Commit fb9296f0 (unverified) · sglang

Authored Jun 12, 2024 by Ying Sheng · Committed by GitHub on Jun 12, 2024

Higher priority for user input of max_prefill_tokens & format (#540)

Parent: 1374334d
Changes: 50 changed files in this commit. This page shows 20 changed files with 505 additions and 351 deletions (+505, -351); the remaining files are listed on the following pages.
python/sglang/srt/managers/controller/manager_multi.py (+5, -5)
python/sglang/srt/managers/controller/manager_single.py (+5, -2)
python/sglang/srt/managers/controller/model_runner.py (+13, -5)
python/sglang/srt/managers/controller/radix_cache.py (+1, -0)
python/sglang/srt/managers/controller/schedule_heuristic.py (+1, -0)
python/sglang/srt/managers/controller/tp_worker.py (+14, -14)
python/sglang/srt/managers/detokenizer_manager.py (+2, -1)
python/sglang/srt/managers/io_struct.py (+1, -1)
python/sglang/srt/managers/tokenizer_manager.py (+4, -4)
python/sglang/srt/model_config.py (+21, -12)
python/sglang/srt/models/chatglm.py (+48, -38)
python/sglang/srt/models/commandr.py (+2, -2)
python/sglang/srt/models/dbrx.py (+1, -1)
python/sglang/srt/models/gemma.py (+1, -1)
python/sglang/srt/models/grok.py (+204, -136)
python/sglang/srt/models/llama2.py (+6, -4)
python/sglang/srt/models/llava.py (+11, -8)
python/sglang/srt/models/llavavid.py (+1, -1)
python/sglang/srt/models/mixtral.py (+164, -115)
python/sglang/srt/models/mixtral_quant.py (+0, -1)
python/sglang/srt/managers/controller/manager_multi.py

@@ -13,15 +13,15 @@ import zmq
import zmq.asyncio

from sglang.global_config import global_config
+from sglang.srt.managers.controller.dp_worker import (
+    DataParallelWorkerThread,
+    start_data_parallel_worker,
+)
from sglang.srt.managers.io_struct import (
    AbortReq,
    FlushCacheReq,
    TokenizedGenerateReqInput,
)
-from sglang.srt.managers.controller.dp_worker import (
-    DataParallelWorkerThread,
-    start_data_parallel_worker,
-)
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.utils import get_exception_traceback

@@ -136,7 +136,7 @@ class Controller:
            self.recv_reqs = []
            if next_step_input:
                await self.dispatching(next_step_input)
-            #else:
+            # else:
            #     logger.error("There is no live worker.")
            await asyncio.sleep(global_config.wait_for_new_request_delay)
python/sglang/srt/managers/controller/manager_single.py

"""A controller that manages a group of tensor parallel workers."""

import asyncio
import logging
import time

@@ -49,7 +50,9 @@ class ControllerSingle:
            # async sleep for receiving the subsequent request and avoiding cache miss
            slept = False
            if len(out_pyobjs) != 0:
-                has_finished = any([obj.finished_reason is not None for obj in out_pyobjs])
+                has_finished = any(
+                    [obj.finished_reason is not None for obj in out_pyobjs]
+                )
                if has_finished:
                    if self.request_dependency_delay > 0:
                        slept = True
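The `any(...)` check above feeds a small scheduling heuristic: when some request in the current step has finished, the controller sleeps briefly so a dependent follow-up request can arrive and be batched against the warm radix cache. Below is a minimal asyncio sketch of that pattern; the `Output` stand-in class and the delay value are illustrative, not sglang's real types.

import asyncio
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class Output:  # stand-in for sglang's per-request batch output objects
    finished_reason: Optional[str] = None


async def recv_step(out_pyobjs: List[Output], request_dependency_delay: float = 0.02):
    # Sleep only when at least one request in this step has finished,
    # giving a dependent follow-up request time to arrive and reuse the cache.
    if out_pyobjs:
        has_finished = any(obj.finished_reason is not None for obj in out_pyobjs)
        if has_finished and request_dependency_delay > 0:
            await asyncio.sleep(request_dependency_delay)


asyncio.run(recv_step([Output(), Output(finished_reason="stop")]))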
python/sglang/srt/managers/controller/model_runner.py

"""ModelRunner runs the forward passes of the models."""

import importlib
import importlib.resources
import logging

@@ -12,15 +13,18 @@ import torch
import torch.nn as nn
from vllm.config import DeviceConfig, LoadConfig
from vllm.config import ModelConfig as VllmModelConfig
-from vllm.distributed import initialize_model_parallel, init_distributed_environment
+from vllm.distributed import init_distributed_environment, initialize_model_parallel
from vllm.model_executor.model_loader import get_model
from vllm.model_executor.models import ModelRegistry

from sglang.srt.managers.controller.infer_batch import Batch, ForwardMode
from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool
from sglang.srt.server_args import ServerArgs
-from sglang.srt.utils import get_available_gpu_memory, is_multimodal_model, monkey_patch_vllm_p2p_access_check
+from sglang.srt.utils import (
+    get_available_gpu_memory,
+    is_multimodal_model,
+    monkey_patch_vllm_p2p_access_check,
+)

logger = logging.getLogger("srt.model_runner")

@@ -441,7 +445,9 @@ def import_model_classes():
        module = importlib.import_module(name)
        if hasattr(module, "EntryClass"):
            entry = module.EntryClass
-            if isinstance(entry, list):  # To support multiple model classes in one module
+            if isinstance(
+                entry, list
+            ):  # To support multiple model classes in one module
                for tmp in entry:
                    model_arch_name_to_cls[tmp.__name__] = tmp
            else:

@@ -449,7 +455,9 @@ def import_model_classes():
        # compat: some models such as chatglm has incorrect class set in config.json
        # usage: [ tuple("From_Entry_Class_Name": EntryClass), ]
-        if hasattr(module, "EntryClassRemapping") and isinstance(module.EntryClassRemapping, list):
+        if hasattr(module, "EntryClassRemapping") and isinstance(
+            module.EntryClassRemapping, list
+        ):
            for remap in module.EntryClassRemapping:
                if isinstance(remap, tuple) and len(remap) == 2:
                    model_arch_name_to_cls[remap[0]] = remap[1]
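For context, `import_model_classes` builds a map from architecture name to model class by importing each module in the models package and reading its `EntryClass` (one class or a list of classes) plus an optional `EntryClassRemapping` list of `(name, class)` tuples. Below is a rough, self-contained sketch of that collection logic under those assumptions; module discovery is elided, the single-class branch is inferred from the truncated `else:` above, and the placeholder module and class are invented for the example.

from types import SimpleNamespace


def collect_model_classes(modules):
    """Build {architecture_name: model_class} from modules exposing EntryClass
    and, optionally, EntryClassRemapping (a list of (name, class) tuples)."""
    model_arch_name_to_cls = {}
    for module in modules:
        if hasattr(module, "EntryClass"):
            entry = module.EntryClass
            if isinstance(entry, list):  # multiple model classes in one module
                for tmp in entry:
                    model_arch_name_to_cls[tmp.__name__] = tmp
            else:
                model_arch_name_to_cls[entry.__name__] = entry
        if hasattr(module, "EntryClassRemapping") and isinstance(
            module.EntryClassRemapping, list
        ):
            for remap in module.EntryClassRemapping:
                if isinstance(remap, tuple) and len(remap) == 2:
                    model_arch_name_to_cls[remap[0]] = remap[1]
    return model_arch_name_to_cls


class ChatGLMForCausalLM:  # illustrative placeholder class
    pass


fake_module = SimpleNamespace(
    EntryClass=ChatGLMForCausalLM,
    EntryClassRemapping=[("ChatGLMModel", ChatGLMForCausalLM)],
)
print(collect_model_classes([fake_module]))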
python/sglang/srt/managers/controller/radix_cache.py

"""
The radix tree data structure for managing the KV cache.
"""

import heapq
import time
from collections import defaultdict
python/sglang/srt/managers/controller/schedule_heuristic.py

"""Request scheduler heuristic."""

import random
from collections import defaultdict
python/sglang/srt/managers/controller/tp_worker.py

@@ -15,22 +15,22 @@ from sglang.global_config import global_config
from sglang.srt.constrained.fsm_cache import FSMCache
from sglang.srt.constrained.jump_forward import JumpForwardCache
from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
-from sglang.srt.managers.io_struct import (
-    AbortReq,
-    BatchTokenIDOut,
-    FlushCacheReq,
-    TokenizedGenerateReqInput,
-)
from sglang.srt.managers.controller.infer_batch import (
+    FINISH_ABORT,
    BaseFinishReason,
    Batch,
-    FINISH_ABORT,
    ForwardMode,
    Req,
)
from sglang.srt.managers.controller.model_runner import ModelRunner
from sglang.srt.managers.controller.radix_cache import RadixCache
from sglang.srt.managers.controller.schedule_heuristic import ScheduleHeuristic
+from sglang.srt.managers.io_struct import (
+    AbortReq,
+    BatchTokenIDOut,
+    FlushCacheReq,
+    TokenizedGenerateReqInput,
+)
from sglang.srt.model_config import ModelConfig
from sglang.srt.server_args import ModelPortArgs, ServerArgs
from sglang.srt.utils import (

@@ -96,13 +96,13 @@ class ModelTpServer:
            trust_remote_code=server_args.trust_remote_code,
        )
        self.max_total_num_tokens = self.model_runner.max_total_num_tokens
-        self.max_prefill_tokens = max(
-            self.model_config.context_len,
-            (
-                min(self.max_total_num_tokens // 6, 65536)
-                if server_args.max_prefill_tokens is None
-                else server_args.max_prefill_tokens
-            ),
-        )
+        self.max_prefill_tokens = (
+            max(
+                self.model_config.context_len,
+                min(self.max_total_num_tokens // 6, 65536),
+            )
+            if server_args.max_prefill_tokens is None
+            else server_args.max_prefill_tokens
+        )
        self.max_running_requests = (
            self.max_total_num_tokens // 2
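This hunk is the change the commit title describes: a user-supplied `max_prefill_tokens` is now used verbatim, and the heuristic `max(context_len, min(max_total_num_tokens // 6, 65536))` only applies when the user gives no value (previously the user value was still folded into the `max(...)` with the context length). Below is a standalone sketch of that resolution order; the function name and the numbers in the example calls are illustrative.

from typing import Optional


def resolve_max_prefill_tokens(
    user_max_prefill_tokens: Optional[int],
    context_len: int,
    max_total_num_tokens: int,
) -> int:
    """User input wins; otherwise fall back to the size-based heuristic."""
    if user_max_prefill_tokens is not None:
        return user_max_prefill_tokens
    return max(context_len, min(max_total_num_tokens // 6, 65536))


# No user value: heuristic picks max(8192, min(200_000 // 6, 65536)) == 33333.
print(resolve_max_prefill_tokens(None, context_len=8192, max_total_num_tokens=200_000))
# An explicit user value is honored even when smaller than the context length.
print(resolve_max_prefill_tokens(4096, context_len=8192, max_total_num_tokens=200_000))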
python/sglang/srt/managers/detokenizer_manager.py

"""DetokenizerManager is a process that detokenizes the token ids."""

import asyncio
import inspect

@@ -7,10 +8,10 @@ import zmq
import zmq.asyncio

from sglang.srt.hf_transformers_utils import get_tokenizer
+from sglang.srt.managers.controller.infer_batch import FINISH_MATCHED_STR
from sglang.srt.managers.io_struct import BatchStrOut, BatchTokenIDOut
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.utils import get_exception_traceback, graceful_registry
-from sglang.srt.managers.controller.infer_batch import FINISH_MATCHED_STR

asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
python/sglang/srt/managers/io_struct.py

@@ -7,8 +7,8 @@ import uuid
from dataclasses import dataclass
from typing import Dict, List, Optional, Union

-from sglang.srt.sampling_params import SamplingParams
from sglang.srt.managers.controller.infer_batch import BaseFinishReason
+from sglang.srt.sampling_params import SamplingParams


@dataclass
python/sglang/srt/managers/tokenizer_manager.py

"""TokenizerManager is a process that tokenizes the text."""

import asyncio
import concurrent.futures
import dataclasses
import logging
import multiprocessing as mp
import os
-from typing import List, Dict
+from typing import Dict, List

import numpy as np
import transformers

@@ -23,11 +24,11 @@ from sglang.srt.hf_transformers_utils import (
from sglang.srt.managers.io_struct import (
    AbortReq,
    BatchStrOut,
+    BatchTokenIDOut,
    FlushCacheReq,
    GenerateReqInput,
    TokenizedGenerateReqInput,
)
-from sglang.srt.managers.io_struct import BatchTokenIDOut
from sglang.srt.mm_utils import expand2square, process_anyres_image
from sglang.srt.sampling_params import SamplingParams
from sglang.srt.server_args import PortArgs, ServerArgs

@@ -322,7 +323,6 @@ class TokenizerManager:
                state.finished = recv_obj.finished_reason[i] is not None
                state.event.set()

    def convert_logprob_style(
        self, ret, return_logprob, top_logprobs_num, return_text_in_logprobs
    ):
python/sglang/srt/model_config.py

from typing import Optional

-from sglang.srt.hf_transformers_utils import get_config, get_context_length
from transformers import PretrainedConfig
+from sglang.srt.hf_transformers_utils import get_config, get_context_length


class ModelConfig:
    def __init__(

@@ -17,8 +18,12 @@ class ModelConfig:
        self.trust_remote_code = trust_remote_code
        self.revision = revision
        self.model_overide_args = model_overide_args
-        self.hf_config = get_config(self.path, trust_remote_code, revision, model_overide_args=model_overide_args)
+        self.hf_config = get_config(
+            self.path,
+            trust_remote_code,
+            revision,
+            model_overide_args=model_overide_args,
+        )
        self.hf_text_config = get_hf_text_config(self.hf_config)
        if context_length is not None:
            self.context_len = context_length

@@ -56,17 +61,22 @@ class ModelConfig:
        falcon_model_types = ["falcon", "RefinedWeb", "RefinedWebModel"]
        new_decoder_arch_falcon = (
            self.hf_config.model_type in falcon_model_types
-            and getattr(self.hf_config, "new_decoder_architecture", False))
-        if not new_decoder_arch_falcon and getattr(self.hf_text_config, "multi_query", False):
+            and getattr(self.hf_config, "new_decoder_architecture", False)
+        )
+        if not new_decoder_arch_falcon and getattr(
+            self.hf_text_config, "multi_query", False
+        ):
            # Multi-query attention, only one KV head.
            # Currently, tensor parallelism is not supported in this case.
            return 1

        # For DBRX and MPT
        if self.hf_config.model_type in ["dbrx", "mpt"]:
-            return getattr(self.hf_config.attn_config, "kv_n_heads", self.hf_config.num_attention_heads)
+            return getattr(
+                self.hf_config.attn_config,
+                "kv_n_heads",
+                self.hf_config.num_attention_heads,
+            )

        attributes = [
            # For Falcon:

@@ -94,8 +104,7 @@ class ModelConfig:
        # the tensor parallel size. We will replicate the KV heads in the
        # case where the number of KV heads is smaller than the tensor
        # parallel size so each GPU has at least one KV head.
        return max(1, total_num_kv_heads // tensor_parallel_size)


def get_hf_text_config(config: PretrainedConfig):
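The logic above boils down to: determine the model's total number of KV heads (1 for pure multi-query attention, `attn_config.kv_n_heads` for DBRX/MPT, otherwise a config attribute), then give each tensor-parallel rank `max(1, total // tp_size)` of them so models with few KV heads are replicated rather than left with none. Below is a small sketch of that final division only; the config probing is omitted.

def num_kv_heads_per_rank(total_num_kv_heads: int, tensor_parallel_size: int) -> int:
    # Replicate KV heads when there are fewer heads than ranks,
    # so every GPU still owns at least one KV head.
    return max(1, total_num_kv_heads // tensor_parallel_size)


# GQA model with 8 KV heads on 4 GPUs -> 2 KV heads per GPU.
assert num_kv_heads_per_rank(8, 4) == 2
# MQA model with 1 KV head on 8 GPUs -> replicated, 1 per GPU.
assert num_kv_heads_per_rank(1, 8) == 1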
python/sglang/srt/models/chatglm.py

@@ -5,30 +5,32 @@
from typing import Iterable, List, Optional, Tuple

import torch
-from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.managers.controller.model_runner import InputMetadata
-from sglang.srt.layers.logits_processor import LogitsProcessor
from torch import nn
from torch.nn import LayerNorm
from vllm.config import CacheConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (
    MergedColumnParallelLinear,
    QKVParallelLinear,
    RowParallelLinear,
)
-from vllm.model_executor.layers.quantization.base_config import (
-    QuantizationConfig)
+from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
    ParallelLMHead,
    VocabParallelEmbedding,
)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput
from vllm.transformers_utils.configs import ChatGLMConfig
+from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.managers.controller.model_runner import InputMetadata

LoraConfig = None

@@ -49,9 +51,11 @@ class GLMAttention(nn.Module):
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.multi_query_attention = config.multi_query_attention
        self.total_num_kv_heads = (
            config.multi_query_group_num
            if config.multi_query_attention
            else config.num_attention_heads
        )
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.

@@ -91,11 +95,13 @@ class GLMAttention(nn.Module):
            base=10000 * rope_ratio,
            is_neox_style=False,
        )
        self.attn = RadixAttention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            layer_id=layer_id,
        )

    def forward(
        self,

@@ -176,14 +182,16 @@ class GLMBlock(nn.Module):
    ):
        super().__init__()
        self.apply_residual_connection_post_layernorm = (
            config.apply_residual_connection_post_layernorm
        )

        self.fp32_residual_connection = config.fp32_residual_connection

        layer_norm_func = RMSNorm if config.rmsnorm else LayerNorm
        # Layernorm on the input data.
        self.input_layernorm = layer_norm_func(
            config.hidden_size, eps=config.layernorm_epsilon
        )

        # Self attention.
        self.self_attention = GLMAttention(config, layer_id, cache_config, quant_config)

@@ -191,7 +199,8 @@ class GLMBlock(nn.Module):
        # Layernorm on the attention output
        self.post_attention_layernorm = layer_norm_func(
            config.hidden_size, eps=config.layernorm_epsilon
        )

        # MLP
        self.mlp = GLMMLP(config, quant_config)

@@ -250,16 +259,19 @@ class GLMTransformer(nn.Module):
        self.num_layers = config.num_layers

        # Transformer layers.
        self.layers = nn.ModuleList(
            [
                GLMBlock(config, i, cache_config, quant_config)
                for i in range(self.num_layers)
            ]
        )

        if self.post_layer_norm:
            layer_norm_func = RMSNorm if config.rmsnorm else LayerNorm
            # Final layer norm before output.
            self.final_layernorm = layer_norm_func(
                config.hidden_size, eps=config.layernorm_epsilon
            )

    def forward(
        self,

@@ -291,16 +303,16 @@ class ChatGLMModel(nn.Module):
    ):
        super().__init__()

        self.embedding = VocabParallelEmbedding(
            config.padded_vocab_size, config.hidden_size
        )

        self.num_layers = config.num_layers
        self.multi_query_group_num = config.multi_query_group_num
        self.kv_channels = config.kv_channels
        self.encoder = GLMTransformer(config, cache_config, quant_config)

        self.output_layer = ParallelLMHead(
            config.padded_vocab_size, config.hidden_size
        )

    def forward(
        self,

@@ -322,7 +334,7 @@ class ChatGLMModel(nn.Module):
class ChatGLMForCausalLM(nn.Module):
    packed_modules_mapping = {
        "query_key_value": ["query_key_value"],
-        "dense_h_to_4h": ["dense_h_to_4h"]
+        "dense_h_to_4h": ["dense_h_to_4h"],
    }
    # LoRA specific attributes
    supported_lora_modules = [

@@ -344,8 +356,7 @@ class ChatGLMForCausalLM(nn.Module):
        super().__init__()
        self.config: ChatGLMConfig = config
        self.quant_config = quant_config
        self.max_position_embeddings = getattr(config, "max_sequence_length", 8192)
        self.transformer = ChatGLMModel(config, cache_config, quant_config)
        self.lm_head = self.transformer.output_layer
        self.logits_processor = LogitsProcessor(config)

@@ -357,8 +368,7 @@ class ChatGLMForCausalLM(nn.Module):
        positions: torch.Tensor,
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        hidden_states = self.transformer(input_ids, positions, input_metadata)
        return self.logits_processor(
            input_ids, hidden_states, self.lm_head.weight, input_metadata
        )

@@ -382,10 +392,10 @@ class ChatGLMForCausalLM(nn.Module):
            if name.endswith(".bias") and name not in params_dict:
                continue
            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader", default_weight_loader)
            weight_loader(param, loaded_weight)


EntryClass = ChatGLMForCausalLM
# compat: glm model.config class == ChatGLMModel
EntryClassRemapping = [("ChatGLMModel", ChatGLMForCausalLM)]
python/sglang/srt/models/commandr.py

@@ -23,7 +23,7 @@
# This file is based on the LLama model definition file in transformers
"""PyTorch Cohere model."""
-from typing import Optional, Tuple, Iterable
+from typing import Iterable, Optional, Tuple

import torch
import torch.utils.checkpoint

@@ -44,8 +44,8 @@ from vllm.model_executor.layers.linear import (
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
-from vllm.model_executor.utils import set_weight_attrs
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+from vllm.model_executor.utils import set_weight_attrs

from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention
python/sglang/srt/models/dbrx.py

@@ -24,8 +24,8 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
    ParallelLMHead,
    VocabParallelEmbedding,
)
-from vllm.model_executor.utils import set_weight_attrs
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+from vllm.model_executor.utils import set_weight_attrs
from vllm.transformers_utils.configs.dbrx import DbrxConfig

from sglang.srt.layers.logits_processor import LogitsProcessor
python/sglang/srt/models/gemma.py

@@ -6,7 +6,7 @@ from typing import Iterable, Optional, Tuple
import torch
from torch import nn
from transformers import PretrainedConfig
-from vllm.config import LoRAConfig, CacheConfig
+from vllm.config import CacheConfig, LoRAConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import GeluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
python/sglang/srt/models/grok.py

# Adapted from
# https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/mixtral.py#L1
"""Inference-only Grok1 model."""
-from typing import Iterable, Optional, Tuple, List
+from typing import Iterable, List, Optional, Tuple

import numpy as np
import torch

@@ -9,7 +9,6 @@ import torch.nn.functional as F
import tqdm
from torch import nn
from transformers import PretrainedConfig
from vllm import _custom_ops as ops
from vllm.config import CacheConfig
from vllm.distributed import (

@@ -35,12 +34,11 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.utils import set_weight_attrs
from vllm.utils import print_warning_once

-from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.fused_moe import fused_moe
+from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.managers.controller.model_runner import InputMetadata

use_fused = True

@@ -134,9 +132,12 @@ class Grok1MoEUnfused(nn.Module):
        final_hidden_states = torch.zeros(
            (hidden_states.shape[0], hidden_dim),
            dtype=hidden_states.dtype,
            device=hidden_states.device,
        )
        expert_mask = torch.nn.functional.one_hot(
            selected_experts, num_classes=self.num_total_experts
        ).permute(2, 1, 0)

        for expert_idx in self.expert_indicies:
            expert_layer = self.experts[expert_idx]

@@ -153,7 +154,10 @@ class Grok1MoEUnfused(nn.Module):
            # the current expert. We need to make sure to multiply the output hidden
            # states by `routing_weights` on the corresponding tokens (top-1 and top-2)
            current_state = hidden_states[None, top_x_list].reshape(-1, hidden_dim)
-            current_hidden_states = expert_layer(current_state) * routing_weights[top_x_list, idx_list, None]
+            current_hidden_states = (
+                expert_layer(current_state)
+                * routing_weights[top_x_list, idx_list, None]
+            )

            # However `index_add_` only support torch tensors for indexing so we'll use
            # the `top_x` tensor here.

@@ -198,32 +202,46 @@ class Grok1MoE(nn.Module):
        self.params_dtype = params_dtype

        # Gate always runs at half / full precision for now.
        self.gate = ReplicatedLinear(
            self.hidden_size,
            self.num_total_experts,
            bias=False,
            params_dtype=self.params_dtype,
            quant_config=None,
        )

        if self.use_fp8 and self.quant_config.is_checkpoint_fp8_serialized:
            params_dtype = torch.float8_e4m3fn

        self.w13_weight = nn.Parameter(
            torch.empty(
                self.num_total_experts,
                2 * self.intermediate_size,
                self.hidden_size,
                dtype=params_dtype,
            )
        )
        self.w2_weight = nn.Parameter(
            torch.empty(
                self.num_total_experts,
                self.hidden_size,
                self.intermediate_size,
                dtype=params_dtype,
            )
        )
        set_weight_attrs(
            self.w13_weight,
            {
                "weight_loader": self.weight_loader,
            },
        )
        set_weight_attrs(
            self.w2_weight,
            {
                "weight_loader": self.weight_loader,
            },
        )

        # Used for fp8.
        self.w13_scale = None

@@ -233,46 +251,69 @@ class Grok1MoE(nn.Module):
        if self.use_fp8:
            # WEIGHT_SCALE (for fp8)
            self.w13_scale = nn.Parameter(
                torch.ones(self.num_total_experts, dtype=torch.float32),
                requires_grad=False,
            )
            self.w2_scale = nn.Parameter(
                torch.ones(self.num_total_experts, dtype=torch.float32),
                requires_grad=False,
            )

            # If loading fp8 checkpoint, pass the weight loaders.
            # If loading an fp16 checkpoint, do not (we will quantize in
            # process_weights_after_loading()
            if quant_config.is_checkpoint_fp8_serialized:
                set_weight_attrs(
                    self.w13_scale,
                    {
                        "weight_loader": self.weight_loader,
                    },
                )
                set_weight_attrs(
                    self.w2_scale,
                    {
                        "weight_loader": self.weight_loader,
                    },
                )

            # ACT_SCALE (for fp8)
            if quant_config.activation_scheme == "static":
                if not quant_config.is_checkpoint_fp8_serialized:
                    raise ValueError(
                        "Found static activation scheme for checkpoint that "
                        "was not serialized fp8."
                    )
                self.a13_scale = nn.Parameter(
                    torch.zeros(self.num_total_experts, dtype=torch.float32),
                    requires_grad=False,
                )
                self.a2_scale = nn.Parameter(
                    torch.zeros(self.num_total_experts, dtype=torch.float32),
                    requires_grad=False,
                )
                set_weight_attrs(
                    self.a13_scale,
                    {
                        "weight_loader": self.weight_loader,
                    },
                )
                set_weight_attrs(
                    self.a2_scale,
                    {
                        "weight_loader": self.weight_loader,
                    },
                )

    def weight_loader(
        self,
        param: nn.Parameter,
        loaded_weight: torch.Tensor,
        weight_name: str,
        expert_id: int,
        pre_sharded: bool,
    ):
        param_data = param.data
        shard_size = self.intermediate_size
        if pre_sharded:
@@ -284,8 +325,9 @@ class Grok1MoE(nn.Module):
        if weight_name.endswith("w1.weight"):
            param_data[expert_id, 0:shard_size, :] = loaded_weight[shard, :]
        if weight_name.endswith("w3.weight"):
            param_data[expert_id, shard_size : 2 * shard_size, :] = loaded_weight[
                shard, :
            ]
        if weight_name.endswith("w2.weight"):
            param_data[expert_id, :, :] = loaded_weight[:, shard]
        if "act_scale" in weight_name or "weight_scale" in weight_name:
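`weight_loader` above packs each expert's `w1` and `w3` projections into one stacked `w13_weight` tensor and keeps `w2` separate, copying only this rank's shard (rows for `w1`/`w3`, columns for `w2`). Below is a simplified NumPy illustration of the same indexing; the shapes and the trivial `shard` slice are invented for the example.

import numpy as np

num_experts, intermediate_size, hidden_size = 2, 4, 3
w13_weight = np.zeros((num_experts, 2 * intermediate_size, hidden_size))
w2_weight = np.zeros((num_experts, hidden_size, intermediate_size))

# Pretend this rank owns the full weight (tp_size == 1), so the shard is everything.
shard = slice(None)
expert_id = 0
w1 = np.ones((intermediate_size, hidden_size))      # gate projection
w3 = 2 * np.ones((intermediate_size, hidden_size))  # up projection
w2 = 3 * np.ones((hidden_size, intermediate_size))  # down projection

shard_size = intermediate_size
w13_weight[expert_id, 0:shard_size, :] = w1[shard, :]                  # w1 rows first
w13_weight[expert_id, shard_size : 2 * shard_size, :] = w3[shard, :]   # then w3 rows
w2_weight[expert_id, :, :] = w2[:, shard]                              # w2 sharded on columns

print(w13_weight[0, :, 0])  # -> [1. 1. 1. 1. 2. 2. 2. 2.]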
@@ -298,17 +340,17 @@ class Grok1MoE(nn.Module):
        # If checkpoint is fp16, quantize here.
        if not self.quant_config.is_checkpoint_fp8_serialized:
            w13_weight = torch.empty_like(
                self.w13_weight.data, dtype=torch.float8_e4m3fn
            )
            w2_weight = torch.empty_like(
                self.w2_weight.data, dtype=torch.float8_e4m3fn
            )
            for expert in range(self.num_total_experts):
                w13_weight[expert, :, :], self.w13_scale[expert] = ops.scaled_fp8_quant(
                    self.w13_weight.data[expert, :, :]
                )
                w2_weight[expert, :, :], self.w2_scale[expert] = ops.scaled_fp8_quant(
                    self.w2_weight.data[expert, :, :]
                )
            self.w13_weight = nn.Parameter(w13_weight, requires_grad=False)
            self.w2_weight = nn.Parameter(w2_weight, requires_grad=False)

@@ -319,25 +361,25 @@ class Grok1MoE(nn.Module):
            if self.a13_scale is None or self.a2_scale is None:
                raise ValueError(
                    "QuantConfig has static quantization, but found "
                    "activation scales are None."
                )
-            if (not all_close_1d(self.a13_scale) or not all_close_1d(self.a2_scale)):
+            if not all_close_1d(self.a13_scale) or not all_close_1d(self.a2_scale):
                print_warning_once(
                    "Found act_scales that are not equal for fp8 MoE layer. "
                    "Using the maximum across experts for each layer. "
                )
            self.a13_scale = nn.Parameter(self.a13_scale.max(), requires_grad=False)
            self.a2_scale = nn.Parameter(self.a2_scale.max(), requires_grad=False)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        num_tokens, hidden_size = hidden_states.shape
        hidden_states = hidden_states.view(-1, self.hidden_size)
        # router_logits: (num_tokens, n_experts)
        router_logits, _ = self.gate(hidden_states)
        final_hidden_states = fused_moe(
            hidden_states,
            self.w13_weight,
            self.w2_weight,
            router_logits,

@@ -348,11 +390,11 @@ class Grok1MoE(nn.Module):
            w1_scale=self.w13_scale,
            w2_scale=self.w2_scale,
            a1_scale=self.a13_scale,
-            a2_scale=self.a2_scale)
+            a2_scale=self.a2_scale,
+        )
        if self.tp_size > 1:
            final_hidden_states = tensor_model_parallel_all_reduce(
                final_hidden_states
            )

        return final_hidden_states.view(num_tokens, hidden_size)

@@ -462,10 +504,12 @@ class Grok1DecoderLayer(nn.Module):
                top_k=config.num_experts_per_tok,
                hidden_size=config.hidden_size,
                intermediate_size=config.intermediate_size,
-                quant_config=quant_config)
+                quant_config=quant_config,
+            )
        else:
-            self.block_sparse_moe = Grok1MoEUnfused(config=config, quant_config=quant_config)
+            self.block_sparse_moe = Grok1MoEUnfused(
+                config=config, quant_config=quant_config
+            )
        self.pre_attn_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attn_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.pre_moe_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -478,12 +522,21 @@ class Grok1DecoderLayer(nn.Module):
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        hidden_states = (
            self.post_attn_norm(
                self.self_attn(
                    positions=positions,
                    hidden_states=self.pre_attn_norm(hidden_states),
                    input_metadata=input_metadata,
                )
            )
            + hidden_states
        )

        hidden_states = (
            self.post_moe_norm(self.block_sparse_moe(self.pre_moe_norm(hidden_states)))
            + hidden_states
        )

        return hidden_states
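The reformatted forward pass makes the Grok-1 block structure easier to see: each sub-layer is wrapped in a pre-norm and a post-norm, and the residual is added outside both, i.e. `x = post_norm(sublayer(pre_norm(x))) + x`, once for attention and once for the MoE. Below is a toy PyTorch sketch of that wiring; `LayerNorm` and `Linear` stand in for sglang's RMSNorm, attention, and block-sparse MoE modules.

import torch
from torch import nn


class PrePostNormBlock(nn.Module):
    """x -> post_norm(sublayer(pre_norm(x))) + x, applied for attention then MoE."""

    def __init__(self, hidden_size: int):
        super().__init__()
        # LayerNorm stands in for RMSNorm; Linear stands in for the real sub-layers.
        self.pre_attn_norm = nn.LayerNorm(hidden_size)
        self.post_attn_norm = nn.LayerNorm(hidden_size)
        self.attn = nn.Linear(hidden_size, hidden_size)
        self.pre_moe_norm = nn.LayerNorm(hidden_size)
        self.post_moe_norm = nn.LayerNorm(hidden_size)
        self.moe = nn.Linear(hidden_size, hidden_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.post_attn_norm(self.attn(self.pre_attn_norm(x))) + x
        x = self.post_moe_norm(self.moe(self.pre_moe_norm(x))) + x
        return x


block = PrePostNormBlock(hidden_size=16)
print(block(torch.randn(2, 16)).shape)  # torch.Size([2, 16])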
@@ -525,9 +578,7 @@ class Grok1Model(nn.Module):
        hidden_states.mul_(self.config.embedding_multiplier_scale)

        for i in range(len(self.layers)):
            hidden_states = self.layers[i](positions, hidden_states, input_metadata)

        hidden_states = self.norm(hidden_states)
        hidden_states.mul_(self.config.output_multiplier_scale)

@@ -572,28 +623,41 @@ class Grok1ModelForCausalLM(nn.Module):
        ]
        if use_fused:
            expert_params_mapping = (
                [
                    # These are the weight scales for the experts
                    # (param_name, weight_name, expert_id)
                    (
                        "w13_scale" if weight_name in ["w1", "w3"] else "w2_scale",
                        f"experts.{expert_id}.{weight_name}.weight_scale",
                        expert_id,
                    )
                    for expert_id in range(self.config.num_local_experts)
                    for weight_name in ["w1", "w2", "w3"]
                ]
                + [
                    # These are the weights for the experts
                    # (param_name, weight_name, expert_id)
                    (
                        "w13_weight" if weight_name in ["w1", "w3"] else "w2_weight",
                        f"experts.{expert_id}.{weight_name}.weight",
                        expert_id,
                    )
                    for expert_id in range(self.config.num_local_experts)
                    for weight_name in ["w1", "w2", "w3"]
                ]
                + [
                    # These are the activation scales for the experts
                    # (param_name, weight_name, expert_id)
                    (
                        "a13_scale" if weight_name in ["w1", "w3"] else "a2_scale",
                        f"experts.{expert_id}.{weight_name}.act_scale",
                        expert_id,
                    )
                    for expert_id in range(self.config.num_local_experts)
                    for weight_name in ["w1", "w2", "w3"]
                ]
            )
        else:
            expert_params_mapping = []

@@ -601,11 +665,11 @@ class Grok1ModelForCausalLM(nn.Module):
        if get_tensor_model_parallel_rank() == 0:
            weights = tqdm.tqdm(weights, total=int(len(params_dict) * 3.4))

        for name, loaded_weight in weights:
-            #print(get_tensor_model_parallel_rank(), name)
+            # print(get_tensor_model_parallel_rank(), name)
            if "rotary_emb.inv_freq" in name:
                continue

-            for (param_name, weight_name, shard_id) in stacked_params_mapping:
+            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)

@@ -623,19 +687,22 @@ class Grok1ModelForCausalLM(nn.Module):
                    name = name.replace(weight_name, param_name)
                    param = params_dict[name]
                    weight_loader = param.weight_loader
                    weight_loader(
                        param,
                        loaded_weight,
                        weight_name,
                        expert_id=expert_id,
                        pre_sharded=get_tensor_model_parallel_world_size() > 1,
                    )
                    break
                else:
                    # Skip loading extra bias for GPTQ models.
                    if name.endswith(".bias") and name not in params_dict:
                        continue
                    param = params_dict[name]
                    weight_loader = getattr(
                        param, "weight_loader", default_weight_loader
                    )
                    weight_loader(param, loaded_weight)

@@ -645,10 +712,11 @@ def all_close_1d(x: torch.Tensor) -> bool:
old_prepare_weights = getattr(DefaultModelLoader, "_prepare_weights")


-def _prepare_presharded_weights(self, model_name_or_path: str, revision: Optional[str], fall_back_to_pt: bool) -> Tuple[str, List[str], bool]:
+def _prepare_presharded_weights(
+    self, model_name_or_path: str, revision: Optional[str], fall_back_to_pt: bool
+) -> Tuple[str, List[str], bool]:
    import glob
    import os
python/sglang/srt/models/llama2.py

# Adapted from
# https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/llama.py#L1
"""Inference-only LLaMA model compatible with HuggingFace weights."""
-from typing import Any, Dict, Optional, Tuple, Iterable
+from typing import Any, Dict, Iterable, Optional, Tuple

import torch
import tqdm

@@ -10,7 +10,7 @@ from transformers import LlamaConfig
from vllm.config import CacheConfig
from vllm.distributed import (
    get_tensor_model_parallel_rank,
-    get_tensor_model_parallel_world_size
+    get_tensor_model_parallel_world_size,
)
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm

@@ -158,9 +158,11 @@ class LlamaDecoderLayer(nn.Module):
        rope_theta = getattr(config, "rope_theta", 10000)
        rope_scaling = getattr(config, "rope_scaling", None)
        if rope_scaling is not None and getattr(
            config, "original_max_position_embeddings", None
        ):
            rope_scaling["original_max_position_embeddings"] = (
                config.original_max_position_embeddings
            )
        max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
        self.self_attn = LlamaAttention(
            hidden_size=self.hidden_size,
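The hunk above copies `original_max_position_embeddings` from the HF config into the `rope_scaling` dict when both are present, so the rotary-embedding helper later sees it next to the scaling type and factor. Below is a small sketch of that config plumbing; the `SimpleNamespace` config and its field values are stand-ins, not a real transformers `LlamaConfig`.

from types import SimpleNamespace

# Stand-in config object for the example.
config = SimpleNamespace(
    rope_theta=10000,
    rope_scaling={"type": "yarn", "factor": 4.0},
    original_max_position_embeddings=8192,
)

rope_theta = getattr(config, "rope_theta", 10000)
rope_scaling = getattr(config, "rope_scaling", None)
if rope_scaling is not None and getattr(
    config, "original_max_position_embeddings", None
):
    rope_scaling["original_max_position_embeddings"] = (
        config.original_max_position_embeddings
    )
print(rope_scaling)
# {'type': 'yarn', 'factor': 4.0, 'original_max_position_embeddings': 8192}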
python/sglang/srt/models/llava.py

"""Inference-only LLaVa model compatible with HuggingFace weights."""
-from typing import List, Iterable, Optional, Tuple
+from typing import Iterable, List, Optional, Tuple

import numpy as np
import torch
from torch import nn
-from transformers import CLIPVisionModel, CLIPVisionConfig, LlavaConfig, Qwen2Config, MistralConfig
+from transformers import (
+    CLIPVisionConfig,
+    CLIPVisionModel,
+    LlavaConfig,
+    MistralConfig,
+    Qwen2Config,
+)
from transformers.models.llava.modeling_llava import LlavaMultiModalProjector
from vllm.config import CacheConfig
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig

@@ -19,8 +25,8 @@ from sglang.srt.mm_utils import (
    unpad_image_shape,
)
from sglang.srt.models.llama2 import LlamaForCausalLM
-from sglang.srt.models.qwen2 import Qwen2ForCausalLM
from sglang.srt.models.mistral import MistralForCausalLM
+from sglang.srt.models.qwen2 import Qwen2ForCausalLM


class LlavaLlamaForCausalLM(nn.Module):

@@ -359,6 +365,7 @@ class LlavaMistralForCausalLM(LlavaLlamaForCausalLM):
first_call = True


def clip_vision_embed_forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
    batch_size = pixel_values.shape[0]

@@ -388,8 +395,4 @@ def monkey_path_clip_vision_embed_forward():
    )


EntryClass = [LlavaLlamaForCausalLM, LlavaQwenForCausalLM, LlavaMistralForCausalLM]
python/sglang/srt/models/llavavid.py

"""Inference-only LLaVa video model compatible with HuggingFace weights."""
-from typing import List, Iterable, Optional, Tuple
+from typing import Iterable, List, Optional, Tuple

import numpy as np
import torch
python/sglang/srt/models/mixtral.py

@@ -33,13 +33,11 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.utils import set_weight_attrs
from vllm.utils import print_warning_once

from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.managers.controller.model_runner import InputMetadata


class MixtralMoE(nn.Module):
    """A tensor-parallel MoE implementation for Mixtral that shards each expert
    across all ranks.

@@ -76,32 +74,46 @@ class MixtralMoE(nn.Module):
        self.params_dtype = params_dtype

        # Gate always runs at half / full precision for now.
        self.gate = ReplicatedLinear(
            self.hidden_size,
            self.num_total_experts,
            bias=False,
            params_dtype=self.params_dtype,
            quant_config=None,
        )

        if self.use_fp8 and self.quant_config.is_checkpoint_fp8_serialized:
            params_dtype = torch.float8_e4m3fn

        self.w13_weight = nn.Parameter(
            torch.empty(
                self.num_total_experts,
                2 * self.intermediate_size,
                self.hidden_size,
                dtype=params_dtype,
            )
        )
        self.w2_weight = nn.Parameter(
            torch.empty(
                self.num_total_experts,
                self.hidden_size,
                self.intermediate_size,
                dtype=params_dtype,
            )
        )
        set_weight_attrs(
            self.w13_weight,
            {
                "weight_loader": self.weight_loader,
            },
        )
        set_weight_attrs(
            self.w2_weight,
            {
                "weight_loader": self.weight_loader,
            },
        )

        # Used for fp8.
        self.w13_scale = None

@@ -111,46 +123,68 @@ class MixtralMoE(nn.Module):
        if self.use_fp8:
            # WEIGHT_SCALE (for fp8)
            self.w13_scale = nn.Parameter(
                torch.ones(self.num_total_experts, dtype=torch.float32),
                requires_grad=False,
            )
            self.w2_scale = nn.Parameter(
                torch.ones(self.num_total_experts, dtype=torch.float32),
                requires_grad=False,
            )

            # If loading fp8 checkpoint, pass the weight loaders.
            # If loading an fp16 checkpoint, do not (we will quantize in
            # process_weights_after_loading()
            if quant_config.is_checkpoint_fp8_serialized:
                set_weight_attrs(
                    self.w13_scale,
                    {
                        "weight_loader": self.weight_loader,
                    },
                )
                set_weight_attrs(
                    self.w2_scale,
                    {
                        "weight_loader": self.weight_loader,
                    },
                )

            # ACT_SCALE (for fp8)
            if quant_config.activation_scheme == "static":
                if not quant_config.is_checkpoint_fp8_serialized:
                    raise ValueError(
                        "Found static activation scheme for checkpoint that "
                        "was not serialized fp8."
                    )
                self.a13_scale = nn.Parameter(
                    torch.zeros(self.num_total_experts, dtype=torch.float32),
                    requires_grad=False,
                )
                self.a2_scale = nn.Parameter(
                    torch.zeros(self.num_total_experts, dtype=torch.float32),
                    requires_grad=False,
                )
                set_weight_attrs(
                    self.a13_scale,
                    {
                        "weight_loader": self.weight_loader,
                    },
                )
                set_weight_attrs(
                    self.a2_scale,
                    {
                        "weight_loader": self.weight_loader,
                    },
                )

    def weight_loader(
        self,
        param: nn.Parameter,
        loaded_weight: torch.Tensor,
        weight_name: str,
        expert_id: int,
    ):
        tp_rank = get_tensor_model_parallel_rank()
        param_data = param.data
        shard_size = self.intermediate_size

@@ -158,8 +192,9 @@ class MixtralMoE(nn.Module):
        if weight_name.endswith("w1.weight"):
            param_data[expert_id, 0:shard_size, :] = loaded_weight[shard, :]
        if weight_name.endswith("w3.weight"):
            param_data[expert_id, shard_size : 2 * shard_size, :] = loaded_weight[
                shard, :
            ]
        if weight_name.endswith("w2.weight"):
            param_data[expert_id, :, :] = loaded_weight[:, shard]
        if "act_scale" in weight_name or "weight_scale" in weight_name:

@@ -172,17 +207,17 @@ class MixtralMoE(nn.Module):
        # If checkpoint is fp16, quantize here.
        if not self.quant_config.is_checkpoint_fp8_serialized:
            w13_weight = torch.empty_like(
                self.w13_weight.data, dtype=torch.float8_e4m3fn
            )
            w2_weight = torch.empty_like(
                self.w2_weight.data, dtype=torch.float8_e4m3fn
            )
            for expert in range(self.num_total_experts):
                w13_weight[expert, :, :], self.w13_scale[expert] = ops.scaled_fp8_quant(
                    self.w13_weight.data[expert, :, :]
                )
                w2_weight[expert, :, :], self.w2_scale[expert] = ops.scaled_fp8_quant(
                    self.w2_weight.data[expert, :, :]
                )
            self.w13_weight = nn.Parameter(w13_weight, requires_grad=False)
            self.w2_weight = nn.Parameter(w2_weight, requires_grad=False)
@@ -193,25 +228,25 @@ class MixtralMoE(nn.Module):
            if self.a13_scale is None or self.a2_scale is None:
                raise ValueError(
                    "QuantConfig has static quantization, but found "
                    "activation scales are None."
                )
-            if (not all_close_1d(self.a13_scale) or not all_close_1d(self.a2_scale)):
+            if not all_close_1d(self.a13_scale) or not all_close_1d(self.a2_scale):
                print_warning_once(
                    "Found act_scales that are not equal for fp8 MoE layer. "
                    "Using the maximum across experts for each layer. "
                )
            self.a13_scale = nn.Parameter(self.a13_scale.max(), requires_grad=False)
            self.a2_scale = nn.Parameter(self.a2_scale.max(), requires_grad=False)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        num_tokens, hidden_size = hidden_states.shape
        hidden_states = hidden_states.view(-1, self.hidden_size)
        # router_logits: (num_tokens, n_experts)
        router_logits, _ = self.gate(hidden_states)
        final_hidden_states = fused_moe(
            hidden_states,
            self.w13_weight,
            self.w2_weight,
            router_logits,
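`process_weights_after_loading` tolerates per-expert activation scales that disagree: it warns once and then collapses each scale tensor to its maximum, which is conservative but safe for fp8 dequantization. Below is a small sketch of that collapse, assuming the scales are plain 1-D tensors; the helper and values are illustrative.

import torch


def all_close_1d(x: torch.Tensor) -> bool:
    assert x.dim() == 1
    return bool(torch.allclose(x, x[0].expand_as(x)))


def collapse_act_scales(a13_scale: torch.Tensor, a2_scale: torch.Tensor):
    """Reduce per-expert activation scales to one (max) scale per tensor."""
    if not all_close_1d(a13_scale) or not all_close_1d(a2_scale):
        print("warning: per-expert act_scales differ; using the maximum per layer")
    return a13_scale.max(), a2_scale.max()


a13 = torch.tensor([0.9, 1.1, 1.0, 1.2])
a2 = torch.tensor([0.5, 0.5, 0.5, 0.5])
print(collapse_act_scales(a13, a2))  # (tensor(1.2000), tensor(0.5000))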
@@ -222,11 +257,11 @@ class MixtralMoE(nn.Module):
            w1_scale=self.w13_scale,
            w2_scale=self.w2_scale,
            a1_scale=self.a13_scale,
-            a2_scale=self.a2_scale)
+            a2_scale=self.a2_scale,
+        )
        if self.tp_size > 1:
            final_hidden_states = tensor_model_parallel_all_reduce(
                final_hidden_states
            )

        return final_hidden_states.view(num_tokens, hidden_size)

@@ -335,7 +370,8 @@ class MixtralDecoderLayer(nn.Module):
            top_k=config.num_experts_per_tok,
            hidden_size=config.hidden_size,
            intermediate_size=config.intermediate_size,
-            quant_config=quant_config)
+            quant_config=quant_config,
+        )
        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
@@ -444,35 +480,48 @@ class MixtralForCausalLM(nn.Module):
            ("qkv_proj", "v_proj", "v"),
        ]

        expert_params_mapping = (
            [
                # These are the weight scales for the experts
                # (param_name, weight_name, expert_id)
                (
                    "w13_scale" if weight_name in ["w1", "w3"] else "w2_scale",
                    f"experts.{expert_id}.{weight_name}.weight_scale",
                    expert_id,
                )
                for expert_id in range(self.config.num_local_experts)
                for weight_name in ["w1", "w2", "w3"]
            ]
            + [
                # These are the weights for the experts
                # (param_name, weight_name, expert_id)
                (
                    "w13_weight" if weight_name in ["w1", "w3"] else "w2_weight",
                    f"experts.{expert_id}.{weight_name}.weight",
                    expert_id,
                )
                for expert_id in range(self.config.num_local_experts)
                for weight_name in ["w1", "w2", "w3"]
            ]
            + [
                # These are the activation scales for the experts
                # (param_name, weight_name, expert_id)
                (
                    "a13_scale" if weight_name in ["w1", "w3"] else "a2_scale",
                    f"experts.{expert_id}.{weight_name}.act_scale",
                    expert_id,
                )
                for expert_id in range(self.config.num_local_experts)
                for weight_name in ["w1", "w2", "w3"]
            ]
        )

        params_dict = dict(self.named_parameters())

        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue
-            for (param_name, weight_name, shard_id) in stacked_params_mapping:
+            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
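The `expert_params_mapping` comprehension above simply enumerates, for every expert and projection name, which stacked parameter a checkpoint tensor belongs to and under which expert index. Below is a standalone sketch of building and consuming such a mapping; the full checkpoint key is illustrative, only the `experts.<id>.<w>.weight` suffix comes from the diff.

def build_expert_params_mapping(num_local_experts):
    """[(param_name, weight_name_substring, expert_id), ...] for expert weights."""
    return [
        (
            "w13_weight" if weight_name in ["w1", "w3"] else "w2_weight",
            f"experts.{expert_id}.{weight_name}.weight",
            expert_id,
        )
        for expert_id in range(num_local_experts)
        for weight_name in ["w1", "w2", "w3"]
    ]


mapping = build_expert_params_mapping(num_local_experts=2)
# Illustrative checkpoint key; the prefix is an assumption for the example.
checkpoint_name = "model.layers.0.block_sparse_moe.experts.1.w3.weight"
for param_name, weight_name, expert_id in mapping:
    if weight_name in checkpoint_name:
        print(param_name, expert_id)  # -> w13_weight 1
        break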
@@ -490,18 +539,18 @@ class MixtralForCausalLM(nn.Module):
                    name = name.replace(weight_name, param_name)
                    param = params_dict[name]
                    weight_loader = param.weight_loader
                    weight_loader(
                        param, loaded_weight, weight_name, expert_id=expert_id
                    )
                    break
                else:
                    # Skip loading extra bias for GPTQ models.
                    if name.endswith(".bias") and name not in params_dict:
                        continue
                    param = params_dict[name]
                    weight_loader = getattr(
                        param, "weight_loader", default_weight_loader
                    )
                    weight_loader(param, loaded_weight)
python/sglang/srt/models/mixtral_quant.py

@@ -28,7 +28,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader

from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.managers.controller.model_runner import InputMetadata