sglang · Commits · 0992d85f

support llava video (#426)

Unverified commit 0992d85f, authored May 14, 2024 by Yuanhan Zhang, committed by GitHub May 13, 2024.
Parent commit: 5dc55a5f
Showing 17 changed files with 663 additions and 183 deletions (+663, -183).
python/sglang/srt/model_config.py       +4    -0
python/sglang/srt/models/commandr.py    +6   -10
python/sglang/srt/models/dbrx.py       +14   -15
python/sglang/srt/models/gemma.py       +7   -10
python/sglang/srt/models/llama2.py      +7   -10
python/sglang/srt/models/llava.py       +2    -6
python/sglang/srt/models/llavavid.py  +307    -0
python/sglang/srt/models/mixtral.py     +7   -13
python/sglang/srt/models/qwen.py       +20   -13
python/sglang/srt/models/qwen2.py       +7   -10
python/sglang/srt/models/stablelm.py   +13   -12
python/sglang/srt/models/yivl.py        +1    -4
python/sglang/srt/server.py            +26   -15
python/sglang/srt/server_args.py        +7    -5
python/sglang/srt/utils.py             +97    -8
python/sglang/srt/weight_utils.py      +66   -51
python/sglang/utils.py                 +72    -1

python/sglang/srt/model_config.py

@@ -10,12 +10,16 @@ class ModelConfig:
         trust_remote_code: bool = True,
         revision: Optional[str] = None,
         context_length: Optional[int] = None,
+        model_overide_args: Optional[dict] = None,
     ) -> None:
         self.path = path
         self.trust_remote_code = trust_remote_code
         self.revision = revision
         self.hf_config = get_config(self.path, trust_remote_code, revision)
+        if model_overide_args is not None:
+            self.hf_config.update(model_overide_args)
+
         if context_length is not None:
             self.context_len = context_length
         else:

python/sglang/srt/models/commandr.py

@@ -27,29 +27,25 @@ import torch.utils.checkpoint
 from torch import nn
 from torch.nn.parameter import Parameter
 from transformers import PretrainedConfig
+from vllm.distributed import (
+    get_tensor_model_parallel_rank,
+    get_tensor_model_parallel_world_size,
+)
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.linear import (
     MergedColumnParallelLinear,
     QKVParallelLinear,
     RowParallelLinear,
 )
-from vllm.model_executor.layers.quantization.base_config import (
-    QuantizationConfig)
+from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
-from vllm.distributed import (
-    get_tensor_model_parallel_rank,
-    get_tensor_model_parallel_world_size,
-)
 from vllm.model_executor.utils import set_weight_attrs
-from sglang.srt.weight_utils import (
-    default_weight_loader,
-    hf_model_weights_iterator,
-)
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.managers.router.model_runner import InputMetadata
+from sglang.srt.weight_utils import default_weight_loader, hf_model_weights_iterator


 @torch.compile

python/sglang/srt/models/dbrx.py

@@ -5,37 +5,31 @@ from typing import Optional
 import torch
 import torch.nn as nn
+from vllm.distributed import (
+    get_tensor_model_parallel_rank,
+    get_tensor_model_parallel_world_size,
+    tensor_model_parallel_all_reduce,
+)
 from vllm.model_executor.layers.fused_moe import fused_moe
 from vllm.model_executor.layers.linear import (
     QKVParallelLinear,
     ReplicatedLinear,
     RowParallelLinear,
 )
-from vllm.model_executor.layers.quantization.base_config import (
-    QuantizationConfig)
+from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     DEFAULT_VOCAB_PADDING_SIZE,
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from vllm.distributed import (
-    tensor_model_parallel_all_reduce,
-)
-from vllm.distributed import (
-    get_tensor_model_parallel_rank,
-    get_tensor_model_parallel_world_size,
-)
 from vllm.model_executor.utils import set_weight_attrs
-from sglang.srt.weight_utils import (
-    default_weight_loader,
-    hf_model_weights_iterator,
-)
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.managers.router.model_runner import InputMetadata
 from sglang.srt.models.dbrx_config import DbrxConfig
+from sglang.srt.weight_utils import default_weight_loader, hf_model_weights_iterator


 class DbrxRouter(nn.Module):

@@ -291,7 +285,9 @@ class DbrxBlock(nn.Module):
         quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
-        self.norm_attn_norm = DbrxFusedNormAttention(config, layer_id, quant_config=quant_config)
+        self.norm_attn_norm = DbrxFusedNormAttention(
+            config, layer_id, quant_config=quant_config
+        )
         self.ffn = DbrxExperts(config, quant_config=quant_config)

     def forward(

@@ -322,7 +318,10 @@ class DbrxModel(nn.Module):
             config.d_model,
         )
         self.blocks = nn.ModuleList(
-            [DbrxBlock(config, i, quant_config=quant_config) for i in range(config.n_layers)]
+            [
+                DbrxBlock(config, i, quant_config=quant_config)
+                for i in range(config.n_layers)
+            ]
         )
         self.norm_f = nn.LayerNorm(config.d_model, eps=1e-5)
         for module in self.modules():

python/sglang/srt/models/gemma.py

@@ -7,6 +7,7 @@ import torch
 from torch import nn
 from transformers import PretrainedConfig
 from vllm.config import LoRAConfig
+from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.activation import GeluAndMul
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (

@@ -14,21 +15,14 @@ from vllm.model_executor.layers.linear import (
     QKVParallelLinear,
     RowParallelLinear,
 )
-from vllm.model_executor.layers.quantization.base_config import (
-    QuantizationConfig)
+from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
-from vllm.distributed import (
-    get_tensor_model_parallel_world_size,
-)
-from sglang.srt.weight_utils import (
-    default_weight_loader,
-    hf_model_weights_iterator,
-)
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.managers.router.model_runner import InputMetadata
+from sglang.srt.weight_utils import default_weight_loader, hf_model_weights_iterator


 class GemmaMLP(nn.Module):

@@ -46,7 +40,10 @@ class GemmaMLP(nn.Module):
             quant_config=quant_config,
         )
         self.down_proj = RowParallelLinear(
-            intermediate_size, hidden_size, bias=False, quant_config=quant_config,
+            intermediate_size,
+            hidden_size,
+            bias=False,
+            quant_config=quant_config,
         )
         self.act_fn = GeluAndMul()

python/sglang/srt/models/llama2.py

@@ -6,6 +6,7 @@ from typing import Any, Dict, Optional, Tuple
 import torch
 from torch import nn
 from transformers import LlamaConfig
+from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (

@@ -13,24 +14,17 @@ from vllm.model_executor.layers.linear import (
     QKVParallelLinear,
     RowParallelLinear,
 )
-from vllm.model_executor.layers.quantization.base_config import (
-    QuantizationConfig)
+from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from vllm.distributed import (
-    get_tensor_model_parallel_world_size,
-)
-from sglang.srt.weight_utils import (
-    default_weight_loader,
-    hf_model_weights_iterator,
-)
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.managers.router.model_runner import InputMetadata
+from sglang.srt.weight_utils import default_weight_loader, hf_model_weights_iterator


 class LlamaMLP(nn.Module):

@@ -49,7 +43,10 @@ class LlamaMLP(nn.Module):
             quant_config=quant_config,
        )
         self.down_proj = RowParallelLinear(
-            intermediate_size, hidden_size, bias=False, quant_config=quant_config,
+            intermediate_size,
+            hidden_size,
+            bias=False,
+            quant_config=quant_config,
         )
         if hidden_act != "silu":
             raise ValueError(

python/sglang/srt/models/llava.py

@@ -7,12 +7,7 @@ import torch
 from torch import nn
 from transformers import CLIPVisionModel, LlavaConfig
 from transformers.models.llava.modeling_llava import LlavaMultiModalProjector
-from vllm.model_executor.layers.quantization.base_config import (
-    QuantizationConfig)
-from sglang.srt.weight_utils import (
-    default_weight_loader,
-    hf_model_weights_iterator,
-)
+from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.managers.router.infer_batch import ForwardMode
 from sglang.srt.managers.router.model_runner import InputMetadata

@@ -22,6 +17,7 @@ from sglang.srt.mm_utils import (
     unpad_image_shape,
 )
 from sglang.srt.models.llama2 import LlamaForCausalLM
+from sglang.srt.weight_utils import default_weight_loader, hf_model_weights_iterator


 class LlavaLlamaForCausalLM(nn.Module):

python/sglang/srt/models/llavavid.py (new file, mode 100644, +307 lines)

"""Inference-only LLaVa video model compatible with HuggingFace weights."""

import os
from typing import List, Optional

import numpy as np
import torch
from torch import nn
from transformers import CLIPVisionModel, LlamaConfig, LlavaConfig
from transformers.models.llava.modeling_llava import LlavaMultiModalProjector
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig

from sglang.srt.managers.router.infer_batch import ForwardMode
from sglang.srt.managers.router.model_runner import InputMetadata
from sglang.srt.mm_utils import (
    get_anyres_image_grid_shape,
    unpad_image,
    unpad_image_shape,
)
from sglang.srt.models.llama2 import LlamaForCausalLM
from sglang.srt.weight_utils import default_weight_loader, hf_model_weights_iterator


class LlavaVidForCausalLM(nn.Module):
    def __init__(
        self,
        config: LlavaConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.config = config
        self.vision_tower = None
        self.config.vision_config.hidden_size = config.mm_hidden_size
        self.config.text_config.hidden_size = config.hidden_size
        self.multi_modal_projector = LlavaMultiModalProjector(config)
        self.mm_spatial_pool_stride = getattr(self.config, "mm_spatial_pool_stride", 2)
        self.resampler = nn.AvgPool2d(
            kernel_size=self.mm_spatial_pool_stride, stride=self.mm_spatial_pool_stride
        )
        self.language_model = LlamaForCausalLM(config, quant_config=quant_config)
        self.num_frames = getattr(self.config, "num_frames", 16)
        if "unpad" in getattr(config, "mm_patch_merge_type", ""):
            self.language_model.model.image_newline = nn.Parameter(
                torch.empty(config.text_config.hidden_size, dtype=torch.float16)
            )

    def pad_input_ids(self, input_ids, pad_value, pt_shape=None, image_size=None):
        new_image_feature_len = self.image_feature_len
        # now only support spatial_unpad + anyres
        # if self.mm_patch_merge_type.startswith("spatial"):
        #     height = width = self.num_patches_per_side
        #     if pt_shape[0] > 1:
        #         if self.image_aspect_ratio == "anyres":
        #             num_patch_width, num_patch_height = get_anyres_image_grid_shape(
        #                 image_size,
        #                 self.image_grid_pinpoints,
        #                 self.vision_tower.config.image_size,
        #             )
        #         if "unpad" in self.mm_patch_merge_type:
        #             h = num_patch_height * height
        #             w = num_patch_width * width
        #             new_h, new_w = unpad_image_shape(h, w, image_size)
        #             new_image_feature_len += new_h * (new_w + 1)

        pad_ids = pad_value * (
            (new_image_feature_len + len(pad_value)) // len(pad_value)
        )
        # print(input_ids)
        offset = input_ids.index(self.config.image_token_index)
        # old_len + pad_len - 1, because we need to remove image_token_id
        new_input_ids = (
            input_ids[:offset]
            + pad_ids[:new_image_feature_len]
            + input_ids[offset + 1 :]
        )
        return new_input_ids, offset

    def encode_images(self, pixel_values: torch.Tensor) -> torch.Tensor:
        image_outputs = self.vision_tower(pixel_values, output_hidden_states=True)
        # NOTE: This is not memory efficient. (output_hidden_states=True) will save all the hidden stated.

        selected_image_feature = image_outputs.hidden_states[self.vision_feature_layer]
        if self.vision_feature_select_strategy in ["default", "patch"]:
            selected_image_feature = selected_image_feature[:, 1:]
        elif self.vision_feature_select_strategy == "full":
            selected_image_feature = selected_image_feature
        else:
            raise ValueError(
                f"Unexpected select feature strategy: {self.config.vision_feature_select_strategy}"
            )

        height = width = self.num_patches_per_side
        num_of_frames = selected_image_feature.shape[0]
        selected_image_feature = selected_image_feature.view(
            num_of_frames, height, width, -1
        )
        selected_image_feature = selected_image_feature.permute(0, 3, 1, 2).contiguous()
        selected_image_feature = (
            self.resampler(selected_image_feature)
            .flatten(2)
            .transpose(1, 2)
            .contiguous()
        )

        image_features = self.multi_modal_projector(selected_image_feature)

        return image_features

    def forward(
        self,
        input_ids: torch.LongTensor,
        positions: torch.Tensor,
        input_metadata: InputMetadata,
        pixel_values: Optional[List[Optional[np.array]]] = None,
        image_sizes: Optional[List[List[int]]] = None,
        image_offsets: Optional[List[int]] = None,
    ) -> torch.Tensor:
        if input_metadata.forward_mode == ForwardMode.EXTEND:
            bs = input_metadata.batch_size

            # Embed text input
            input_embeds = self.language_model.model.embed_tokens(input_ids)

            # Embed vision input
            need_vision = (
                (positions[input_metadata.extend_start_loc] < self.image_feature_len)
                .cpu()
                .numpy()
            )
            # FIXME: We need to substract the length of the system prompt
            has_pixel = np.array([pixel_values[i] is not None for i in range(bs)])
            need_vision = need_vision & has_pixel

            if need_vision.any():
                pixel_values = [pixel_values[i] for i in range(bs) if need_vision[i]]
                image_sizes = [image_sizes[i] for i in range(bs) if need_vision[i]]

                ########## Encode Image ########

                if pixel_values[0].ndim == 4:
                    # llava-hd: BS, num_patch, C=3, H=336, W=336, num_patch obtained from process_images
                    np.concatenate(pixel_values, axis=0)
                    # ndim=4
                    concat_images = torch.tensor(
                        np.concatenate(pixel_values, axis=0),
                        device=self.vision_tower.device,
                    )
                    # image_features = self.encode_images(concat_images)
                    # split_sizes = [image.shape[0] for image in pixel_values]
                    # image_features = torch.split(image_features, split_sizes, dim=0)
                    image_features = self.encode_images(concat_images)
                    # , prompts)#, image_counts, long_video=long_video)
                    split_sizes = [image.shape[0] for image in pixel_values]
                    image_features = torch.split(image_features, split_sizes, dim=0)
                    # hd image_features: BS, num_patch, 576, 4096
                else:
                    # normal pixel: BS, C=3, H=336, W=336
                    pixel_values = torch.tensor(
                        np.array(pixel_values), device=self.vision_tower.device
                    )
                    image_features = self.encode_images(pixel_values)
                    # image_features: BS, 576, 4096

                new_image_features = []
                for image_idx, image_feature in enumerate(image_features):
                    new_image_features.append(image_feature.flatten(0, 1))
                image_features = new_image_features

                extend_start_loc_cpu = input_metadata.extend_start_loc.cpu().numpy()
                pt = 0
                for i in range(bs):
                    if not need_vision[i]:
                        continue

                    start_idx = extend_start_loc_cpu[i]
                    pad_len, pad_dim = image_features[pt].shape  # 576, 4096
                    dim = input_embeds.shape[1]
                    assert (
                        pad_dim == dim
                    ), "invalid pad_dim={}, input_embed_dim={}!".format(pad_dim, dim)
                    # Fill in the placeholder for the image
                    try:
                        input_embeds[
                            start_idx
                            + image_offsets[i] : start_idx
                            + image_offsets[i]
                            + pad_len
                        ] = image_features[pt]
                    except RuntimeError as e:
                        print(f"RuntimeError in llava image encoding: {e}")
                        print(input_embeds.shape)
                        print(start_idx, image_offsets[i])
                    pt += 1

            return self.language_model(
                input_ids, positions, input_metadata, input_embeds=input_embeds
            )
        elif input_metadata.forward_mode == ForwardMode.DECODE:
            return self.language_model(input_ids, positions, input_metadata)

    def load_weights(
        self,
        model_name_or_path: str,
        cache_dir: Optional[str] = None,
        load_format: str = "auto",
        revision: Optional[str] = None,
    ):
        # load clip vision model by cfg['mm_vision_tower']:
        #   huggingface_name or path_of_clip_relative_to_llava_model_dir
        vision_path = self.config.mm_vision_tower
        self.vision_tower = CLIPVisionModel.from_pretrained(
            vision_path, torch_dtype=torch.float16
        ).cuda()
        self.vision_tower.eval()

        self.vision_feature_layer = self.config.mm_vision_select_layer
        self.vision_feature_select_strategy = self.config.mm_vision_select_feature
        self.image_size = self.vision_tower.config.image_size
        self.patch_size = self.vision_tower.config.patch_size

        self.mm_patch_merge_type = getattr(self.config, "mm_patch_merge_type", "flat")
        self.image_aspect_ratio = getattr(self.config, "image_aspect_ratio", "square")
        self.image_grid_pinpoints = getattr(self.config, "image_grid_pinpoints", None)

        print(f"target_frames: {self.num_frames}")
        self.image_feature_len = self.num_frames * int(
            (self.image_size / self.patch_size / self.mm_spatial_pool_stride) ** 2
        )
        if self.vision_feature_select_strategy == "patch":
            pass
        elif self.vision_feature_select_strategy == "cls_patch":
            self.image_feature_len += 1
        else:
            raise ValueError(f"Unexpected select feature: {self.select_feature}")

        # load mm_projector
        projector_weights = {
            "model.mm_projector.0": "multi_modal_projector.linear_1",
            "model.mm_projector.2": "multi_modal_projector.linear_2",
            "model.vision_resampler.mm_projector.0": "multi_modal_projector.linear_1",
            "model.vision_resampler.mm_projector.2": "multi_modal_projector.linear_2",
            "model.vision_tower.vision_tower": "vision_tower",
            # Update the vision tower weights if we find them in the checkpoint (it may be finetuned).
        }
        params_dict = dict(self.named_parameters())
        for name, loaded_weight in hf_model_weights_iterator(
            model_name_or_path, cache_dir, load_format, revision
        ):
            # FIXME: why projector weights read two times?
            if "projector" in name or "vision_tower" in name:
                for weight_name, param_name in projector_weights.items():
                    if weight_name in name:
                        name = name.replace(weight_name, param_name)
                if name in params_dict:
                    param = params_dict[name]
                else:
                    print(f"Warning: {name} not found in the model")
                    continue
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)

        # load language model
        self.language_model.load_weights(
            model_name_or_path, cache_dir, load_format, revision
        )

        monkey_path_clip_vision_embed_forward()

    @property
    def num_patches_per_side(self):
        return self.image_size // self.patch_size


first_call = True


def clip_vision_embed_forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
    batch_size = pixel_values.shape[0]

    # Move this conv layer to CPU to avoid a bug in torch >= 2.1 on A10G.
    global first_call
    if first_call:
        self.patch_embedding.cpu().float()
        first_call = False
    pixel_values = pixel_values.to(dtype=torch.float32, device="cpu")
    patch_embeds = self.patch_embedding(pixel_values).cuda().half()

    patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
    class_embeds = self.class_embedding.expand(batch_size, 1, -1)
    embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
    embeddings = embeddings + self.position_embedding(self.position_ids)
    return embeddings


def monkey_path_clip_vision_embed_forward():
    import transformers

    setattr(
        transformers.models.clip.modeling_clip.CLIPVisionEmbeddings,
        "forward",
        clip_vision_embed_forward,
    )


EntryClass = LlavaVidForCausalLM

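For orientation, the placeholder length that pad_input_ids reserves per video follows directly from the arithmetic in load_weights. Assuming the usual CLIP ViT-L/14-336 vision tower that llava checkpoints ship with (image_size 336, patch_size 14 — both values actually come from the loaded tower's config, not from this file), the defaults num_frames=16 and mm_spatial_pool_stride=2 give image_feature_len = 16 * int((336 / 14 / 2) ** 2) = 16 * 144 = 2304 image tokens per video.
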
python/sglang/srt/models/mixtral.py

@@ -8,34 +8,28 @@ import torch
 import torch.nn.functional as F
 from torch import nn
 from transformers import MixtralConfig
+from vllm.distributed import (
+    get_tensor_model_parallel_rank,
+    get_tensor_model_parallel_world_size,
+    tensor_model_parallel_all_reduce,
+)
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (
     QKVParallelLinear,
     ReplicatedLinear,
     RowParallelLinear,
 )
-from vllm.model_executor.layers.quantization.base_config import (
-    QuantizationConfig)
+from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from vllm.distributed import (
-    tensor_model_parallel_all_reduce,
-)
-from vllm.distributed import (
-    get_tensor_model_parallel_rank,
-    get_tensor_model_parallel_world_size,
-)
-from sglang.srt.weight_utils import (
-    default_weight_loader,
-    hf_model_weights_iterator,
-)
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.managers.router.model_runner import InputMetadata
+from sglang.srt.weight_utils import default_weight_loader, hf_model_weights_iterator


 class MixtralMLP(nn.Module):

python/sglang/srt/models/qwen.py

@@ -3,6 +3,7 @@ from typing import Any, Dict, Optional
 import torch
 from torch import nn
 from transformers import PretrainedConfig
+from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (

@@ -10,24 +11,17 @@ from vllm.model_executor.layers.linear import (
     QKVParallelLinear,
     RowParallelLinear,
 )
-from vllm.model_executor.layers.quantization.base_config import (
-    QuantizationConfig)
+from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from vllm.distributed import (
-    get_tensor_model_parallel_world_size,
-)
-from sglang.srt.weight_utils import (
-    default_weight_loader,
-    hf_model_weights_iterator,
-)
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.managers.router.model_runner import InputMetadata
+from sglang.srt.weight_utils import default_weight_loader, hf_model_weights_iterator


 class QWenMLP(nn.Module):

@@ -132,7 +126,12 @@ class QWenAttention(nn.Module):
 class QWenBlock(nn.Module):
-    def __init__(self, config: PretrainedConfig, layer_id, quant_config: Optional[QuantizationConfig] = None,):
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        layer_id,
+        quant_config: Optional[QuantizationConfig] = None,
+    ):
         super().__init__()
         self.ln_1 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)

@@ -181,7 +180,11 @@ class QWenBlock(nn.Module):
 class QWenModel(nn.Module):
-    def __init__(self, config: PretrainedConfig, quant_config: Optional[QuantizationConfig] = None,):
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+    ):
         super().__init__()
         self.config = config
         self.vocab_size = config.vocab_size

@@ -218,7 +221,11 @@ class QWenModel(nn.Module):
 class QWenLMHeadModel(nn.Module):
-    def __init__(self, config: PretrainedConfig, quant_config: Optional[QuantizationConfig] = None,):
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+    ):
         super().__init__()
         self.config = config
         self.transformer = QWenModel(config, quant_config=quant_config)

@@ -276,4 +283,4 @@ class QWenLMHeadModel(nn.Module):
             weight_loader(param, loaded_weight)


-EntryClass = QWenLMHeadModel
\ No newline at end of file
+EntryClass = QWenLMHeadModel

python/sglang/srt/models/qwen2.py

@@ -5,6 +5,7 @@ from typing import Any, Dict, Optional, Tuple
 import torch
 from torch import nn
+from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (

@@ -12,24 +13,17 @@ from vllm.model_executor.layers.linear import (
     QKVParallelLinear,
     RowParallelLinear,
 )
-from vllm.model_executor.layers.quantization.base_config import (
-    QuantizationConfig)
+from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from vllm.distributed import (
-    get_tensor_model_parallel_world_size,
-)
-from sglang.srt.weight_utils import (
-    default_weight_loader,
-    hf_model_weights_iterator,
-)
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.managers.router.model_runner import InputMetadata
+from sglang.srt.weight_utils import default_weight_loader, hf_model_weights_iterator

 Qwen2Config = None

@@ -50,7 +44,10 @@ class Qwen2MLP(nn.Module):
             quant_config=quant_config,
         )
         self.down_proj = RowParallelLinear(
-            intermediate_size, hidden_size, bias=False, quant_config=quant_config,
+            intermediate_size,
+            hidden_size,
+            bias=False,
+            quant_config=quant_config,
         )
         if hidden_act != "silu":
             raise ValueError(

python/sglang/srt/models/stablelm.py

@@ -7,35 +7,31 @@ from typing import Optional, Tuple
 import torch
 from torch import nn
 from transformers import PretrainedConfig
+from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.linear import (
     MergedColumnParallelLinear,
     QKVParallelLinear,
     RowParallelLinear,
 )
-from vllm.model_executor.layers.quantization.base_config import (
-    QuantizationConfig)
+from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from vllm.distributed import (
-    get_tensor_model_parallel_world_size,
-)
-from sglang.srt.weight_utils import (
-    default_weight_loader,
-    hf_model_weights_iterator,
-)
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.managers.router.model_runner import InputMetadata
+from sglang.srt.weight_utils import default_weight_loader, hf_model_weights_iterator


 class StablelmMLP(nn.Module):
     def __init__(
-        self, config: PretrainedConfig, quant_config: Optional[QuantizationConfig] = None,
+        self,
+        config: PretrainedConfig,
+        quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
         self.config = config

@@ -48,7 +44,10 @@ class StablelmMLP(nn.Module):
             quant_config=quant_config,
         )
         self.down_proj = RowParallelLinear(
-            config.intermediate_size, config.hidden_size, bias=False, quant_config=quant_config,
+            config.intermediate_size,
+            config.hidden_size,
+            bias=False,
+            quant_config=quant_config,
         )
         self.act_fn = SiluAndMul()

@@ -181,7 +180,9 @@ class StablelmDecoderLayer(nn.Module):
 class StableLMEpochModel(nn.Module):
     def __init__(
-        self, config: PretrainedConfig, quant_config: Optional[QuantizationConfig] = None,
+        self,
+        config: PretrainedConfig,
+        quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
         self.embed_tokens = VocabParallelEmbedding(

python/sglang/srt/models/yivl.py

@@ -6,16 +6,13 @@ from typing import List, Optional
 import torch
 import torch.nn as nn
 from transformers import CLIPVisionModel, LlavaConfig
-from sglang.srt.weight_utils import (
-    default_weight_loader,
-    hf_model_weights_iterator,
-)
 from sglang.srt.models.llava import (
     LlavaLlamaForCausalLM,
     clip_vision_embed_forward,
     monkey_path_clip_vision_embed_forward,
 )
+from sglang.srt.weight_utils import default_weight_loader, hf_model_weights_iterator


 class YiVLForCausalLM(LlavaLlamaForCausalLM):

python/sglang/srt/server.py

@@ -107,7 +107,7 @@ async def openai_v1_chat_completions(raw_request: Request):
     return await v1_chat_completions(tokenizer_manager, raw_request)


-def launch_server(server_args: ServerArgs, pipe_finish_writer):
+def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_args=None):
     global tokenizer_manager

     logging.basicConfig(

@@ -140,17 +140,13 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer):
     )

     # Launch processes
-    tokenizer_manager = TokenizerManager(server_args, port_args)
+    tokenizer_manager = TokenizerManager(server_args, port_args, model_overide_args)
     pipe_router_reader, pipe_router_writer = mp.Pipe(duplex=False)
     pipe_detoken_reader, pipe_detoken_writer = mp.Pipe(duplex=False)

     proc_router = mp.Process(
         target=start_router_process,
-        args=(
-            server_args,
-            port_args,
-            pipe_router_writer,
-        ),
+        args=(server_args, port_args, pipe_router_writer, model_overide_args),
     )
     proc_router.start()
     proc_detoken = mp.Process(

@@ -170,8 +166,13 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer):
     if router_init_state != "init ok" or detoken_init_state != "init ok":
         proc_router.kill()
         proc_detoken.kill()
-        print(f"Initialization failed. router_init_state: {router_init_state}", flush=True)
-        print(f"Initialization failed. detoken_init_state: {detoken_init_state}", flush=True)
+        print(
+            f"Initialization failed. router_init_state: {router_init_state}", flush=True
+        )
+        print(
+            f"Initialization failed. detoken_init_state: {detoken_init_state}",
+            flush=True,
+        )
         sys.exit(1)
     assert proc_router.is_alive() and proc_detoken.is_alive()

@@ -189,6 +190,7 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer):
             time.sleep(0.5)
             try:
                 requests.get(url + "/get_model_info", timeout=5, headers=headers)
+                success = True  # Set flag to True if request succeeds
                 break
             except requests.exceptions.RequestException as e:
                 pass

@@ -205,7 +207,7 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer):
                 },
             },
             headers=headers,
-            timeout=60,
+            timeout=600,
         )
         assert res.status_code == 200
     except Exception as e:

@@ -235,7 +237,8 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer):
 class Runtime:
     def __init__(
         self,
-        log_evel="error",
+        log_evel: str = "error",
+        model_overide_args: Optional[dict] = None,
         *args,
         **kwargs,
     ):

@@ -244,7 +247,10 @@ class Runtime:
         # Pre-allocate ports
         self.server_args.port, self.server_args.additional_ports = allocate_init_ports(
-            self.server_args.port, self.server_args.additional_ports, self.server_args.tp_size)
+            self.server_args.port,
+            self.server_args.additional_ports,
+            self.server_args.tp_size,
+        )

         self.url = self.server_args.url()
         self.generate_url = (

@@ -253,7 +259,10 @@ class Runtime:
         self.pid = None
         pipe_reader, pipe_writer = mp.Pipe(duplex=False)
-        proc = mp.Process(target=launch_server, args=(self.server_args, pipe_writer))
+        proc = mp.Process(
+            target=launch_server,
+            args=(self.server_args, pipe_writer, model_overide_args),
+        )
         proc.start()
         pipe_writer.close()
         self.pid = proc.pid

@@ -265,7 +274,9 @@ class Runtime:
         if init_state != "init ok":
             self.shutdown()
-            raise RuntimeError("Initialization failed. Please see the error messages above.")
+            raise RuntimeError(
+                "Initialization failed. Please see the error messages above."
+            )

         self.endpoint = RuntimeEndpoint(self.url)

@@ -317,4 +328,4 @@ class Runtime:
             pos += len(cur)

     def __del__(self):
-        self.shutdown()
\ No newline at end of file
+        self.shutdown()

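The server.py changes above thread model_overide_args from Runtime.__init__ through launch_server into TokenizerManager and the router process, and ModelConfig (the first file in this commit) applies that dict on top of the HuggingFace config. A minimal usage sketch follows; only the model_overide_args plumbing comes from this commit — the checkpoint path and the choice of override keys are illustrative assumptions:

    # Sketch: launch an sglang Runtime with config overrides that
    # ModelConfig.hf_config.update() applies before the model is built.
    # The model path below is a placeholder, not part of this diff.
    import sglang as sgl

    runtime = sgl.Runtime(
        model_path="<path-to-llava-video-checkpoint>",  # placeholder
        model_overide_args={
            "num_frames": 16,             # read by LlavaVidForCausalLM via getattr(config, "num_frames", 16)
            "mm_spatial_pool_stride": 2,  # read via getattr(config, "mm_spatial_pool_stride", 2)
        },
    )
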
python/sglang/srt/server_args.py

@@ -80,10 +80,12 @@ class ServerArgs:
             default=ServerArgs.tokenizer_path,
             help="The path of the tokenizer.",
         )
-        parser.add_argument("--host", type=str, default=ServerArgs.host,
-                            help="The host of the server.")
-        parser.add_argument("--port", type=int, default=ServerArgs.port,
-                            help="The port of the server.")
+        parser.add_argument(
+            "--host", type=str, default=ServerArgs.host, help="The host of the server."
+        )
+        parser.add_argument(
+            "--port", type=int, default=ServerArgs.port, help="The port of the server."
+        )
         parser.add_argument(
             "--additional-ports",
             type=int,

@@ -261,4 +263,4 @@ class PortArgs:
     router_port: int
     detokenizer_port: int
     nccl_port: int
-    model_rpc_ports: List[int]
\ No newline at end of file
+    model_rpc_ports: List[int]

python/sglang/srt/utils.py

@@ -131,11 +131,13 @@ def alloc_usable_network_port(num, used_list=()):
             continue

         with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
+            s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
             try:
                 s.bind(("", port))
+                s.listen(1)  # Attempt to listen on the port
                 port_list.append(port)
             except socket.error:
-                pass
+                pass  # If any error occurs, this port is not usable

     if len(port_list) == num:
         return port_list

@@ -265,20 +267,102 @@ def wrap_kernel_launcher(kernel):


 def is_multimodal_model(model):
-    if isinstance(model, str):
-        return "llava" in model or "yi-vl" in model
     from sglang.srt.model_config import ModelConfig
+
+    if isinstance(model, str):
+        model = model.lower()
+        return "llava" in model or "yi-vl" in model or "llava-next" in model
+
     if isinstance(model, ModelConfig):
         model_path = model.path.lower()
-        return "llava" in model_path or "yi-vl" in model_path
-    raise Exception("unrecognized type")
+        return "llava" in model_path or "yi-vl" in model_path or "llava-next" in model_path
+
+    raise ValueError("unrecognized type")
+
+
+def decode_video_base64(video_base64):
+    from PIL import Image
+
+    # Decode the base64 string
+    video_bytes = base64.b64decode(video_base64)
+
+    # Placeholder for the start indices of each PNG image
+    img_starts = []
+
+    frame_format = "PNG"  # str(os.getenv('FRAME_FORMAT', "JPEG"))
+
+    assert frame_format in [
+        "PNG",
+        "JPEG",
+    ], "FRAME_FORMAT must be either 'PNG' or 'JPEG'"
+
+    if frame_format == "PNG":
+        # Find each PNG start signature to isolate images
+        i = 0
+        while i < len(video_bytes) - 7:  # Adjusted for the length of the PNG signature
+            # Check if we found the start of a PNG file
+            if (
+                video_bytes[i] == 0x89
+                and video_bytes[i + 1] == 0x50
+                and video_bytes[i + 2] == 0x4E
+                and video_bytes[i + 3] == 0x47
+                and video_bytes[i + 4] == 0x0D
+                and video_bytes[i + 5] == 0x0A
+                and video_bytes[i + 6] == 0x1A
+                and video_bytes[i + 7] == 0x0A
+            ):
+                img_starts.append(i)
+                i += 8  # Skip the PNG signature
+            else:
+                i += 1
+    else:
+        # Find each JPEG start (0xFFD8) to isolate images
+        i = 0
+        while i < len(video_bytes) - 1:  # Adjusted for the length of the JPEG SOI signature
+            # Check if we found the start of a JPEG file
+            if video_bytes[i] == 0xFF and video_bytes[i + 1] == 0xD8:
+                img_starts.append(i)
+                # Move to the next byte to continue searching for the next image start
+                i += 2
+            else:
+                i += 1

+    frames = []
+    for start_idx in img_starts:
+        # Assuming each image is back-to-back, the end of one image is the start of another
+        # The last image goes until the end of the byte string
+        end_idx = (
+            img_starts[img_starts.index(start_idx) + 1]
+            if img_starts.index(start_idx) + 1 < len(img_starts)
+            else len(video_bytes)
+        )
+        img_bytes = video_bytes[start_idx:end_idx]
+
+        # Convert bytes to a PIL Image
+        img = Image.open(BytesIO(img_bytes))
+
+        # Convert PIL Image to a NumPy array
+        frame = np.array(img)
+
+        # Append the frame to the list of frames
+        frames.append(frame)
+
+    # Ensure there's at least one frame to avoid errors with np.stack
+    if frames:
+        return np.stack(frames, axis=0), img.size
+    else:
+        return np.array([]), (0, 0)  # Return an empty array and size tuple if no frames were found


 def load_image(image_file):
     from PIL import Image

-    image = None
+    image = image_size = None

     if image_file.startswith("http://") or image_file.startswith("https://"):
         timeout = int(os.getenv("REQUEST_TIMEOUT", "3"))

@@ -289,10 +373,13 @@ def load_image(image_file):
     elif image_file.startswith("data:"):
         image_file = image_file.split(",")[1]
         image = Image.open(BytesIO(base64.b64decode(image_file)))
+    elif image_file.startswith("video:"):
+        image_file = image_file.replace("video:", "")
+        image, image_size = decode_video_base64(image_file)
     else:
         image = Image.open(BytesIO(base64.b64decode(image_file)))

-    return image
+    return image, image_size


 def assert_pkg_version(pkg: str, min_version: str):

@@ -304,7 +391,9 @@ def assert_pkg_version(pkg: str, min_version: str):
                 f"is less than the minimum required version {min_version}"
             )
     except PackageNotFoundError:
-        raise Exception(f"{pkg} with minimum required version {min_version} is not installed")
+        raise Exception(
+            f"{pkg} with minimum required version {min_version} is not installed"
+        )


 API_KEY_HEADER_NAME = "X-API-Key"

python/sglang/srt/weight_utils.py

@@ -19,11 +19,12 @@ import torch
 from huggingface_hub import HfFileSystem, snapshot_download
 from safetensors.torch import load_file, safe_open, save_file
 from tqdm.auto import tqdm
 from vllm.config import ModelConfig
 from vllm.logger import init_logger
-from vllm.model_executor.layers.quantization import (QuantizationConfig,
-                                                     get_quantization_config)
+from vllm.model_executor.layers.quantization import (
+    QuantizationConfig,
+    get_quantization_config,
+)
 from vllm.model_executor.layers.quantization.schema import QuantParamSchema

 logger = init_logger(__name__)

@@ -32,17 +33,21 @@ logger = init_logger(__name__)
 # can share the same lock without error.
 # lock files in the temp directory will be automatically deleted when the
 # system reboots, so users will not complain about annoying lock files
-temp_dir = os.environ.get('TMPDIR') or os.environ.get(
-    'TEMP') or os.environ.get('TMP') or "/tmp/"
+temp_dir = (
+    os.environ.get("TMPDIR")
+    or os.environ.get("TEMP")
+    or os.environ.get("TMP")
+    or "/tmp/"
+)


 def enable_hf_transfer():
-    """automatically activates hf_transfer
-    """
+    """automatically activates hf_transfer"""
     if "HF_HUB_ENABLE_HF_TRANSFER" not in os.environ:
         try:
             # enable hf hub transfer if available
             import hf_transfer  # type: ignore # noqa
+
             huggingface_hub.constants.HF_HUB_ENABLE_HF_TRANSFER = True
         except ImportError:
             pass

@@ -65,8 +70,7 @@ def get_lock(model_name_or_path: str, cache_dir: Optional[str] = None):
     # add hash to avoid conflict with old users' lock files
     lock_file_name = hash_name + model_name + ".lock"
     # mode 0o666 is required for the filelock to be shared across users
-    lock = filelock.FileLock(os.path.join(lock_dir, lock_file_name),
-                             mode=0o666)
+    lock = filelock.FileLock(os.path.join(lock_dir, lock_file_name), mode=0o666)
     return lock

@@ -104,10 +108,12 @@ def convert_bin_to_safetensor_file(
     sf_size = os.stat(sf_filename).st_size
     pt_size = os.stat(pt_filename).st_size
     if (sf_size - pt_size) / pt_size > 0.01:
-        raise RuntimeError(f"""The file size different is more than 1%:
+        raise RuntimeError(
+            f"""The file size different is more than 1%:
          - {sf_filename}: {sf_size}
          - {pt_filename}: {pt_size}
-         """)
+         """
+        )

     # check if the tensors are the same
     reloaded = load_file(sf_filename)

@@ -122,8 +128,7 @@ def convert_bin_to_safetensor_file(
 def get_quant_config(model_config: ModelConfig) -> QuantizationConfig:
     quant_cls = get_quantization_config(model_config.quantization)
     # Read the quantization config from the HF model config, if available.
-    hf_quant_config = getattr(model_config.hf_config, "quantization_config",
-                              None)
+    hf_quant_config = getattr(model_config.hf_config, "quantization_config", None)
     if hf_quant_config is not None:
         return quant_cls.from_config(hf_quant_config)
     model_name_or_path = model_config.model

@@ -131,26 +136,29 @@ def get_quant_config(model_config: ModelConfig) -> QuantizationConfig:
     if not is_local:
         # Download the config files.
         with get_lock(model_name_or_path, model_config.download_dir):
-            hf_folder = snapshot_download(model_name_or_path,
-                                          revision=model_config.revision,
-                                          allow_patterns="*.json",
-                                          cache_dir=model_config.download_dir,
-                                          tqdm_class=Disabledtqdm)
+            hf_folder = snapshot_download(
+                model_name_or_path,
+                revision=model_config.revision,
+                allow_patterns="*.json",
+                cache_dir=model_config.download_dir,
+                tqdm_class=Disabledtqdm,
+            )
     else:
         hf_folder = model_name_or_path
     config_files = glob.glob(os.path.join(hf_folder, "*.json"))

     quant_config_files = [
-        f for f in config_files if any(
-            f.endswith(x) for x in quant_cls.get_config_filenames())
+        f
+        for f in config_files
+        if any(f.endswith(x) for x in quant_cls.get_config_filenames())
     ]
     if len(quant_config_files) == 0:
-        raise ValueError(
-            f"Cannot find the config file for {model_config.quantization}")
+        raise ValueError(f"Cannot find the config file for {model_config.quantization}")
     if len(quant_config_files) > 1:
         raise ValueError(
             f"Found multiple config files for {model_config.quantization}: "
-            f"{quant_config_files}")
+            f"{quant_config_files}"
+        )

     quant_config_file = quant_config_files[0]
     with open(quant_config_file, "r") as f:

@@ -166,8 +174,7 @@ def prepare_hf_model_weights(
     revision: Optional[str] = None,
 ) -> Tuple[str, List[str], bool]:
     # Download model weights from huggingface.
-    is_local = os.path.isdir(model_name_or_path) \
-        and load_format != "tensorizer"
+    is_local = os.path.isdir(model_name_or_path) and load_format != "tensorizer"
     use_safetensors = False
     # Some quantized models use .pt files for storing the weights.
     if load_format == "auto":

@@ -203,11 +210,13 @@ def prepare_hf_model_weights(
         # Use file lock to prevent multiple processes from
         # downloading the same model weights at the same time.
         with get_lock(model_name_or_path, cache_dir):
-            hf_folder = snapshot_download(model_name_or_path,
-                                          allow_patterns=allow_patterns,
-                                          cache_dir=cache_dir,
-                                          tqdm_class=Disabledtqdm,
-                                          revision=revision)
+            hf_folder = snapshot_download(
+                model_name_or_path,
+                allow_patterns=allow_patterns,
+                cache_dir=cache_dir,
+                tqdm_class=Disabledtqdm,
+                revision=revision,
+            )
     else:
         hf_folder = model_name_or_path
     hf_weights_files: List[str] = []

@@ -228,16 +237,14 @@ def prepare_hf_model_weights(
             "scaler.pt",
         ]
         hf_weights_files = [
-            f for f in hf_weights_files
-            if not any(f.endswith(x) for x in blacklist)
+            f for f in hf_weights_files if not any(f.endswith(x) for x in blacklist)
         ]

     if load_format == "tensorizer":
         return hf_folder, hf_weights_files, use_safetensors

     if len(hf_weights_files) == 0:
-        raise RuntimeError(
-            f"Cannot find any model weights with `{model_name_or_path}`")
+        raise RuntimeError(f"Cannot find any model weights with `{model_name_or_path}`")

     return hf_folder, hf_weights_files, use_safetensors

@@ -254,7 +261,8 @@ def hf_model_weights_iterator(
         cache_dir=cache_dir,
         load_format=load_format,
         fall_back_to_pt=fall_back_to_pt,
-        revision=revision)
+        revision=revision,
+    )
     if load_format == "npcache":
         # Currently np_cache only support *.bin checkpoints

@@ -289,22 +297,25 @@ def hf_model_weights_iterator(
                 param = np.load(f)
             yield name, torch.from_numpy(param)
     elif load_format == "tensorizer":
-        from vllm.model_executor.tensorizer_loader import (TensorDeserializer,
-                                                           open_stream,
-                                                           tensorizer_warning)
+        from vllm.model_executor.tensorizer_loader import (
+            TensorDeserializer,
+            open_stream,
+            tensorizer_warning,
+        )
+
         tensorizer_args = load_format.params
         tensorizer_warning(
             "Deserializing HuggingFace models is not optimized for "
             "loading on vLLM, as tensorizer is forced to load to CPU. "
             "Consider deserializing a vLLM model instead for faster "
-            "load times. See the examples/tensorize_vllm_model.py example "
-            "script for serializing vLLM models.")
+            "load times. See the examples/tensorize_vllm_model.py example "
+            "script for serializing vLLM models."
+        )

         deserializer_args = tensorizer_args.deserializer_params
         stream_params = tensorizer_args.stream_params
         stream = open_stream(tensorizer_args.tensorizer_uri, **stream_params)
-        with TensorDeserializer(stream, **deserializer_args,
-                                device="cpu") as state:
+        with TensorDeserializer(stream, **deserializer_args, device="cpu") as state:
             for name, param in state.items():
                 yield name, param
             del state

@@ -324,8 +335,12 @@ def hf_model_weights_iterator(
 def kv_cache_scales_loader(
-        filename: str, tp_rank: int, tp_size: int, num_hidden_layers: int,
-        model_type: Optional[str]) -> Iterable[Tuple[int, float]]:
+    filename: str,
+    tp_rank: int,
+    tp_size: int,
+    num_hidden_layers: int,
+    model_type: Optional[str],
+) -> Iterable[Tuple[int, float]]:
     """
     A simple utility to read in KV cache scaling factors that have been
     previously serialized to disk. Used by the model to populate the appropriate

@@ -343,8 +358,7 @@ def kv_cache_scales_loader(
                 "tp_size": tp_size,
             }
             schema_dct = json.load(f)
-            schema = QuantParamSchema.model_validate(schema_dct,
-                                                     context=context)
+            schema = QuantParamSchema.model_validate(schema_dct, context=context)
             layer_scales_map = schema.kv_cache.scaling_factor[tp_rank]
             return layer_scales_map.items()

@@ -357,9 +371,11 @@ def kv_cache_scales_loader(
     # This section is reached if and only if any of the excepts are hit
     # Return an empty iterable (list) => no KV cache scales are loaded
     # which ultimately defaults to 1.0 scales
-    logger.warning("Defaulting to KV cache scaling factors = 1.0 "
-                   f"for all layers in TP rank {tp_rank} "
-                   "as an error occurred during loading.")
+    logger.warning(
+        "Defaulting to KV cache scaling factors = 1.0 "
+        f"for all layers in TP rank {tp_rank} "
+        "as an error occurred during loading."
+    )
     return []

@@ -378,8 +394,7 @@ def convert_pyslice_to_tensor(x: Any) -> torch.Tensor:
     return x


-def default_weight_loader(param: torch.Tensor,
-                          loaded_weight: torch.Tensor) -> None:
+def default_weight_loader(param: torch.Tensor, loaded_weight: torch.Tensor) -> None:
     """Default weight loader."""
     assert param.size() == loaded_weight.size()
     param.data.copy_(loaded_weight)

@@ -399,4 +414,4 @@ def initialize_dummy_weights(
     """
     for param in model.state_dict().values():
         if torch.is_floating_point(param):
-            param.data.uniform_(low, high)
\ No newline at end of file
+            param.data.uniform_(low, high)

python/sglang/utils.py

@@ -2,13 +2,16 @@
 import base64
 import json
 import os
 import sys
 import threading
 import traceback
 import urllib.request
+from concurrent.futures import ThreadPoolExecutor
 from io import BytesIO
 from json import dumps

+import numpy as np
 import requests

@@ -110,6 +113,74 @@ def encode_image_base64(image_path):
     return base64.b64encode(buffered.getvalue()).decode("utf-8")


+def encode_frame(frame):
+    import cv2  # pip install opencv-python-headless
+    from PIL import Image
+
+    # Convert the frame to RGB (OpenCV uses BGR by default)
+    frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+
+    # Convert the frame to PIL Image to easily convert to bytes
+    im_pil = Image.fromarray(frame)
+
+    # Convert to bytes
+    buffered = BytesIO()
+
+    # frame_format = str(os.getenv('FRAME_FORMAT', "JPEG"))
+
+    im_pil.save(buffered, format="PNG")
+
+    frame_bytes = buffered.getvalue()
+
+    # Return the bytes of the frame
+    return frame_bytes
+
+
+def encode_video_base64(video_path, num_frames=16):
+    import cv2
+
+    cap = cv2.VideoCapture(video_path)
+    if not cap.isOpened():
+        raise IOError(f"Could not open video file: {video_path}")
+
+    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
+    print(f"target_frames: {num_frames}")
+
+    frame_indices = np.linspace(0, total_frames - 1, num_frames, dtype=int)
+
+    frames = []
+    for i in range(total_frames):
+        ret, frame = cap.read()
+        if ret:
+            frames.append(frame)
+        else:
+            # Handle the case where the frame could not be read
+            # print(f"Warning: Could not read frame at index {i}.")
+            pass
+
+    cap.release()
+
+    # Safely select frames based on frame_indices, avoiding IndexError
+    frames = [frames[i] for i in frame_indices if i < len(frames)]
+
+    # If there are not enough frames, duplicate the last frame until we reach the target
+    while len(frames) < num_frames:
+        frames.append(frames[-1])
+
+    # Use ThreadPoolExecutor to process and encode frames in parallel
+    with ThreadPoolExecutor() as executor:
+        encoded_frames = list(executor.map(encode_frame, frames))
+
+    # encoded_frames = list(map(encode_frame, frames))
+
+    # Concatenate all frames bytes
+    video_bytes = b"".join(encoded_frames)
+
+    # Encode the concatenated bytes to base64
+    video_base64 = "video:" + base64.b64encode(video_bytes).decode("utf-8")
+
+    return video_base64
+
+
 def _is_chinese_char(cp):
     """Checks whether CP is the codepoint of a CJK character."""
     # This defines a "chinese character" as anything in the CJK Unicode block:

@@ -170,4 +241,4 @@ def run_with_timeout(func, args=(), kwargs=None, timeout=None):
     if not ret_value:
         raise RuntimeError()

-    return ret_value[0]
\ No newline at end of file
+    return ret_value[0]

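Together with decode_video_base64 and the load_image changes in python/sglang/srt/utils.py above, these helpers close the video loop: the client samples num_frames frames, concatenates them as PNG bytes behind a "video:" prefix, and the server splits the payload on PNG signatures and stacks the frames into a NumPy array. A rough sketch of that round trip, assuming only what the diff shows (the video path is a placeholder):

    # Sketch: encode a local video on the client and decode it with the new
    # server-side helper; "my_video.mp4" is a placeholder path.
    from sglang.utils import encode_video_base64
    from sglang.srt.utils import load_image

    payload = encode_video_base64("my_video.mp4", num_frames=16)  # "video:<base64 of PNG frames>"
    frames, frame_size = load_image(payload)  # (num_frames, H, W, C) array and (W, H)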