Unverified commit fb9296f0, authored by Ying Sheng and committed by GitHub

Higher priority for user input of max_prefill_tokens & format (#540)

parent 1374334d
...@@ -13,15 +13,15 @@ import zmq ...@@ -13,15 +13,15 @@ import zmq
import zmq.asyncio import zmq.asyncio
from sglang.global_config import global_config from sglang.global_config import global_config
from sglang.srt.managers.controller.dp_worker import (
DataParallelWorkerThread,
start_data_parallel_worker,
)
from sglang.srt.managers.io_struct import ( from sglang.srt.managers.io_struct import (
AbortReq, AbortReq,
FlushCacheReq, FlushCacheReq,
TokenizedGenerateReqInput, TokenizedGenerateReqInput,
) )
from sglang.srt.managers.controller.dp_worker import (
DataParallelWorkerThread,
start_data_parallel_worker,
)
from sglang.srt.server_args import PortArgs, ServerArgs from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.utils import get_exception_traceback from sglang.utils import get_exception_traceback
...@@ -136,7 +136,7 @@ class Controller: ...@@ -136,7 +136,7 @@ class Controller:
self.recv_reqs = [] self.recv_reqs = []
if next_step_input: if next_step_input:
await self.dispatching(next_step_input) await self.dispatching(next_step_input)
#else: # else:
# logger.error("There is no live worker.") # logger.error("There is no live worker.")
await asyncio.sleep(global_config.wait_for_new_request_delay) await asyncio.sleep(global_config.wait_for_new_request_delay)
......
"""A controller that manages a group of tensor parallel workers.""" """A controller that manages a group of tensor parallel workers."""
import asyncio import asyncio
import logging import logging
import time import time
...@@ -49,7 +50,9 @@ class ControllerSingle: ...@@ -49,7 +50,9 @@ class ControllerSingle:
# async sleep for receiving the subsequent request and avoiding cache miss # async sleep for receiving the subsequent request and avoiding cache miss
slept = False slept = False
if len(out_pyobjs) != 0: if len(out_pyobjs) != 0:
has_finished = any([obj.finished_reason is not None for obj in out_pyobjs]) has_finished = any(
[obj.finished_reason is not None for obj in out_pyobjs]
)
if has_finished: if has_finished:
if self.request_dependency_delay > 0: if self.request_dependency_delay > 0:
slept = True slept = True
...@@ -94,4 +97,4 @@ def start_controller_process( ...@@ -94,4 +97,4 @@ def start_controller_process(
except Exception: except Exception:
logger.error("Exception in ControllerSingle:\n" + get_exception_traceback()) logger.error("Exception in ControllerSingle:\n" + get_exception_traceback())
finally: finally:
kill_parent_process() kill_parent_process()
\ No newline at end of file
"""ModelRunner runs the forward passes of the models.""" """ModelRunner runs the forward passes of the models."""
import importlib import importlib
import importlib.resources import importlib.resources
import logging import logging
...@@ -12,15 +13,18 @@ import torch ...@@ -12,15 +13,18 @@ import torch
import torch.nn as nn import torch.nn as nn
from vllm.config import DeviceConfig, LoadConfig from vllm.config import DeviceConfig, LoadConfig
from vllm.config import ModelConfig as VllmModelConfig from vllm.config import ModelConfig as VllmModelConfig
from vllm.distributed import initialize_model_parallel, init_distributed_environment from vllm.distributed import init_distributed_environment, initialize_model_parallel
from vllm.model_executor.model_loader import get_model from vllm.model_executor.model_loader import get_model
from vllm.model_executor.models import ModelRegistry from vllm.model_executor.models import ModelRegistry
from sglang.srt.managers.controller.infer_batch import Batch, ForwardMode from sglang.srt.managers.controller.infer_batch import Batch, ForwardMode
from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool
from sglang.srt.server_args import ServerArgs from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import get_available_gpu_memory, is_multimodal_model, monkey_patch_vllm_p2p_access_check from sglang.srt.utils import (
get_available_gpu_memory,
is_multimodal_model,
monkey_patch_vllm_p2p_access_check,
)
logger = logging.getLogger("srt.model_runner") logger = logging.getLogger("srt.model_runner")
...@@ -441,7 +445,9 @@ def import_model_classes(): ...@@ -441,7 +445,9 @@ def import_model_classes():
module = importlib.import_module(name) module = importlib.import_module(name)
if hasattr(module, "EntryClass"): if hasattr(module, "EntryClass"):
entry = module.EntryClass entry = module.EntryClass
if isinstance(entry, list): # To support multiple model classes in one module if isinstance(
entry, list
): # To support multiple model classes in one module
for tmp in entry: for tmp in entry:
model_arch_name_to_cls[tmp.__name__] = tmp model_arch_name_to_cls[tmp.__name__] = tmp
else: else:
...@@ -449,7 +455,9 @@ def import_model_classes(): ...@@ -449,7 +455,9 @@ def import_model_classes():
# compat: some models such as chatglm has incorrect class set in config.json # compat: some models such as chatglm has incorrect class set in config.json
# usage: [ tuple("From_Entry_Class_Name": EntryClass), ] # usage: [ tuple("From_Entry_Class_Name": EntryClass), ]
if hasattr(module, "EntryClassRemapping") and isinstance(module.EntryClassRemapping, list): if hasattr(module, "EntryClassRemapping") and isinstance(
module.EntryClassRemapping, list
):
for remap in module.EntryClassRemapping: for remap in module.EntryClassRemapping:
if isinstance(remap, tuple) and len(remap) == 2: if isinstance(remap, tuple) and len(remap) == 2:
model_arch_name_to_cls[remap[0]] = remap[1] model_arch_name_to_cls[remap[0]] = remap[1]
......
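For readers unfamiliar with the registry that `import_model_classes` populates, here is a minimal sketch of the convention the hunk above relies on. `collect_entry_classes`, `MyModelForCausalLM`, and the fake module are illustrative stand-ins, not names from the repository:

```python
from types import SimpleNamespace


class MyModelForCausalLM:  # hypothetical model class, for illustration only
    pass


def collect_entry_classes(module):
    """Simplified mirror of the EntryClass handling in import_model_classes()."""
    model_arch_name_to_cls = {}
    if hasattr(module, "EntryClass"):
        entry = module.EntryClass
        # A module may export a single class or a list of classes.
        classes = entry if isinstance(entry, list) else [entry]
        for cls in classes:
            model_arch_name_to_cls[cls.__name__] = cls
    # compat: remap architecture names that config.json reports incorrectly,
    # given as a list of ("From_Entry_Class_Name", EntryClass) tuples.
    if hasattr(module, "EntryClassRemapping") and isinstance(
        module.EntryClassRemapping, list
    ):
        for remap in module.EntryClassRemapping:
            if isinstance(remap, tuple) and len(remap) == 2:
                model_arch_name_to_cls[remap[0]] = remap[1]
    return model_arch_name_to_cls


fake_module = SimpleNamespace(
    EntryClass=[MyModelForCausalLM],
    EntryClassRemapping=[("MyModelLegacyName", MyModelForCausalLM)],
)
print(collect_entry_classes(fake_module))
# {'MyModelForCausalLM': <class ...>, 'MyModelLegacyName': <class ...>}
```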
""" """
The radix tree data structure for managing the KV cache. The radix tree data structure for managing the KV cache.
""" """
import heapq import heapq
import time import time
from collections import defaultdict from collections import defaultdict
......
"""Request scheduler heuristic.""" """Request scheduler heuristic."""
import random import random
from collections import defaultdict from collections import defaultdict
......
...@@ -15,22 +15,22 @@ from sglang.global_config import global_config ...@@ -15,22 +15,22 @@ from sglang.global_config import global_config
from sglang.srt.constrained.fsm_cache import FSMCache from sglang.srt.constrained.fsm_cache import FSMCache
from sglang.srt.constrained.jump_forward import JumpForwardCache from sglang.srt.constrained.jump_forward import JumpForwardCache
from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
from sglang.srt.managers.io_struct import (
AbortReq,
BatchTokenIDOut,
FlushCacheReq,
TokenizedGenerateReqInput,
)
from sglang.srt.managers.controller.infer_batch import ( from sglang.srt.managers.controller.infer_batch import (
FINISH_ABORT,
BaseFinishReason, BaseFinishReason,
Batch, Batch,
FINISH_ABORT,
ForwardMode, ForwardMode,
Req, Req,
) )
from sglang.srt.managers.controller.model_runner import ModelRunner from sglang.srt.managers.controller.model_runner import ModelRunner
from sglang.srt.managers.controller.radix_cache import RadixCache from sglang.srt.managers.controller.radix_cache import RadixCache
from sglang.srt.managers.controller.schedule_heuristic import ScheduleHeuristic from sglang.srt.managers.controller.schedule_heuristic import ScheduleHeuristic
from sglang.srt.managers.io_struct import (
AbortReq,
BatchTokenIDOut,
FlushCacheReq,
TokenizedGenerateReqInput,
)
from sglang.srt.model_config import ModelConfig from sglang.srt.model_config import ModelConfig
from sglang.srt.server_args import ModelPortArgs, ServerArgs from sglang.srt.server_args import ModelPortArgs, ServerArgs
from sglang.srt.utils import ( from sglang.srt.utils import (
@@ -96,13 +96,13 @@ class ModelTpServer:
             trust_remote_code=server_args.trust_remote_code,
         )
         self.max_total_num_tokens = self.model_runner.max_total_num_tokens
-        self.max_prefill_tokens = max(
-            self.model_config.context_len,
-            (
-                min(self.max_total_num_tokens // 6, 65536)
-                if server_args.max_prefill_tokens is None
-                else server_args.max_prefill_tokens
-            ),
-        )
+        self.max_prefill_tokens = (
+            max(
+                self.model_config.context_len,
+                min(self.max_total_num_tokens // 6, 65536),
+            )
+            if server_args.max_prefill_tokens is None
+            else server_args.max_prefill_tokens
+        )
         self.max_running_requests = (
             self.max_total_num_tokens // 2
......
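The hunk above is the substantive change in this commit: a user-supplied `--max-prefill-tokens` is now used verbatim instead of being clamped by `max(context_len, ...)`. A minimal sketch of the resulting selection logic; the standalone function and its argument names are illustrative, not part of the codebase:

```python
def resolve_max_prefill_tokens(
    user_max_prefill_tokens,  # ServerArgs.max_prefill_tokens, None if not given
    context_len,              # model_config.context_len
    max_total_num_tokens,     # KV-cache capacity reported by the model runner
):
    """Selection logic after this commit (sketch)."""
    if user_max_prefill_tokens is not None:
        # User input has the highest priority and is taken as-is.
        return user_max_prefill_tokens
    # Otherwise use the heuristic default, but never go below the context length.
    return max(context_len, min(max_total_num_tokens // 6, 65536))


# Example: with context_len=4096 and max_total_num_tokens=300_000 the default is
# max(4096, min(50_000, 65_536)) = 50_000. Passing --max-prefill-tokens 2048 now
# yields 2048; before this commit the user value was still wrapped in
# max(context_len, ...), so it would have been raised to 4096.
```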
"""DetokenizerManager is a process that detokenizes the token ids.""" """DetokenizerManager is a process that detokenizes the token ids."""
import asyncio import asyncio
import inspect import inspect
...@@ -7,10 +8,10 @@ import zmq ...@@ -7,10 +8,10 @@ import zmq
import zmq.asyncio import zmq.asyncio
from sglang.srt.hf_transformers_utils import get_tokenizer from sglang.srt.hf_transformers_utils import get_tokenizer
from sglang.srt.managers.controller.infer_batch import FINISH_MATCHED_STR
from sglang.srt.managers.io_struct import BatchStrOut, BatchTokenIDOut from sglang.srt.managers.io_struct import BatchStrOut, BatchTokenIDOut
from sglang.srt.server_args import PortArgs, ServerArgs from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.utils import get_exception_traceback, graceful_registry from sglang.utils import get_exception_traceback, graceful_registry
from sglang.srt.managers.controller.infer_batch import FINISH_MATCHED_STR
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
......
...@@ -7,8 +7,8 @@ import uuid ...@@ -7,8 +7,8 @@ import uuid
from dataclasses import dataclass from dataclasses import dataclass
from typing import Dict, List, Optional, Union from typing import Dict, List, Optional, Union
from sglang.srt.sampling_params import SamplingParams
from sglang.srt.managers.controller.infer_batch import BaseFinishReason from sglang.srt.managers.controller.infer_batch import BaseFinishReason
from sglang.srt.sampling_params import SamplingParams
@dataclass @dataclass
......
"""TokenizerManager is a process that tokenizes the text.""" """TokenizerManager is a process that tokenizes the text."""
import asyncio import asyncio
import concurrent.futures import concurrent.futures
import dataclasses import dataclasses
import logging import logging
import multiprocessing as mp import multiprocessing as mp
import os import os
from typing import List, Dict from typing import Dict, List
import numpy as np import numpy as np
import transformers import transformers
...@@ -23,11 +24,11 @@ from sglang.srt.hf_transformers_utils import ( ...@@ -23,11 +24,11 @@ from sglang.srt.hf_transformers_utils import (
from sglang.srt.managers.io_struct import ( from sglang.srt.managers.io_struct import (
AbortReq, AbortReq,
BatchStrOut, BatchStrOut,
BatchTokenIDOut,
FlushCacheReq, FlushCacheReq,
GenerateReqInput, GenerateReqInput,
TokenizedGenerateReqInput, TokenizedGenerateReqInput,
) )
from sglang.srt.managers.io_struct import BatchTokenIDOut
from sglang.srt.mm_utils import expand2square, process_anyres_image from sglang.srt.mm_utils import expand2square, process_anyres_image
from sglang.srt.sampling_params import SamplingParams from sglang.srt.sampling_params import SamplingParams
from sglang.srt.server_args import PortArgs, ServerArgs from sglang.srt.server_args import PortArgs, ServerArgs
...@@ -91,7 +92,7 @@ class TokenizerManager: ...@@ -91,7 +92,7 @@ class TokenizerManager:
) )
self.to_create_loop = True self.to_create_loop = True
self.rid_to_state: Dict[str, ReqState] = {} self.rid_to_state: Dict[str, ReqState] = {}
async def get_pixel_values(self, image_data): async def get_pixel_values(self, image_data):
aspect_ratio = getattr(self.hf_config, "image_aspect_ratio", None) aspect_ratio = getattr(self.hf_config, "image_aspect_ratio", None)
...@@ -322,7 +323,6 @@ class TokenizerManager: ...@@ -322,7 +323,6 @@ class TokenizerManager:
state.finished = recv_obj.finished_reason[i] is not None state.finished = recv_obj.finished_reason[i] is not None
state.event.set() state.event.set()
def convert_logprob_style( def convert_logprob_style(
self, ret, return_logprob, top_logprobs_num, return_text_in_logprobs self, ret, return_logprob, top_logprobs_num, return_text_in_logprobs
): ):
......
from typing import Optional from typing import Optional
from sglang.srt.hf_transformers_utils import get_config, get_context_length
from transformers import PretrainedConfig from transformers import PretrainedConfig
from sglang.srt.hf_transformers_utils import get_config, get_context_length
class ModelConfig: class ModelConfig:
def __init__( def __init__(
...@@ -17,8 +18,12 @@ class ModelConfig: ...@@ -17,8 +18,12 @@ class ModelConfig:
self.trust_remote_code = trust_remote_code self.trust_remote_code = trust_remote_code
self.revision = revision self.revision = revision
self.model_overide_args = model_overide_args self.model_overide_args = model_overide_args
self.hf_config = get_config(self.path, trust_remote_code, revision, self.hf_config = get_config(
model_overide_args=model_overide_args) self.path,
trust_remote_code,
revision,
model_overide_args=model_overide_args,
)
self.hf_text_config = get_hf_text_config(self.hf_config) self.hf_text_config = get_hf_text_config(self.hf_config)
if context_length is not None: if context_length is not None:
self.context_len = context_length self.context_len = context_length
...@@ -55,18 +60,23 @@ class ModelConfig: ...@@ -55,18 +60,23 @@ class ModelConfig:
# KV heads. # KV heads.
falcon_model_types = ["falcon", "RefinedWeb", "RefinedWebModel"] falcon_model_types = ["falcon", "RefinedWeb", "RefinedWebModel"]
new_decoder_arch_falcon = ( new_decoder_arch_falcon = (
self.hf_config.model_type in falcon_model_types self.hf_config.model_type in falcon_model_types
and getattr(self.hf_config, "new_decoder_architecture", False)) and getattr(self.hf_config, "new_decoder_architecture", False)
if not new_decoder_arch_falcon and getattr(self.hf_text_config, )
"multi_query", False): if not new_decoder_arch_falcon and getattr(
self.hf_text_config, "multi_query", False
):
# Multi-query attention, only one KV head. # Multi-query attention, only one KV head.
# Currently, tensor parallelism is not supported in this case. # Currently, tensor parallelism is not supported in this case.
return 1 return 1
# For DBRX and MPT # For DBRX and MPT
if self.hf_config.model_type in ["dbrx", "mpt"]: if self.hf_config.model_type in ["dbrx", "mpt"]:
return getattr(self.hf_config.attn_config, "kv_n_heads", return getattr(
self.hf_config.num_attention_heads) self.hf_config.attn_config,
"kv_n_heads",
self.hf_config.num_attention_heads,
)
attributes = [ attributes = [
# For Falcon: # For Falcon:
...@@ -94,13 +104,12 @@ class ModelConfig: ...@@ -94,13 +104,12 @@ class ModelConfig:
# the tensor parallel size. We will replicate the KV heads in the # the tensor parallel size. We will replicate the KV heads in the
# case where the number of KV heads is smaller than the tensor # case where the number of KV heads is smaller than the tensor
# parallel size so each GPU has at least one KV head. # parallel size so each GPU has at least one KV head.
return max(1, return max(1, total_num_kv_heads // tensor_parallel_size)
total_num_kv_heads // tensor_parallel_size)
def get_hf_text_config(config: PretrainedConfig): def get_hf_text_config(config: PretrainedConfig):
"""Get the "sub" config relevant to llm for multi modal models. """Get the "sub" config relevant to llm for multi modal models.
No op for pure text models. No op for pure text models.
""" """
if hasattr(config, "text_config"): if hasattr(config, "text_config"):
# The code operates under the assumption that text_config should have # The code operates under the assumption that text_config should have
......
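The `ModelConfig` hunks above are pure reformatting, but the KV-head resolution they touch is easy to lose in the diff noise. Here is a simplified summary sketch; the final fallback over attribute names is abbreviated, since that list is truncated in the diff, and the helper names are not from the repository:

```python
def total_num_kv_heads(hf_config, hf_text_config):
    """Simplified resolution order, following the code shown above."""
    falcon_model_types = ["falcon", "RefinedWeb", "RefinedWebModel"]
    new_decoder_arch_falcon = hf_config.model_type in falcon_model_types and getattr(
        hf_config, "new_decoder_architecture", False
    )
    # Multi-query attention (excluding new-architecture Falcon): one shared KV head.
    if not new_decoder_arch_falcon and getattr(hf_text_config, "multi_query", False):
        return 1
    # DBRX and MPT keep the value inside attn_config.
    if hf_config.model_type in ["dbrx", "mpt"]:
        return getattr(
            hf_config.attn_config, "kv_n_heads", hf_config.num_attention_heads
        )
    # The real code then scans a list of known attribute names (truncated in the
    # diff above); num_key_value_heads is one common example, with
    # num_attention_heads as the final default.
    return getattr(hf_text_config, "num_key_value_heads", hf_config.num_attention_heads)


def num_kv_heads_per_gpu(total_num_kv_heads, tensor_parallel_size):
    # Replicate KV heads when there are fewer than TP ranks so every GPU gets one.
    return max(1, total_num_kv_heads // tensor_parallel_size)
```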
...@@ -5,30 +5,32 @@ ...@@ -5,30 +5,32 @@
from typing import Iterable, List, Optional, Tuple from typing import Iterable, List, Optional, Tuple
import torch import torch
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.managers.controller.model_runner import InputMetadata
from sglang.srt.layers.logits_processor import LogitsProcessor
from torch import nn from torch import nn
from torch.nn import LayerNorm from torch.nn import LayerNorm
from vllm.config import CacheConfig from vllm.config import CacheConfig
from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, from vllm.model_executor.layers.linear import (
QKVParallelLinear, MergedColumnParallelLinear,
RowParallelLinear) QKVParallelLinear,
from vllm.model_executor.layers.quantization.base_config import ( RowParallelLinear,
QuantizationConfig) )
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding) ParallelLMHead,
VocabParallelEmbedding,
)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput from vllm.sequence import SamplerOutput
from vllm.transformers_utils.configs import ChatGLMConfig from vllm.transformers_utils.configs import ChatGLMConfig
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.managers.controller.model_runner import InputMetadata
LoraConfig = None LoraConfig = None
...@@ -49,9 +51,11 @@ class GLMAttention(nn.Module): ...@@ -49,9 +51,11 @@ class GLMAttention(nn.Module):
assert self.total_num_heads % tp_size == 0 assert self.total_num_heads % tp_size == 0
self.num_heads = self.total_num_heads // tp_size self.num_heads = self.total_num_heads // tp_size
self.multi_query_attention = config.multi_query_attention self.multi_query_attention = config.multi_query_attention
self.total_num_kv_heads = (config.multi_query_group_num self.total_num_kv_heads = (
if config.multi_query_attention else config.multi_query_group_num
config.num_attention_heads) if config.multi_query_attention
else config.num_attention_heads
)
if self.total_num_kv_heads >= tp_size: if self.total_num_kv_heads >= tp_size:
# Number of KV heads is greater than TP size, so we partition # Number of KV heads is greater than TP size, so we partition
# the KV heads across multiple tensor parallel GPUs. # the KV heads across multiple tensor parallel GPUs.
...@@ -91,11 +95,13 @@ class GLMAttention(nn.Module): ...@@ -91,11 +95,13 @@ class GLMAttention(nn.Module):
base=10000 * rope_ratio, base=10000 * rope_ratio,
is_neox_style=False, is_neox_style=False,
) )
self.attn = RadixAttention(self.num_heads, self.attn = RadixAttention(
self.head_dim, self.num_heads,
self.scaling, self.head_dim,
num_kv_heads=self.num_kv_heads, self.scaling,
layer_id=layer_id) num_kv_heads=self.num_kv_heads,
layer_id=layer_id,
)
def forward( def forward(
self, self,
...@@ -176,14 +182,16 @@ class GLMBlock(nn.Module): ...@@ -176,14 +182,16 @@ class GLMBlock(nn.Module):
): ):
super().__init__() super().__init__()
self.apply_residual_connection_post_layernorm = ( self.apply_residual_connection_post_layernorm = (
config.apply_residual_connection_post_layernorm) config.apply_residual_connection_post_layernorm
)
self.fp32_residual_connection = config.fp32_residual_connection self.fp32_residual_connection = config.fp32_residual_connection
layer_norm_func = RMSNorm if config.rmsnorm else LayerNorm layer_norm_func = RMSNorm if config.rmsnorm else LayerNorm
# Layernorm on the input data. # Layernorm on the input data.
self.input_layernorm = layer_norm_func(config.hidden_size, self.input_layernorm = layer_norm_func(
eps=config.layernorm_epsilon) config.hidden_size, eps=config.layernorm_epsilon
)
# Self attention. # Self attention.
self.self_attention = GLMAttention(config, layer_id, cache_config, quant_config) self.self_attention = GLMAttention(config, layer_id, cache_config, quant_config)
...@@ -191,7 +199,8 @@ class GLMBlock(nn.Module): ...@@ -191,7 +199,8 @@ class GLMBlock(nn.Module):
# Layernorm on the attention output # Layernorm on the attention output
self.post_attention_layernorm = layer_norm_func( self.post_attention_layernorm = layer_norm_func(
config.hidden_size, eps=config.layernorm_epsilon) config.hidden_size, eps=config.layernorm_epsilon
)
# MLP # MLP
self.mlp = GLMMLP(config, quant_config) self.mlp = GLMMLP(config, quant_config)
...@@ -250,16 +259,19 @@ class GLMTransformer(nn.Module): ...@@ -250,16 +259,19 @@ class GLMTransformer(nn.Module):
self.num_layers = config.num_layers self.num_layers = config.num_layers
# Transformer layers. # Transformer layers.
self.layers = nn.ModuleList([ self.layers = nn.ModuleList(
GLMBlock(config, i, cache_config, quant_config) [
for i in range(self.num_layers) GLMBlock(config, i, cache_config, quant_config)
]) for i in range(self.num_layers)
]
)
if self.post_layer_norm: if self.post_layer_norm:
layer_norm_func = RMSNorm if config.rmsnorm else LayerNorm layer_norm_func = RMSNorm if config.rmsnorm else LayerNorm
# Final layer norm before output. # Final layer norm before output.
self.final_layernorm = layer_norm_func( self.final_layernorm = layer_norm_func(
config.hidden_size, eps=config.layernorm_epsilon) config.hidden_size, eps=config.layernorm_epsilon
)
def forward( def forward(
self, self,
...@@ -291,16 +303,16 @@ class ChatGLMModel(nn.Module): ...@@ -291,16 +303,16 @@ class ChatGLMModel(nn.Module):
): ):
super().__init__() super().__init__()
self.embedding = VocabParallelEmbedding(config.padded_vocab_size, self.embedding = VocabParallelEmbedding(
config.hidden_size) config.padded_vocab_size, config.hidden_size
)
self.num_layers = config.num_layers self.num_layers = config.num_layers
self.multi_query_group_num = config.multi_query_group_num self.multi_query_group_num = config.multi_query_group_num
self.kv_channels = config.kv_channels self.kv_channels = config.kv_channels
self.encoder = GLMTransformer(config, cache_config, quant_config) self.encoder = GLMTransformer(config, cache_config, quant_config)
self.output_layer = ParallelLMHead(config.padded_vocab_size, self.output_layer = ParallelLMHead(config.padded_vocab_size, config.hidden_size)
config.hidden_size)
def forward( def forward(
self, self,
...@@ -322,7 +334,7 @@ class ChatGLMModel(nn.Module): ...@@ -322,7 +334,7 @@ class ChatGLMModel(nn.Module):
class ChatGLMForCausalLM(nn.Module): class ChatGLMForCausalLM(nn.Module):
packed_modules_mapping = { packed_modules_mapping = {
"query_key_value": ["query_key_value"], "query_key_value": ["query_key_value"],
"dense_h_to_4h": ["dense_h_to_4h"] "dense_h_to_4h": ["dense_h_to_4h"],
} }
# LoRA specific attributes # LoRA specific attributes
supported_lora_modules = [ supported_lora_modules = [
...@@ -344,8 +356,7 @@ class ChatGLMForCausalLM(nn.Module): ...@@ -344,8 +356,7 @@ class ChatGLMForCausalLM(nn.Module):
super().__init__() super().__init__()
self.config: ChatGLMConfig = config self.config: ChatGLMConfig = config
self.quant_config = quant_config self.quant_config = quant_config
self.max_position_embeddings = getattr(config, "max_sequence_length", self.max_position_embeddings = getattr(config, "max_sequence_length", 8192)
8192)
self.transformer = ChatGLMModel(config, cache_config, quant_config) self.transformer = ChatGLMModel(config, cache_config, quant_config)
self.lm_head = self.transformer.output_layer self.lm_head = self.transformer.output_layer
self.logits_processor = LogitsProcessor(config) self.logits_processor = LogitsProcessor(config)
...@@ -357,8 +368,7 @@ class ChatGLMForCausalLM(nn.Module): ...@@ -357,8 +368,7 @@ class ChatGLMForCausalLM(nn.Module):
positions: torch.Tensor, positions: torch.Tensor,
input_metadata: InputMetadata, input_metadata: InputMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.transformer(input_ids, positions, hidden_states = self.transformer(input_ids, positions, input_metadata)
input_metadata)
return self.logits_processor( return self.logits_processor(
input_ids, hidden_states, self.lm_head.weight, input_metadata input_ids, hidden_states, self.lm_head.weight, input_metadata
) )
...@@ -382,10 +392,10 @@ class ChatGLMForCausalLM(nn.Module): ...@@ -382,10 +392,10 @@ class ChatGLMForCausalLM(nn.Module):
if name.endswith(".bias") and name not in params_dict: if name.endswith(".bias") and name not in params_dict:
continue continue
param = params_dict[name] param = params_dict[name]
weight_loader = getattr(param, "weight_loader", weight_loader = getattr(param, "weight_loader", default_weight_loader)
default_weight_loader)
weight_loader(param, loaded_weight) weight_loader(param, loaded_weight)
EntryClass = ChatGLMForCausalLM EntryClass = ChatGLMForCausalLM
# compat: glm model.config class == ChatGLMModel # compat: glm model.config class == ChatGLMModel
EntryClassRemapping = [("ChatGLMModel", ChatGLMForCausalLM)] EntryClassRemapping = [("ChatGLMModel", ChatGLMForCausalLM)]
...@@ -23,7 +23,7 @@ ...@@ -23,7 +23,7 @@
# This file is based on the LLama model definition file in transformers # This file is based on the LLama model definition file in transformers
"""PyTorch Cohere model.""" """PyTorch Cohere model."""
from typing import Optional, Tuple, Iterable from typing import Iterable, Optional, Tuple
import torch import torch
import torch.utils.checkpoint import torch.utils.checkpoint
...@@ -44,8 +44,8 @@ from vllm.model_executor.layers.linear import ( ...@@ -44,8 +44,8 @@ from vllm.model_executor.layers.linear import (
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
from vllm.model_executor.utils import set_weight_attrs
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.utils import set_weight_attrs
from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.layers.radix_attention import RadixAttention
......
...@@ -24,8 +24,8 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( ...@@ -24,8 +24,8 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, ParallelLMHead,
VocabParallelEmbedding, VocabParallelEmbedding,
) )
from vllm.model_executor.utils import set_weight_attrs
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.utils import set_weight_attrs
from vllm.transformers_utils.configs.dbrx import DbrxConfig from vllm.transformers_utils.configs.dbrx import DbrxConfig
from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.logits_processor import LogitsProcessor
......
...@@ -6,7 +6,7 @@ from typing import Iterable, Optional, Tuple ...@@ -6,7 +6,7 @@ from typing import Iterable, Optional, Tuple
import torch import torch
from torch import nn from torch import nn
from transformers import PretrainedConfig from transformers import PretrainedConfig
from vllm.config import LoRAConfig, CacheConfig from vllm.config import CacheConfig, LoRAConfig
from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import GeluAndMul from vllm.model_executor.layers.activation import GeluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
......
[Diff for one file is collapsed and not shown.]
# Adapted from # Adapted from
# https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/llama.py#L1 # https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/llama.py#L1
"""Inference-only LLaMA model compatible with HuggingFace weights.""" """Inference-only LLaMA model compatible with HuggingFace weights."""
from typing import Any, Dict, Optional, Tuple, Iterable from typing import Any, Dict, Iterable, Optional, Tuple
import torch import torch
import tqdm import tqdm
...@@ -10,7 +10,7 @@ from transformers import LlamaConfig ...@@ -10,7 +10,7 @@ from transformers import LlamaConfig
from vllm.config import CacheConfig from vllm.config import CacheConfig
from vllm.distributed import ( from vllm.distributed import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size get_tensor_model_parallel_world_size,
) )
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
...@@ -158,9 +158,11 @@ class LlamaDecoderLayer(nn.Module): ...@@ -158,9 +158,11 @@ class LlamaDecoderLayer(nn.Module):
rope_theta = getattr(config, "rope_theta", 10000) rope_theta = getattr(config, "rope_theta", 10000)
rope_scaling = getattr(config, "rope_scaling", None) rope_scaling = getattr(config, "rope_scaling", None)
if rope_scaling is not None and getattr( if rope_scaling is not None and getattr(
config, "original_max_position_embeddings", None): config, "original_max_position_embeddings", None
):
rope_scaling["original_max_position_embeddings"] = ( rope_scaling["original_max_position_embeddings"] = (
config.original_max_position_embeddings) config.original_max_position_embeddings
)
max_position_embeddings = getattr(config, "max_position_embeddings", 8192) max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
self.self_attn = LlamaAttention( self.self_attn = LlamaAttention(
hidden_size=self.hidden_size, hidden_size=self.hidden_size,
......
"""Inference-only LLaVa model compatible with HuggingFace weights.""" """Inference-only LLaVa model compatible with HuggingFace weights."""
from typing import List, Iterable, Optional, Tuple from typing import Iterable, List, Optional, Tuple
import numpy as np import numpy as np
import torch import torch
from torch import nn from torch import nn
from transformers import CLIPVisionModel, CLIPVisionConfig, LlavaConfig, Qwen2Config, MistralConfig from transformers import (
CLIPVisionConfig,
CLIPVisionModel,
LlavaConfig,
MistralConfig,
Qwen2Config,
)
from transformers.models.llava.modeling_llava import LlavaMultiModalProjector from transformers.models.llava.modeling_llava import LlavaMultiModalProjector
from vllm.config import CacheConfig from vllm.config import CacheConfig
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
...@@ -19,8 +25,8 @@ from sglang.srt.mm_utils import ( ...@@ -19,8 +25,8 @@ from sglang.srt.mm_utils import (
unpad_image_shape, unpad_image_shape,
) )
from sglang.srt.models.llama2 import LlamaForCausalLM from sglang.srt.models.llama2 import LlamaForCausalLM
from sglang.srt.models.qwen2 import Qwen2ForCausalLM
from sglang.srt.models.mistral import MistralForCausalLM from sglang.srt.models.mistral import MistralForCausalLM
from sglang.srt.models.qwen2 import Qwen2ForCausalLM
class LlavaLlamaForCausalLM(nn.Module): class LlavaLlamaForCausalLM(nn.Module):
...@@ -359,6 +365,7 @@ class LlavaMistralForCausalLM(LlavaLlamaForCausalLM): ...@@ -359,6 +365,7 @@ class LlavaMistralForCausalLM(LlavaLlamaForCausalLM):
first_call = True first_call = True
def clip_vision_embed_forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor: def clip_vision_embed_forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
batch_size = pixel_values.shape[0] batch_size = pixel_values.shape[0]
...@@ -388,8 +395,4 @@ def monkey_path_clip_vision_embed_forward(): ...@@ -388,8 +395,4 @@ def monkey_path_clip_vision_embed_forward():
) )
EntryClass = [ EntryClass = [LlavaLlamaForCausalLM, LlavaQwenForCausalLM, LlavaMistralForCausalLM]
LlavaLlamaForCausalLM,
LlavaQwenForCausalLM,
LlavaMistralForCausalLM
]
"""Inference-only LLaVa video model compatible with HuggingFace weights.""" """Inference-only LLaVa video model compatible with HuggingFace weights."""
from typing import List, Iterable, Optional, Tuple from typing import Iterable, List, Optional, Tuple
import numpy as np import numpy as np
import torch import torch
......
...@@ -33,13 +33,11 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader ...@@ -33,13 +33,11 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.utils import set_weight_attrs
from vllm.utils import print_warning_once from vllm.utils import print_warning_once
from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.managers.controller.model_runner import InputMetadata from sglang.srt.managers.controller.model_runner import InputMetadata
class MixtralMoE(nn.Module): class MixtralMoE(nn.Module):
"""A tensor-parallel MoE implementation for Mixtral that shards each expert """A tensor-parallel MoE implementation for Mixtral that shards each expert
across all ranks. across all ranks.
...@@ -76,32 +74,46 @@ class MixtralMoE(nn.Module): ...@@ -76,32 +74,46 @@ class MixtralMoE(nn.Module):
self.params_dtype = params_dtype self.params_dtype = params_dtype
# Gate always runs at half / full precision for now. # Gate always runs at half / full precision for now.
self.gate = ReplicatedLinear(self.hidden_size, self.gate = ReplicatedLinear(
self.num_total_experts, self.hidden_size,
bias=False, self.num_total_experts,
params_dtype=self.params_dtype, bias=False,
quant_config=None) params_dtype=self.params_dtype,
quant_config=None,
)
if self.use_fp8 and self.quant_config.is_checkpoint_fp8_serialized: if self.use_fp8 and self.quant_config.is_checkpoint_fp8_serialized:
params_dtype = torch.float8_e4m3fn params_dtype = torch.float8_e4m3fn
self.w13_weight = nn.Parameter( self.w13_weight = nn.Parameter(
torch.empty(self.num_total_experts, torch.empty(
2 * self.intermediate_size, self.num_total_experts,
self.hidden_size, 2 * self.intermediate_size,
dtype=params_dtype)) self.hidden_size,
dtype=params_dtype,
)
)
self.w2_weight = nn.Parameter( self.w2_weight = nn.Parameter(
torch.empty(self.num_total_experts, torch.empty(
self.hidden_size, self.num_total_experts,
self.intermediate_size, self.hidden_size,
dtype=params_dtype)) self.intermediate_size,
dtype=params_dtype,
set_weight_attrs(self.w13_weight, { )
"weight_loader": self.weight_loader, )
})
set_weight_attrs(self.w2_weight, { set_weight_attrs(
"weight_loader": self.weight_loader, self.w13_weight,
}) {
"weight_loader": self.weight_loader,
},
)
set_weight_attrs(
self.w2_weight,
{
"weight_loader": self.weight_loader,
},
)
# Used for fp8. # Used for fp8.
self.w13_scale = None self.w13_scale = None
...@@ -111,46 +123,68 @@ class MixtralMoE(nn.Module): ...@@ -111,46 +123,68 @@ class MixtralMoE(nn.Module):
if self.use_fp8: if self.use_fp8:
# WEIGHT_SCALE (for fp8) # WEIGHT_SCALE (for fp8)
self.w13_scale = nn.Parameter(torch.ones(self.num_total_experts, self.w13_scale = nn.Parameter(
dtype=torch.float32), torch.ones(self.num_total_experts, dtype=torch.float32),
requires_grad=False) requires_grad=False,
self.w2_scale = nn.Parameter(torch.ones(self.num_total_experts, )
dtype=torch.float32), self.w2_scale = nn.Parameter(
requires_grad=False) torch.ones(self.num_total_experts, dtype=torch.float32),
requires_grad=False,
)
# If loading fp8 checkpoint, pass the weight loaders. # If loading fp8 checkpoint, pass the weight loaders.
# If loading an fp16 checkpoint, do not (we will quantize in # If loading an fp16 checkpoint, do not (we will quantize in
# process_weights_after_loading() # process_weights_after_loading()
if quant_config.is_checkpoint_fp8_serialized: if quant_config.is_checkpoint_fp8_serialized:
set_weight_attrs(self.w13_scale, { set_weight_attrs(
"weight_loader": self.weight_loader, self.w13_scale,
}) {
set_weight_attrs(self.w2_scale, { "weight_loader": self.weight_loader,
"weight_loader": self.weight_loader, },
}) )
set_weight_attrs(
self.w2_scale,
{
"weight_loader": self.weight_loader,
},
)
# ACT_SCALE (for fp8) # ACT_SCALE (for fp8)
if quant_config.activation_scheme == "static": if quant_config.activation_scheme == "static":
if not quant_config.is_checkpoint_fp8_serialized: if not quant_config.is_checkpoint_fp8_serialized:
raise ValueError( raise ValueError(
"Found static activation scheme for checkpoint that " "Found static activation scheme for checkpoint that "
"was not serialized fp8.") "was not serialized fp8."
self.a13_scale = nn.Parameter(torch.zeros( )
self.num_total_experts, dtype=torch.float32), self.a13_scale = nn.Parameter(
requires_grad=False) torch.zeros(self.num_total_experts, dtype=torch.float32),
self.a2_scale = nn.Parameter(torch.zeros( requires_grad=False,
self.num_total_experts, dtype=torch.float32), )
requires_grad=False) self.a2_scale = nn.Parameter(
torch.zeros(self.num_total_experts, dtype=torch.float32),
set_weight_attrs(self.a13_scale, { requires_grad=False,
"weight_loader": self.weight_loader, )
})
set_weight_attrs(self.a2_scale, { set_weight_attrs(
"weight_loader": self.weight_loader, self.a13_scale,
}) {
"weight_loader": self.weight_loader,
def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor, },
weight_name: str, expert_id: int): )
set_weight_attrs(
self.a2_scale,
{
"weight_loader": self.weight_loader,
},
)
def weight_loader(
self,
param: nn.Parameter,
loaded_weight: torch.Tensor,
weight_name: str,
expert_id: int,
):
tp_rank = get_tensor_model_parallel_rank() tp_rank = get_tensor_model_parallel_rank()
param_data = param.data param_data = param.data
shard_size = self.intermediate_size shard_size = self.intermediate_size
...@@ -158,8 +192,9 @@ class MixtralMoE(nn.Module): ...@@ -158,8 +192,9 @@ class MixtralMoE(nn.Module):
if weight_name.endswith("w1.weight"): if weight_name.endswith("w1.weight"):
param_data[expert_id, 0:shard_size, :] = loaded_weight[shard, :] param_data[expert_id, 0:shard_size, :] = loaded_weight[shard, :]
if weight_name.endswith("w3.weight"): if weight_name.endswith("w3.weight"):
param_data[expert_id, param_data[expert_id, shard_size : 2 * shard_size, :] = loaded_weight[
shard_size:2 * shard_size, :] = loaded_weight[shard, :] shard, :
]
if weight_name.endswith("w2.weight"): if weight_name.endswith("w2.weight"):
param_data[expert_id, :, :] = loaded_weight[:, shard] param_data[expert_id, :, :] = loaded_weight[:, shard]
if "act_scale" in weight_name or "weight_scale" in weight_name: if "act_scale" in weight_name or "weight_scale" in weight_name:
...@@ -172,17 +207,17 @@ class MixtralMoE(nn.Module): ...@@ -172,17 +207,17 @@ class MixtralMoE(nn.Module):
# If checkpoint is fp16, quantize here. # If checkpoint is fp16, quantize here.
if not self.quant_config.is_checkpoint_fp8_serialized: if not self.quant_config.is_checkpoint_fp8_serialized:
w13_weight = torch.empty_like(self.w13_weight.data, w13_weight = torch.empty_like(
dtype=torch.float8_e4m3fn) self.w13_weight.data, dtype=torch.float8_e4m3fn
w2_weight = torch.empty_like(self.w2_weight.data, )
dtype=torch.float8_e4m3fn) w2_weight = torch.empty_like(self.w2_weight.data, dtype=torch.float8_e4m3fn)
for expert in range(self.num_total_experts): for expert in range(self.num_total_experts):
w13_weight[expert, :, :], self.w13_scale[ w13_weight[expert, :, :], self.w13_scale[expert] = ops.scaled_fp8_quant(
expert] = ops.scaled_fp8_quant( self.w13_weight.data[expert, :, :]
self.w13_weight.data[expert, :, :]) )
w2_weight[expert, :, :], self.w2_scale[ w2_weight[expert, :, :], self.w2_scale[expert] = ops.scaled_fp8_quant(
expert] = ops.scaled_fp8_quant( self.w2_weight.data[expert, :, :]
self.w2_weight.data[expert, :, :]) )
self.w13_weight = nn.Parameter(w13_weight, requires_grad=False) self.w13_weight = nn.Parameter(w13_weight, requires_grad=False)
self.w2_weight = nn.Parameter(w2_weight, requires_grad=False) self.w2_weight = nn.Parameter(w2_weight, requires_grad=False)
...@@ -193,40 +228,40 @@ class MixtralMoE(nn.Module): ...@@ -193,40 +228,40 @@ class MixtralMoE(nn.Module):
if self.a13_scale is None or self.a2_scale is None: if self.a13_scale is None or self.a2_scale is None:
raise ValueError( raise ValueError(
"QuantConfig has static quantization, but found " "QuantConfig has static quantization, but found "
"activation scales are None.") "activation scales are None."
)
if (not all_close_1d(self.a13_scale) if not all_close_1d(self.a13_scale) or not all_close_1d(self.a2_scale):
or not all_close_1d(self.a2_scale)):
print_warning_once( print_warning_once(
"Found act_scales that are not equal for fp8 MoE layer. " "Found act_scales that are not equal for fp8 MoE layer. "
"Using the maximum across experts for each layer. ") "Using the maximum across experts for each layer. "
)
self.a13_scale = nn.Parameter(self.a13_scale.max(), self.a13_scale = nn.Parameter(self.a13_scale.max(), requires_grad=False)
requires_grad=False) self.a2_scale = nn.Parameter(self.a2_scale.max(), requires_grad=False)
self.a2_scale = nn.Parameter(self.a2_scale.max(),
requires_grad=False)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
num_tokens, hidden_size = hidden_states.shape num_tokens, hidden_size = hidden_states.shape
hidden_states = hidden_states.view(-1, self.hidden_size) hidden_states = hidden_states.view(-1, self.hidden_size)
# router_logits: (num_tokens, n_experts) # router_logits: (num_tokens, n_experts)
router_logits, _ = self.gate(hidden_states) router_logits, _ = self.gate(hidden_states)
final_hidden_states = fused_moe(hidden_states, final_hidden_states = fused_moe(
self.w13_weight, hidden_states,
self.w2_weight, self.w13_weight,
router_logits, self.w2_weight,
self.top_k, router_logits,
renormalize=True, self.top_k,
inplace=True, renormalize=True,
use_fp8=self.use_fp8, inplace=True,
w1_scale=self.w13_scale, use_fp8=self.use_fp8,
w2_scale=self.w2_scale, w1_scale=self.w13_scale,
a1_scale=self.a13_scale, w2_scale=self.w2_scale,
a2_scale=self.a2_scale) a1_scale=self.a13_scale,
a2_scale=self.a2_scale,
)
if self.tp_size > 1: if self.tp_size > 1:
final_hidden_states = tensor_model_parallel_all_reduce( final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
final_hidden_states)
return final_hidden_states.view(num_tokens, hidden_size) return final_hidden_states.view(num_tokens, hidden_size)
...@@ -335,7 +370,8 @@ class MixtralDecoderLayer(nn.Module): ...@@ -335,7 +370,8 @@ class MixtralDecoderLayer(nn.Module):
top_k=config.num_experts_per_tok, top_k=config.num_experts_per_tok,
hidden_size=config.hidden_size, hidden_size=config.hidden_size,
intermediate_size=config.intermediate_size, intermediate_size=config.intermediate_size,
quant_config=quant_config) quant_config=quant_config,
)
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = RMSNorm( self.post_attention_layernorm = RMSNorm(
config.hidden_size, eps=config.rms_norm_eps config.hidden_size, eps=config.rms_norm_eps
...@@ -444,35 +480,48 @@ class MixtralForCausalLM(nn.Module): ...@@ -444,35 +480,48 @@ class MixtralForCausalLM(nn.Module):
("qkv_proj", "v_proj", "v"), ("qkv_proj", "v_proj", "v"),
] ]
expert_params_mapping = [ expert_params_mapping = (
# These are the weight scales for the experts [
# (param_name, weight_name, expert_id) # These are the weight scales for the experts
("w13_scale" if weight_name in ["w1", "w3"] else "w2_scale", # (param_name, weight_name, expert_id)
f"experts.{expert_id}.{weight_name}.weight_scale", expert_id) (
for expert_id in range(self.config.num_local_experts) "w13_scale" if weight_name in ["w1", "w3"] else "w2_scale",
for weight_name in ["w1", "w2", "w3"] f"experts.{expert_id}.{weight_name}.weight_scale",
] + [ expert_id,
# These are the weights for the experts )
# (param_name, weight_name, expert_id) for expert_id in range(self.config.num_local_experts)
("w13_weight" if weight_name in ["w1", "w3"] else "w2_weight", for weight_name in ["w1", "w2", "w3"]
f"experts.{expert_id}.{weight_name}.weight", expert_id) ]
for expert_id in range(self.config.num_local_experts) + [
for weight_name in ["w1", "w2", "w3"] # These are the weights for the experts
] + [ # (param_name, weight_name, expert_id)
# These are the activation scales for the experts (
# (param_name, weight_name, expert_id) "w13_weight" if weight_name in ["w1", "w3"] else "w2_weight",
("a13_scale" if weight_name in ["w1", "w3"] else "a2_scale", f"experts.{expert_id}.{weight_name}.weight",
f"experts.{expert_id}.{weight_name}.act_scale", expert_id) expert_id,
for expert_id in range(self.config.num_local_experts) )
for weight_name in ["w1", "w2", "w3"] for expert_id in range(self.config.num_local_experts)
] for weight_name in ["w1", "w2", "w3"]
]
+ [
# These are the activation scales for the experts
# (param_name, weight_name, expert_id)
(
"a13_scale" if weight_name in ["w1", "w3"] else "a2_scale",
f"experts.{expert_id}.{weight_name}.act_scale",
expert_id,
)
for expert_id in range(self.config.num_local_experts)
for weight_name in ["w1", "w2", "w3"]
]
)
params_dict = dict(self.named_parameters()) params_dict = dict(self.named_parameters())
for name, loaded_weight in weights: for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name: if "rotary_emb.inv_freq" in name:
continue continue
for (param_name, weight_name, shard_id) in stacked_params_mapping: for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name: if weight_name not in name:
continue continue
name = name.replace(weight_name, param_name) name = name.replace(weight_name, param_name)
...@@ -490,18 +539,18 @@ class MixtralForCausalLM(nn.Module): ...@@ -490,18 +539,18 @@ class MixtralForCausalLM(nn.Module):
name = name.replace(weight_name, param_name) name = name.replace(weight_name, param_name)
param = params_dict[name] param = params_dict[name]
weight_loader = param.weight_loader weight_loader = param.weight_loader
weight_loader(param, weight_loader(
loaded_weight, param, loaded_weight, weight_name, expert_id=expert_id
weight_name, )
expert_id=expert_id)
break break
else: else:
# Skip loading extra bias for GPTQ models. # Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict: if name.endswith(".bias") and name not in params_dict:
continue continue
param = params_dict[name] param = params_dict[name]
weight_loader = getattr(param, "weight_loader", weight_loader = getattr(
default_weight_loader) param, "weight_loader", default_weight_loader
)
weight_loader(param, loaded_weight) weight_loader(param, loaded_weight)
......
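As a small aid for reading the reformatted `expert_params_mapping` above: the nested comprehensions expand to `(param_name, checkpoint_weight_name, expert_id)` triples. For a hypothetical config with `num_local_experts = 2`, the first block (weight scales) produces the entries sketched below; this is an illustration, not output of the real code:

```python
# Weight-scale entries from the first comprehension (expert_id outer loop,
# weight_name inner loop over ["w1", "w2", "w3"]), for num_local_experts = 2:
weight_scale_entries = [
    ("w13_scale", "experts.0.w1.weight_scale", 0),
    ("w2_scale", "experts.0.w2.weight_scale", 0),
    ("w13_scale", "experts.0.w3.weight_scale", 0),
    ("w13_scale", "experts.1.w1.weight_scale", 1),
    ("w2_scale", "experts.1.w2.weight_scale", 1),
    ("w13_scale", "experts.1.w3.weight_scale", 1),
]
# The second and third blocks follow the same pattern with the ".weight" and
# ".act_scale" suffixes, mapping onto w13_weight/w2_weight and a13_scale/a2_scale.
```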
...@@ -28,7 +28,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( ...@@ -28,7 +28,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
) )
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.managers.controller.model_runner import InputMetadata from sglang.srt.managers.controller.model_runner import InputMetadata
......