Unverified Commit fb9296f0 authored by Ying Sheng, committed by GitHub

Higher priority for user input of max_prefill_tokens & format (#540)

parent 1374334d
......@@ -13,15 +13,15 @@ import zmq
import zmq.asyncio
from sglang.global_config import global_config
from sglang.srt.managers.controller.dp_worker import (
DataParallelWorkerThread,
start_data_parallel_worker,
)
from sglang.srt.managers.io_struct import (
AbortReq,
FlushCacheReq,
TokenizedGenerateReqInput,
)
from sglang.srt.managers.controller.dp_worker import (
DataParallelWorkerThread,
start_data_parallel_worker,
)
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.utils import get_exception_traceback
......@@ -136,7 +136,7 @@ class Controller:
self.recv_reqs = []
if next_step_input:
await self.dispatching(next_step_input)
#else:
# else:
# logger.error("There is no live worker.")
await asyncio.sleep(global_config.wait_for_new_request_delay)
......
"""A controller that manages a group of tensor parallel workers."""
import asyncio
import logging
import time
......@@ -49,7 +50,9 @@ class ControllerSingle:
# async sleep for receiving the subsequent request and avoiding cache miss
slept = False
if len(out_pyobjs) != 0:
has_finished = any([obj.finished_reason is not None for obj in out_pyobjs])
has_finished = any(
[obj.finished_reason is not None for obj in out_pyobjs]
)
if has_finished:
if self.request_dependency_delay > 0:
slept = True
......@@ -94,4 +97,4 @@ def start_controller_process(
except Exception:
logger.error("Exception in ControllerSingle:\n" + get_exception_traceback())
finally:
kill_parent_process()
\ No newline at end of file
kill_parent_process()
"""ModelRunner runs the forward passes of the models."""
import importlib
import importlib.resources
import logging
......@@ -12,15 +13,18 @@ import torch
import torch.nn as nn
from vllm.config import DeviceConfig, LoadConfig
from vllm.config import ModelConfig as VllmModelConfig
from vllm.distributed import initialize_model_parallel, init_distributed_environment
from vllm.distributed import init_distributed_environment, initialize_model_parallel
from vllm.model_executor.model_loader import get_model
from vllm.model_executor.models import ModelRegistry
from sglang.srt.managers.controller.infer_batch import Batch, ForwardMode
from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import get_available_gpu_memory, is_multimodal_model, monkey_patch_vllm_p2p_access_check
from sglang.srt.utils import (
get_available_gpu_memory,
is_multimodal_model,
monkey_patch_vllm_p2p_access_check,
)
logger = logging.getLogger("srt.model_runner")
......@@ -441,7 +445,9 @@ def import_model_classes():
module = importlib.import_module(name)
if hasattr(module, "EntryClass"):
entry = module.EntryClass
if isinstance(entry, list): # To support multiple model classes in one module
if isinstance(
entry, list
): # To support multiple model classes in one module
for tmp in entry:
model_arch_name_to_cls[tmp.__name__] = tmp
else:
......@@ -449,7 +455,9 @@ def import_model_classes():
# compat: some models such as chatglm has incorrect class set in config.json
# usage: [ tuple("From_Entry_Class_Name": EntryClass), ]
if hasattr(module, "EntryClassRemapping") and isinstance(module.EntryClassRemapping, list):
if hasattr(module, "EntryClassRemapping") and isinstance(
module.EntryClassRemapping, list
):
for remap in module.EntryClassRemapping:
if isinstance(remap, tuple) and len(remap) == 2:
model_arch_name_to_cls[remap[0]] = remap[1]
......
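For context, import_model_classes() above relies on a module-level convention rather than an explicit registry: every file in the models package may export EntryClass (a single class or a list of classes) and, optionally, EntryClassRemapping for checkpoints whose config.json names a different architecture. A hedged sketch of what such a module exports, mirroring the chatglm.py and llava.py entries that appear later in this diff:

# At the bottom of a model definition module (class names stand in for the
# classes defined in that file):
EntryClass = ChatGLMForCausalLM
# Optional: map the architecture string found in config.json to the real class.
EntryClassRemapping = [("ChatGLMModel", ChatGLMForCausalLM)]
# Modules with several entry points export a list instead, e.g.
# EntryClass = [LlavaLlamaForCausalLM, LlavaQwenForCausalLM, LlavaMistralForCausalLM]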
"""
The radix tree data structure for managing the KV cache.
"""
import heapq
import time
from collections import defaultdict
......
"""Request scheduler heuristic."""
import random
from collections import defaultdict
......
......@@ -15,22 +15,22 @@ from sglang.global_config import global_config
from sglang.srt.constrained.fsm_cache import FSMCache
from sglang.srt.constrained.jump_forward import JumpForwardCache
from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
from sglang.srt.managers.io_struct import (
AbortReq,
BatchTokenIDOut,
FlushCacheReq,
TokenizedGenerateReqInput,
)
from sglang.srt.managers.controller.infer_batch import (
FINISH_ABORT,
BaseFinishReason,
Batch,
FINISH_ABORT,
ForwardMode,
Req,
)
from sglang.srt.managers.controller.model_runner import ModelRunner
from sglang.srt.managers.controller.radix_cache import RadixCache
from sglang.srt.managers.controller.schedule_heuristic import ScheduleHeuristic
from sglang.srt.managers.io_struct import (
AbortReq,
BatchTokenIDOut,
FlushCacheReq,
TokenizedGenerateReqInput,
)
from sglang.srt.model_config import ModelConfig
from sglang.srt.server_args import ModelPortArgs, ServerArgs
from sglang.srt.utils import (
......@@ -96,13 +96,13 @@ class ModelTpServer:
trust_remote_code=server_args.trust_remote_code,
)
self.max_total_num_tokens = self.model_runner.max_total_num_tokens
self.max_prefill_tokens = max(
self.model_config.context_len,
(
min(self.max_total_num_tokens // 6, 65536)
if server_args.max_prefill_tokens is None
else server_args.max_prefill_tokens
),
self.max_prefill_tokens = (
max(
self.model_config.context_len,
min(self.max_total_num_tokens // 6, 65536),
)
if server_args.max_prefill_tokens is None
else server_args.max_prefill_tokens
)
self.max_running_requests = (
self.max_total_num_tokens // 2
......
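This hunk is the substantive change behind the commit title. Previously a user-supplied server_args.max_prefill_tokens was placed inside the outer max() and could still be raised to the model's context length; now the user value, when given, is used unchanged, and the heuristic max(context_len, min(max_total_num_tokens // 6, 65536)) serves only as the default. A minimal sketch of the new precedence (the standalone function and its name are illustrative, not part of the codebase):

def resolve_max_prefill_tokens(user_value, context_len, max_total_num_tokens):
    # User input now wins unconditionally (this is the behavioral change).
    if user_value is not None:
        return user_value
    # Otherwise fall back to the heuristic default used in ModelTpServer.
    return max(context_len, min(max_total_num_tokens // 6, 65536))

# Example: with context_len=4096 and user_value=2048, the old expression yielded
# max(4096, 2048) = 4096, whereas the new logic keeps the requested 2048.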
"""DetokenizerManager is a process that detokenizes the token ids."""
import asyncio
import inspect
......@@ -7,10 +8,10 @@ import zmq
import zmq.asyncio
from sglang.srt.hf_transformers_utils import get_tokenizer
from sglang.srt.managers.controller.infer_batch import FINISH_MATCHED_STR
from sglang.srt.managers.io_struct import BatchStrOut, BatchTokenIDOut
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.utils import get_exception_traceback, graceful_registry
from sglang.srt.managers.controller.infer_batch import FINISH_MATCHED_STR
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
......
......@@ -7,8 +7,8 @@ import uuid
from dataclasses import dataclass
from typing import Dict, List, Optional, Union
from sglang.srt.sampling_params import SamplingParams
from sglang.srt.managers.controller.infer_batch import BaseFinishReason
from sglang.srt.sampling_params import SamplingParams
@dataclass
......
"""TokenizerManager is a process that tokenizes the text."""
import asyncio
import concurrent.futures
import dataclasses
import logging
import multiprocessing as mp
import os
from typing import List, Dict
from typing import Dict, List
import numpy as np
import transformers
......@@ -23,11 +24,11 @@ from sglang.srt.hf_transformers_utils import (
from sglang.srt.managers.io_struct import (
AbortReq,
BatchStrOut,
BatchTokenIDOut,
FlushCacheReq,
GenerateReqInput,
TokenizedGenerateReqInput,
)
from sglang.srt.managers.io_struct import BatchTokenIDOut
from sglang.srt.mm_utils import expand2square, process_anyres_image
from sglang.srt.sampling_params import SamplingParams
from sglang.srt.server_args import PortArgs, ServerArgs
......@@ -91,7 +92,7 @@ class TokenizerManager:
)
self.to_create_loop = True
self.rid_to_state: Dict[str, ReqState] = {}
self.rid_to_state: Dict[str, ReqState] = {}
async def get_pixel_values(self, image_data):
aspect_ratio = getattr(self.hf_config, "image_aspect_ratio", None)
......@@ -322,7 +323,6 @@ class TokenizerManager:
state.finished = recv_obj.finished_reason[i] is not None
state.event.set()
def convert_logprob_style(
self, ret, return_logprob, top_logprobs_num, return_text_in_logprobs
):
......
from typing import Optional
from sglang.srt.hf_transformers_utils import get_config, get_context_length
from transformers import PretrainedConfig
from sglang.srt.hf_transformers_utils import get_config, get_context_length
class ModelConfig:
def __init__(
......@@ -17,8 +18,12 @@ class ModelConfig:
self.trust_remote_code = trust_remote_code
self.revision = revision
self.model_overide_args = model_overide_args
self.hf_config = get_config(self.path, trust_remote_code, revision,
model_overide_args=model_overide_args)
self.hf_config = get_config(
self.path,
trust_remote_code,
revision,
model_overide_args=model_overide_args,
)
self.hf_text_config = get_hf_text_config(self.hf_config)
if context_length is not None:
self.context_len = context_length
......@@ -55,18 +60,23 @@ class ModelConfig:
# KV heads.
falcon_model_types = ["falcon", "RefinedWeb", "RefinedWebModel"]
new_decoder_arch_falcon = (
self.hf_config.model_type in falcon_model_types
and getattr(self.hf_config, "new_decoder_architecture", False))
if not new_decoder_arch_falcon and getattr(self.hf_text_config,
"multi_query", False):
self.hf_config.model_type in falcon_model_types
and getattr(self.hf_config, "new_decoder_architecture", False)
)
if not new_decoder_arch_falcon and getattr(
self.hf_text_config, "multi_query", False
):
# Multi-query attention, only one KV head.
# Currently, tensor parallelism is not supported in this case.
return 1
# For DBRX and MPT
if self.hf_config.model_type in ["dbrx", "mpt"]:
return getattr(self.hf_config.attn_config, "kv_n_heads",
self.hf_config.num_attention_heads)
return getattr(
self.hf_config.attn_config,
"kv_n_heads",
self.hf_config.num_attention_heads,
)
attributes = [
# For Falcon:
......@@ -94,13 +104,12 @@ class ModelConfig:
# the tensor parallel size. We will replicate the KV heads in the
# case where the number of KV heads is smaller than the tensor
# parallel size so each GPU has at least one KV head.
return max(1,
total_num_kv_heads // tensor_parallel_size)
return max(1, total_num_kv_heads // tensor_parallel_size)
def get_hf_text_config(config: PretrainedConfig):
"""Get the "sub" config relevant to llm for multi modal models.
No op for pure text models.
No op for pure text models.
"""
if hasattr(config, "text_config"):
# The code operates under the assumption that text_config should have
......
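The get_num_kv_heads() logic in the hunk above decides how many KV heads each tensor-parallel rank serves: multi-query models pin it to 1, DBRX and MPT read attn_config.kv_n_heads, and all other models divide the total head count by the TP size with a floor of 1 so heads are replicated whenever there are fewer heads than ranks. A small sketch of that final division, with example values that are illustrative only:

def kv_heads_per_rank(total_num_kv_heads, tensor_parallel_size):
    # Replicate KV heads when there are fewer heads than TP ranks,
    # so that every GPU still holds at least one head.
    return max(1, total_num_kv_heads // tensor_parallel_size)

# e.g. 8 heads on 4 ranks -> 2 per rank; 2 heads on 8 ranks -> 1 per rank (replicated).
# Multi-query models short-circuit to 1 before reaching this division.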
......@@ -5,30 +5,32 @@
from typing import Iterable, List, Optional, Tuple
import torch
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.managers.controller.model_runner import InputMetadata
from sglang.srt.layers.logits_processor import LogitsProcessor
from torch import nn
from torch.nn import LayerNorm
from vllm.config import CacheConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.linear import (
MergedColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear,
)
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
ParallelLMHead,
VocabParallelEmbedding,
)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput
from vllm.transformers_utils.configs import ChatGLMConfig
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.managers.controller.model_runner import InputMetadata
LoraConfig = None
......@@ -49,9 +51,11 @@ class GLMAttention(nn.Module):
assert self.total_num_heads % tp_size == 0
self.num_heads = self.total_num_heads // tp_size
self.multi_query_attention = config.multi_query_attention
self.total_num_kv_heads = (config.multi_query_group_num
if config.multi_query_attention else
config.num_attention_heads)
self.total_num_kv_heads = (
config.multi_query_group_num
if config.multi_query_attention
else config.num_attention_heads
)
if self.total_num_kv_heads >= tp_size:
# Number of KV heads is greater than TP size, so we partition
# the KV heads across multiple tensor parallel GPUs.
......@@ -91,11 +95,13 @@ class GLMAttention(nn.Module):
base=10000 * rope_ratio,
is_neox_style=False,
)
self.attn = RadixAttention(self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
layer_id=layer_id)
self.attn = RadixAttention(
self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
layer_id=layer_id,
)
def forward(
self,
......@@ -176,14 +182,16 @@ class GLMBlock(nn.Module):
):
super().__init__()
self.apply_residual_connection_post_layernorm = (
config.apply_residual_connection_post_layernorm)
config.apply_residual_connection_post_layernorm
)
self.fp32_residual_connection = config.fp32_residual_connection
layer_norm_func = RMSNorm if config.rmsnorm else LayerNorm
# Layernorm on the input data.
self.input_layernorm = layer_norm_func(config.hidden_size,
eps=config.layernorm_epsilon)
self.input_layernorm = layer_norm_func(
config.hidden_size, eps=config.layernorm_epsilon
)
# Self attention.
self.self_attention = GLMAttention(config, layer_id, cache_config, quant_config)
......@@ -191,7 +199,8 @@ class GLMBlock(nn.Module):
# Layernorm on the attention output
self.post_attention_layernorm = layer_norm_func(
config.hidden_size, eps=config.layernorm_epsilon)
config.hidden_size, eps=config.layernorm_epsilon
)
# MLP
self.mlp = GLMMLP(config, quant_config)
......@@ -250,16 +259,19 @@ class GLMTransformer(nn.Module):
self.num_layers = config.num_layers
# Transformer layers.
self.layers = nn.ModuleList([
GLMBlock(config, i, cache_config, quant_config)
for i in range(self.num_layers)
])
self.layers = nn.ModuleList(
[
GLMBlock(config, i, cache_config, quant_config)
for i in range(self.num_layers)
]
)
if self.post_layer_norm:
layer_norm_func = RMSNorm if config.rmsnorm else LayerNorm
# Final layer norm before output.
self.final_layernorm = layer_norm_func(
config.hidden_size, eps=config.layernorm_epsilon)
config.hidden_size, eps=config.layernorm_epsilon
)
def forward(
self,
......@@ -291,16 +303,16 @@ class ChatGLMModel(nn.Module):
):
super().__init__()
self.embedding = VocabParallelEmbedding(config.padded_vocab_size,
config.hidden_size)
self.embedding = VocabParallelEmbedding(
config.padded_vocab_size, config.hidden_size
)
self.num_layers = config.num_layers
self.multi_query_group_num = config.multi_query_group_num
self.kv_channels = config.kv_channels
self.encoder = GLMTransformer(config, cache_config, quant_config)
self.output_layer = ParallelLMHead(config.padded_vocab_size,
config.hidden_size)
self.output_layer = ParallelLMHead(config.padded_vocab_size, config.hidden_size)
def forward(
self,
......@@ -322,7 +334,7 @@ class ChatGLMModel(nn.Module):
class ChatGLMForCausalLM(nn.Module):
packed_modules_mapping = {
"query_key_value": ["query_key_value"],
"dense_h_to_4h": ["dense_h_to_4h"]
"dense_h_to_4h": ["dense_h_to_4h"],
}
# LoRA specific attributes
supported_lora_modules = [
......@@ -344,8 +356,7 @@ class ChatGLMForCausalLM(nn.Module):
super().__init__()
self.config: ChatGLMConfig = config
self.quant_config = quant_config
self.max_position_embeddings = getattr(config, "max_sequence_length",
8192)
self.max_position_embeddings = getattr(config, "max_sequence_length", 8192)
self.transformer = ChatGLMModel(config, cache_config, quant_config)
self.lm_head = self.transformer.output_layer
self.logits_processor = LogitsProcessor(config)
......@@ -357,8 +368,7 @@ class ChatGLMForCausalLM(nn.Module):
positions: torch.Tensor,
input_metadata: InputMetadata,
) -> torch.Tensor:
hidden_states = self.transformer(input_ids, positions,
input_metadata)
hidden_states = self.transformer(input_ids, positions, input_metadata)
return self.logits_processor(
input_ids, hidden_states, self.lm_head.weight, input_metadata
)
......@@ -382,10 +392,10 @@ class ChatGLMForCausalLM(nn.Module):
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight)
EntryClass = ChatGLMForCausalLM
# compat: glm model.config class == ChatGLMModel
EntryClassRemapping = [("ChatGLMModel", ChatGLMForCausalLM)]
......@@ -23,7 +23,7 @@
# This file is based on the LLama model definition file in transformers
"""PyTorch Cohere model."""
from typing import Optional, Tuple, Iterable
from typing import Iterable, Optional, Tuple
import torch
import torch.utils.checkpoint
......@@ -44,8 +44,8 @@ from vllm.model_executor.layers.linear import (
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
from vllm.model_executor.utils import set_weight_attrs
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.utils import set_weight_attrs
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention
......
......@@ -24,8 +24,8 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead,
VocabParallelEmbedding,
)
from vllm.model_executor.utils import set_weight_attrs
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.utils import set_weight_attrs
from vllm.transformers_utils.configs.dbrx import DbrxConfig
from sglang.srt.layers.logits_processor import LogitsProcessor
......
......@@ -6,7 +6,7 @@ from typing import Iterable, Optional, Tuple
import torch
from torch import nn
from transformers import PretrainedConfig
from vllm.config import LoRAConfig, CacheConfig
from vllm.config import CacheConfig, LoRAConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import GeluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
......
This diff is collapsed.
# Adapted from
# https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/llama.py#L1
"""Inference-only LLaMA model compatible with HuggingFace weights."""
from typing import Any, Dict, Optional, Tuple, Iterable
from typing import Any, Dict, Iterable, Optional, Tuple
import torch
import tqdm
......@@ -10,7 +10,7 @@ from transformers import LlamaConfig
from vllm.config import CacheConfig
from vllm.distributed import (
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size
get_tensor_model_parallel_world_size,
)
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
......@@ -158,9 +158,11 @@ class LlamaDecoderLayer(nn.Module):
rope_theta = getattr(config, "rope_theta", 10000)
rope_scaling = getattr(config, "rope_scaling", None)
if rope_scaling is not None and getattr(
config, "original_max_position_embeddings", None):
config, "original_max_position_embeddings", None
):
rope_scaling["original_max_position_embeddings"] = (
config.original_max_position_embeddings)
config.original_max_position_embeddings
)
max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
self.self_attn = LlamaAttention(
hidden_size=self.hidden_size,
......
"""Inference-only LLaVa model compatible with HuggingFace weights."""
from typing import List, Iterable, Optional, Tuple
from typing import Iterable, List, Optional, Tuple
import numpy as np
import torch
from torch import nn
from transformers import CLIPVisionModel, CLIPVisionConfig, LlavaConfig, Qwen2Config, MistralConfig
from transformers import (
CLIPVisionConfig,
CLIPVisionModel,
LlavaConfig,
MistralConfig,
Qwen2Config,
)
from transformers.models.llava.modeling_llava import LlavaMultiModalProjector
from vllm.config import CacheConfig
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
......@@ -19,8 +25,8 @@ from sglang.srt.mm_utils import (
unpad_image_shape,
)
from sglang.srt.models.llama2 import LlamaForCausalLM
from sglang.srt.models.qwen2 import Qwen2ForCausalLM
from sglang.srt.models.mistral import MistralForCausalLM
from sglang.srt.models.qwen2 import Qwen2ForCausalLM
class LlavaLlamaForCausalLM(nn.Module):
......@@ -359,6 +365,7 @@ class LlavaMistralForCausalLM(LlavaLlamaForCausalLM):
first_call = True
def clip_vision_embed_forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
batch_size = pixel_values.shape[0]
......@@ -388,8 +395,4 @@ def monkey_path_clip_vision_embed_forward():
)
EntryClass = [
LlavaLlamaForCausalLM,
LlavaQwenForCausalLM,
LlavaMistralForCausalLM
]
EntryClass = [LlavaLlamaForCausalLM, LlavaQwenForCausalLM, LlavaMistralForCausalLM]
"""Inference-only LLaVa video model compatible with HuggingFace weights."""
from typing import List, Iterable, Optional, Tuple
from typing import Iterable, List, Optional, Tuple
import numpy as np
import torch
......
......@@ -33,13 +33,11 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.utils import set_weight_attrs
from vllm.utils import print_warning_once
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.managers.controller.model_runner import InputMetadata
class MixtralMoE(nn.Module):
"""A tensor-parallel MoE implementation for Mixtral that shards each expert
across all ranks.
......@@ -76,32 +74,46 @@ class MixtralMoE(nn.Module):
self.params_dtype = params_dtype
# Gate always runs at half / full precision for now.
self.gate = ReplicatedLinear(self.hidden_size,
self.num_total_experts,
bias=False,
params_dtype=self.params_dtype,
quant_config=None)
self.gate = ReplicatedLinear(
self.hidden_size,
self.num_total_experts,
bias=False,
params_dtype=self.params_dtype,
quant_config=None,
)
if self.use_fp8 and self.quant_config.is_checkpoint_fp8_serialized:
params_dtype = torch.float8_e4m3fn
self.w13_weight = nn.Parameter(
torch.empty(self.num_total_experts,
2 * self.intermediate_size,
self.hidden_size,
dtype=params_dtype))
torch.empty(
self.num_total_experts,
2 * self.intermediate_size,
self.hidden_size,
dtype=params_dtype,
)
)
self.w2_weight = nn.Parameter(
torch.empty(self.num_total_experts,
self.hidden_size,
self.intermediate_size,
dtype=params_dtype))
set_weight_attrs(self.w13_weight, {
"weight_loader": self.weight_loader,
})
set_weight_attrs(self.w2_weight, {
"weight_loader": self.weight_loader,
})
torch.empty(
self.num_total_experts,
self.hidden_size,
self.intermediate_size,
dtype=params_dtype,
)
)
set_weight_attrs(
self.w13_weight,
{
"weight_loader": self.weight_loader,
},
)
set_weight_attrs(
self.w2_weight,
{
"weight_loader": self.weight_loader,
},
)
# Used for fp8.
self.w13_scale = None
......@@ -111,46 +123,68 @@ class MixtralMoE(nn.Module):
if self.use_fp8:
# WEIGHT_SCALE (for fp8)
self.w13_scale = nn.Parameter(torch.ones(self.num_total_experts,
dtype=torch.float32),
requires_grad=False)
self.w2_scale = nn.Parameter(torch.ones(self.num_total_experts,
dtype=torch.float32),
requires_grad=False)
self.w13_scale = nn.Parameter(
torch.ones(self.num_total_experts, dtype=torch.float32),
requires_grad=False,
)
self.w2_scale = nn.Parameter(
torch.ones(self.num_total_experts, dtype=torch.float32),
requires_grad=False,
)
# If loading fp8 checkpoint, pass the weight loaders.
# If loading an fp16 checkpoint, do not (we will quantize in
# process_weights_after_loading()
if quant_config.is_checkpoint_fp8_serialized:
set_weight_attrs(self.w13_scale, {
"weight_loader": self.weight_loader,
})
set_weight_attrs(self.w2_scale, {
"weight_loader": self.weight_loader,
})
set_weight_attrs(
self.w13_scale,
{
"weight_loader": self.weight_loader,
},
)
set_weight_attrs(
self.w2_scale,
{
"weight_loader": self.weight_loader,
},
)
# ACT_SCALE (for fp8)
if quant_config.activation_scheme == "static":
if not quant_config.is_checkpoint_fp8_serialized:
raise ValueError(
"Found static activation scheme for checkpoint that "
"was not serialized fp8.")
self.a13_scale = nn.Parameter(torch.zeros(
self.num_total_experts, dtype=torch.float32),
requires_grad=False)
self.a2_scale = nn.Parameter(torch.zeros(
self.num_total_experts, dtype=torch.float32),
requires_grad=False)
set_weight_attrs(self.a13_scale, {
"weight_loader": self.weight_loader,
})
set_weight_attrs(self.a2_scale, {
"weight_loader": self.weight_loader,
})
def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor,
weight_name: str, expert_id: int):
"was not serialized fp8."
)
self.a13_scale = nn.Parameter(
torch.zeros(self.num_total_experts, dtype=torch.float32),
requires_grad=False,
)
self.a2_scale = nn.Parameter(
torch.zeros(self.num_total_experts, dtype=torch.float32),
requires_grad=False,
)
set_weight_attrs(
self.a13_scale,
{
"weight_loader": self.weight_loader,
},
)
set_weight_attrs(
self.a2_scale,
{
"weight_loader": self.weight_loader,
},
)
def weight_loader(
self,
param: nn.Parameter,
loaded_weight: torch.Tensor,
weight_name: str,
expert_id: int,
):
tp_rank = get_tensor_model_parallel_rank()
param_data = param.data
shard_size = self.intermediate_size
......@@ -158,8 +192,9 @@ class MixtralMoE(nn.Module):
if weight_name.endswith("w1.weight"):
param_data[expert_id, 0:shard_size, :] = loaded_weight[shard, :]
if weight_name.endswith("w3.weight"):
param_data[expert_id,
shard_size:2 * shard_size, :] = loaded_weight[shard, :]
param_data[expert_id, shard_size : 2 * shard_size, :] = loaded_weight[
shard, :
]
if weight_name.endswith("w2.weight"):
param_data[expert_id, :, :] = loaded_weight[:, shard]
if "act_scale" in weight_name or "weight_scale" in weight_name:
......@@ -172,17 +207,17 @@ class MixtralMoE(nn.Module):
# If checkpoint is fp16, quantize here.
if not self.quant_config.is_checkpoint_fp8_serialized:
w13_weight = torch.empty_like(self.w13_weight.data,
dtype=torch.float8_e4m3fn)
w2_weight = torch.empty_like(self.w2_weight.data,
dtype=torch.float8_e4m3fn)
w13_weight = torch.empty_like(
self.w13_weight.data, dtype=torch.float8_e4m3fn
)
w2_weight = torch.empty_like(self.w2_weight.data, dtype=torch.float8_e4m3fn)
for expert in range(self.num_total_experts):
w13_weight[expert, :, :], self.w13_scale[
expert] = ops.scaled_fp8_quant(
self.w13_weight.data[expert, :, :])
w2_weight[expert, :, :], self.w2_scale[
expert] = ops.scaled_fp8_quant(
self.w2_weight.data[expert, :, :])
w13_weight[expert, :, :], self.w13_scale[expert] = ops.scaled_fp8_quant(
self.w13_weight.data[expert, :, :]
)
w2_weight[expert, :, :], self.w2_scale[expert] = ops.scaled_fp8_quant(
self.w2_weight.data[expert, :, :]
)
self.w13_weight = nn.Parameter(w13_weight, requires_grad=False)
self.w2_weight = nn.Parameter(w2_weight, requires_grad=False)
......@@ -193,40 +228,40 @@ class MixtralMoE(nn.Module):
if self.a13_scale is None or self.a2_scale is None:
raise ValueError(
"QuantConfig has static quantization, but found "
"activation scales are None.")
"activation scales are None."
)
if (not all_close_1d(self.a13_scale)
or not all_close_1d(self.a2_scale)):
if not all_close_1d(self.a13_scale) or not all_close_1d(self.a2_scale):
print_warning_once(
"Found act_scales that are not equal for fp8 MoE layer. "
"Using the maximum across experts for each layer. ")
"Using the maximum across experts for each layer. "
)
self.a13_scale = nn.Parameter(self.a13_scale.max(),
requires_grad=False)
self.a2_scale = nn.Parameter(self.a2_scale.max(),
requires_grad=False)
self.a13_scale = nn.Parameter(self.a13_scale.max(), requires_grad=False)
self.a2_scale = nn.Parameter(self.a2_scale.max(), requires_grad=False)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
num_tokens, hidden_size = hidden_states.shape
hidden_states = hidden_states.view(-1, self.hidden_size)
# router_logits: (num_tokens, n_experts)
router_logits, _ = self.gate(hidden_states)
final_hidden_states = fused_moe(hidden_states,
self.w13_weight,
self.w2_weight,
router_logits,
self.top_k,
renormalize=True,
inplace=True,
use_fp8=self.use_fp8,
w1_scale=self.w13_scale,
w2_scale=self.w2_scale,
a1_scale=self.a13_scale,
a2_scale=self.a2_scale)
final_hidden_states = fused_moe(
hidden_states,
self.w13_weight,
self.w2_weight,
router_logits,
self.top_k,
renormalize=True,
inplace=True,
use_fp8=self.use_fp8,
w1_scale=self.w13_scale,
w2_scale=self.w2_scale,
a1_scale=self.a13_scale,
a2_scale=self.a2_scale,
)
if self.tp_size > 1:
final_hidden_states = tensor_model_parallel_all_reduce(
final_hidden_states)
final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
return final_hidden_states.view(num_tokens, hidden_size)
......@@ -335,7 +370,8 @@ class MixtralDecoderLayer(nn.Module):
top_k=config.num_experts_per_tok,
hidden_size=config.hidden_size,
intermediate_size=config.intermediate_size,
quant_config=quant_config)
quant_config=quant_config,
)
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = RMSNorm(
config.hidden_size, eps=config.rms_norm_eps
......@@ -444,35 +480,48 @@ class MixtralForCausalLM(nn.Module):
("qkv_proj", "v_proj", "v"),
]
expert_params_mapping = [
# These are the weight scales for the experts
# (param_name, weight_name, expert_id)
("w13_scale" if weight_name in ["w1", "w3"] else "w2_scale",
f"experts.{expert_id}.{weight_name}.weight_scale", expert_id)
for expert_id in range(self.config.num_local_experts)
for weight_name in ["w1", "w2", "w3"]
] + [
# These are the weights for the experts
# (param_name, weight_name, expert_id)
("w13_weight" if weight_name in ["w1", "w3"] else "w2_weight",
f"experts.{expert_id}.{weight_name}.weight", expert_id)
for expert_id in range(self.config.num_local_experts)
for weight_name in ["w1", "w2", "w3"]
] + [
# These are the activation scales for the experts
# (param_name, weight_name, expert_id)
("a13_scale" if weight_name in ["w1", "w3"] else "a2_scale",
f"experts.{expert_id}.{weight_name}.act_scale", expert_id)
for expert_id in range(self.config.num_local_experts)
for weight_name in ["w1", "w2", "w3"]
]
expert_params_mapping = (
[
# These are the weight scales for the experts
# (param_name, weight_name, expert_id)
(
"w13_scale" if weight_name in ["w1", "w3"] else "w2_scale",
f"experts.{expert_id}.{weight_name}.weight_scale",
expert_id,
)
for expert_id in range(self.config.num_local_experts)
for weight_name in ["w1", "w2", "w3"]
]
+ [
# These are the weights for the experts
# (param_name, weight_name, expert_id)
(
"w13_weight" if weight_name in ["w1", "w3"] else "w2_weight",
f"experts.{expert_id}.{weight_name}.weight",
expert_id,
)
for expert_id in range(self.config.num_local_experts)
for weight_name in ["w1", "w2", "w3"]
]
+ [
# These are the activation scales for the experts
# (param_name, weight_name, expert_id)
(
"a13_scale" if weight_name in ["w1", "w3"] else "a2_scale",
f"experts.{expert_id}.{weight_name}.act_scale",
expert_id,
)
for expert_id in range(self.config.num_local_experts)
for weight_name in ["w1", "w2", "w3"]
]
)
params_dict = dict(self.named_parameters())
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
for (param_name, weight_name, shard_id) in stacked_params_mapping:
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
......@@ -490,18 +539,18 @@ class MixtralForCausalLM(nn.Module):
name = name.replace(weight_name, param_name)
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param,
loaded_weight,
weight_name,
expert_id=expert_id)
weight_loader(
param, loaded_weight, weight_name, expert_id=expert_id
)
break
else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader = getattr(
param, "weight_loader", default_weight_loader
)
weight_loader(param, loaded_weight)
......
......@@ -28,7 +28,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.managers.controller.model_runner import InputMetadata
......