sglang · Commits · bf53bf51

Commit bf53bf51 (unverified)
[Fix] Fix llava on multi images (#1247)

Authored Aug 28, 2024 by Lianmin Zheng; committed by GitHub on Aug 28, 2024.
Parent: b1a540ec

Showing 20 changed files with 254 additions and 453 deletions.
Changed files:

- README.md (+1 −1)
- examples/frontend_language/usage/llava_video/srt_example_llava_v.py (+2 −11)
- python/sglang/launch_server_llavavid.py (+26 −0, new file)
- python/sglang/srt/hf_transformers_utils.py (+0 −149)
- python/sglang/srt/managers/io_struct.py (+5 −4)
- python/sglang/srt/managers/schedule_batch.py (+5 −5)
- python/sglang/srt/managers/tokenizer_manager.py (+74 −61)
- python/sglang/srt/managers/tp_worker.py (+9 −10)
- python/sglang/srt/model_executor/forward_batch_info.py (+10 −20)
- python/sglang/srt/model_executor/model_runner.py (+15 −6)
- python/sglang/srt/models/chatglm.py (+1 −1)
- python/sglang/srt/models/grok.py (+9 −3)
- python/sglang/srt/models/llama2.py (+3 −4)
- python/sglang/srt/models/llama_classification.py (+0 −4)
- python/sglang/srt/models/llama_embedding.py (+3 −4)
- python/sglang/srt/models/llava.py (+42 −69)
- python/sglang/srt/models/llavavid.py (+40 −86)
- python/sglang/srt/models/qwen2.py (+3 −4)
- python/sglang/srt/models/yivl.py (+2 −7)
- python/sglang/srt/server.py (+4 −4)
README.md

```diff
@@ -240,7 +240,7 @@ python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct
 - Qwen / Qwen 2 / Qwen 2 MoE
 - DeepSeek / DeepSeek 2
 - [LLaVA-OneVision](https://llava-vl.github.io/blog/2024-08-05-llava-onevision/)
-  - `python3 -m sglang.launch_server --model-path lmms-lab/llava-onevision-qwen2-72b-ov --port=30000 --tp-size=8 --chat-template=chatml-llava --chunked-prefill-size=16384`
+  - `python3 -m sglang.launch_server --model-path lmms-lab/llava-onevision-qwen2-72b-ov --port=30000 --tp-size=8 --chat-template=chatml-llava`
   - Query the server with the [OpenAI Vision API](https://platform.openai.com/docs/guides/vision). See examples at [test/srt/test_vision_openai_server.py](test/srt/test_vision_openai_server.py)
 - LLaVA 1.5 / 1.6 / NeXT
   - `python -m sglang.launch_server --model-path lmms-lab/llama3-llava-next-8b --port=30000 --tp-size=1 --chat-template=llava_llama_3`
```
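For context, a multi-image request of the kind this commit fixes can be issued through that API. The following is a minimal sketch, assuming a server already running on port 30000; the two image URLs are placeholders, not taken from this repository:

```python
import openai

# Point the OpenAI client at the local sglang server started above.
client = openai.Client(base_url="http://127.0.0.1:30000/v1", api_key="EMPTY")

response = client.chat.completions.create(
    model="default",
    messages=[
        {
            "role": "user",
            "content": [
                # Two images in one request exercises the multi-image path.
                {"type": "image_url", "image_url": {"url": "https://example.com/a.jpg"}},
                {"type": "image_url", "image_url": {"url": "https://example.com/b.jpg"}},
                {"type": "text", "text": "Compare the two images."},
            ],
        }
    ],
    max_tokens=64,
)
print(response.choices[0].message.content)
```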
examples/frontend_language/usage/llava_video/srt_example_llava_v.py

```diff
@@ -184,13 +184,9 @@ if __name__ == "__main__":
     # Parse the arguments
     args = parser.parse_args()
-    cur_port = args.port
-    cur_chunk = args.chunk_idx
-    num_chunks = args.num_chunks
-    num_frames = args.num_frames

     if "34b" in args.model_path.lower():
@@ -202,7 +198,6 @@ if __name__ == "__main__":
         exit()

     model_overide_args = {}
     model_overide_args["mm_spatial_pool_stride"] = args.mm_spatial_pool_stride
     model_overide_args["architectures"] = ["LlavaVidForCausalLM"]
     model_overide_args["num_frames"] = args.num_frames
@@ -235,7 +230,6 @@ if __name__ == "__main__":
     print(f"chat template: {runtime.endpoint.chat_template.name}")

     # Run a single request
-    # try:
     print("\n========== single ==========\n")
     root = args.video_dir
     if os.path.isfile(root):
@@ -257,13 +251,10 @@ if __name__ == "__main__":
         )
         # Calculate the average processing time
         print(f"Average processing time per video: {average_time:.2f} seconds")
         runtime.shutdown()
-    # except Exception as e:
-    #     print(e)
-    #     runtime.shutdown()

-    # # Run a batch of requests
+    # Run a batch of requests
     # print("\n========== batch ==========\n")
     # if not os.path.exists(args.save_dir):
     #     os.makedirs(args.save_dir)
-    # batch(args.video_dir,args.save_dir,cur_chunk, num_chunks, num_frames, num_chunks)
+    # batch(args.video_dir, args.save_dir, cur_chunk, num_chunks, num_frames, num_chunks)
     # runtime.shutdown()
```
python/sglang/launch_server_llavavid.py (new file, mode 100644)

```python
"""Launch the inference server for Llava-video model."""

import argparse

from sglang.srt.server import ServerArgs, launch_server

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    ServerArgs.add_cli_args(parser)
    args = parser.parse_args()
    server_args = ServerArgs.from_cli_args(args)

    model_overide_args = {}
    model_overide_args["mm_spatial_pool_stride"] = 2
    model_overide_args["architectures"] = ["LlavaVidForCausalLM"]
    model_overide_args["num_frames"] = 16
    model_overide_args["model_type"] = "llavavid"
    if model_overide_args["num_frames"] == 32:
        model_overide_args["rope_scaling"] = {"factor": 2.0, "type": "linear"}
        model_overide_args["max_sequence_length"] = 4096 * 2
        model_overide_args["tokenizer_model_max_length"] = 4096 * 2
        model_overide_args["model_max_length"] = 4096 * 2
    if "34b" in args.model_path.lower():
        model_overide_args["image_token_index"] = 64002

    launch_server(server_args, model_overide_args, None)
```
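The `model_overide_args` dict passed to `launch_server` overrides fields of the loaded Huggingface model config before the model is built. A rough illustration of that mechanism follows; this is not sglang code, and the config object and field values are invented for the example:

```python
from types import SimpleNamespace

# Stand-in for a loaded Huggingface config.
hf_config = SimpleNamespace(num_frames=8, model_type="llava")
model_overide_args = {"num_frames": 16, "model_type": "llavavid"}

# Each override simply replaces the value loaded from the checkpoint.
for key, value in model_overide_args.items():
    setattr(hf_config, key, value)

print(hf_config.num_frames, hf_config.model_type)  # -> 16 llavavid
```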
python/sglang/srt/hf_transformers_utils.py

```diff
@@ -119,24 +119,7 @@ def get_tokenizer(
     tokenizer_revision: Optional[str] = None,
     **kwargs,
 ) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
-    if tokenizer_name.endswith(".json"):
-        return TiktokenTokenizer(tokenizer_name)
-
-    if tokenizer_name.endswith(".model"):
-        return SentencePieceTokenizer(tokenizer_name)
-
     """Gets a tokenizer for the given model name via Huggingface."""
-    if is_multimodal_model(tokenizer_name):
-        processor = get_processor(
-            tokenizer_name,
-            *args,
-            trust_remote_code=trust_remote_code,
-            tokenizer_revision=tokenizer_revision,
-            **kwargs,
-        )
-        tokenizer = processor.tokenizer
-        return tokenizer
-
     if tokenizer_mode == "slow":
         if kwargs.get("use_fast", False):
             raise ValueError("Cannot use the fast tokenizer in slow tokenizer mode.")
@@ -199,135 +182,3 @@ def get_processor(
         **kwargs,
     )
     return processor
-
-
-class TiktokenTokenizer:
-    def __init__(self, tokenizer_path):
-        import tiktoken
-        from jinja2 import Template
-
-        PAT_STR_B = r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"""
-
-        # Read JSON
-        name = "tmp-json"
-        with open(tokenizer_path, "rb") as fin:
-            tok_dict = json.load(fin)
-
-        mergeable_ranks = {
-            bytes(item["bytes"]): item["token"] for item in tok_dict["regular_tokens"]
-        }
-        special_tokens = {
-            bytes(item["bytes"]).decode(): item["token"]
-            for item in tok_dict["special_tokens"]
-        }
-        assert tok_dict["word_split"] == "V1"
-
-        default_allowed_special = None
-        kwargs = {
-            "name": name,
-            "pat_str": tok_dict.get("pat_str", PAT_STR_B),
-            "mergeable_ranks": mergeable_ranks,
-            "special_tokens": special_tokens,
-        }
-        if "default_allowed_special" in tok_dict:
-            default_allowed_special = set(
-                [
-                    bytes(bytes_list).decode()
-                    for bytes_list in tok_dict["default_allowed_special"]
-                ]
-            )
-        if "vocab_size" in tok_dict:
-            kwargs["explicit_n_vocab"] = tok_dict["vocab_size"]
-
-        PAD = "<|pad|>"
-        EOS = "<|eos|>"
-        SEP = "<|separator|>"
-        DEFAULT_CONTROL_TOKENS = {"pad": PAD, "sep": EOS, "eos": SEP}
-
-        tokenizer = tiktoken.Encoding(**kwargs)
-        tokenizer._default_allowed_special = default_allowed_special or set()
-        tokenizer._control_tokens = DEFAULT_CONTROL_TOKENS
-
-        def encode_patched(
-            self,
-            text: str,
-            *,
-            allowed_special: Union[Literal["all"], AbstractSet[str]] = set(),  # noqa: B006
-            disallowed_special: Union[Literal["all"], Collection[str]] = "all",
-        ) -> List[int]:
-            if isinstance(allowed_special, set):
-                allowed_special |= self._default_allowed_special
-            return tiktoken.Encoding.encode(
-                self,
-                text,
-                allowed_special=allowed_special,
-                disallowed_special=(),
-            )
-
-        tokenizer.encode = functools.partial(encode_patched, tokenizer)
-
-        # Convert to HF interface
-        self.tokenizer = tokenizer
-        self.eos_token_id = tokenizer._special_tokens[EOS]
-        self.vocab_size = tokenizer.n_vocab
-        self.chat_template = Template(
-            "{% for message in messages %}{% if message['role'] == 'user' %}{{ 'Human: ' + message['content'].strip() + '<|separator|>\n\n' }}{% elif message['role'] == 'system' %}{{ 'System: ' + message['content'].strip() + '<|separator|>\n\n' }}{% elif message['role'] == 'assistant' %}{{ 'Assistant: ' + message['content'] + '<|separator|>\n\n' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ 'Assistant:' }}{% endif %}"
-        )
-
-    def encode(self, x, add_special_tokens=False):
-        return self.tokenizer.encode(x)
-
-    def decode(self, x):
-        return self.tokenizer.decode(x)
-
-    def batch_decode(
-        self, batch, skip_special_tokens=True, spaces_between_special_tokens=False
-    ):
-        if isinstance(batch[0], int):
-            batch = [[x] for x in batch]
-        return self.tokenizer.decode_batch(batch)
-
-    def apply_chat_template(self, messages, tokenize, add_generation_prompt):
-        ret = self.chat_template.render(
-            messages=messages, add_generation_prompt=add_generation_prompt
-        )
-        return self.encode(ret) if tokenize else ret
-
-
-class SentencePieceTokenizer:
-    def __init__(self, tokenizer_path):
-        import sentencepiece as spm
-        from jinja2 import Template
-
-        tokenizer = spm.SentencePieceProcessor(model_file=tokenizer_path)
-
-        # Convert to HF interface
-        self.tokenizer = tokenizer
-        self.eos_token_id = tokenizer.eos_id()
-        self.vocab_size = tokenizer.vocab_size()
-        self.chat_template = Template(
-            "{% for message in messages %}{% if message['role'] == 'user' %}{{ 'Human: ' + message['content'].strip() + '<|separator|>\n\n' }}{% elif message['role'] == 'system' %}{{ 'System: ' + message['content'].strip() + '<|separator|>\n\n' }}{% elif message['role'] == 'assistant' %}{{ 'Assistant: ' + message['content'] + '<|separator|>\n\n' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ 'Assistant:' }}{% endif %}"
-        )
-
-    def encode(self, x, add_special_tokens=False):
-        return self.tokenizer.encode(x)
-
-    def decode(self, x):
-        return self.tokenizer.decode(x)
-
-    def batch_decode(
-        self, batch, skip_special_tokens=True, spaces_between_special_tokens=False
-    ):
-        if isinstance(batch[0], int):
-            batch = [[x] for x in batch]
-        return self.tokenizer.decode(batch)
-
-    def apply_chat_template(self, messages, tokenize, add_generation_prompt):
-        ret = self.chat_template.render(
-            messages=messages, add_generation_prompt=add_generation_prompt
-        )
-        return self.encode(ret) if tokenize else ret
```
python/sglang/srt/managers/io_struct.py

```diff
@@ -55,6 +55,7 @@ class GenerateReqInput:
             self.text is not None and self.input_ids is not None
         ):
             raise ValueError("Either text or input_ids should be provided.")
         if (
             isinstance(self.sampling_params, dict)
             and self.sampling_params.get("n", 1) != 1
@@ -161,10 +162,10 @@ class TokenizedGenerateReqInput:
     input_ids: List[int]
     # The pixel values for input images
     pixel_values: List[float]
-    # The hash of input images
-    image_hash: int
-    # The image size
-    image_size: List[int]
+    # The hash values of input images
+    image_hashes: List[int]
+    # The image sizes
+    image_sizes: List[List[int]]
     # The sampling parameters
     sampling_params: SamplingParams
     # Whether to return the logprobs
```
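To make the pluralization concrete, here is what the new fields might carry for a request with two input images; all values below are illustrative only:

```python
# One hash per image instead of a single int.
image_hashes = [hash(b"first-image-bytes"), hash(b"second-image-bytes")]
# One (width, height) pair per image instead of a single pair.
image_sizes = [[640, 480], [1024, 768]]
```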
python/sglang/srt/managers/schedule_batch.py

```diff
@@ -121,8 +121,8 @@ class Req:
         # For vision input
         self.pixel_values = None
-        self.image_size = None
-        self.image_offset = None
+        self.image_sizes = None
+        self.image_offsets = None
         self.pad_value = None

         # Prefix info
@@ -600,12 +600,12 @@ class ScheduleBatch:
             if req.pixel_values is not None:
                 (
                     req.origin_input_ids,
-                    req.image_offset,
+                    req.image_offsets,
                 ) = model_runner.model.pad_input_ids(
                     req.origin_input_ids_unpadded,
                     req.pad_value,
-                    req.pixel_values.shape,
-                    req.image_size,
+                    req.pixel_values,
+                    req.image_sizes,
                 )

                 jump_forward_reqs.append(req)
```
python/sglang/srt/managers/tokenizer_manager.py

```diff
@@ -23,6 +23,7 @@ import multiprocessing as mp
 import os
 from typing import Dict, List, Optional, Tuple, Union

+import fastapi
 import numpy as np
 import transformers
 import uvloop
@@ -96,21 +97,18 @@ class TokenizerManager:
             trust_remote_code=server_args.trust_remote_code,
             model_overide_args=model_overide_args,
         )
         self.is_generation = is_generation_model(
             self.hf_config.architectures, self.server_args.is_embedding
         )
-        if server_args.context_length is not None:
-            self.context_len = server_args.context_length
-        else:
-            self.context_len = get_context_length(self.hf_config)
+        self.context_len = server_args.context_length or get_context_length(
+            self.hf_config
+        )

         # Create tokenizer
         if server_args.skip_tokenizer_init:
             self.tokenizer = self.processor = None
         else:
-            if is_multimodal_model(self.model_path):
+            if is_multimodal_model(self.hf_config.architectures):
                 self.processor = get_processor(
                     server_args.tokenizer_path,
                     tokenizer_mode=server_args.tokenizer_mode,
@@ -118,6 +116,9 @@ class TokenizerManager:
                 )
                 self.tokenizer = self.processor.tokenizer
                 os.environ["TOKENIZERS_PARALLELISM"] = "false"
+
+                # We want to parallelize the image pre-processing so we
+                # create an executor for it
                 self.executor = concurrent.futures.ProcessPoolExecutor(
                     initializer=init_global_processor,
                     mp_context=mp.get_context("fork"),
@@ -134,12 +135,14 @@ class TokenizerManager:
         self.to_create_loop = True
         self.rid_to_state: Dict[str, ReqState] = {}

-        # for update model weights
+        # For update model weights
         self.model_update_lock = asyncio.Lock()
         self.model_update_result = None

     async def generate_request(
-        self, obj: Union[GenerateReqInput, EmbeddingReqInput], request=None
+        self,
+        obj: Union[GenerateReqInput, EmbeddingReqInput],
+        request: Optional[fastapi.Request] = None,
     ):
         if self.to_create_loop:
             self.create_handle_loop()
@@ -160,7 +163,7 @@ class TokenizerManager:
     async def _handle_single_request(
         self,
         obj: Union[GenerateReqInput, EmbeddingReqInput],
-        request,
+        request: Optional[fastapi.Request] = None,
         index: Optional[int] = None,
         is_cache_for_prefill: Optional[bool] = False,
     ):
@@ -182,8 +185,8 @@ class TokenizerManager:
             )

             if self.is_generation:
-                pixel_values, image_hash, image_size = await self._get_pixel_values(
-                    obj.image_data
+                pixel_values, image_hashes, image_sizes = await self._get_pixel_values(
+                    obj.image_data if not_use_index else obj.image_data[index]
                 )
                 return_logprob = (
                     obj.return_logprob if not_use_index else obj.return_logprob[index]
@@ -195,7 +198,6 @@ class TokenizerManager:
                 )
                 if return_logprob and logprob_start_len == -1:
                     logprob_start_len = len(input_ids) - 1
-
                 top_logprobs_num = (
                     obj.top_logprobs_num
                     if not_use_index
@@ -238,13 +240,14 @@ class TokenizerManager:
                 sampling_params = SamplingParams(**obj.sampling_params[0])
                 sampling_params.max_new_tokens = 0
-                pixel_values, image_hash, image_size = await self._get_pixel_values(
+                pixel_values, image_hashes, image_sizes = await self._get_pixel_values(
                     obj.image_data[0]
                 )
                 return_logprob = obj.return_logprob[0]
                 logprob_start_len = obj.logprob_start_len[0]
                 top_logprobs_num = obj.top_logprobs_num[0]

             # Send to the controller
             if self.is_generation:
+                if return_logprob and logprob_start_len == -1:
+                    logprob_start_len = len(input_ids) - 1
@@ -253,8 +256,8 @@ class TokenizerManager:
                     input_text,
                     input_ids,
                     pixel_values,
-                    image_hash,
-                    image_size,
+                    image_hashes,
+                    image_sizes,
                     sampling_params,
                     return_logprob,
                     logprob_start_len,
@@ -268,24 +271,24 @@ class TokenizerManager:
                     input_ids,
                     sampling_params,
                 )
             self.send_to_router.send_pyobj(tokenized_obj)

-            # Recv results
             event = asyncio.Event()
             state = ReqState([], False, event)
             self.rid_to_state[rid] = state

             if not is_cache_for_prefill:
-                async for response in self._wait_for_response(
-                    event, state, obj, rid, request
-                ):
+                async for response in self._wait_for_response(state, obj, rid, request):
                     yield response
             else:
                 assert self.is_generation
-                await self._wait_for_cache_prefill_response(event, state, obj, rid, request)
+                await self._wait_for_cache_prefill_response(state, obj, rid, request)
                 yield input_ids

     async def _handle_batch_request(
-        self, obj: Union[GenerateReqInput, EmbeddingReqInput], request
+        self,
+        obj: Union[GenerateReqInput, EmbeddingReqInput],
+        request: Optional[fastapi.Request] = None,
     ):
         batch_size = obj.batch_size

         if self.is_generation:
@@ -340,8 +343,8 @@ class TokenizerManager:
             if self.is_generation:
                 if obj.return_logprob[index] and obj.logprob_start_len[index] == -1:
                     obj.logprob_start_len[index] = len(input_ids) - 1
-                pixel_values, image_hash, image_size = await self._get_pixel_values(
-                    obj.image_data[index]
+                pixel_values, image_hashes, image_sizes = (
+                    await self._get_pixel_values(obj.image_data[index])
                 )

                 tokenized_obj = TokenizedGenerateReqInput(
@@ -349,8 +352,8 @@ class TokenizerManager:
                     input_text,
                     input_ids,
                     pixel_values,
-                    image_hash,
-                    image_size,
+                    image_hashes,
+                    image_sizes,
                     sampling_params,
                     obj.return_logprob[index],
                     obj.logprob_start_len[index],
@@ -372,7 +375,6 @@ class TokenizerManager:
                 generators.append(
                     self._wait_for_response(
-                        event,
                         state,
                         obj,
                         rid,
@@ -388,6 +390,7 @@ class TokenizerManager:
             tasks = [asyncio.create_task(gen.__anext__()) for gen in generators]
             output_list = [None] * len(tasks)

+            # Recv results
             while tasks:
                 done, _ = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
@@ -426,25 +429,18 @@ class TokenizerManager:
         sampling_params.verify()
         return sampling_params

-    async def _get_pixel_values(self, image_data):
-        if image_data is None:
-            return None, None, None
-        else:
-            return await self._get_pixel_values_internal(image_data)
-
     async def _wait_for_response(
         self,
-        event: asyncio.Event,
         state: ReqState,
         obj: Union[GenerateReqInput, EmbeddingReqInput],
         rid: str,
-        request,
-        index: int = None,
+        request: Optional[fastapi.Request] = None,
+        index: Optional[int] = None,
         response_index: int = 0,
     ):
         while True:
             try:
-                await asyncio.wait_for(event.wait(), timeout=4)
+                await asyncio.wait_for(state.event.wait(), timeout=4)
             except asyncio.TimeoutError:
                 if request is not None and await request.is_disconnected():
                     for rid in [obj.rid] if obj.is_single else obj.rid:
@@ -478,16 +474,15 @@ class TokenizerManager:
                     yield out
                     break

-            event.clear()
+            state.event.clear()
             yield out

     async def _wait_for_cache_prefill_response(
         self,
-        event: asyncio.Event,
         state: ReqState,
         obj: GenerateReqInput,
         rid: str,
-        request,
+        request: Optional[fastapi.Request] = None,
     ):
         while True:
             try:
@@ -514,7 +509,9 @@ class TokenizerManager:
         req = AbortReq(rid)
         self.send_to_router.send_pyobj(req)

-    async def update_weights(self, obj: UpdateWeightReqInput, request):
+    async def update_weights(
+        self, obj: UpdateWeightReqInput, request: Optional[fastapi.Request] = None
+    ):
         if self.to_create_loop:
             self.create_handle_loop()
@@ -659,12 +656,11 @@ class TokenizerManager:
         )
         return top_logprobs

-    async def _get_pixel_values_internal(self, image_data, aspect_ratio=None):
-        aspect_ratio = (
-            getattr(self.hf_config, "image_aspect_ratio", None)
-            if aspect_ratio is None
-            else aspect_ratio
-        )
+    async def _get_pixel_values(self, image_data: List[Union[str, bytes]]):
+        if not image_data:
+            return None, None, None
+
+        aspect_ratio = getattr(self.hf_config, "image_aspect_ratio", None)
         grid_pinpoints = (
             self.hf_config.image_grid_pinpoints
             if hasattr(self.hf_config, "image_grid_pinpoints")
@@ -673,35 +669,42 @@ class TokenizerManager:
         )

         if isinstance(image_data, list) and len(image_data) > 0:
-            pixel_values, image_hash, image_size = [], [], []
             # Multiple images
             if len(image_data) > 1:
-                aspect_ratio = "pad"
+                aspect_ratio = "pad"  # LLaVA OneVision Handling: more than one image --> interleaved image mode or video mode. We do not use anyres
+                pixel_values, image_hashes, image_sizes = [], [], []
                 for img_data in image_data:
                     pixel_v, image_h, image_s = await self._process_single_image(
                         img_data, aspect_ratio, grid_pinpoints
                     )
                     pixel_values.append(pixel_v)
-                    image_hash.append(image_h)
-                    image_size.append(image_s)
-                pixel_values = np.stack(pixel_values, axis=0)
+                    image_hashes.append(image_h)
+                    image_sizes.append(image_s)
+
+                if isinstance(pixel_values[0], np.ndarray):
+                    pixel_values = np.stack(pixel_values, axis=0)
             else:
                 # A single image
                 pixel_values, image_hash, image_size = await self._process_single_image(
                     image_data[0], aspect_ratio, grid_pinpoints
                 )
-                image_hash = [image_hash]
-                image_size = [image_size]
+                image_hashes = [image_hash]
+                image_sizes = [image_size]
         elif isinstance(image_data, str):
             # A single image
             pixel_values, image_hash, image_size = await self._process_single_image(
                 image_data, aspect_ratio, grid_pinpoints
             )
-            image_hash = [image_hash]
-            image_size = [image_size]
+            image_hashes = [image_hash]
+            image_sizes = [image_size]
         else:
-            pixel_values, image_hash, image_size = None, None, None
+            raise ValueError(f"Invalid image data: {image_data}")

-        return pixel_values, image_hash, image_size
+        return pixel_values, image_hashes, image_sizes

-    async def _process_single_image(self, image_data, aspect_ratio, grid_pinpoints):
+    async def _process_single_image(
+        self, image_data: Union[bytes, str], aspect_ratio: str, grid_pinpoints: str
+    ):
         if self.executor is not None:
             loop = asyncio.get_event_loop()
             return await loop.run_in_executor(
@@ -732,12 +735,16 @@ def init_global_processor(server_args: ServerArgs):
 def _process_single_image_task(
-    image_data, image_aspect_ratio=None, image_grid_pinpoints=None, processor=None
+    image_data: Union[str, bytes],
+    image_aspect_ratio: Optional[str] = None,
+    image_grid_pinpoints: Optional[str] = None,
+    processor=None,
 ):
     try:
         processor = processor or global_processor
         image, image_size = load_image(image_data)
         if image_size is not None:
             # It is a video with multiple images
             image_hash = hash(image_data)
             pixel_values = processor.image_processor(image)["pixel_values"]
             for _ in range(len(pixel_values)):
@@ -745,6 +752,7 @@ def _process_single_image_task(
             pixel_values = np.stack(pixel_values, axis=0)
             return pixel_values, image_hash, image_size
         else:
+            # It is an image
             image_hash = hash(image_data)
             if image_aspect_ratio == "pad":
                 image = expand2square(
@@ -754,13 +762,18 @@ def _process_single_image_task(
                 pixel_values = processor.image_processor(image.convert("RGB"))[
                     "pixel_values"
                 ][0]
-            elif image_aspect_ratio == "anyres" or "anyres_max" in image_aspect_ratio:
+            elif image_aspect_ratio == "anyres" or (
+                image_aspect_ratio is not None and "anyres_max" in image_aspect_ratio
+            ):
                 pixel_values = process_anyres_image(
                     image, processor.image_processor, image_grid_pinpoints
                 )
             else:
                 pixel_values = processor.image_processor(image)["pixel_values"][0]
-            pixel_values = pixel_values.astype(np.float16)
+
+            if isinstance(pixel_values, np.ndarray):
+                pixel_values = pixel_values.astype(np.float16)
+
             return pixel_values, image_hash, image.size
     except Exception:
         logger.error("Exception in TokenizerManager:\n" + get_exception_traceback())
```
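Distilled from the hunks above: preprocessing now collects one entry per image and forces "pad" mode whenever more than one image is present. A self-contained sketch of that branch follows, with a hypothetical `process` callback standing in for `_process_single_image`:

```python
import numpy as np

def get_pixel_values(image_data, process):
    """process(img) -> (pixel_values, image_hash, image_size); hypothetical."""
    if not image_data:
        return None, None, None
    if len(image_data) > 1:
        # Multiple images: interleaved-image or video mode, no anyres.
        pixel_values, image_hashes, image_sizes = [], [], []
        for img in image_data:
            v, h, s = process(img)
            pixel_values.append(v)
            image_hashes.append(h)
            image_sizes.append(s)
        if isinstance(pixel_values[0], np.ndarray):
            pixel_values = np.stack(pixel_values, axis=0)
        return pixel_values, image_hashes, image_sizes
    # Single image: the hash and size are still wrapped in lists.
    v, h, s = process(image_data[0])
    return v, [h], [s]
```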
python/sglang/srt/managers/tp_worker.py

```diff
@@ -108,7 +108,7 @@ class ModelTpServer:
         if server_args.skip_tokenizer_init:
             self.tokenizer = self.processor = None
         else:
-            if is_multimodal_model(server_args.model_path):
+            if is_multimodal_model(self.model_config.hf_config.architectures):
                 self.processor = get_processor(
                     server_args.tokenizer_path,
                     tokenizer_mode=server_args.tokenizer_mode,
@@ -333,26 +333,24 @@ class ModelTpServer:
         if self.model_runner.is_generation:
             req.pixel_values = recv_req.pixel_values
             if req.pixel_values is not None:
-                image_hash = (
-                    hash(tuple(recv_req.image_hash))
-                    if isinstance(recv_req.image_hash, list)
-                    else recv_req.image_hash
-                )
+                # Use image hash as fake token_ids, which is then used
+                # for prefix matching
+                image_hash = hash(tuple(recv_req.image_hashes))
                 req.pad_value = [
                     (image_hash) % self.model_config.vocab_size,
                     (image_hash >> 16) % self.model_config.vocab_size,
                     (image_hash >> 32) % self.model_config.vocab_size,
                     (image_hash >> 64) % self.model_config.vocab_size,
                 ]
-                req.image_size = recv_req.image_size
+                req.image_sizes = recv_req.image_sizes
                 (
                     req.origin_input_ids,
-                    req.image_offset,
+                    req.image_offsets,
                 ) = self.model_runner.model.pad_input_ids(
                     req.origin_input_ids_unpadded,
                     req.pad_value,
-                    req.pixel_values.shape,
-                    req.image_size,
+                    req.pixel_values,
+                    req.image_sizes,
                 )
             req.return_logprob = recv_req.return_logprob
             req.logprob_start_len = recv_req.logprob_start_len
@@ -368,6 +366,7 @@ class ModelTpServer:
             req.jump_forward_map = self.jump_forward_cache.query(
                 computed_regex_string
             )
+        # Init regex fsm
        elif req.sampling_params.regex is not None:
            req.regex_fsm = self.regex_fsm_cache.query(req.sampling_params.regex)
```
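The pad-value scheme above derives four pseudo token ids from a single combined hash of all image hashes, so two requests with identical images pad their prompts identically and can share a cached prefix. A standalone sketch; the vocab size and hash values are illustrative:

```python
vocab_size = 32000
image_hashes = [1234567890123, 9876543210987]

# One hash over the whole image list, then sliced into vocab-range ids.
image_hash = hash(tuple(image_hashes))
pad_value = [
    image_hash % vocab_size,
    (image_hash >> 16) % vocab_size,
    (image_hash >> 32) % vocab_size,
    (image_hash >> 64) % vocab_size,
]
print(pad_value)  # four deterministic "fake" token ids
```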
python/sglang/srt/model_executor/forward_batch_info.py

```diff
@@ -16,7 +16,7 @@ limitations under the License.
 """ModelRunner runs the forward passes of the models."""
 from dataclasses import dataclass
 from enum import IntEnum, auto
-from typing import TYPE_CHECKING, List, Optional
+from typing import TYPE_CHECKING, List

 import numpy as np
 import torch
@@ -58,6 +58,7 @@ class InputMetadata:
     # For extend
     extend_seq_lens: torch.Tensor = None
+    extend_prefix_lens: torch.Tensor = None
     extend_start_loc: torch.Tensor = None
     extend_no_prefix: bool = None
@@ -69,8 +70,8 @@ class InputMetadata:
     # For multimodal
     pixel_values: List[torch.Tensor] = None
-    image_sizes: List[List[int]] = None
-    image_offsets: List[int] = None
+    image_sizes: List[List[List[int]]] = None
+    image_offsets: List[List[int]] = None

     # Trition attention backend
     triton_max_seq_len: int = 0
@@ -87,20 +88,8 @@ class InputMetadata:
     def init_multimuldal_info(self, batch: ScheduleBatch):
         reqs = batch.reqs
         self.pixel_values = [r.pixel_values for r in reqs]
-        self.image_sizes = [r.image_size for r in reqs]
-        self.image_offsets = []
-        for r in reqs:
-            if isinstance(r.image_offset, list):
-                self.image_offsets.append(
-                    [
-                        (image_offset - len(r.prefix_indices))
-                        for image_offset in r.image_offset
-                    ]
-                )
-            elif isinstance(r.image_offset, int):
-                self.image_offsets.append(r.image_offset - len(r.prefix_indices))
-            elif r.image_offset is None:
-                self.image_offsets.append(0)
+        self.image_sizes = [r.image_sizes for r in reqs]
+        self.image_offsets = [r.image_offsets for r in reqs]

     def compute_positions(self, batch: ScheduleBatch):
         position_ids_offsets = batch.position_ids_offsets
@@ -153,6 +142,7 @@ class InputMetadata:
                 for i, r in enumerate(batch.reqs)
             ]
             self.extend_seq_lens = torch.tensor(extend_lens_cpu, device="cuda")
+            self.extend_prefix_lens = torch.tensor(batch.prefix_lens_cpu, device="cuda")
             self.extend_start_loc = torch.zeros_like(self.seq_lens)
             self.extend_start_loc[1:] = torch.cumsum(self.extend_seq_lens[:-1], dim=0)
             self.extend_no_prefix = all(l == 0 for l in batch.prefix_lens_cpu)
@@ -238,10 +228,10 @@ class InputMetadata:
         prefix_lens_cpu,
         flashinfer_use_ragged,
     ):
-        if self.forward_mode != ForwardMode.DECODE:
-            prefix_lens = torch.tensor(prefix_lens_cpu, device="cuda")
-        else:
+        if self.forward_mode == ForwardMode.DECODE:
             prefix_lens = None
+        else:
+            prefix_lens = self.extend_prefix_lens

         update_flashinfer_indices(
             self.forward_mode,
```
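After this change the raw per-image offsets are kept on each request, and the cached-prefix length is subtracted only at embedding fill-in time (see llava.py below). A small worked example with invented numbers:

```python
image_offsets = [5, 600]  # absolute positions of the image placeholders
prefix_len = 300          # tokens already covered by the radix-cache prefix

for offset in image_offsets:
    if offset < prefix_len:
        continue  # the first image lies inside the cached prefix; skip it
    local = offset - prefix_len  # position within the newly extended tokens
    print(local)                 # -> 300 (only the second image is filled in)
```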
python/sglang/srt/model_executor/model_runner.py

```diff
@@ -50,7 +50,7 @@ from sglang.srt.mem_cache.memory_pool import (
     MLATokenToKVPool,
     ReqToTokenPool,
 )
-from sglang.srt.model_config import AttentionArch
+from sglang.srt.model_config import AttentionArch, ModelConfig
 from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import (
@@ -69,7 +69,7 @@ logger = logging.getLogger(__name__)
 class ModelRunner:
     def __init__(
         self,
-        model_config,
+        model_config: ModelConfig,
         mem_fraction_static: float,
         gpu_id: int,
         tp_rank: int,
@@ -85,7 +85,9 @@ class ModelRunner:
         self.tp_size = tp_size
         self.nccl_port = nccl_port
         self.server_args = server_args
-        self.is_multimodal_model = is_multimodal_model(self.model_config)
+        self.is_multimodal_model = is_multimodal_model(
+            self.model_config.hf_config.architectures
+        )
         global_server_args_dict.update(
             {
                 "disable_flashinfer": server_args.disable_flashinfer,
@@ -95,6 +97,13 @@ class ModelRunner:
             }
         )

+        if self.is_multimodal_model:
+            logger.info(
+                "Automatically turn off --chunked-prefill-size and adjust --mem-fraction-static for multimodal models."
+            )
+            server_args.chunked_prefill_size = None
+            server_args.mem_fraction_static *= 0.95
+
         min_per_gpu_memory = self.init_torch_distributed()
         self.load_model()
         self.init_memory_pool(
@@ -507,9 +516,9 @@ class ModelRunner:
             raise Exception(
                 f"Capture cuda graph failed: {e}\n"
                 "Possible solutions:\n"
-                "1. disable torch compile by not using --enable-torch-compile\n"
-                "2. disable cuda graph by --disable-cuda-graph\n"
-                "3. set --mem-fraction-static to a smaller value\n"
+                "1. disable cuda graph by --disable-cuda-graph\n"
+                "2. set --mem-fraction-static to a smaller value\n"
+                "3. disable torch compile by not using --enable-torch-compile\n"
                 "Open an issue on GitHub https://github.com/sgl-project/sglang/issues/new/choose\n"
             )
```
python/sglang/srt/models/chatglm.py

```diff
@@ -17,7 +17,7 @@ limitations under the License.
 # Adapted from
 # https://github.com/THUDM/ChatGLM2-6B
 """Inference-only ChatGLM model compatible with THUDM weights."""
-from typing import Iterable, List, Optional, Tuple
+from typing import Iterable, Optional, Tuple

 import torch
 from torch import nn
```
python/sglang/srt/models/grok.py

```diff
@@ -273,9 +273,9 @@ class Grok1Model(nn.Module):
     ) -> torch.Tensor:
         if input_embeds is None:
             hidden_states = self.embed_tokens(input_ids)
+            hidden_states.mul_(self.config.embedding_multiplier_scale)
         else:
             hidden_states = input_embeds
-        hidden_states.mul_(self.config.embedding_multiplier_scale)

         for i in range(len(self.layers)):
             hidden_states = self.layers[i](positions, hidden_states, input_metadata)
@@ -284,7 +284,7 @@ class Grok1Model(nn.Module):
         return hidden_states


-class Grok1ModelForCausalLM(nn.Module):
+class Grok1ForCausalLM(nn.Module):
     def __init__(
         self,
         config: PretrainedConfig,
@@ -415,4 +415,10 @@ def _prepare_presharded_weights(
     return hf_folder, hf_weights_files, use_safetensors


-EntryClass = Grok1ModelForCausalLM
+class Grok1ModelForCausalLM(Grok1ForCausalLM):
+    """An alias for backward-compatbility."""
+
+    pass
+
+
+EntryClass = [Grok1ForCausalLM, Grok1ModelForCausalLM]
```
python/sglang/srt/models/llama2.py

```diff
@@ -357,6 +357,9 @@ class LlamaForCausalLM(nn.Module):
                 # Models trained using ColossalAI may include these tensors in
                 # the checkpoint. Skip them.
                 return
+            if name.startswith("model.vision_tower") and name not in params_dict:
+                return
+
             for param_name, weight_name, shard_id in stacked_params_mapping:
                 if weight_name not in name:
                     continue
@@ -364,8 +367,6 @@ class LlamaForCausalLM(nn.Module):
                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
                     continue
-                if name.startswith("model.vision_tower") and name not in params_dict:
-                    continue
                 param = params_dict[name]
                 weight_loader = param.weight_loader
                 weight_loader(param, loaded_weight, shard_id)
@@ -374,8 +375,6 @@ class LlamaForCausalLM(nn.Module):
                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
                     return
-                if name.startswith("model.vision_tower") and name not in params_dict:
-                    return
                 param = params_dict[name]
                 weight_loader = getattr(param, "weight_loader", default_weight_loader)
                 weight_loader(param, loaded_weight)
```
python/sglang/srt/models/llama_classification.py

```diff
@@ -103,8 +103,6 @@ class LlamaForClassification(nn.Module):
             # Skip loading extra bias for GPTQ models.
             if name.endswith(".bias") and name not in params_dict:
                 continue
-            if name.startswith("model.vision_tower") and name not in params_dict:
-                continue
             param = params_dict[name]
             weight_loader = param.weight_loader
             weight_loader(param, loaded_weight, shard_id)
@@ -113,8 +111,6 @@ class LlamaForClassification(nn.Module):
             # Skip loading extra bias for GPTQ models.
             if name.endswith(".bias") and name not in params_dict:
                 continue
-            if name.startswith("model.vision_tower") and name not in params_dict:
-                continue
             param = params_dict[name]
             weight_loader = getattr(param, "weight_loader", default_weight_loader)
             weight_loader(param, loaded_weight)
```
python/sglang/srt/models/llama_embedding.py

```diff
@@ -57,6 +57,9 @@ class LlamaEmbeddingModel(nn.Module):
                 # Models trained using ColossalAI may include these tensors in
                 # the checkpoint. Skip them.
                 return
+            if name.startswith("model.vision_tower") and name not in params_dict:
+                return
+
             for param_name, weight_name, shard_id in stacked_params_mapping:
                 if weight_name not in name:
                     continue
@@ -64,8 +67,6 @@ class LlamaEmbeddingModel(nn.Module):
                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
                     continue
-                if name.startswith("model.vision_tower") and name not in params_dict:
-                    continue
                 param = params_dict[name]
                 weight_loader = param.weight_loader
                 weight_loader(param, loaded_weight, shard_id)
@@ -74,8 +75,6 @@ class LlamaEmbeddingModel(nn.Module):
                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
                     return
-                if name.startswith("model.vision_tower") and name not in params_dict:
-                    return
                 param = params_dict[name]
                 weight_loader = getattr(param, "weight_loader", default_weight_loader)
                 weight_loader(param, loaded_weight)
```
python/sglang/srt/models/llava.py

```diff
@@ -28,7 +28,6 @@ from transformers import (
     LlavaConfig,
     MistralConfig,
     Qwen2Config,
-    SiglipVisionConfig,
     SiglipVisionModel,
 )
 from transformers.models.llava.modeling_llava import LlavaMultiModalProjector
@@ -66,13 +65,18 @@ class LlavaLlamaForCausalLM(nn.Module):
             torch.empty(config.text_config.hidden_size, dtype=torch.float16)
         )

-    def pad_input_ids(self, input_ids, pad_value, pt_shape=None, image_size=None):
+    def pad_input_ids(
+        self,
+        input_ids: List[int],
+        pad_value: List[int],
+        pixel_values: List,
+        image_sizes: List[List[int]],
+    ):
         # hardcode for spatial_unpad + anyres
-        image_aspect_ratio = "anyres" if len(image_size) == 1 else "pad"
+        image_aspect_ratio = "anyres" if len(image_sizes) == 1 else "pad"
         offset_list = []
-        for image_s in image_size:
-            if len(image_size) > 16:
+        for image_s in image_sizes:
+            if len(image_sizes) > 16:
                 # 2x2 pooling with stride 2
                 new_image_feature_len = (
                     math.ceil(self.image_size / self.patch_size / 2) ** 2
@@ -153,17 +157,15 @@ class LlavaLlamaForCausalLM(nn.Module):
         if input_metadata.forward_mode == ForwardMode.EXTEND:
             bs = input_metadata.batch_size
-            # Embed text input
+            # Embed text inputs
             input_embeds = self.language_model.model.embed_tokens(input_ids)

-            # Embed vision input
-            need_vision = (
-                (positions[input_metadata.extend_start_loc] < self.image_feature_len)
-                .cpu()
-                .numpy()
-            )
-            # FIXME: We need to substract the length of the system prompt
-            has_pixel = np.array([pixel_values[i] is not None for i in range(bs)])
-            need_vision = need_vision & has_pixel
+            # Whether the requests need vision inputs
+            max_image_offset = np.array(
+                [max(image_offsets[i]) if image_offsets[i] else -1 for i in range(bs)]
+            )
+            start_positions = positions[input_metadata.extend_start_loc].cpu().numpy()
+            need_vision = start_positions <= max_image_offset

             if need_vision.any():
                 pixel_values = [pixel_values[i] for i in range(bs) if need_vision[i]]
@@ -332,31 +334,35 @@ class LlavaLlamaForCausalLM(nn.Module):
                     new_image_features.append(image_feature)
                 image_features = new_image_features

+                # Fill in the placeholder for the image
                 extend_start_loc_cpu = input_metadata.extend_start_loc.cpu().numpy()
+                prefix_lens_cpu = input_metadata.extend_prefix_lens.cpu().numpy()
                 pt = 0
                 for i in range(bs):
                     if not need_vision[i]:
                         continue

                     start_idx = extend_start_loc_cpu[i]
-                    pad_dim = image_features[pt].shape[-1]  # 576, 4096
-                    dim = input_embeds.shape[1]
-                    assert (
-                        pad_dim == dim
-                    ), "invalid pad_dim={}, input_embed_dim={}!".format(pad_dim, dim)
-                    # Fill in the placeholder for the image
-                    try:
-                        for j, image_off in enumerate(image_offsets[i]):
-                            # print("actual image_features length: ", image_features[pt][j].shape[0])
-                            pad_len = image_features[pt][j].shape[0]
-                            input_embeds[
-                                start_idx + image_off : start_idx + image_off + pad_len
-                            ] = image_features[pt][j]
-                    except RuntimeError as e:
-                        print(f"RuntimeError in llava image encoding: {e}")
-                        print(image_features[pt].shape)
-                        print(input_embeds.shape)
-                        print(start_idx, image_offsets[i])
+                    prefix_len = prefix_lens_cpu[i]
+
+                    # Multiple images
+                    for j, image_offset in enumerate(image_offsets[i]):
+                        if image_offset < prefix_len:
+                            continue
+
+                        tmp_image_feature = image_features[pt][j]
+                        pad_len = tmp_image_feature.shape[0]
+
+                        left_idx = start_idx + (image_offset - prefix_len)
+                        right_idx = start_idx + (image_offset - prefix_len) + pad_len
+                        try:
+                            input_embeds[left_idx:right_idx] = tmp_image_feature
+                        except RuntimeError as e:
+                            print(f"RuntimeError in image encoding: {e}")
+                            print(f"{input_embeds.shape=}, {tmp_image_feature.shape=}")
+                            print(
+                                f"{start_idx=}, {image_offset=}, {prefix_len=}, {pad_len=}"
+                            )
                     pt += 1

             return self.language_model(
@@ -366,8 +372,9 @@ class LlavaLlamaForCausalLM(nn.Module):
             return self.language_model(input_ids, positions, input_metadata)

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
-        # load clip vision model by cfg['mm_vision_tower']:
-        # huggingface_name or path_of_clip_relative_to_llava_model_dir
+        # Load clip vision model by cfg['mm_vision_tower']:
+        # huggingface_name or path_of_clip_relative_to_llava_model_dir
+        # We put the initialization here instead of __init__ to allow it being reused by other subclasses.
         vision_path = self.config.mm_vision_tower
         if "clip" in vision_path:
             self.vision_tower = CLIPVisionModel.from_pretrained(
@@ -422,8 +429,6 @@ class LlavaLlamaForCausalLM(nn.Module):
         # load language model
         self.language_model.load_weights(weights)

-        monkey_path_clip_vision_embed_forward()
-
     @property
     def num_patches_per_side(self):
         return self.image_size // self.patch_size
@@ -495,36 +500,4 @@ class LlavaMistralForCausalLM(LlavaLlamaForCausalLM):
         )


-first_call = True
-
-
-def clip_vision_embed_forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
-    batch_size = pixel_values.shape[0]
-
-    # Move this conv layer to CPU to avoid a bug in torch >= 2.1 on A10G.
-    global first_call
-    if first_call:
-        self.patch_embedding.cpu().float()
-        first_call = False
-    pixel_values = pixel_values.to(dtype=torch.float32, device="cpu")
-    patch_embeds = self.patch_embedding(pixel_values).cuda().half()
-
-    patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
-    class_embeds = self.class_embedding.expand(batch_size, 1, -1)
-    embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
-    embeddings = embeddings + self.position_embedding(self.position_ids)
-    return embeddings
-
-
-def monkey_path_clip_vision_embed_forward():
-    import transformers
-
-    setattr(
-        transformers.models.clip.modeling_clip.CLIPVisionEmbeddings,
-        "forward",
-        clip_vision_embed_forward,
-    )
-
-
 EntryClass = [LlavaLlamaForCausalLM, LlavaQwenForCausalLM, LlavaMistralForCausalLM]
```
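The new `need_vision` test can be checked in isolation: a request needs vision work only if its first uncached position does not lie past its last image offset. A sketch with invented values:

```python
import numpy as np

image_offsets = [[5, 600], []]        # per-request offsets; [] means no image
start_positions = np.array([0, 700])  # first new token position per request

max_image_offset = np.array(
    [max(offsets) if offsets else -1 for offsets in image_offsets]
)
need_vision = start_positions <= max_image_offset
print(need_vision)  # -> [ True False ]: request 1 has no vision work left
```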
python/sglang/srt/models/llavavid.py

```diff
@@ -26,11 +26,6 @@ from vllm.config import CacheConfig
 from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader

-from sglang.srt.mm_utils import (
-    get_anyres_image_grid_shape,
-    unpad_image,
-    unpad_image_shape,
-)
 from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
 from sglang.srt.models.llama2 import LlamaForCausalLM
@@ -59,23 +54,14 @@ class LlavaVidForCausalLM(nn.Module):
             torch.empty(config.text_config.hidden_size, dtype=torch.float16)
         )

-    def pad_input_ids(self, input_ids, pad_value, pt_shape=None, image_size=None):
+    def pad_input_ids(
+        self,
+        input_ids: List[int],
+        pad_value: List[int],
+        pixel_values: List,
+        image_sizes: List[List[int]],
+    ):
         new_image_feature_len = self.image_feature_len
-        # now only support spatial_unpad + anyres
-        # if self.mm_patch_merge_type.startswith("spatial"):
-        #     height = width = self.num_patches_per_side
-        #     if pt_shape[0] > 1:
-        #         if self.image_aspect_ratio == "anyres":
-        #             num_patch_width, num_patch_height = get_anyres_image_grid_shape(
-        #                 image_size,
-        #                 self.image_grid_pinpoints,
-        #                 self.vision_tower.config.image_size,
-        #             )
-        #         if "unpad" in self.mm_patch_merge_type:
-        #             h = num_patch_height * height
-        #             w = num_patch_width * width
-        #             new_h, new_w = unpad_image_shape(h, w, image_size)
-        #             new_image_feature_len += new_h * (new_w + 1)

         pad_ids = pad_value * (
             (new_image_feature_len + len(pad_value)) // len(pad_value)
@@ -87,7 +73,7 @@ class LlavaVidForCausalLM(nn.Module):
             + pad_ids[:new_image_feature_len]
             + input_ids[offset + 1 :]
         )
-        return new_input_ids, offset
+        return new_input_ids, [offset]

     def encode_images(self, pixel_values: torch.Tensor) -> torch.Tensor:
         image_outputs = self.vision_tower(pixel_values, output_hidden_states=True)
@@ -133,22 +119,18 @@ class LlavaVidForCausalLM(nn.Module):
         if input_metadata.forward_mode == ForwardMode.EXTEND:
             bs = input_metadata.batch_size
-            # Embed text input
+            # Embed text inputs
             input_embeds = self.language_model.model.embed_tokens(input_ids)

-            # Embed vision input
-            need_vision = (
-                (positions[input_metadata.extend_start_loc] < self.image_feature_len)
-                .cpu()
-                .numpy()
-            )
-            # FIXME: We need to substract the length of the system prompt
-            has_pixel = np.array([pixel_values[i] is not None for i in range(bs)])
-            need_vision = need_vision & has_pixel
+            # Whether the requests need vision inputs
+            max_image_offset = np.array(
+                [max(image_offsets[i]) if image_offsets[i] else -1 for i in range(bs)]
+            )
+            start_positions = positions[input_metadata.extend_start_loc].cpu().numpy()
+            need_vision = start_positions <= max_image_offset

             if need_vision.any():
                 pixel_values = [pixel_values[i] for i in range(bs) if need_vision[i]]
                 image_sizes = [image_sizes[i] for i in range(bs) if need_vision[i]]

                 ########## Encode Image ########
@@ -183,31 +165,36 @@ class LlavaVidForCausalLM(nn.Module):
                     new_image_features.append(image_feature.flatten(0, 1))
                 image_features = new_image_features

+                # Fill in the placeholder for the image
                 extend_start_loc_cpu = input_metadata.extend_start_loc.cpu().numpy()
+                prefix_lens_cpu = input_metadata.extend_prefix_lens.cpu().numpy()
                 pt = 0
                 for i in range(bs):
                     if not need_vision[i]:
                         continue

                     start_idx = extend_start_loc_cpu[i]
-                    pad_len, pad_dim = image_features[pt].shape  # 576, 4096
-                    dim = input_embeds.shape[1]
-                    assert (
-                        pad_dim == dim
-                    ), "invalid pad_dim={}, input_embed_dim={}!".format(pad_dim, dim)
-                    # Fill in the placeholder for the image
-                    try:
-                        input_embeds[
-                            start_idx + image_offsets[i] : start_idx + image_offsets[i] + pad_len
-                        ] = image_features[pt]
-                    except RuntimeError as e:
-                        print(f"RuntimeError in llava image encoding: {e}")
-                        print(input_embeds.shape)
-                        print(start_idx, image_offsets[i])
-                    pt += 1
+                    prefix_len = prefix_lens_cpu[i]
+
+                    # Multiple images
+                    for image_offset in image_offsets[i]:
+                        if image_offset < prefix_len:
+                            continue
+
+                        tmp_image_feature = image_features[pt]
+                        pad_len = tmp_image_feature.shape[0]
+
+                        left_idx = start_idx + (image_offset - prefix_len)
+                        right_idx = start_idx + (image_offset - prefix_len) + pad_len
+                        try:
+                            input_embeds[left_idx:right_idx] = tmp_image_feature
+                        except RuntimeError as e:
+                            print(f"RuntimeError in image encoding: {e}")
+                            print(f"{input_embeds.shape=}, {tmp_image_feature.shape=}")
+                            print(
+                                f"{start_idx=}, {image_offset=}, {prefix_len=}, {pad_len=}"
+                            )
+                        pt += 1

             return self.language_model(
                 input_ids, positions, input_metadata, input_embeds=input_embeds
@@ -216,8 +203,9 @@ class LlavaVidForCausalLM(nn.Module):
             return self.language_model(input_ids, positions, input_metadata)

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
-        # load clip vision model by cfg['mm_vision_tower']:
-        # huggingface_name or path_of_clip_relative_to_llava_model_dir
+        # Load clip vision model by cfg['mm_vision_tower']:
+        # huggingface_name or path_of_clip_relative_to_llava_model_dir
+        # We put the initialization here instead of __init__ to allow it being reused by other subclasses.
         vision_path = self.config.mm_vision_tower
         self.vision_tower = CLIPVisionModel.from_pretrained(
             vision_path, torch_dtype=torch.float16
@@ -271,43 +259,9 @@ class LlavaVidForCausalLM(nn.Module):
         # load language model
         self.language_model.load_weights(weights)

-        monkey_path_clip_vision_embed_forward()
-
     @property
     def num_patches_per_side(self):
         return self.image_size // self.patch_size


-first_call = True
-
-
-def clip_vision_embed_forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
-    batch_size = pixel_values.shape[0]
-
-    # Move this conv layer to CPU to avoid a bug in torch >= 2.1 on A10G.
-    global first_call
-    if first_call:
-        self.patch_embedding.cpu().float()
-        first_call = False
-    pixel_values = pixel_values.to(dtype=torch.float32, device="cpu")
-    patch_embeds = self.patch_embedding(pixel_values).cuda().half()
-
-    patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
-    class_embeds = self.class_embedding.expand(batch_size, 1, -1)
-    embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
-    embeddings = embeddings + self.position_embedding(self.position_ids)
-    return embeddings
-
-
-def monkey_path_clip_vision_embed_forward():
-    import transformers
-
-    setattr(
-        transformers.models.clip.modeling_clip.CLIPVisionEmbeddings,
-        "forward",
-        clip_vision_embed_forward,
-    )
-
-
 EntryClass = LlavaVidForCausalLM
```
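The placeholder padding in `pad_input_ids` can be traced with toy numbers; note that the offset is now returned as a single-element list so downstream code can treat all models uniformly. All values below are illustrative:

```python
pad_value = [101, 102]     # pseudo token ids derived from the image hash
new_image_feature_len = 4  # number of embedding slots the video occupies
input_ids = [1, 2, 3, 999, 4, 5]
offset = 3                 # index of the image token (999 is a stand-in)

# Repeat pad_value until it covers the feature length, then trim.
pad_ids = pad_value * ((new_image_feature_len + len(pad_value)) // len(pad_value))
new_input_ids = (
    input_ids[:offset] + pad_ids[:new_image_feature_len] + input_ids[offset + 1 :]
)
print(new_input_ids)  # -> [1, 2, 3, 101, 102, 101, 102, 4, 5]
print([offset])       # -> [3], offsets as a list after this commit
```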
python/sglang/srt/models/qwen2.py

```diff
@@ -312,6 +312,9 @@ class Qwen2ForCausalLM(nn.Module):
                 # Models trained using ColossalAI may include these tensors in
                 # the checkpoint. Skip them.
                 continue
+            if name.startswith("model.vision_tower") and name not in params_dict:
+                continue
+
             for param_name, weight_name, shard_id in stacked_params_mapping:
                 if weight_name not in name:
                     continue
@@ -319,8 +322,6 @@ class Qwen2ForCausalLM(nn.Module):
                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
                     continue
-                if name.startswith("model.vision_tower") and name not in params_dict:
-                    continue
                 param = params_dict[name]
                 weight_loader = param.weight_loader
                 weight_loader(param, loaded_weight, shard_id)
@@ -329,8 +330,6 @@ class Qwen2ForCausalLM(nn.Module):
                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
                     continue
-                if name.startswith("model.vision_tower") and name not in params_dict:
-                    continue
                 param = params_dict[name]
                 weight_loader = getattr(param, "weight_loader", default_weight_loader)
                 weight_loader(param, loaded_weight)
```
python/sglang/srt/models/yivl.py

```diff
@@ -24,10 +24,7 @@ from vllm.config import CacheConfig
 from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader

-from sglang.srt.models.llava import (
-    LlavaLlamaForCausalLM,
-    monkey_path_clip_vision_embed_forward,
-)
+from sglang.srt.models.llava import LlavaLlamaForCausalLM


 class YiVLForCausalLM(LlavaLlamaForCausalLM):
@@ -50,7 +47,7 @@ class YiVLForCausalLM(LlavaLlamaForCausalLM):
             self.config._name_or_path,
             torch_dtype=torch.float16,
             subfolder=self.vision_tower_subfolder,
-        ).cuda()
+        ).to("cuda")

         self.vision_tower.eval()
@@ -94,8 +91,6 @@ class YiVLForCausalLM(LlavaLlamaForCausalLM):
         # load language model
         self.language_model.load_weights(weights)

-        monkey_path_clip_vision_embed_forward()
-

 class YiVLMultiModalProjector(nn.Module):
     def __init__(self, config: LlavaConfig):
```
python/sglang/srt/server.py

```diff
@@ -335,12 +335,12 @@ def launch_server(
     pipe_detoken_reader, pipe_detoken_writer = mp.Pipe(duplex=False)

     if server_args.dp_size == 1:
-        start_process = start_controller_process_single
+        start_controller_process = start_controller_process_single
     else:
-        start_process = start_controller_process_multi
+        start_controller_process = start_controller_process_multi
     proc_controller = mp.Process(
-        target=start_process,
+        target=start_controller_process,
         args=(server_args, port_args, pipe_controller_writer, model_overide_args),
     )
     proc_controller.start()
@@ -421,7 +421,7 @@ def _set_envs_and_config(server_args: ServerArgs):
     if not server_args.disable_flashinfer:
         assert_pkg_version(
             "flashinfer",
-            "0.1.6",
+            "0.1.5",
             "Please uninstall the old version and "
             "reinstall the latest version by following the instructions "
             "at https://docs.flashinfer.ai/installation.html.",
```