change / sglang / Commits / 0992d85f

Unverified commit 0992d85f, authored May 14, 2024 by Yuanhan Zhang, committed by GitHub May 13, 2024
support llava video (#426)
Parent: 5dc55a5f

Changes: 37 | Showing 20 changed files with 476 additions and 39 deletions (+476 / -39); this is page 1 of 2 of the diff.
Files changed on this page (additions / deletions):

.gitignore                                                +4 / -0
README.md                                                 +1 / -1
examples/quick_start/srt_example_llava.py                 +2 / -2
examples/usage/llava_video/srt_example_llava_v.py         +208 / -0
examples/usage/llava_video/srt_example_llava_v.sh         +130 / -0
examples/usage/llava_video/videos/Q98Z4OTh8RwmDonc.mp4    +0 / -0
python/pyproject.toml                                     +2 / -2
python/sglang/__init__.py                                 +2 / -0
python/sglang/api.py                                      +5 / -0
python/sglang/lang/chat_template.py                       +9 / -2
python/sglang/lang/interpreter.py                         +14 / -1
python/sglang/lang/ir.py                                  +9 / -0
python/sglang/lang/tracer.py                              +1 / -1
python/sglang/launch_server.py                            +1 / -2
python/sglang/launch_server_llavavid.py                   +31 / -0
python/sglang/srt/hf_transformers_utils.py                +8 / -1
python/sglang/srt/managers/router/manager.py              +2 / -4
python/sglang/srt/managers/router/model_rpc.py            +19 / -6
python/sglang/srt/managers/router/model_runner.py         +2 / -2
python/sglang/srt/managers/tokenizer_manager.py           +26 / -15
.gitignore (view file @ 0992d85f)

```diff
@@ -177,3 +177,7 @@ tmp*.txt
 # Plots
 *.png
 *.pdf
+
+# personnal
+work_dirs/
+*.csv
```
README.md (view file @ 0992d85f)
examples/quick_start/srt_example_llava.py (view file @ 0992d85f)

```diff
@@ -14,7 +14,7 @@ def single():
     state = image_qa.run(
         image_path="images/cat.jpeg",
         question="What is this?",
-        max_new_tokens=64)
+        max_new_tokens=128)
     print(state["answer"], "\n")
@@ -36,7 +36,7 @@ def batch():
             {"image_path": "images/cat.jpeg", "question": "What is this?"},
             {"image_path": "images/dog.jpeg", "question": "What is this?"},
         ],
-        max_new_tokens=64,
+        max_new_tokens=128,
     )
     for s in states:
         print(s["answer"], "\n")
```
examples/usage/llava_video/srt_example_llava_v.py (new file, mode 100644; view file @ 0992d85f)

```python
"""
Usage: python3 srt_example_llava.py
"""
import sglang as sgl
import os
import csv
import time
import argparse


@sgl.function
def video_qa(s, num_frames, video_path, question):
    s += sgl.user(sgl.video(video_path, num_frames) + question)
    s += sgl.assistant(sgl.gen("answer"))


def single(path, num_frames=16):
    state = video_qa.run(
        num_frames=num_frames,
        video_path=path,
        question="Please provide a detailed description of the video, focusing on the main subjects, their actions, the background scenes",
        temperature=0.0,
        max_new_tokens=1024,
    )
    print(state["answer"], "\n")


def split_into_chunks(lst, num_chunks):
    """Split a list into a specified number of chunks."""
    # Calculate the chunk size using integer division. Note that this may drop some items if not evenly divisible.
    chunk_size = len(lst) // num_chunks

    if chunk_size == 0:
        chunk_size = len(lst)

    # Use list comprehension to generate chunks. The last chunk will take any remainder if the list size isn't evenly divisible.
    chunks = [lst[i:i + chunk_size] for i in range(0, len(lst), chunk_size)]
    # Ensure we have exactly num_chunks chunks, even if some are empty
    chunks.extend([[] for _ in range(num_chunks - len(chunks))])
    return chunks


def save_batch_results(batch_video_files, states, cur_chunk, batch_idx, save_dir):
    csv_filename = f"{save_dir}/chunk_{cur_chunk}_batch_{batch_idx}.csv"
    with open(csv_filename, 'w', newline='') as csvfile:
        writer = csv.writer(csvfile)
        writer.writerow(['video_name', 'answer'])
        for video_path, state in zip(batch_video_files, states):
            video_name = os.path.basename(video_path)
            writer.writerow([video_name, state["answer"]])


def compile_and_cleanup_final_results(cur_chunk, num_batches, save_dir):
    final_csv_filename = f"{save_dir}/final_results_chunk_{cur_chunk}.csv"
    with open(final_csv_filename, 'w', newline='') as final_csvfile:
        writer = csv.writer(final_csvfile)
        writer.writerow(['video_name', 'answer'])
        for batch_idx in range(num_batches):
            batch_csv_filename = f"{save_dir}/chunk_{cur_chunk}_batch_{batch_idx}.csv"
            with open(batch_csv_filename, 'r') as batch_csvfile:
                reader = csv.reader(batch_csvfile)
                next(reader)  # Skip header row
                for row in reader:
                    writer.writerow(row)
            os.remove(batch_csv_filename)


def find_video_files(video_dir):
    # Check if the video_dir is actually a file
    if os.path.isfile(video_dir):
        # If it's a file, return it as a single-element list
        return [video_dir]

    # Original logic to find video files in a directory
    video_files = []
    for root, dirs, files in os.walk(video_dir):
        for file in files:
            if file.endswith(('.mp4', '.avi', '.mov')):
                video_files.append(os.path.join(root, file))
    return video_files


def batch(video_dir, save_dir, cur_chunk, num_chunks, num_frames=16, batch_size=64):
    video_files = find_video_files(video_dir)
    chunked_video_files = split_into_chunks(video_files, num_chunks)[cur_chunk]
    num_batches = 0

    for i in range(0, len(chunked_video_files), batch_size):
        batch_video_files = chunked_video_files[i:i + batch_size]
        print(f"Processing batch of {len(batch_video_files)} video(s)...")

        if not batch_video_files:
            print("No video files found in the specified directory.")
            return

        batch_input = [
            {
                "num_frames": num_frames,
                "video_path": video_path,
                "question": "Please provide a detailed description of the video, focusing on the main subjects, their actions, the background scenes.",
            }
            for video_path in batch_video_files
        ]

        start_time = time.time()
        states = video_qa.run_batch(batch_input, max_new_tokens=512, temperature=0.2)
        total_time = time.time() - start_time
        average_time = total_time / len(batch_video_files)
        print(f"Number of videos in batch: {len(batch_video_files)}. Average processing time per video: {average_time:.2f} seconds. Total time for this batch: {total_time:.2f} seconds")

        save_batch_results(batch_video_files, states, cur_chunk, num_batches, save_dir)
        num_batches += 1

    compile_and_cleanup_final_results(cur_chunk, num_batches, save_dir)


if __name__ == "__main__":
    # Create the parser
    parser = argparse.ArgumentParser(description='Run video processing with specified port.')

    # Add an argument for the port
    parser.add_argument('--port', type=int, default=30000, help='The master port for distributed serving.')
    parser.add_argument('--chunk-idx', type=int, default=0, help='The index of the chunk to process.')
    parser.add_argument('--num-chunks', type=int, default=8, help='The number of chunks to process.')
    parser.add_argument('--save-dir', type=str, default="./work_dirs/llava_video", help='The directory to save the processed video files.')
    parser.add_argument('--video-dir', type=str, default="./videos/Q98Z4OTh8RwmDonc.mp4", help='The directory or path for the processed video files.')
    parser.add_argument('--model-path', type=str, default="lmms-lab/LLaVA-NeXT-Video-7B", help='The model path for the video processing.')
    parser.add_argument('--num-frames', type=int, default=16, help='The number of frames to process in each video.')
    parser.add_argument("--mm_spatial_pool_stride", type=int, default=2)

    # Parse the arguments
    args = parser.parse_args()

    cur_port = args.port
    cur_chunk = args.chunk_idx
    num_chunks = args.num_chunks
    num_frames = args.num_frames

    if "34b" in args.model_path.lower():
        tokenizer_path = "liuhaotian/llava-v1.6-34b-tokenizer"
    elif "7b" in args.model_path.lower():
        tokenizer_path = "llava-hf/llava-1.5-7b-hf"
    else:
        print("Invalid model path. Please specify a valid model path.")
        exit()

    model_overide_args = {}
    model_overide_args["mm_spatial_pool_stride"] = args.mm_spatial_pool_stride
    model_overide_args["architectures"] = ["LlavaVidForCausalLM"]
    model_overide_args["num_frames"] = args.num_frames
    model_overide_args["model_type"] = "llava"

    if "34b" in args.model_path.lower():
        model_overide_args["image_token_index"] = 64002

    if args.num_frames == 32:
        model_overide_args["rope_scaling"] = {"factor": 2.0, "type": "linear"}
        model_overide_args["max_sequence_length"] = 4096 * 2
        model_overide_args["tokenizer_model_max_length"] = 4096 * 2
    elif args.num_frames < 32:
        pass
    else:
        print("The maximum number of frames to process is 32. Please specify a valid number of frames.")
        exit()

    runtime = sgl.Runtime(
        model_path=args.model_path,  # "liuhaotian/llava-v1.6-vicuna-7b",
        tokenizer_path=tokenizer_path,
        port=cur_port,
        additional_ports=[cur_port + 1, cur_port + 2, cur_port + 3, cur_port + 4],
        model_overide_args=model_overide_args,
        tp_size=1,
    )
    sgl.set_default_backend(runtime)
    print(f"chat template: {runtime.endpoint.chat_template.name}")

    # Run a single request
    # try:
    print("\n========== single ==========\n")
    root = args.video_dir
    if os.path.isfile(root):
        video_files = [root]
    else:
        video_files = [
            os.path.join(root, f)
            for f in os.listdir(root)
            if f.endswith(('.mp4', '.avi', '.mov'))
        ]  # Add more extensions if needed
    start_time = time.time()  # Start time for processing a single video
    for cur_video in video_files[:1]:
        print(cur_video)
        single(cur_video, num_frames)
    end_time = time.time()  # End time for processing a single video
    total_time = end_time - start_time
    average_time = total_time / len(video_files)  # Calculate the average processing time
    print(f"Average processing time per video: {average_time:.2f} seconds")
    runtime.shutdown()
    # except Exception as e:
    #     print(e)
    runtime.shutdown()

    # # # Run a batch of requests
    # print("\n========== batch ==========\n")
    # if not os.path.exists(args.save_dir):
    #     os.makedirs(args.save_dir)
    # batch(args.video_dir,args.save_dir,cur_chunk, num_chunks, num_frames, num_chunks)
    # runtime.shutdown()
```

(No newline at end of file.)
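The batch path at the bottom of this script ships commented out, and the commented call passes `num_chunks` as the sixth positional argument, which `batch()` interprets as `batch_size`. A minimal sketch of re-enabling that path with an explicit batch size is shown below; it is an assumption about intended usage, not part of the commit, and relies on the names (`args`, `cur_chunk`, `num_chunks`, `num_frames`, `batch`, `runtime`) already defined in the `__main__` block above.

```python
# Hypothetical re-enable of the commented-out batch path above.
# Assumes this snippet is appended inside __main__ of srt_example_llava_v.py,
# so args, cur_chunk, num_chunks, num_frames, batch, and runtime already exist.
print("\n========== batch ==========\n")
if not os.path.exists(args.save_dir):
    os.makedirs(args.save_dir)

# batch(video_dir, save_dir, cur_chunk, num_chunks, num_frames=16, batch_size=64)
batch(args.video_dir, args.save_dir, cur_chunk, num_chunks, num_frames, batch_size=64)
runtime.shutdown()
```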
examples/usage/llava_video/srt_example_llava_v.sh (new file, mode 100644; view file @ 0992d85f)

```bash
#!/bin/bash

##### USAGE #####
# - First node:
# ```sh
# bash examples/quick_start/srt_example_llava_v.sh K 0 YOUR_VIDEO_PATH YOUR_MODEL_PATH FRAMES_PER_VIDEO
# ```
# - Second node:
# ```sh
# bash examples/quick_start/srt_example_llava_v.sh K 1 YOUR_VIDEO_PATH YOUR_MODEL_PATH FRAMES_PER_VIDEO
# ```
# - The K node:
# ```sh
# bash examples/quick_start/srt_example_llava_v.sh K K-1 YOUR_VIDEO_PATH YOUR_MODEL_PATH FRAMES_PER_VIDEO
# ```
# Replace `K`, `YOUR_VIDEO_PATH`, `YOUR_MODEL_PATH`, and `FRAMES_PER_VIDEO` with your specific details.

CURRENT_ROOT="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"

echo ${CURRENT_ROOT}

cd ${CURRENT_ROOT}

export PYTHONWARNINGS=ignore

START_TIME=$(date +%s)  # Capture start time

NUM_NODES=$1
CUR_NODES_IDX=$2
VIDEO_DIR=$3
MODEL_PATH=$4
NUM_FRAMES=$5

# FRAME_FORMAT=$6
# FRAME_FORMAT=$(echo $FRAME_FORMAT | tr '[:lower:]' '[:upper:]')
# # Check if FRAME_FORMAT is either JPEG or PNG
# if [[ "$FRAME_FORMAT" != "JPEG" && "$FRAME_FORMAT" != "PNG" ]]; then
#     echo "Error: FRAME_FORMAT must be either JPEG or PNG."
#     exit 1
# fi
# export TARGET_FRAMES=$TARGET_FRAMES

echo "Each video you will sample $NUM_FRAMES frames"

# export FRAME_FORMAT=$FRAME_FORMAT
# echo "The frame format is $FRAME_FORMAT"

# Assuming GPULIST is a bash array containing your GPUs
GPULIST=(0 1 2 3 4 5 6 7)
LOCAL_CHUNKS=${#GPULIST[@]}

echo "Number of GPUs in GPULIST: $LOCAL_CHUNKS"

ALL_CHUNKS=$((NUM_NODES * LOCAL_CHUNKS))

# Calculate GPUs per chunk
GPUS_PER_CHUNK=8

echo $GPUS_PER_CHUNK

for IDX in $(seq 1 $LOCAL_CHUNKS); do
    (
        START=$(((IDX-1) * GPUS_PER_CHUNK))
        LENGTH=$GPUS_PER_CHUNK  # Length for slicing, not the end index

        CHUNK_GPUS=(${GPULIST[@]:$START:$LENGTH})

        # Convert the chunk GPUs array to a comma-separated string
        CHUNK_GPUS_STR=$(IFS=,; echo "${CHUNK_GPUS[*]}")

        LOCAL_IDX=$((CUR_NODES_IDX * LOCAL_CHUNKS + IDX))

        echo "Chunk $(($LOCAL_IDX - 1)) will run on GPUs $CHUNK_GPUS_STR"

        # Calculate the port for this chunk. Ensure it's incremented by 5 for each chunk.
        PORT=$((10000 + RANDOM % 55536))

        MAX_RETRIES=10
        RETRY_COUNT=0
        COMMAND_STATUS=1  # Initialize as failed

        while [ $RETRY_COUNT -lt $MAX_RETRIES ] && [ $COMMAND_STATUS -ne 0 ]; do
            echo "Running chunk $(($LOCAL_IDX - 1)) on GPUs $CHUNK_GPUS_STR with port $PORT. Attempt $(($RETRY_COUNT + 1))"

            #!/bin/bash
            CUDA_VISIBLE_DEVICES=$CHUNK_GPUS_STR python3 examples/usage/llava_video/srt_example_llava_v.py \
                --port $PORT \
                --num-chunks $ALL_CHUNKS \
                --chunk-idx $(($LOCAL_IDX - 1)) \
                --save-dir work_dirs/llava_next_video_inference_results \
                --video-dir $VIDEO_DIR \
                --model-path $MODEL_PATH \
                --num-frames $NUM_FRAMES #&

            wait $!  # Wait for the process to finish and capture its exit status
            COMMAND_STATUS=$?

            if [ $COMMAND_STATUS -ne 0 ]; then
                echo "Execution failed for chunk $(($LOCAL_IDX - 1)), attempt $(($RETRY_COUNT + 1)). Retrying..."
                RETRY_COUNT=$(($RETRY_COUNT + 1))
                sleep 180  # Wait a bit before retrying
            else
                echo "Execution succeeded for chunk $(($LOCAL_IDX - 1))."
            fi
        done

        if [ $COMMAND_STATUS -ne 0 ]; then
            echo "Execution failed for chunk $(($LOCAL_IDX - 1)) after $MAX_RETRIES attempts."
        fi
    ) #&
    sleep 2  # Slight delay to stagger the start times
done

wait

cat work_dirs/llava_next_video_inference_results/final_results_chunk_*.csv > work_dirs/llava_next_video_inference_results/final_results_node_${CUR_NODES_IDX}.csv

END_TIME=$(date +%s)  # Capture end time
ELAPSED_TIME=$(($END_TIME - $START_TIME))

echo "Total execution time: $ELAPSED_TIME seconds."
```

(No newline at end of file.)
examples/usage/llava_video/videos/Q98Z4OTh8RwmDonc.mp4 (new file, mode 100644; view file @ 0992d85f)

Binary file added.
python/pyproject.toml (view file @ 0992d85f)

```diff
@@ -20,7 +20,7 @@ dependencies = [
 [project.optional-dependencies]
-srt = ["aiohttp", "fastapi", "psutil", "rpyc", "torch", "uvloop", "uvicorn",
-       "zmq", "vllm>=0.4.2", "interegular", "pydantic", "pillow", "outlines>=0.0.27", "packaging"]
+srt = ["aiohttp", "fastapi", "psutil", "rpyc", "torch", "uvloop", "uvicorn",
+       "zmq", "vllm>=0.4.2", "interegular", "pydantic", "pillow", "packaging", "huggingface_hub", "hf_transfer", "outlines>=0.0.34"]
 openai = ["openai>=1.0", "numpy", "tiktoken"]
 anthropic = ["anthropic>=0.20.0", "numpy"]
 all = ["sglang[srt]", "sglang[openai]", "sglang[anthropic]"]
```
python/sglang/__init__.py (view file @ 0992d85f)

```diff
@@ -19,6 +19,7 @@ from sglang.api import (
     user,
     user_begin,
     user_end,
+    video,
 )
 
 # SGL Backends
@@ -46,6 +47,7 @@ __all__ = [
     "gen_int",
     "gen_string",
     "image",
+    "video",
     "select",
     "system",
     "user",
```
python/sglang/api.py (view file @ 0992d85f)

```diff
@@ -15,6 +15,7 @@ from sglang.lang.ir import (
     SglRoleBegin,
     SglRoleEnd,
     SglSelect,
+    SglVideo,
 )
@@ -151,6 +152,10 @@ def image(expr: SglExpr):
     return SglImage(expr)
 
 
+def video(path: str, num_frames: int):
+    return SglVideo(path, num_frames)
+
+
 def select(
     name: Optional[str] = None,
     choices: List[str] = None,
```
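The new `video()` helper mirrors `image()`: it wraps a path and a frame count in an `SglVideo` expression that can be concatenated into a prompt. A minimal usage sketch follows; the video file name and question are placeholders, and a video-capable runtime (such as the one set up in `srt_example_llava_v.py` above) is assumed to already be the default backend.

```python
import sglang as sgl


@sgl.function
def video_qa(s, num_frames, video_path, question):
    # sgl.video(path, num_frames) is the primitive added in this commit;
    # it behaves like sgl.image but samples num_frames frames from the clip.
    s += sgl.user(sgl.video(video_path, num_frames) + question)
    s += sgl.assistant(sgl.gen("answer"))


# Assumes sgl.set_default_backend(runtime) was already called with a
# LLaVA-NeXT-Video runtime, as in examples/usage/llava_video/srt_example_llava_v.py.
state = video_qa.run(num_frames=16, video_path="clip.mp4", question="What happens in this video?")
print(state["answer"])
```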
python/sglang/lang/chat_template.py (view file @ 0992d85f)

```diff
@@ -259,6 +259,8 @@ def match_vicuna(model_path: str):
         return get_chat_template("vicuna_v1.1")
     if "llava-v1.5" in model_path.lower():
         return get_chat_template("vicuna_v1.1")
+    if "llava-next-video-7b" in model_path.lower():
+        return get_chat_template("vicuna_v1.1")
 
 
 @register_chat_template_matching_function
@@ -283,19 +285,24 @@ def match_llama3_instruct(model_path: str):
 @register_chat_template_matching_function
 def match_chat_ml(model_path: str):
+    # import pdb;pdb.set_trace()
     model_path = model_path.lower()
     if "tinyllama" in model_path:
         return get_chat_template("chatml")
     if "qwen" in model_path and "chat" in model_path:
         return get_chat_template("chatml")
-    if "llava-v1.6-34b" in model_path:
+    if (
+        "llava-v1.6-34b" in model_path
+        or "llava-v1.6-yi-34b" in model_path
+        or "llava-next-video-34b" in model_path
+    ):
         return get_chat_template("chatml-llava")
 
 
 @register_chat_template_matching_function
 def match_chat_yi(model_path: str):
     model_path = model_path.lower()
-    if "yi" in model_path:
+    if "yi" in model_path and "llava" not in model_path:
         return get_chat_template("yi")
```
python/sglang/lang/interpreter.py (view file @ 0992d85f)

```diff
@@ -28,8 +28,9 @@ from sglang.lang.ir import (
     SglVariable,
     SglVarScopeBegin,
     SglVarScopeEnd,
+    SglVideo,
 )
-from sglang.utils import encode_image_base64, get_exception_traceback
+from sglang.utils import encode_image_base64, encode_video_base64, get_exception_traceback
 
 
 def run_internal(state, program, func_args, func_kwargs, sync):
@@ -361,6 +362,8 @@ class StreamExecutor:
             self._execute_role_end(other)
         elif isinstance(other, SglImage):
             self._execute_image(other)
+        elif isinstance(other, SglVideo):
+            self._execute_video(other)
         elif isinstance(other, SglVariable):
             self._execute_variable(other)
         elif isinstance(other, SglVarScopeBegin):
@@ -397,6 +400,16 @@ class StreamExecutor:
         self.cur_images.append((path, base64_data))
         self.text_ += self.chat_template.image_token
 
+    def _execute_video(self, expr: SglVideo):
+        path = expr.path
+        num_frames = expr.num_frames
+
+        base64_data = encode_video_base64(path, num_frames)
+
+        self.images_.append((path, base64_data))
+        self.cur_images.append((path, base64_data))
+        self.text_ += self.chat_template.image_token
+
         # if global_config.eager_fill_image:
         #     self.backend.fill_image(self)
```
python/sglang/lang/ir.py (view file @ 0992d85f)

```diff
@@ -330,6 +330,15 @@ class SglImage(SglExpr):
         return f"SglImage({self.path})"
 
 
+class SglVideo(SglExpr):
+    def __init__(self, path, num_frames):
+        self.path = path
+        self.num_frames = num_frames
+
+    def __repr__(self) -> str:
+        return f"SglVideo({self.path}, {self.num_frames})"
+
+
 class SglGen(SglExpr):
     def __init__(
         self,
```
python/sglang/lang/tracer.py (view file @ 0992d85f)

```diff
@@ -110,7 +110,7 @@ class TracerProgramState(ProgramState):
     ##################################
 
     def fork(self, size: int = 1, position_ids_offset: Optional[List[int]] = None):
-        assert (size >= 1)
+        assert size >= 1
 
         if self.only_trace_prefix:
             raise StopTracing()
```
python/sglang/launch_server.py (view file @ 0992d85f)

```diff
@@ -2,7 +2,6 @@ import argparse
 
 from sglang.srt.server import ServerArgs, launch_server
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     ServerArgs.add_cli_args(parser)
```
python/sglang/launch_server_llavavid.py (new file, mode 100644; view file @ 0992d85f)

```python
import argparse
import multiprocessing as mp

from sglang.srt.server import ServerArgs, launch_server

if __name__ == "__main__":
    model_overide_args = {}

    model_overide_args["mm_spatial_pool_stride"] = 2
    model_overide_args["architectures"] = ["LlavaVidForCausalLM"]
    model_overide_args["num_frames"] = 16
    model_overide_args["model_type"] = "llavavid"

    if model_overide_args["num_frames"] == 32:
        model_overide_args["rope_scaling"] = {"factor": 2.0, "type": "linear"}
        model_overide_args["max_sequence_length"] = 4096 * 2
        model_overide_args["tokenizer_model_max_length"] = 4096 * 2
        model_overide_args["model_max_length"] = 4096 * 2

    parser = argparse.ArgumentParser()
    ServerArgs.add_cli_args(parser)
    args = parser.parse_args()

    if "34b" in args.model_path.lower():
        model_overide_args["image_token_index"] = 64002

    server_args = ServerArgs.from_cli_args(args)

    pipe_reader, pipe_writer = mp.Pipe(duplex=False)

    launch_server(server_args, pipe_writer, model_overide_args)
```
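The same overrides can also be applied to an in-process runtime instead of the standalone server, which is what the example script earlier in this commit does. A minimal sketch is shown below; the model path, tokenizer path, and port are placeholders, and the override values correspond to the 16-frame case above.

```python
import sglang as sgl

# Sketch only: the override dict from launch_server_llavavid.py, passed to an
# in-process sgl.Runtime instead of the standalone HTTP server.
model_overide_args = {
    "mm_spatial_pool_stride": 2,
    "architectures": ["LlavaVidForCausalLM"],
    "num_frames": 16,
    "model_type": "llavavid",
}

runtime = sgl.Runtime(
    model_path="lmms-lab/LLaVA-NeXT-Video-7B",  # placeholder checkpoint
    tokenizer_path="llava-hf/llava-1.5-7b-hf",  # tokenizer used for 7B models in this commit
    port=30000,                                 # placeholder port
    model_overide_args=model_overide_args,
    tp_size=1,
)
sgl.set_default_backend(runtime)
```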
python/sglang/srt/hf_transformers_utils.py (view file @ 0992d85f)

```diff
@@ -30,10 +30,17 @@ def get_config_json(model_path: str):
     return config
 
 
-def get_config(model: str, trust_remote_code: bool, revision: Optional[str] = None):
+def get_config(
+    model: str,
+    trust_remote_code: bool,
+    revision: Optional[str] = None,
+    model_overide_args: Optional[dict] = None,
+):
     config = AutoConfig.from_pretrained(
         model, trust_remote_code=trust_remote_code, revision=revision
     )
+    if model_overide_args:
+        config.update(model_overide_args)
     return config
```
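Since the overrides are applied with `config.update(...)`, every key in `model_overide_args` simply becomes an attribute on the Hugging Face config object that the rest of the stack reads. The sketch below illustrates that behavior using transformers' stock `PretrainedConfig.update`; the model name is a placeholder, not something this commit requires.

```python
from transformers import AutoConfig

# Sketch of what get_config() now does internally: load the checkpoint's
# config, then overwrite/extend it with caller-supplied override keys.
config = AutoConfig.from_pretrained("lmms-lab/LLaVA-NeXT-Video-7B", trust_remote_code=True)
config.update({"num_frames": 16, "mm_spatial_pool_stride": 2})

# The override keys are now plain attributes on the config object.
assert config.num_frames == 16
assert config.mm_spatial_pool_stride == 2
```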
python/sglang/srt/managers/router/manager.py (view file @ 0992d85f)

```diff
@@ -60,9 +60,7 @@ class RouterManager:
 
 def start_router_process(
-    server_args: ServerArgs,
-    port_args: PortArgs,
-    pipe_writer,
+    server_args: ServerArgs, port_args: PortArgs, pipe_writer, model_overide_args
 ):
     logging.basicConfig(
         level=getattr(logging, server_args.log_level.upper()),
@@ -70,7 +68,7 @@ def start_router_process(
     )
 
     try:
-        model_client = ModelRpcClient(server_args, port_args)
+        model_client = ModelRpcClient(server_args, port_args, model_overide_args)
         router = RouterManager(model_client, port_args)
     except Exception:
         pipe_writer.send(get_exception_traceback())
```
python/sglang/srt/managers/router/model_rpc.py (view file @ 0992d85f)

```diff
@@ -4,12 +4,13 @@ import multiprocessing
 import time
 import warnings
 from concurrent.futures import ThreadPoolExecutor
-from typing import List
+from typing import Any, Dict, List, Optional, Tuple, Union
 
 import rpyc
 import torch
 from rpyc.utils.classic import obtain
 from rpyc.utils.server import ThreadedServer
+
 try:
     from vllm.logger import _default_handler as vllm_default_logger
 except ImportError:
@@ -48,6 +49,7 @@ class ModelRpcServer:
         tp_rank: int,
         server_args: ServerArgs,
         port_args: PortArgs,
+        model_overide_args: Optional[dict] = None,
     ):
         server_args, port_args = [obtain(x) for x in [server_args, port_args]]
@@ -62,6 +64,7 @@ class ModelRpcServer:
             server_args.model_path,
             server_args.trust_remote_code,
             context_length=server_args.context_length,
+            model_overide_args=model_overide_args,
         )
 
         # For model end global settings
@@ -673,13 +676,15 @@ class ModelRpcService(rpyc.Service):
 
 class ModelRpcClient:
-    def __init__(self, server_args: ServerArgs, port_args: PortArgs):
+    def __init__(
+        self, server_args: ServerArgs, port_args: PortArgs, model_overide_args
+    ):
         tp_size = server_args.tp_size
 
         if tp_size == 1:
             # Init model
             self.model_server = ModelRpcService().exposed_ModelRpcServer(
-                0, server_args, port_args
+                0, server_args, port_args, model_overide_args
             )
 
             # Wrap functions
@@ -700,7 +705,7 @@ class ModelRpcClient:
             # Init model
             def init_model(i):
                 return self.remote_services[i].ModelRpcServer(
-                    i, server_args, port_args
+                    i, server_args, port_args, model_overide_args
                 )
 
             self.model_servers = executor.map(init_model, range(tp_size))
@@ -723,7 +728,11 @@ def _init_service(port):
     t = ThreadedServer(
         ModelRpcService(),
         port=port,
-        protocol_config={"allow_pickle": True, "sync_request_timeout": 1800},
+        protocol_config={
+            "allow_public_attrs": True,
+            "allow_pickle": True,
+            "sync_request_timeout": 1800,
+        },
     )
     t.start()
@@ -739,7 +748,11 @@ def start_model_process(port):
         con = rpyc.connect(
             "localhost",
             port,
-            config={"allow_pickle": True, "sync_request_timeout": 1800},
+            config={
+                "allow_public_attrs": True,
+                "allow_pickle": True,
+                "sync_request_timeout": 1800,
+            },
         )
         break
     except ConnectionRefusedError:
```
python/sglang/srt/managers/router/model_runner.py (view file @ 0992d85f)

```diff
@@ -9,11 +9,11 @@ from typing import List
 
 import numpy as np
 import torch
+from vllm.distributed import initialize_model_parallel
 from vllm.model_executor.layers.quantization.awq import AWQConfig
 from vllm.model_executor.layers.quantization.gptq import GPTQConfig
 from vllm.model_executor.layers.quantization.marlin import MarlinConfig
 from vllm.model_executor.model_loader.utils import set_default_torch_dtype
-from vllm.distributed import initialize_model_parallel
 
 from sglang.srt.managers.router.infer_batch import Batch, ForwardMode
 from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool
@@ -143,7 +143,7 @@ class InputMetadata:
             self.kv_last_page_len,
             self.model_runner.model_config.num_attention_heads // tp_size,
             self.model_runner.model_config.num_key_value_heads // tp_size,
-            self.model_runner.model_config.head_dim
+            self.model_runner.model_config.head_dim,
         ]
 
         self.prefill_wrapper.begin_forward(*args)
```
python/sglang/srt/managers/tokenizer_manager.py (view file @ 0992d85f)

```diff
@@ -60,7 +60,15 @@ def get_pixel_values(
 ):
     try:
         processor = processor or global_processor
-        image = load_image(image_data)
+        image, image_size = load_image(image_data)
+        if image_size != None:
+            image_hash = hash(image_data)
+            pixel_values = processor.image_processor(image)["pixel_values"]
+            for _ in range(len(pixel_values)):
+                pixel_values[_] = pixel_values[_].astype(np.float16)
+            pixel_values = np.stack(pixel_values, axis=0)
+            return pixel_values, image_hash, image_size
+        else:
         image_hash = hash(image_data)
         if image_aspect_ratio == "pad":
             image = expand2square(
@@ -84,6 +92,7 @@ class TokenizerManager:
         self,
         server_args: ServerArgs,
         port_args: PortArgs,
+        model_overide_args: dict = None,
     ):
         self.server_args = server_args
@@ -96,7 +105,9 @@ class TokenizerManager:
         self.model_path = server_args.model_path
         self.hf_config = get_config(
-            self.model_path, trust_remote_code=server_args.trust_remote_code
+            self.model_path,
+            trust_remote_code=server_args.trust_remote_code,
+            model_overide_args=model_overide_args,
         )
         self.context_len = get_context_length(self.hf_config)
```