sglang / Commits / fb9296f0

Higher priority for user input of max_prefill_tokens & format (#540)

Commit fb9296f0 (unverified), authored Jun 12, 2024 by Ying Sheng; committed via GitHub on Jun 12, 2024.
Parent commit: 1374334d

Showing 20 changed files with 505 additions and 351 deletions (+505 −351).
Changed files:

  python/sglang/srt/managers/controller/manager_multi.py        +5    −5
  python/sglang/srt/managers/controller/manager_single.py       +5    −2
  python/sglang/srt/managers/controller/model_runner.py         +13   −5
  python/sglang/srt/managers/controller/radix_cache.py          +1    −0
  python/sglang/srt/managers/controller/schedule_heuristic.py   +1    −0
  python/sglang/srt/managers/controller/tp_worker.py            +14   −14
  python/sglang/srt/managers/detokenizer_manager.py             +2    −1
  python/sglang/srt/managers/io_struct.py                       +1    −1
  python/sglang/srt/managers/tokenizer_manager.py               +4    −4
  python/sglang/srt/model_config.py                             +21   −12
  python/sglang/srt/models/chatglm.py                           +48   −38
  python/sglang/srt/models/commandr.py                          +2    −2
  python/sglang/srt/models/dbrx.py                              +1    −1
  python/sglang/srt/models/gemma.py                             +1    −1
  python/sglang/srt/models/grok.py                              +204  −136
  python/sglang/srt/models/llama2.py                            +6    −4
  python/sglang/srt/models/llava.py                             +11   −8
  python/sglang/srt/models/llavavid.py                          +1    −1
  python/sglang/srt/models/mixtral.py                           +164  −115
  python/sglang/srt/models/mixtral_quant.py                     +0    −1
python/sglang/srt/managers/controller/manager_multi.py (+5 −5)

@@ -13,15 +13,15 @@ import zmq
Import order fixed (isort): the dp_worker import block moves above the io_struct block.
 import zmq.asyncio
 from sglang.global_config import global_config
+from sglang.srt.managers.controller.dp_worker import (
+    DataParallelWorkerThread,
+    start_data_parallel_worker,
+)
 from sglang.srt.managers.io_struct import (
     AbortReq,
     FlushCacheReq,
     TokenizedGenerateReqInput,
 )
-from sglang.srt.managers.controller.dp_worker import (
-    DataParallelWorkerThread,
-    start_data_parallel_worker,
-)
 from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.utils import get_exception_traceback

@@ -136,7 +136,7 @@ class Controller:
             self.recv_reqs = []
             if next_step_input:
                 await self.dispatching(next_step_input)
-            #else:
+            # else:
             #     logger.error("There is no live worker.")
             await asyncio.sleep(global_config.wait_for_new_request_delay)
python/sglang/srt/managers/controller/manager_single.py (+5 −2)

Top of file: the module docstring ("""A controller that manages a group of tensor parallel workers.""") and the asyncio / logging / time imports are unchanged context; one formatting-only line is added here.

@@ -49,7 +50,9 @@ class ControllerSingle:
         # async sleep for receiving the subsequent request and avoiding cache miss
         slept = False
         if len(out_pyobjs) != 0:
-            has_finished = any([obj.finished_reason is not None for obj in out_pyobjs])
+            has_finished = any(
+                [obj.finished_reason is not None for obj in out_pyobjs]
+            )
             if has_finished:
                 if self.request_dependency_delay > 0:
                     slept = True

@@ -94,4 +97,4 @@ def start_controller_process(
     except Exception:
         logger.error("Exception in ControllerSingle:\n" + get_exception_traceback())
     finally:
         kill_parent_process()
A trailing newline is added at the end of the file (the old version ended with "\ No newline at end of file").
python/sglang/srt/managers/controller/model_runner.py (+13 −5)

Top of file: the module docstring ("""ModelRunner runs the forward passes of the models.""") and the importlib / importlib.resources / logging imports are unchanged context; one formatting-only line is added here.

@@ -12,15 +13,18 @@ import torch
 import torch.nn as nn
 from vllm.config import DeviceConfig, LoadConfig
 from vllm.config import ModelConfig as VllmModelConfig
-from vllm.distributed import initialize_model_parallel, init_distributed_environment
+from vllm.distributed import init_distributed_environment, initialize_model_parallel
 from vllm.model_executor.model_loader import get_model
 from vllm.model_executor.models import ModelRegistry
 from sglang.srt.managers.controller.infer_batch import Batch, ForwardMode
 from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool
 from sglang.srt.server_args import ServerArgs
-from sglang.srt.utils import get_available_gpu_memory, is_multimodal_model, monkey_patch_vllm_p2p_access_check
+from sglang.srt.utils import (
+    get_available_gpu_memory,
+    is_multimodal_model,
+    monkey_patch_vllm_p2p_access_check,
+)
 logger = logging.getLogger("srt.model_runner")

@@ -441,7 +445,9 @@ def import_model_classes():
         module = importlib.import_module(name)
         if hasattr(module, "EntryClass"):
             entry = module.EntryClass
-            if isinstance(entry, list):  # To support multiple model classes in one module
+            if isinstance(entry, list):
+                # To support multiple model classes in one module
                 for tmp in entry:
                     model_arch_name_to_cls[tmp.__name__] = tmp
             else:

@@ -449,7 +455,9 @@ def import_model_classes():
         # compat: some models such as chatglm has incorrect class set in config.json
         # usage: [ tuple("From_Entry_Class_Name": EntryClass), ]
-        if hasattr(module, "EntryClassRemapping") and isinstance(module.EntryClassRemapping, list):
+        if hasattr(module, "EntryClassRemapping") and isinstance(
+            module.EntryClassRemapping, list
+        ):
             for remap in module.EntryClassRemapping:
                 if isinstance(remap, tuple) and len(remap) == 2:
                     model_arch_name_to_cls[remap[0]] = remap[1]
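For readers unfamiliar with the registry these two hunks reformat: model modules advertise their classes through module-level EntryClass / EntryClassRemapping attributes, which import_model_classes() collects into a name-to-class map. Below is a minimal sketch of that convention, condensed from the hunks above; the names match the diff, but the surrounding scaffolding (function name, else branch) is illustrative rather than a verbatim copy of the file.

    # Sketch of the registration convention touched above (condensed from the diff).
    model_arch_name_to_cls = {}

    def collect_entry_classes(module):
        if hasattr(module, "EntryClass"):
            entry = module.EntryClass
            if isinstance(entry, list):
                # To support multiple model classes in one module
                for tmp in entry:
                    model_arch_name_to_cls[tmp.__name__] = tmp
            else:
                # Single-class module (illustrative branch body)
                model_arch_name_to_cls[entry.__name__] = entry
        # compat: some models such as chatglm have an incorrect class set in config.json
        # usage: [ tuple("From_Entry_Class_Name", EntryClass), ]
        if hasattr(module, "EntryClassRemapping") and isinstance(
            module.EntryClassRemapping, list
        ):
            for remap in module.EntryClassRemapping:
                if isinstance(remap, tuple) and len(remap) == 2:
                    model_arch_name_to_cls[remap[0]] = remap[1]

chatglm.py later in this commit is a concrete example: it sets EntryClass = ChatGLMForCausalLM and EntryClassRemapping = [("ChatGLMModel", ChatGLMForCausalLM)], so checkpoints whose config.json names ChatGLMModel still resolve to the right class.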
python/sglang/srt/managers/controller/radix_cache.py (+1 −0)

Only context is shown: the module docstring ("""The radix tree data structure for managing the KV cache.""") and the heapq / time / collections imports. The single added line is formatting only.
python/sglang/srt/managers/controller/schedule_heuristic.py (+1 −0)

Only context is shown: the module docstring ("""Request scheduler heuristic.""") and the random / collections imports. The single added line is formatting only.
python/sglang/srt/managers/controller/tp_worker.py (+14 −14)

@@ -15,22 +15,22 @@ from sglang.global_config import global_config
Import order fixed (isort): the io_struct block moves below the controller imports, and FINISH_ABORT moves to the top of the infer_batch import list.
 from sglang.srt.constrained.fsm_cache import FSMCache
 from sglang.srt.constrained.jump_forward import JumpForwardCache
 from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
-from sglang.srt.managers.io_struct import (
-    AbortReq,
-    BatchTokenIDOut,
-    FlushCacheReq,
-    TokenizedGenerateReqInput,
-)
 from sglang.srt.managers.controller.infer_batch import (
+    FINISH_ABORT,
     BaseFinishReason,
     Batch,
-    FINISH_ABORT,
     ForwardMode,
     Req,
 )
 from sglang.srt.managers.controller.model_runner import ModelRunner
 from sglang.srt.managers.controller.radix_cache import RadixCache
 from sglang.srt.managers.controller.schedule_heuristic import ScheduleHeuristic
+from sglang.srt.managers.io_struct import (
+    AbortReq,
+    BatchTokenIDOut,
+    FlushCacheReq,
+    TokenizedGenerateReqInput,
+)
 from sglang.srt.model_config import ModelConfig
 from sglang.srt.server_args import ModelPortArgs, ServerArgs
 from sglang.srt.utils import (

@@ -96,13 +96,13 @@ class ModelTpServer:
This is the behavioral change named in the commit title: a user-supplied server_args.max_prefill_tokens is now used as-is, and the max(context_len, ...) clamp only applies to the derived default.
             trust_remote_code=server_args.trust_remote_code,
         )
         self.max_total_num_tokens = self.model_runner.max_total_num_tokens
-        self.max_prefill_tokens = max(
-            self.model_config.context_len,
-            (
-                min(self.max_total_num_tokens // 6, 65536)
-                if server_args.max_prefill_tokens is None
-                else server_args.max_prefill_tokens
-            ),
-        )
+        self.max_prefill_tokens = (
+            max(
+                self.model_config.context_len,
+                min(self.max_total_num_tokens // 6, 65536),
+            )
+            if server_args.max_prefill_tokens is None
+            else server_args.max_prefill_tokens
+        )
         self.max_running_requests = (
             self.max_total_num_tokens // 2
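In effect, the hunk above changes the precedence between the user-provided max_prefill_tokens and the derived default. A minimal sketch of the new rule follows; the function name and the concrete numbers are purely illustrative, not part of the source.

    def resolve_max_prefill_tokens(user_value, context_len, max_total_num_tokens):
        # New behavior: a user-supplied value wins outright.
        if user_value is not None:
            return user_value
        # Default: at least the model context length, capped well below the
        # total KV budget (max_total_num_tokens // 6, up to 65536).
        return max(context_len, min(max_total_num_tokens // 6, 65536))

    # Example: context_len=8192, max_total_num_tokens=120000, user_value=4096.
    # The old code computed max(8192, 4096) == 8192, silently overriding the
    # user's setting; the new code returns 4096.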
python/sglang/srt/managers/detokenizer_manager.py (+2 −1)

Top of file: the module docstring ("""DetokenizerManager is a process that detokenizes the token ids.""") and the asyncio / inspect imports are unchanged context; one formatting-only line is added here.

@@ -7,10 +8,10 @@ import zmq
Import order fixed (isort): the FINISH_MATCHED_STR import moves up with the other sglang.srt.managers imports.
 import zmq.asyncio
 from sglang.srt.hf_transformers_utils import get_tokenizer
+from sglang.srt.managers.controller.infer_batch import FINISH_MATCHED_STR
 from sglang.srt.managers.io_struct import BatchStrOut, BatchTokenIDOut
 from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.utils import get_exception_traceback, graceful_registry
-from sglang.srt.managers.controller.infer_batch import FINISH_MATCHED_STR
 asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
python/sglang/srt/managers/io_struct.py (+1 −1)

@@ -7,8 +7,8 @@ import uuid
Import order fixed (isort):
 from dataclasses import dataclass
 from typing import Dict, List, Optional, Union
-from sglang.srt.sampling_params import SamplingParams
 from sglang.srt.managers.controller.infer_batch import BaseFinishReason
+from sglang.srt.sampling_params import SamplingParams
 @dataclass
python/sglang/srt/managers/tokenizer_manager.py (+4 −4)

Top of file:
 """TokenizerManager is a process that tokenizes the text."""
 import asyncio
 import concurrent.futures
 import dataclasses
 import logging
 import multiprocessing as mp
 import os
-from typing import List, Dict
+from typing import Dict, List
 import numpy as np
 import transformers

@@ -23,11 +24,11 @@ from sglang.srt.hf_transformers_utils import (
The separate BatchTokenIDOut import is folded into the main io_struct import block.
 from sglang.srt.managers.io_struct import (
     AbortReq,
     BatchStrOut,
+    BatchTokenIDOut,
     FlushCacheReq,
     GenerateReqInput,
     TokenizedGenerateReqInput,
 )
-from sglang.srt.managers.io_struct import BatchTokenIDOut
 from sglang.srt.mm_utils import expand2square, process_anyres_image
 from sglang.srt.sampling_params import SamplingParams
 from sglang.srt.server_args import PortArgs, ServerArgs

@@ -91,7 +92,7 @@ class TokenizerManager:
Formatting-only change; unchanged context:
         )
         self.to_create_loop = True
         self.rid_to_state: Dict[str, ReqState] = {}
     async def get_pixel_values(self, image_data):
         aspect_ratio = getattr(self.hf_config, "image_aspect_ratio", None)

@@ -322,7 +323,6 @@ class TokenizerManager:
One formatting-only line removed; unchanged context:
             state.finished = recv_obj.finished_reason[i] is not None
             state.event.set()
     def convert_logprob_style(
         self, ret, return_logprob, top_logprobs_num, return_text_in_logprobs
     ):
python/sglang/srt/model_config.py (+21 −12)

Top of file — import order fixed (isort): the sglang import moves below the third-party transformers import.
 from typing import Optional
-from sglang.srt.hf_transformers_utils import get_config, get_context_length
 from transformers import PretrainedConfig
+from sglang.srt.hf_transformers_utils import get_config, get_context_length
 class ModelConfig:
     def __init__(

@@ -17,8 +18,12 @@ class ModelConfig:
         self.trust_remote_code = trust_remote_code
         self.revision = revision
         self.model_overide_args = model_overide_args
-        self.hf_config = get_config(self.path, trust_remote_code, revision,
-                                    model_overide_args=model_overide_args)
+        self.hf_config = get_config(
+            self.path,
+            trust_remote_code,
+            revision,
+            model_overide_args=model_overide_args,
+        )
         self.hf_text_config = get_hf_text_config(self.hf_config)
         if context_length is not None:
             self.context_len = context_length

@@ -55,18 +60,23 @@ class ModelConfig:
Black-style rewrapping of the KV-head detection logic; behavior is unchanged.
         # KV heads.
         falcon_model_types = ["falcon", "RefinedWeb", "RefinedWebModel"]
         new_decoder_arch_falcon = (
             self.hf_config.model_type in falcon_model_types
-            and getattr(self.hf_config, "new_decoder_architecture", False))
-        if not new_decoder_arch_falcon and getattr(self.hf_text_config,
-                                                   "multi_query", False):
+            and getattr(self.hf_config, "new_decoder_architecture", False)
+        )
+        if not new_decoder_arch_falcon and getattr(
+            self.hf_text_config, "multi_query", False
+        ):
             # Multi-query attention, only one KV head.
             # Currently, tensor parallelism is not supported in this case.
             return 1
         # For DBRX and MPT
         if self.hf_config.model_type in ["dbrx", "mpt"]:
-            return getattr(self.hf_config.attn_config, "kv_n_heads",
-                           self.hf_config.num_attention_heads)
+            return getattr(
+                self.hf_config.attn_config,
+                "kv_n_heads",
+                self.hf_config.num_attention_heads,
+            )
         attributes = [
             # For Falcon:

@@ -94,13 +104,12 @@ class ModelConfig:
         # the tensor parallel size. We will replicate the KV heads in the
         # case where the number of KV heads is smaller than the tensor
         # parallel size so each GPU has at least one KV head.
-        return max(1,
-                   total_num_kv_heads // tensor_parallel_size)
+        return max(1, total_num_kv_heads // tensor_parallel_size)
 def get_hf_text_config(config: PretrainedConfig):
     """Get the "sub" config relevant to llm for multi modal models.
     No op for pure text models.
     """
     if hasattr(config, "text_config"):
         # The code operates under the assumption that text_config should have
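The rewrapped return statement in the last hunk is the KV-head replication rule that the surrounding comments describe: each tensor-parallel rank gets an equal share of the KV heads, but never fewer than one. A small illustrative check (the function name and the numbers are examples, not taken from the source):

    def num_kv_heads_per_gpu(total_num_kv_heads, tensor_parallel_size):
        # When there are fewer KV heads than ranks, the floor of 1 means the
        # heads are replicated rather than dropped to zero per GPU.
        return max(1, total_num_kv_heads // tensor_parallel_size)

    assert num_kv_heads_per_gpu(8, 4) == 2   # 8 KV heads split across 4 GPUs
    assert num_kv_heads_per_gpu(8, 16) == 1  # fewer heads than GPUs -> replicate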
python/sglang/srt/models/chatglm.py (+48 −38)

@@ -5,30 +5,32 @@
Import section reordered by isort: the sglang imports move from the top of the block to after the vllm imports, and the parenthesized vllm imports are reflowed to one name per line with trailing commas.
 from typing import Iterable, List, Optional, Tuple
 import torch
-from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.managers.controller.model_runner import InputMetadata
-from sglang.srt.layers.logits_processor import LogitsProcessor
 from torch import nn
 from torch.nn import LayerNorm
 from vllm.config import CacheConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.layernorm import RMSNorm
-from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
-                                               QKVParallelLinear,
-                                               RowParallelLinear)
-from vllm.model_executor.layers.quantization.base_config import (
-    QuantizationConfig)
+from vllm.model_executor.layers.linear import (
+    MergedColumnParallelLinear,
+    QKVParallelLinear,
+    RowParallelLinear,
+)
+from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.sampler import Sampler
-from vllm.model_executor.layers.vocab_parallel_embedding import (
-    ParallelLMHead, VocabParallelEmbedding)
+from vllm.model_executor.layers.vocab_parallel_embedding import (
+    ParallelLMHead,
+    VocabParallelEmbedding,
+)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import SamplerOutput
 from vllm.transformers_utils.configs import ChatGLMConfig
+from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.managers.controller.model_runner import InputMetadata
 LoraConfig = None

The remaining hunks are black-style rewrapping only (one argument per line with trailing commas); representative changes:

@@ -49,9 +51,11 @@ class GLMAttention(nn.Module):
-        self.total_num_kv_heads = (config.multi_query_group_num
-                                   if config.multi_query_attention else
-                                   config.num_attention_heads)
+        self.total_num_kv_heads = (
+            config.multi_query_group_num
+            if config.multi_query_attention
+            else config.num_attention_heads
+        )

@@ -91,11 +95,13 @@ class GLMAttention(nn.Module):
-        self.attn = RadixAttention(self.num_heads,
-                                   self.head_dim,
-                                   self.scaling,
-                                   num_kv_heads=self.num_kv_heads,
-                                   layer_id=layer_id)
+        self.attn = RadixAttention(
+            self.num_heads,
+            self.head_dim,
+            self.scaling,
+            num_kv_heads=self.num_kv_heads,
+            layer_id=layer_id,
+        )

@@ -176,14 +182,16 @@ and @@ -191,7 +199,8 @@ class GLMBlock(nn.Module): the layer_norm_func(...) calls for input_layernorm and post_attention_layernorm are rewrapped in the same style; the GLMAttention(...) and GLMMLP(...) constructor calls are unchanged context.

@@ -250,16 +259,19 @@ class GLMTransformer(nn.Module):
-        self.layers = nn.ModuleList([
-            GLMBlock(config, i, cache_config, quant_config)
-            for i in range(self.num_layers)
-        ])
+        self.layers = nn.ModuleList(
+            [
+                GLMBlock(config, i, cache_config, quant_config)
+                for i in range(self.num_layers)
+            ]
+        )
along with the final_layernorm = layer_norm_func(...) call.

@@ -291,16 +303,16 @@ class ChatGLMModel(nn.Module): the VocabParallelEmbedding(...) and ParallelLMHead(...) constructor calls are rewrapped.

@@ -322,7 +334,7 @@ class ChatGLMForCausalLM(nn.Module):
-        "dense_h_to_4h": ["dense_h_to_4h"]
+        "dense_h_to_4h": ["dense_h_to_4h"],

@@ -344,8 +356,7 @@ and @@ -357,8 +368,7 @@: the getattr(config, "max_sequence_length", 8192) and self.transformer(input_ids, positions, input_metadata) calls are collapsed back onto single lines.

@@ -382,10 +392,10 @@
-        weight_loader = getattr(param, "weight_loader",
-                                default_weight_loader)
+        weight_loader = getattr(param, "weight_loader", default_weight_loader)
 EntryClass = ChatGLMForCausalLM
 # compat: glm model.config class == ChatGLMModel
 EntryClassRemapping = [("ChatGLMModel", ChatGLMForCausalLM)]
python/sglang/srt/models/commandr.py (+2 −2)

@@ -23,7 +23,7 @@
 # This file is based on the LLama model definition file in transformers
 """PyTorch Cohere model."""
-from typing import Optional, Tuple, Iterable
+from typing import Iterable, Optional, Tuple
 import torch
 import torch.utils.checkpoint

@@ -44,8 +44,8 @@ from vllm.model_executor.layers.linear import (
 from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
-from vllm.model_executor.utils import set_weight_attrs
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+from vllm.model_executor.utils import set_weight_attrs
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
python/sglang/srt/models/dbrx.py (+1 −1)

@@ -24,8 +24,8 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from vllm.model_executor.utils import set_weight_attrs
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+from vllm.model_executor.utils import set_weight_attrs
 from vllm.transformers_utils.configs.dbrx import DbrxConfig
 from sglang.srt.layers.logits_processor import LogitsProcessor
python/sglang/srt/models/gemma.py (+1 −1)

@@ -6,7 +6,7 @@ from typing import Iterable, Optional, Tuple
 import torch
 from torch import nn
 from transformers import PretrainedConfig
-from vllm.config import LoRAConfig, CacheConfig
+from vllm.config import CacheConfig, LoRAConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.activation import GeluAndMul
 from vllm.model_executor.layers.layernorm import RMSNorm
python/sglang/srt/models/grok.py (+204 −136)

The diff for this file is collapsed on the page and not shown.
python/sglang/srt/models/llama2.py (+6 −4)

 # Adapted from
 # https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/llama.py#L1
 """Inference-only LLaMA model compatible with HuggingFace weights."""
-from typing import Any, Dict, Optional, Tuple, Iterable
+from typing import Any, Dict, Iterable, Optional, Tuple
 import torch
 import tqdm

@@ -10,7 +10,7 @@ from transformers import LlamaConfig
 from vllm.config import CacheConfig
 from vllm.distributed import (
     get_tensor_model_parallel_rank,
-    get_tensor_model_parallel_world_size
+    get_tensor_model_parallel_world_size,
 )
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.layernorm import RMSNorm

@@ -158,9 +158,11 @@ class LlamaDecoderLayer(nn.Module):
         rope_theta = getattr(config, "rope_theta", 10000)
         rope_scaling = getattr(config, "rope_scaling", None)
-        if rope_scaling is not None and getattr(
-            config, "original_max_position_embeddings", None):
-            rope_scaling["original_max_position_embeddings"] = (
-                config.original_max_position_embeddings)
+        if rope_scaling is not None and getattr(
+            config, "original_max_position_embeddings", None
+        ):
+            rope_scaling["original_max_position_embeddings"] = (
+                config.original_max_position_embeddings
+            )
         max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
         self.self_attn = LlamaAttention(
             hidden_size=self.hidden_size,
python/sglang/srt/models/llava.py (+11 −8)

 """Inference-only LLaVa model compatible with HuggingFace weights."""
-from typing import List, Iterable, Optional, Tuple
+from typing import Iterable, List, Optional, Tuple
 import numpy as np
 import torch
 from torch import nn
-from transformers import CLIPVisionModel, CLIPVisionConfig, LlavaConfig, Qwen2Config, MistralConfig
+from transformers import (
+    CLIPVisionConfig,
+    CLIPVisionModel,
+    LlavaConfig,
+    MistralConfig,
+    Qwen2Config,
+)
 from transformers.models.llava.modeling_llava import LlavaMultiModalProjector
 from vllm.config import CacheConfig
 from vllm.model_executor.layers.quantization.base_config import QuantizationConfig

@@ -19,8 +25,8 @@ from sglang.srt.mm_utils import (
     unpad_image_shape,
 )
 from sglang.srt.models.llama2 import LlamaForCausalLM
-from sglang.srt.models.qwen2 import Qwen2ForCausalLM
 from sglang.srt.models.mistral import MistralForCausalLM
+from sglang.srt.models.qwen2 import Qwen2ForCausalLM
 class LlavaLlamaForCausalLM(nn.Module):

@@ -359,6 +365,7 @@ class LlavaMistralForCausalLM(LlavaLlamaForCausalLM):
One formatting-only line added; unchanged context:
 first_call = True
 def clip_vision_embed_forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
     batch_size = pixel_values.shape[0]

@@ -388,8 +395,4 @@ def monkey_path_clip_vision_embed_forward():
     )
-EntryClass = [
-    LlavaLlamaForCausalLM,
-    LlavaQwenForCausalLM,
-    LlavaMistralForCausalLM
-]
+EntryClass = [LlavaLlamaForCausalLM, LlavaQwenForCausalLM, LlavaMistralForCausalLM]
python/sglang/srt/models/llavavid.py (+1 −1)

 """Inference-only LLaVa video model compatible with HuggingFace weights."""
-from typing import List, Iterable, Optional, Tuple
+from typing import Iterable, List, Optional, Tuple
 import numpy as np
 import torch
python/sglang/srt/models/mixtral.py (+164 −115)

@@ -33,13 +33,11 @@ Import section: the shown imports (set_weight_attrs, print_warning_once, LogitsProcessor, RadixAttention, InputMetadata) and the MixtralMoE class docstring are unchanged context; the two removed lines are formatting only.

@@ -76,32 +74,46 @@, @@ -111,46 +123,68 @@, @@ -158,8 +192,9 @@ and @@ -172,17 +207,17 @@ class MixtralMoE(nn.Module): black-style rewrapping of the constructor and weight-handling code. The ReplicatedLinear(...), nn.Parameter(torch.empty/ones/zeros(...)), set_weight_attrs(...), torch.empty_like(...) and ops.scaled_fp8_quant(...) calls, the ValueError messages, the weight_loader(...) signature, and the param_data slicing for the w3.weight case are rewrapped to one argument per line with trailing commas; behavior is unchanged. Representative change:
-        self.gate = ReplicatedLinear(self.hidden_size,
-                                     self.num_total_experts,
-                                     bias=False,
-                                     params_dtype=self.params_dtype,
-                                     quant_config=None)
+        self.gate = ReplicatedLinear(
+            self.hidden_size,
+            self.num_total_experts,
+            bias=False,
+            params_dtype=self.params_dtype,
+            quant_config=None,
+        )

@@ -193,40 +228,40 @@ class MixtralMoE(nn.Module): besides the same rewrapping of the nn.Parameter(...) scale tensors and the fused_moe(...) and tensor_model_parallel_all_reduce(...) calls in forward(), the redundant parentheses around the all_close_1d check are dropped:
-        if (not all_close_1d(self.a13_scale)
-                or not all_close_1d(self.a2_scale)):
+        if not all_close_1d(self.a13_scale) or not all_close_1d(self.a2_scale):

@@ -335,7 +370,8 @@ class MixtralDecoderLayer(nn.Module): the MixtralMoE(...) constructor call gains a trailing comma and a closing parenthesis on its own line; the RMSNorm calls are unchanged context.

@@ -444,35 +480,48 @@ class MixtralForCausalLM(nn.Module): expert_params_mapping is rebuilt as a single parenthesized expression with one tuple element per line (same tuples, black layout), and the tuple unpacking in the loop drops its parentheses. The first of the three comprehensions, before and after:
-        expert_params_mapping = [
-            # These are the weight scales for the experts
-            # (param_name, weight_name, expert_id)
-            ("w13_scale" if weight_name in ["w1", "w3"] else "w2_scale",
-             f"experts.{expert_id}.{weight_name}.weight_scale", expert_id)
-            for expert_id in range(self.config.num_local_experts)
-            for weight_name in ["w1", "w2", "w3"]
-        ] + [
+        expert_params_mapping = (
+            [
+                # These are the weight scales for the experts
+                # (param_name, weight_name, expert_id)
+                (
+                    "w13_scale" if weight_name in ["w1", "w3"] else "w2_scale",
+                    f"experts.{expert_id}.{weight_name}.weight_scale",
+                    expert_id,
+                )
+                for expert_id in range(self.config.num_local_experts)
+                for weight_name in ["w1", "w2", "w3"]
+            ]
+            + [
The second and third comprehensions (the expert weights, "experts.{expert_id}.{weight_name}.weight" -> "w13_weight"/"w2_weight", and the activation scales, "experts.{expert_id}.{weight_name}.act_scale" -> "a13_scale"/"a2_scale") are rewrapped identically, and the whole expression is closed with ")". In the loading loop:
-            for (param_name, weight_name, shard_id) in stacked_params_mapping:
+            for param_name, weight_name, shard_id in stacked_params_mapping:
                 if weight_name not in name:
                     continue
                 name = name.replace(weight_name, param_name)

@@ -490,18 +539,18 @@ class MixtralForCausalLM(nn.Module): the weight_loader(param, loaded_weight, weight_name, expert_id=expert_id) call and getattr(param, "weight_loader", default_weight_loader) are rewrapped in the same style.
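The restructured expert_params_mapping changes only the layout, not the tuples it yields. As a short hedged sketch of what one of the comprehensions expands to (num_local_experts is set to 2 here purely for illustration; the real value comes from the model config):

    num_local_experts = 2  # illustrative only
    mapping = [
        ("w13_scale" if w in ["w1", "w3"] else "w2_scale",
         f"experts.{e}.{w}.weight_scale", e)
        for e in range(num_local_experts)
        for w in ["w1", "w2", "w3"]
    ]
    # mapping[0] == ("w13_scale", "experts.0.w1.weight_scale", 0)
    # mapping[1] == ("w2_scale",  "experts.0.w2.weight_scale", 0)
    # mapping[2] == ("w13_scale", "experts.0.w3.weight_scale", 0)
    # The weight and activation-scale lists follow the same pattern with
    # ".weight" and ".act_scale" suffixes.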
python/sglang/srt/models/mixtral_quant.py (+0 −1)

@@ -28,7 +28,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
 )
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.managers.controller.model_runner import InputMetadata
One formatting-only line removed; the imports above are unchanged context.