Commit 0992d85f (unverified)
authored May 14, 2024 by Yuanhan Zhang, committed by GitHub May 13, 2024
parent 5dc55a5f

support llava video (#426)
Changes: 37
Showing 20 changed files with 476 additions and 39 deletions (+476, -39)
.gitignore (+4, -0)
README.md (+1, -1)
examples/quick_start/srt_example_llava.py (+2, -2)
examples/usage/llava_video/srt_example_llava_v.py (+208, -0)
examples/usage/llava_video/srt_example_llava_v.sh (+130, -0)
examples/usage/llava_video/videos/Q98Z4OTh8RwmDonc.mp4 (+0, -0)
python/pyproject.toml (+2, -2)
python/sglang/__init__.py (+2, -0)
python/sglang/api.py (+5, -0)
python/sglang/lang/chat_template.py (+9, -2)
python/sglang/lang/interpreter.py (+14, -1)
python/sglang/lang/ir.py (+9, -0)
python/sglang/lang/tracer.py (+1, -1)
python/sglang/launch_server.py (+1, -2)
python/sglang/launch_server_llavavid.py (+31, -0)
python/sglang/srt/hf_transformers_utils.py (+8, -1)
python/sglang/srt/managers/router/manager.py (+2, -4)
python/sglang/srt/managers/router/model_rpc.py (+19, -6)
python/sglang/srt/managers/router/model_runner.py (+2, -2)
python/sglang/srt/managers/tokenizer_manager.py (+26, -15)
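Before the per-file diffs, a minimal sketch of the API this commit introduces, condensed from examples/usage/llava_video/srt_example_llava_v.py below. The model, tokenizer, and video paths are illustrative, and the model_overide_args spelling follows the diff:

import sglang as sgl

@sgl.function
def video_qa(s, num_frames, video_path, question):
    # sgl.video() is the new primitive: it attaches `num_frames` sampled
    # frames from the clip at `video_path` to the user turn.
    s += sgl.user(sgl.video(video_path, num_frames) + question)
    s += sgl.assistant(sgl.gen("answer"))

# The override args route the request through the new LlavaVidForCausalLM
# path (see launch_server_llavavid.py and srt_example_llava_v.py below).
runtime = sgl.Runtime(
    model_path="lmms-lab/LLaVA-NeXT-Video-7B",
    tokenizer_path="llava-hf/llava-1.5-7b-hf",
    model_overide_args={
        "architectures": ["LlavaVidForCausalLM"],
        "model_type": "llava",
        "num_frames": 16,
        "mm_spatial_pool_stride": 2,
    },
)
sgl.set_default_backend(runtime)
state = video_qa.run(
    num_frames=16,
    video_path="videos/Q98Z4OTh8RwmDonc.mp4",
    question="What is happening in this video?",
)
print(state["answer"])
runtime.shutdown()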
.gitignore
@@ -177,3 +177,7 @@ tmp*.txt
 # Plots
 *.png
 *.pdf
+
+# personnal
+work_dirs/
+*.csv
README.md
(diff collapsed)
examples/quick_start/srt_example_llava.py
@@ -14,7 +14,7 @@ def single():
     state = image_qa.run(
         image_path="images/cat.jpeg",
         question="What is this?",
-        max_new_tokens=64)
+        max_new_tokens=128)
     print(state["answer"], "\n")
@@ -36,7 +36,7 @@ def batch():
             {"image_path": "images/cat.jpeg", "question": "What is this?"},
             {"image_path": "images/dog.jpeg", "question": "What is this?"},
         ],
-        max_new_tokens=64,
+        max_new_tokens=128,
     )
     for s in states:
         print(s["answer"], "\n")
examples/usage/llava_video/srt_example_llava_v.py
new file (0 → 100644), +208

"""
Usage: python3 srt_example_llava.py
"""

import sglang as sgl
import os
import csv
import time
import argparse


@sgl.function
def video_qa(s, num_frames, video_path, question):
    s += sgl.user(sgl.video(video_path, num_frames) + question)
    s += sgl.assistant(sgl.gen("answer"))


def single(path, num_frames=16):
    state = video_qa.run(
        num_frames=num_frames,
        video_path=path,
        question="Please provide a detailed description of the video, focusing on the main subjects, their actions, the background scenes",
        temperature=0.0,
        max_new_tokens=1024,
    )
    print(state["answer"], "\n")


def split_into_chunks(lst, num_chunks):
    """Split a list into a specified number of chunks."""
    # Calculate the chunk size using integer division. Note that this may drop some items if not evenly divisible.
    chunk_size = len(lst) // num_chunks

    if chunk_size == 0:
        chunk_size = len(lst)

    # Use list comprehension to generate chunks. The last chunk will take any remainder if the list size isn't evenly divisible.
    chunks = [lst[i : i + chunk_size] for i in range(0, len(lst), chunk_size)]
    # Ensure we have exactly num_chunks chunks, even if some are empty
    chunks.extend([[] for _ in range(num_chunks - len(chunks))])
    return chunks


def save_batch_results(batch_video_files, states, cur_chunk, batch_idx, save_dir):
    csv_filename = f"{save_dir}/chunk_{cur_chunk}_batch_{batch_idx}.csv"
    with open(csv_filename, 'w', newline='') as csvfile:
        writer = csv.writer(csvfile)
        writer.writerow(['video_name', 'answer'])
        for video_path, state in zip(batch_video_files, states):
            video_name = os.path.basename(video_path)
            writer.writerow([video_name, state["answer"]])


def compile_and_cleanup_final_results(cur_chunk, num_batches, save_dir):
    final_csv_filename = f"{save_dir}/final_results_chunk_{cur_chunk}.csv"
    with open(final_csv_filename, 'w', newline='') as final_csvfile:
        writer = csv.writer(final_csvfile)
        writer.writerow(['video_name', 'answer'])
        for batch_idx in range(num_batches):
            batch_csv_filename = f"{save_dir}/chunk_{cur_chunk}_batch_{batch_idx}.csv"
            with open(batch_csv_filename, 'r') as batch_csvfile:
                reader = csv.reader(batch_csvfile)
                next(reader)  # Skip header row
                for row in reader:
                    writer.writerow(row)
            os.remove(batch_csv_filename)


def find_video_files(video_dir):
    # Check if the video_dir is actually a file
    if os.path.isfile(video_dir):
        # If it's a file, return it as a single-element list
        return [video_dir]

    # Original logic to find video files in a directory
    video_files = []
    for root, dirs, files in os.walk(video_dir):
        for file in files:
            if file.endswith(('.mp4', '.avi', '.mov')):
                video_files.append(os.path.join(root, file))
    return video_files


def batch(video_dir, save_dir, cur_chunk, num_chunks, num_frames=16, batch_size=64):
    video_files = find_video_files(video_dir)
    chunked_video_files = split_into_chunks(video_files, num_chunks)[cur_chunk]
    num_batches = 0

    for i in range(0, len(chunked_video_files), batch_size):
        batch_video_files = chunked_video_files[i : i + batch_size]
        print(f"Processing batch of {len(batch_video_files)} video(s)...")

        if not batch_video_files:
            print("No video files found in the specified directory.")
            return

        batch_input = [
            {
                "num_frames": num_frames,
                "video_path": video_path,
                "question": "Please provide a detailed description of the video, focusing on the main subjects, their actions, the background scenes.",
            }
            for video_path in batch_video_files
        ]

        start_time = time.time()
        states = video_qa.run_batch(batch_input, max_new_tokens=512, temperature=0.2)
        total_time = time.time() - start_time
        average_time = total_time / len(batch_video_files)
        print(f"Number of videos in batch: {len(batch_video_files)}. Average processing time per video: {average_time:.2f} seconds. Total time for this batch: {total_time:.2f} seconds")

        save_batch_results(batch_video_files, states, cur_chunk, num_batches, save_dir)
        num_batches += 1

    compile_and_cleanup_final_results(cur_chunk, num_batches, save_dir)


if __name__ == "__main__":
    # Create the parser
    parser = argparse.ArgumentParser(description='Run video processing with specified port.')

    # Add an argument for the port
    parser.add_argument('--port', type=int, default=30000, help='The master port for distributed serving.')
    parser.add_argument('--chunk-idx', type=int, default=0, help='The index of the chunk to process.')
    parser.add_argument('--num-chunks', type=int, default=8, help='The number of chunks to process.')
    parser.add_argument('--save-dir', type=str, default="./work_dirs/llava_video", help='The directory to save the processed video files.')
    parser.add_argument('--video-dir', type=str, default="./videos/Q98Z4OTh8RwmDonc.mp4", help='The directory or path for the processed video files.')
    parser.add_argument('--model-path', type=str, default="lmms-lab/LLaVA-NeXT-Video-7B", help='The model path for the video processing.')
    parser.add_argument('--num-frames', type=int, default=16, help='The number of frames to process in each video.')
    parser.add_argument("--mm_spatial_pool_stride", type=int, default=2)

    # Parse the arguments
    args = parser.parse_args()

    cur_port = args.port
    cur_chunk = args.chunk_idx
    num_chunks = args.num_chunks
    num_frames = args.num_frames

    if "34b" in args.model_path.lower():
        tokenizer_path = "liuhaotian/llava-v1.6-34b-tokenizer"
    elif "7b" in args.model_path.lower():
        tokenizer_path = "llava-hf/llava-1.5-7b-hf"
    else:
        print("Invalid model path. Please specify a valid model path.")
        exit()

    model_overide_args = {}
    model_overide_args["mm_spatial_pool_stride"] = args.mm_spatial_pool_stride
    model_overide_args["architectures"] = ["LlavaVidForCausalLM"]
    model_overide_args["num_frames"] = args.num_frames
    model_overide_args["model_type"] = "llava"

    if "34b" in args.model_path.lower():
        model_overide_args["image_token_index"] = 64002

    if args.num_frames == 32:
        model_overide_args["rope_scaling"] = {"factor": 2.0, "type": "linear"}
        model_overide_args["max_sequence_length"] = 4096 * 2
        model_overide_args["tokenizer_model_max_length"] = 4096 * 2
    elif args.num_frames < 32:
        pass
    else:
        print("The maximum number of frames to process is 32. Please specify a valid number of frames.")
        exit()

    runtime = sgl.Runtime(
        model_path=args.model_path,  # "liuhaotian/llava-v1.6-vicuna-7b",
        tokenizer_path=tokenizer_path,
        port=cur_port,
        additional_ports=[cur_port + 1, cur_port + 2, cur_port + 3, cur_port + 4],
        model_overide_args=model_overide_args,
        tp_size=1,
    )
    sgl.set_default_backend(runtime)
    print(f"chat template: {runtime.endpoint.chat_template.name}")

    # Run a single request
    # try:
    print("\n========== single ==========\n")
    root = args.video_dir
    if os.path.isfile(root):
        video_files = [root]
    else:
        video_files = [
            os.path.join(root, f)
            for f in os.listdir(root)
            if f.endswith(('.mp4', '.avi', '.mov'))
        ]  # Add more extensions if needed
    start_time = time.time()  # Start time for processing a single video
    for cur_video in video_files[:1]:
        print(cur_video)
        single(cur_video, num_frames)
    end_time = time.time()  # End time for processing a single video
    total_time = end_time - start_time
    average_time = total_time / len(video_files)  # Calculate the average processing time
    print(f"Average processing time per video: {average_time:.2f} seconds")
    runtime.shutdown()
    # except Exception as e:
    #     print(e)
    runtime.shutdown()

    # # # Run a batch of requests
    # print("\n========== batch ==========\n")
    # if not os.path.exists(args.save_dir):
    #     os.makedirs(args.save_dir)
    # batch(args.video_dir, args.save_dir, cur_chunk, num_chunks, num_frames, num_chunks)
    # runtime.shutdown()
\ No newline at end of file
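A plausible single-video invocation of the script above, using only flags defined in its argparse block; the clip path is the sample video added by this commit:

python3 examples/usage/llava_video/srt_example_llava_v.py \
    --video-dir ./videos/Q98Z4OTh8RwmDonc.mp4 \
    --model-path lmms-lab/LLaVA-NeXT-Video-7B \
    --num-frames 16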
examples/usage/llava_video/srt_example_llava_v.sh
new file (0 → 100644), +130

#!/bin/bash

##### USAGE #####
# - First node:
# ```sh
# bash examples/quick_start/srt_example_llava_v.sh K 0 YOUR_VIDEO_PATH YOUR_MODEL_PATH FRAMES_PER_VIDEO
# ```
# - Second node:
# ```sh
# bash examples/quick_start/srt_example_llava_v.sh K 1 YOUR_VIDEO_PATH YOUR_MODEL_PATH FRAMES_PER_VIDEO
# ```
# - The K node:
# ```sh
# bash examples/quick_start/srt_example_llava_v.sh K K-1 YOUR_VIDEO_PATH YOUR_MODEL_PATH FRAMES_PER_VIDEO
# ```
# Replace `K`, `YOUR_VIDEO_PATH`, `YOUR_MODEL_PATH`, and `FRAMES_PER_VIDEO` with your specific details.

CURRENT_ROOT="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"

echo ${CURRENT_ROOT}

cd ${CURRENT_ROOT}

export PYTHONWARNINGS=ignore

START_TIME=$(date +%s)  # Capture start time

NUM_NODES=$1
CUR_NODES_IDX=$2
VIDEO_DIR=$3
MODEL_PATH=$4
NUM_FRAMES=$5

# FRAME_FORMAT=$6
# FRAME_FORMAT=$(echo $FRAME_FORMAT | tr '[:lower:]' '[:upper:]')

# # Check if FRAME_FORMAT is either JPEG or PNG
# if [[ "$FRAME_FORMAT" != "JPEG" && "$FRAME_FORMAT" != "PNG" ]]; then
#     echo "Error: FRAME_FORMAT must be either JPEG or PNG."
#     exit 1
# fi

# export TARGET_FRAMES=$TARGET_FRAMES

echo "Each video you will sample $NUM_FRAMES frames"

# export FRAME_FORMAT=$FRAME_FORMAT

# echo "The frame format is $FRAME_FORMAT"

# Assuming GPULIST is a bash array containing your GPUs
GPULIST=(0 1 2 3 4 5 6 7)
LOCAL_CHUNKS=${#GPULIST[@]}
echo "Number of GPUs in GPULIST: $LOCAL_CHUNKS"

ALL_CHUNKS=$((NUM_NODES * LOCAL_CHUNKS))

# Calculate GPUs per chunk
GPUS_PER_CHUNK=8

echo $GPUS_PER_CHUNK

for IDX in $(seq 1 $LOCAL_CHUNKS); do
    (
        START=$(((IDX-1) * GPUS_PER_CHUNK))
        LENGTH=$GPUS_PER_CHUNK  # Length for slicing, not the end index

        CHUNK_GPUS=(${GPULIST[@]:$START:$LENGTH})

        # Convert the chunk GPUs array to a comma-separated string
        CHUNK_GPUS_STR=$(IFS=,; echo "${CHUNK_GPUS[*]}")

        LOCAL_IDX=$((CUR_NODES_IDX * LOCAL_CHUNKS + IDX))

        echo "Chunk $(($LOCAL_IDX - 1)) will run on GPUs $CHUNK_GPUS_STR"

        # Calculate the port for this chunk. Ensure it's incremented by 5 for each chunk.
        PORT=$((10000 + RANDOM % 55536))

        MAX_RETRIES=10
        RETRY_COUNT=0
        COMMAND_STATUS=1  # Initialize as failed

        while [ $RETRY_COUNT -lt $MAX_RETRIES ] && [ $COMMAND_STATUS -ne 0 ]; do
            echo "Running chunk $(($LOCAL_IDX - 1)) on GPUs $CHUNK_GPUS_STR with port $PORT. Attempt $(($RETRY_COUNT + 1))"

            #!/bin/bash
            CUDA_VISIBLE_DEVICES=$CHUNK_GPUS_STR python3 examples/usage/llava_video/srt_example_llava_v.py \
                --port $PORT \
                --num-chunks $ALL_CHUNKS \
                --chunk-idx $(($LOCAL_IDX - 1)) \
                --save-dir work_dirs/llava_next_video_inference_results \
                --video-dir $VIDEO_DIR \
                --model-path $MODEL_PATH \
                --num-frames $NUM_FRAMES #&

            wait $!  # Wait for the process to finish and capture its exit status
            COMMAND_STATUS=$?

            if [ $COMMAND_STATUS -ne 0 ]; then
                echo "Execution failed for chunk $(($LOCAL_IDX - 1)), attempt $(($RETRY_COUNT + 1)). Retrying..."
                RETRY_COUNT=$(($RETRY_COUNT + 1))
                sleep 180  # Wait a bit before retrying
            else
                echo "Execution succeeded for chunk $(($LOCAL_IDX - 1))."
            fi
        done

        if [ $COMMAND_STATUS -ne 0 ]; then
            echo "Execution failed for chunk $(($LOCAL_IDX - 1)) after $MAX_RETRIES attempts."
        fi
    ) #&
    sleep 2  # Slight delay to stagger the start times
done

wait

cat work_dirs/llava_next_video_inference_results/final_results_chunk_*.csv > work_dirs/llava_next_video_inference_results/final_results_node_${CUR_NODES_IDX}.csv

END_TIME=$(date +%s)  # Capture end time
ELAPSED_TIME=$(($END_TIME - $START_TIME))
echo "Total execution time: $ELAPSED_TIME seconds."
\ No newline at end of file
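To make the chunk arithmetic above concrete: with NUM_NODES=2 and the default 8-entry GPULIST, LOCAL_CHUNKS=8 and ALL_CHUNKS=16. On the second node (CUR_NODES_IDX=1), the loop iteration IDX=3 computes LOCAL_IDX = 1*8+3 = 11, so that subprocess runs the Python script with --num-chunks 16 and --chunk-idx 10, writing final_results_chunk_10.csv, which the final cat merges into the per-node CSV.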
examples/usage/llava_video/videos/Q98Z4OTh8RwmDonc.mp4
new file (0 → 100644): file added
python/pyproject.toml
@@ -20,7 +20,7 @@ dependencies = [
 [project.optional-dependencies]
 srt = ["aiohttp", "fastapi", "psutil", "rpyc", "torch", "uvloop", "uvicorn",
-       "zmq", "vllm>=0.4.2", "interegular", "pydantic", "pillow", "outlines>=0.0.27", "packaging"]
+       "zmq", "vllm>=0.4.2", "interegular", "pydantic", "pillow", "packaging", "huggingface_hub", "hf_transfer", "outlines>=0.0.34"]
 openai = ["openai>=1.0", "numpy", "tiktoken"]
 anthropic = ["anthropic>=0.20.0", "numpy"]
 all = ["sglang[srt]", "sglang[openai]", "sglang[anthropic]"]
python/sglang/__init__.py
@@ -19,6 +19,7 @@ from sglang.api import (
     user,
     user_begin,
     user_end,
+    video,
 )

 # SGL Backends
@@ -46,6 +47,7 @@ __all__ = [
     "gen_int",
     "gen_string",
     "image",
+    "video",
     "select",
     "system",
     "user",
python/sglang/api.py
@@ -15,6 +15,7 @@ from sglang.lang.ir import (
     SglRoleBegin,
     SglRoleEnd,
     SglSelect,
+    SglVideo,
 )
@@ -151,6 +152,10 @@ def image(expr: SglExpr):
     return SglImage(expr)


+def video(path: str, num_frames: int):
+    return SglVideo(path, num_frames)
+
+
 def select(
     name: Optional[str] = None,
     choices: List[str] = None,
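A short sketch of how the new video() helper composes inside a program, by analogy with image() above; the function name and clip path are illustrative:

import sglang as sgl

@sgl.function
def describe(s, clip_path):
    # video() returns an SglVideo expression; concatenating it with text
    # builds the user turn just as sgl.image() does for still images.
    s += sgl.user(sgl.video(clip_path, 8) + "Describe this clip.")
    s += sgl.assistant(sgl.gen("answer"))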
python/sglang/lang/chat_template.py
@@ -259,6 +259,8 @@ def match_vicuna(model_path: str):
         return get_chat_template("vicuna_v1.1")
     if "llava-v1.5" in model_path.lower():
         return get_chat_template("vicuna_v1.1")
+    if "llava-next-video-7b" in model_path.lower():
+        return get_chat_template("vicuna_v1.1")


 @register_chat_template_matching_function
@@ -283,19 +285,24 @@ def match_llama3_instruct(model_path: str):
 @register_chat_template_matching_function
 def match_chat_ml(model_path: str):
+    # import pdb;pdb.set_trace()
     model_path = model_path.lower()
     if "tinyllama" in model_path:
         return get_chat_template("chatml")
     if "qwen" in model_path and "chat" in model_path:
         return get_chat_template("chatml")
-    if "llava-v1.6-34b" in model_path:
+    if (
+        "llava-v1.6-34b" in model_path
+        or "llava-v1.6-yi-34b" in model_path
+        or "llava-next-video-34b" in model_path
+    ):
         return get_chat_template("chatml-llava")


 @register_chat_template_matching_function
 def match_chat_yi(model_path: str):
     model_path = model_path.lower()
-    if "yi" in model_path:
+    if "yi" in model_path and "llava" not in model_path:
         return get_chat_template("yi")
python/sglang/lang/interpreter.py
@@ -28,8 +28,9 @@ from sglang.lang.ir import (
     SglVariable,
     SglVarScopeBegin,
     SglVarScopeEnd,
+    SglVideo,
 )
-from sglang.utils import encode_image_base64, get_exception_traceback
+from sglang.utils import encode_image_base64, encode_video_base64, get_exception_traceback


 def run_internal(state, program, func_args, func_kwargs, sync):
@@ -361,6 +362,8 @@ class StreamExecutor:
             self._execute_role_end(other)
         elif isinstance(other, SglImage):
             self._execute_image(other)
+        elif isinstance(other, SglVideo):
+            self._execute_video(other)
         elif isinstance(other, SglVariable):
             self._execute_variable(other)
         elif isinstance(other, SglVarScopeBegin):
@@ -397,6 +400,16 @@ class StreamExecutor:
         self.cur_images.append((path, base64_data))
         self.text_ += self.chat_template.image_token

+    def _execute_video(self, expr: SglVideo):
+        path = expr.path
+        num_frames = expr.num_frames
+
+        base64_data = encode_video_base64(path, num_frames)
+
+        self.images_.append((path, base64_data))
+        self.cur_images.append((path, base64_data))
+        self.text_ += self.chat_template.image_token
+
         # if global_config.eager_fill_image:
         #     self.backend.fill_image(self)
python/sglang/lang/ir.py
@@ -330,6 +330,15 @@ class SglImage(SglExpr):
         return f"SglImage({self.path})"


+class SglVideo(SglExpr):
+    def __init__(self, path, num_frames):
+        self.path = path
+        self.num_frames = num_frames
+
+    def __repr__(self) -> str:
+        return f"SglVideo({self.path}, {self.num_frames})"
+
+
 class SglGen(SglExpr):
     def __init__(
         self,
python/sglang/lang/tracer.py
@@ -110,7 +110,7 @@ class TracerProgramState(ProgramState):
     ##################################

     def fork(self, size: int = 1, position_ids_offset: Optional[List[int]] = None):
-        assert (size >= 1)
+        assert size >= 1

         if self.only_trace_prefix:
             raise StopTracing()
python/sglang/launch_server.py
@@ -2,7 +2,6 @@ import argparse
 from sglang.srt.server import ServerArgs, launch_server

 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     ServerArgs.add_cli_args(parser)
python/sglang/launch_server_llavavid.py
new file (0 → 100644), +31

import argparse
import multiprocessing as mp

from sglang.srt.server import ServerArgs, launch_server

if __name__ == "__main__":
    model_overide_args = {}

    model_overide_args["mm_spatial_pool_stride"] = 2
    model_overide_args["architectures"] = ["LlavaVidForCausalLM"]
    model_overide_args["num_frames"] = 16
    model_overide_args["model_type"] = "llavavid"
    if model_overide_args["num_frames"] == 32:
        model_overide_args["rope_scaling"] = {"factor": 2.0, "type": "linear"}
        model_overide_args["max_sequence_length"] = 4096 * 2
        model_overide_args["tokenizer_model_max_length"] = 4096 * 2
        model_overide_args["model_max_length"] = 4096 * 2

    parser = argparse.ArgumentParser()
    ServerArgs.add_cli_args(parser)
    args = parser.parse_args()

    if "34b" in args.model_path.lower():
        model_overide_args["image_token_index"] = 64002

    server_args = ServerArgs.from_cli_args(args)

    pipe_reader, pipe_writer = mp.Pipe(duplex=False)

    launch_server(server_args, pipe_writer, model_overide_args)
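Assuming a standard sglang install, this new entry point should be launchable the same way as launch_server.py; the flags come from ServerArgs.add_cli_args, and the model path is illustrative:

python3 -m sglang.launch_server_llavavid \
    --model-path lmms-lab/LLaVA-NeXT-Video-7B \
    --tokenizer-path llava-hf/llava-1.5-7b-hf \
    --port 30000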
python/sglang/srt/hf_transformers_utils.py
@@ -30,10 +30,17 @@ def get_config_json(model_path: str):
     return config


-def get_config(model: str, trust_remote_code: bool, revision: Optional[str] = None):
+def get_config(
+    model: str,
+    trust_remote_code: bool,
+    revision: Optional[str] = None,
+    model_overide_args: Optional[dict] = None,
+):
     config = AutoConfig.from_pretrained(
         model, trust_remote_code=trust_remote_code, revision=revision
     )
+    if model_overide_args:
+        config.update(model_overide_args)
     return config
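A minimal sketch of the new override path; config.update() is the standard transformers PretrainedConfig method, so any key in model_overide_args lands directly on the config object. The values mirror the overrides set in launch_server_llavavid.py above and are illustrative:

from sglang.srt.hf_transformers_utils import get_config

config = get_config(
    "lmms-lab/LLaVA-NeXT-Video-7B",
    trust_remote_code=True,
    model_overide_args={"architectures": ["LlavaVidForCausalLM"], "num_frames": 16},
)
print(config.num_frames)  # 16: the override is now part of the HF config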
python/sglang/srt/managers/router/manager.py
@@ -60,9 +60,7 @@ class RouterManager:
 def start_router_process(
-    server_args: ServerArgs,
-    port_args: PortArgs,
-    pipe_writer,
+    server_args: ServerArgs, port_args: PortArgs, pipe_writer, model_overide_args
 ):
     logging.basicConfig(
         level=getattr(logging, server_args.log_level.upper()),
@@ -70,7 +68,7 @@ def start_router_process(
     )

     try:
-        model_client = ModelRpcClient(server_args, port_args)
+        model_client = ModelRpcClient(server_args, port_args, model_overide_args)
         router = RouterManager(model_client, port_args)
     except Exception:
         pipe_writer.send(get_exception_traceback())
python/sglang/srt/managers/router/model_rpc.py
@@ -4,12 +4,13 @@ import multiprocessing
 import time
 import warnings
 from concurrent.futures import ThreadPoolExecutor
-from typing import List
+from typing import Any, Dict, List, Optional, Tuple, Union

 import rpyc
 import torch
 from rpyc.utils.classic import obtain
 from rpyc.utils.server import ThreadedServer

 try:
     from vllm.logger import _default_handler as vllm_default_logger
 except ImportError:
@@ -48,6 +49,7 @@ class ModelRpcServer:
         tp_rank: int,
         server_args: ServerArgs,
         port_args: PortArgs,
+        model_overide_args: Optional[dict] = None,
     ):
         server_args, port_args = [obtain(x) for x in [server_args, port_args]]
@@ -62,6 +64,7 @@ class ModelRpcServer:
             server_args.model_path,
             server_args.trust_remote_code,
             context_length=server_args.context_length,
+            model_overide_args=model_overide_args,
         )

         # For model end global settings
@@ -673,13 +676,15 @@ class ModelRpcService(rpyc.Service):
 class ModelRpcClient:
-    def __init__(self, server_args: ServerArgs, port_args: PortArgs):
+    def __init__(
+        self, server_args: ServerArgs, port_args: PortArgs, model_overide_args
+    ):
         tp_size = server_args.tp_size

         if tp_size == 1:
             # Init model
             self.model_server = ModelRpcService().exposed_ModelRpcServer(
-                0, server_args, port_args
+                0, server_args, port_args, model_overide_args
             )

             # Wrap functions
@@ -700,7 +705,7 @@ class ModelRpcClient:
             # Init model
             def init_model(i):
                 return self.remote_services[i].ModelRpcServer(
-                    i, server_args, port_args
+                    i, server_args, port_args, model_overide_args
                 )

             self.model_servers = executor.map(init_model, range(tp_size))
@@ -723,7 +728,11 @@ def _init_service(port):
     t = ThreadedServer(
         ModelRpcService(),
         port=port,
-        protocol_config={"allow_pickle": True, "sync_request_timeout": 1800},
+        protocol_config={
+            "allow_public_attrs": True,
+            "allow_pickle": True,
+            "sync_request_timeout": 1800,
+        },
     )
     t.start()
@@ -739,7 +748,11 @@ def start_model_process(port):
         con = rpyc.connect(
             "localhost",
             port,
-            config={"allow_pickle": True, "sync_request_timeout": 1800},
+            config={
+                "allow_public_attrs": True,
+                "allow_pickle": True,
+                "sync_request_timeout": 1800,
+            },
         )
         break
     except ConnectionRefusedError:
python/sglang/srt/managers/router/model_runner.py
@@ -9,11 +9,11 @@ from typing import List
 import numpy as np
 import torch
+from vllm.distributed import initialize_model_parallel
 from vllm.model_executor.layers.quantization.awq import AWQConfig
 from vllm.model_executor.layers.quantization.gptq import GPTQConfig
 from vllm.model_executor.layers.quantization.marlin import MarlinConfig
 from vllm.model_executor.model_loader.utils import set_default_torch_dtype
-from vllm.distributed import initialize_model_parallel

 from sglang.srt.managers.router.infer_batch import Batch, ForwardMode
 from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool
@@ -143,7 +143,7 @@ class InputMetadata:
             self.kv_last_page_len,
             self.model_runner.model_config.num_attention_heads // tp_size,
             self.model_runner.model_config.num_key_value_heads // tp_size,
-            self.model_runner.model_config.head_dim
+            self.model_runner.model_config.head_dim,
         ]
         self.prefill_wrapper.begin_forward(*args)
python/sglang/srt/managers/tokenizer_manager.py
@@ -60,7 +60,15 @@ def get_pixel_values(
 ):
     try:
         processor = processor or global_processor
-        image = load_image(image_data)
+        image, image_size = load_image(image_data)
+        if image_size != None:
+            image_hash = hash(image_data)
+            pixel_values = processor.image_processor(image)["pixel_values"]
+            for _ in range(len(pixel_values)):
+                pixel_values[_] = pixel_values[_].astype(np.float16)
+            pixel_values = np.stack(pixel_values, axis=0)
+            return pixel_values, image_hash, image_size
+        else:
-        image_hash = hash(image_data)
-        if image_aspect_ratio == "pad":
-            image = expand2square(
+            image_hash = hash(image_data)
+            if image_aspect_ratio == "pad":
+                image = expand2square(
@@ -84,6 +92,7 @@ class TokenizerManager:
         self,
         server_args: ServerArgs,
         port_args: PortArgs,
+        model_overide_args: dict = None,
     ):
         self.server_args = server_args
@@ -96,7 +105,9 @@ class TokenizerManager:
         self.model_path = server_args.model_path
         self.hf_config = get_config(
-            self.model_path, trust_remote_code=server_args.trust_remote_code
+            self.model_path,
+            trust_remote_code=server_args.trust_remote_code,
+            model_overide_args=model_overide_args,
         )
         self.context_len = get_context_length(self.hf_config)