Commit 0992d85f (Unverified)
Authored May 14, 2024 by Yuanhan Zhang; committed by GitHub May 13, 2024
Parent: 5dc55a5f

support llava video (#426)

Showing 17 changed files with 663 additions and 183 deletions (+663, -183)
python/sglang/srt/model_config.py       +4    -0
python/sglang/srt/models/commandr.py    +6    -10
python/sglang/srt/models/dbrx.py        +14   -15
python/sglang/srt/models/gemma.py       +7    -10
python/sglang/srt/models/llama2.py      +7    -10
python/sglang/srt/models/llava.py       +2    -6
python/sglang/srt/models/llavavid.py    +307  -0
python/sglang/srt/models/mixtral.py     +7    -13
python/sglang/srt/models/qwen.py        +20   -13
python/sglang/srt/models/qwen2.py       +7    -10
python/sglang/srt/models/stablelm.py    +13   -12
python/sglang/srt/models/yivl.py        +1    -4
python/sglang/srt/server.py             +26   -15
python/sglang/srt/server_args.py        +7    -5
python/sglang/srt/utils.py              +97   -8
python/sglang/srt/weight_utils.py       +66   -51
python/sglang/utils.py                  +72   -1
python/sglang/srt/model_config.py

@@ -10,12 +10,16 @@ class ModelConfig:
         trust_remote_code: bool = True,
         revision: Optional[str] = None,
         context_length: Optional[int] = None,
+        model_overide_args: Optional[dict] = None,
     ) -> None:
         self.path = path
         self.trust_remote_code = trust_remote_code
         self.revision = revision
         self.hf_config = get_config(self.path, trust_remote_code, revision)
+
+        if model_overide_args is not None:
+            self.hf_config.update(model_overide_args)
+
         if context_length is not None:
             self.context_len = context_length
         else:
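The new model_overide_args hook is the mechanism the video path relies on: whatever dict the caller passes is merged into the HuggingFace config right after it is loaded, before the model is built. A minimal sketch of the same pattern, using transformers' PretrainedConfig directly rather than sglang's ModelConfig/get_config; the attribute names below are purely illustrative:

    from transformers import PretrainedConfig

    # Stand-in config object; the field names are made up for illustration.
    hf_config = PretrainedConfig(num_frames=16, mm_spatial_pool_stride=2)

    model_overide_args = {"num_frames": 32}

    # Mirrors what ModelConfig.__init__ now does after get_config():
    if model_overide_args is not None:
        hf_config.update(model_overide_args)  # PretrainedConfig.update sets attributes from the dict

    print(hf_config.num_frames)  # 32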
python/sglang/srt/models/commandr.py

@@ -27,29 +27,25 @@ import torch.utils.checkpoint
 from torch import nn
 from torch.nn.parameter import Parameter
 from transformers import PretrainedConfig
+from vllm.distributed import (
+    get_tensor_model_parallel_rank,
+    get_tensor_model_parallel_world_size,
+)
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.linear import (
     MergedColumnParallelLinear,
     QKVParallelLinear,
     RowParallelLinear,
 )
-from vllm.model_executor.layers.quantization.base_config import (
-    QuantizationConfig)
+from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
-from vllm.distributed import (
-    get_tensor_model_parallel_rank,
-    get_tensor_model_parallel_world_size,
-)
 from vllm.model_executor.utils import set_weight_attrs
-from sglang.srt.weight_utils import (
-    default_weight_loader,
-    hf_model_weights_iterator,
-)
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.managers.router.model_runner import InputMetadata
+from sglang.srt.weight_utils import default_weight_loader, hf_model_weights_iterator


 @torch.compile
python/sglang/srt/models/dbrx.py

@@ -5,37 +5,31 @@ from typing import Optional
 import torch
 import torch.nn as nn
+from vllm.distributed import (
+    get_tensor_model_parallel_rank,
+    get_tensor_model_parallel_world_size,
+    tensor_model_parallel_all_reduce,
+)
 from vllm.model_executor.layers.fused_moe import fused_moe
 from vllm.model_executor.layers.linear import (
     QKVParallelLinear,
     ReplicatedLinear,
     RowParallelLinear,
 )
-from vllm.model_executor.layers.quantization.base_config import (
-    QuantizationConfig)
+from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     DEFAULT_VOCAB_PADDING_SIZE,
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from vllm.distributed import (
-    tensor_model_parallel_all_reduce,
-)
-from vllm.distributed import (
-    get_tensor_model_parallel_rank,
-    get_tensor_model_parallel_world_size,
-)
 from vllm.model_executor.utils import set_weight_attrs
-from sglang.srt.weight_utils import (
-    default_weight_loader,
-    hf_model_weights_iterator,
-)
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.managers.router.model_runner import InputMetadata
 from sglang.srt.models.dbrx_config import DbrxConfig
+from sglang.srt.weight_utils import default_weight_loader, hf_model_weights_iterator


 class DbrxRouter(nn.Module):

@@ -291,7 +285,9 @@ class DbrxBlock(nn.Module):
         quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
-        self.norm_attn_norm = DbrxFusedNormAttention(config, layer_id, quant_config=quant_config)
+        self.norm_attn_norm = DbrxFusedNormAttention(
+            config, layer_id, quant_config=quant_config
+        )
         self.ffn = DbrxExperts(config, quant_config=quant_config)

     def forward(

@@ -322,7 +318,10 @@ class DbrxModel(nn.Module):
             config.d_model,
         )
         self.blocks = nn.ModuleList(
-            [DbrxBlock(config, i, quant_config=quant_config) for i in range(config.n_layers)]
+            [
+                DbrxBlock(config, i, quant_config=quant_config)
+                for i in range(config.n_layers)
+            ]
         )
         self.norm_f = nn.LayerNorm(config.d_model, eps=1e-5)
         for module in self.modules():
python/sglang/srt/models/gemma.py

@@ -7,6 +7,7 @@ import torch
 from torch import nn
 from transformers import PretrainedConfig
 from vllm.config import LoRAConfig
+from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.activation import GeluAndMul
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (

@@ -14,21 +15,14 @@ from vllm.model_executor.layers.linear import (
     QKVParallelLinear,
     RowParallelLinear,
 )
-from vllm.model_executor.layers.quantization.base_config import (
-    QuantizationConfig)
+from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
-from vllm.distributed import (
-    get_tensor_model_parallel_world_size,
-)
-from sglang.srt.weight_utils import (
-    default_weight_loader,
-    hf_model_weights_iterator,
-)
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.managers.router.model_runner import InputMetadata
+from sglang.srt.weight_utils import default_weight_loader, hf_model_weights_iterator


 class GemmaMLP(nn.Module):

@@ -46,7 +40,10 @@ class GemmaMLP(nn.Module):
             quant_config=quant_config,
         )
         self.down_proj = RowParallelLinear(
-            intermediate_size, hidden_size, bias=False, quant_config=quant_config
+            intermediate_size,
+            hidden_size,
+            bias=False,
+            quant_config=quant_config,
         )
         self.act_fn = GeluAndMul()
python/sglang/srt/models/llama2.py

@@ -6,6 +6,7 @@ from typing import Any, Dict, Optional, Tuple
 import torch
 from torch import nn
 from transformers import LlamaConfig
+from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (

@@ -13,24 +14,17 @@ from vllm.model_executor.layers.linear import (
     QKVParallelLinear,
     RowParallelLinear,
 )
-from vllm.model_executor.layers.quantization.base_config import (
-    QuantizationConfig)
+from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from vllm.distributed import (
-    get_tensor_model_parallel_world_size,
-)
-from sglang.srt.weight_utils import (
-    default_weight_loader,
-    hf_model_weights_iterator,
-)
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.managers.router.model_runner import InputMetadata
+from sglang.srt.weight_utils import default_weight_loader, hf_model_weights_iterator


 class LlamaMLP(nn.Module):

@@ -49,7 +43,10 @@ class LlamaMLP(nn.Module):
             quant_config=quant_config,
         )
         self.down_proj = RowParallelLinear(
-            intermediate_size, hidden_size, bias=False, quant_config=quant_config
+            intermediate_size,
+            hidden_size,
+            bias=False,
+            quant_config=quant_config,
         )
         if hidden_act != "silu":
             raise ValueError(
python/sglang/srt/models/llava.py

@@ -7,12 +7,7 @@ import torch
 from torch import nn
 from transformers import CLIPVisionModel, LlavaConfig
 from transformers.models.llava.modeling_llava import LlavaMultiModalProjector
-from vllm.model_executor.layers.quantization.base_config import (
-    QuantizationConfig)
-from sglang.srt.weight_utils import (
-    default_weight_loader,
-    hf_model_weights_iterator,
-)
+from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
+
 from sglang.srt.managers.router.infer_batch import ForwardMode
 from sglang.srt.managers.router.model_runner import InputMetadata

@@ -22,6 +17,7 @@ from sglang.srt.mm_utils import (
     unpad_image_shape,
 )
 from sglang.srt.models.llama2 import LlamaForCausalLM
+from sglang.srt.weight_utils import default_weight_loader, hf_model_weights_iterator


 class LlavaLlamaForCausalLM(nn.Module):
python/sglang/srt/models/llavavid.py (new file, 0 → 100644)

"""Inference-only LLaVa video model compatible with HuggingFace weights."""

import os
from typing import List, Optional

import numpy as np
import torch
from torch import nn
from transformers import CLIPVisionModel, LlamaConfig, LlavaConfig
from transformers.models.llava.modeling_llava import LlavaMultiModalProjector
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig

from sglang.srt.managers.router.infer_batch import ForwardMode
from sglang.srt.managers.router.model_runner import InputMetadata
from sglang.srt.mm_utils import (
    get_anyres_image_grid_shape,
    unpad_image,
    unpad_image_shape,
)
from sglang.srt.models.llama2 import LlamaForCausalLM
from sglang.srt.weight_utils import default_weight_loader, hf_model_weights_iterator


class LlavaVidForCausalLM(nn.Module):
    def __init__(
        self,
        config: LlavaConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.config = config
        self.vision_tower = None
        self.config.vision_config.hidden_size = config.mm_hidden_size
        self.config.text_config.hidden_size = config.hidden_size
        self.multi_modal_projector = LlavaMultiModalProjector(config)
        self.mm_spatial_pool_stride = getattr(self.config, "mm_spatial_pool_stride", 2)
        self.resampler = nn.AvgPool2d(
            kernel_size=self.mm_spatial_pool_stride, stride=self.mm_spatial_pool_stride
        )
        self.language_model = LlamaForCausalLM(config, quant_config=quant_config)
        self.num_frames = getattr(self.config, "num_frames", 16)
        if "unpad" in getattr(config, "mm_patch_merge_type", ""):
            self.language_model.model.image_newline = nn.Parameter(
                torch.empty(config.text_config.hidden_size, dtype=torch.float16)
            )

    def pad_input_ids(self, input_ids, pad_value, pt_shape=None, image_size=None):
        new_image_feature_len = self.image_feature_len
        # now only support spatial_unpad + anyres
        # if self.mm_patch_merge_type.startswith("spatial"):
        #     height = width = self.num_patches_per_side
        #     if pt_shape[0] > 1:
        #         if self.image_aspect_ratio == "anyres":
        #             num_patch_width, num_patch_height = get_anyres_image_grid_shape(
        #                 image_size,
        #                 self.image_grid_pinpoints,
        #                 self.vision_tower.config.image_size,
        #             )
        #         if "unpad" in self.mm_patch_merge_type:
        #             h = num_patch_height * height
        #             w = num_patch_width * width
        #             new_h, new_w = unpad_image_shape(h, w, image_size)
        #             new_image_feature_len += new_h * (new_w + 1)

        pad_ids = pad_value * (
            (new_image_feature_len + len(pad_value)) // len(pad_value)
        )
        # print(input_ids)
        offset = input_ids.index(self.config.image_token_index)
        # old_len + pad_len - 1, because we need to remove image_token_id
        new_input_ids = (
            input_ids[:offset]
            + pad_ids[:new_image_feature_len]
            + input_ids[offset + 1 :]
        )
        return new_input_ids, offset

    def encode_images(self, pixel_values: torch.Tensor) -> torch.Tensor:
        image_outputs = self.vision_tower(pixel_values, output_hidden_states=True)
        # NOTE: This is not memory efficient. (output_hidden_states=True) will save all the hidden stated.
        selected_image_feature = image_outputs.hidden_states[self.vision_feature_layer]
        if self.vision_feature_select_strategy in ["default", "patch"]:
            selected_image_feature = selected_image_feature[:, 1:]
        elif self.vision_feature_select_strategy == "full":
            selected_image_feature = selected_image_feature
        else:
            raise ValueError(
                f"Unexpected select feature strategy: {self.config.vision_feature_select_strategy}"
            )

        height = width = self.num_patches_per_side
        num_of_frames = selected_image_feature.shape[0]
        selected_image_feature = selected_image_feature.view(
            num_of_frames, height, width, -1
        )
        selected_image_feature = selected_image_feature.permute(0, 3, 1, 2).contiguous()
        selected_image_feature = (
            self.resampler(selected_image_feature)
            .flatten(2)
            .transpose(1, 2)
            .contiguous()
        )

        image_features = self.multi_modal_projector(selected_image_feature)

        return image_features

    def forward(
        self,
        input_ids: torch.LongTensor,
        positions: torch.Tensor,
        input_metadata: InputMetadata,
        pixel_values: Optional[List[Optional[np.array]]] = None,
        image_sizes: Optional[List[List[int]]] = None,
        image_offsets: Optional[List[int]] = None,
    ) -> torch.Tensor:
        if input_metadata.forward_mode == ForwardMode.EXTEND:
            bs = input_metadata.batch_size

            # Embed text input
            input_embeds = self.language_model.model.embed_tokens(input_ids)

            # Embed vision input
            need_vision = (
                (positions[input_metadata.extend_start_loc] < self.image_feature_len)
                .cpu()
                .numpy()
            )
            # FIXME: We need to substract the length of the system prompt
            has_pixel = np.array([pixel_values[i] is not None for i in range(bs)])
            need_vision = need_vision & has_pixel

            if need_vision.any():
                pixel_values = [pixel_values[i] for i in range(bs) if need_vision[i]]
                image_sizes = [image_sizes[i] for i in range(bs) if need_vision[i]]

                ########## Encode Image ########

                if pixel_values[0].ndim == 4:
                    # llava-hd: BS, num_patch, C=3, H=336, W=336, num_patch obtained from process_images
                    np.concatenate(pixel_values, axis=0)
                    # ndim=4
                    concat_images = torch.tensor(
                        np.concatenate(pixel_values, axis=0),
                        device=self.vision_tower.device,
                    )
                    # image_features = self.encode_images(concat_images)
                    # split_sizes = [image.shape[0] for image in pixel_values]
                    # image_features = torch.split(image_features, split_sizes, dim=0)
                    image_features = self.encode_images(concat_images)
                    # , prompts)#, image_counts, long_video=long_video)
                    split_sizes = [image.shape[0] for image in pixel_values]
                    image_features = torch.split(image_features, split_sizes, dim=0)
                    # hd image_features: BS, num_patch, 576, 4096
                else:
                    # normal pixel: BS, C=3, H=336, W=336
                    pixel_values = torch.tensor(
                        np.array(pixel_values), device=self.vision_tower.device
                    )
                    image_features = self.encode_images(pixel_values)
                    # image_features: BS, 576, 4096

                new_image_features = []
                for image_idx, image_feature in enumerate(image_features):
                    new_image_features.append(image_feature.flatten(0, 1))
                image_features = new_image_features

                extend_start_loc_cpu = input_metadata.extend_start_loc.cpu().numpy()
                pt = 0
                for i in range(bs):
                    if not need_vision[i]:
                        continue

                    start_idx = extend_start_loc_cpu[i]
                    pad_len, pad_dim = image_features[pt].shape  # 576, 4096
                    dim = input_embeds.shape[1]
                    assert (
                        pad_dim == dim
                    ), "invalid pad_dim={}, input_embed_dim={}!".format(pad_dim, dim)
                    # Fill in the placeholder for the image
                    try:
                        input_embeds[
                            start_idx
                            + image_offsets[i] : start_idx
                            + image_offsets[i]
                            + pad_len
                        ] = image_features[pt]
                    except RuntimeError as e:
                        print(f"RuntimeError in llava image encoding: {e}")
                        print(input_embeds.shape)
                        print(start_idx, image_offsets[i])
                    pt += 1

            return self.language_model(
                input_ids, positions, input_metadata, input_embeds=input_embeds
            )
        elif input_metadata.forward_mode == ForwardMode.DECODE:
            return self.language_model(input_ids, positions, input_metadata)

    def load_weights(
        self,
        model_name_or_path: str,
        cache_dir: Optional[str] = None,
        load_format: str = "auto",
        revision: Optional[str] = None,
    ):
        # load clip vision model by cfg['mm_vision_tower']:
        # huggingface_name or path_of_clip_relative_to_llava_model_dir
        vision_path = self.config.mm_vision_tower
        self.vision_tower = CLIPVisionModel.from_pretrained(
            vision_path, torch_dtype=torch.float16
        ).cuda()
        self.vision_tower.eval()

        self.vision_feature_layer = self.config.mm_vision_select_layer
        self.vision_feature_select_strategy = self.config.mm_vision_select_feature
        self.image_size = self.vision_tower.config.image_size
        self.patch_size = self.vision_tower.config.patch_size

        self.mm_patch_merge_type = getattr(self.config, "mm_patch_merge_type", "flat")
        self.image_aspect_ratio = getattr(self.config, "image_aspect_ratio", "square")
        self.image_grid_pinpoints = getattr(self.config, "image_grid_pinpoints", None)

        print(f"target_frames: {self.num_frames}")
        self.image_feature_len = self.num_frames * int(
            (self.image_size / self.patch_size / self.mm_spatial_pool_stride) ** 2
        )
        if self.vision_feature_select_strategy == "patch":
            pass
        elif self.vision_feature_select_strategy == "cls_patch":
            self.image_feature_len += 1
        else:
            raise ValueError(f"Unexpected select feature: {self.select_feature}")

        # load mm_projector
        projector_weights = {
            "model.mm_projector.0": "multi_modal_projector.linear_1",
            "model.mm_projector.2": "multi_modal_projector.linear_2",
            "model.vision_resampler.mm_projector.0": "multi_modal_projector.linear_1",
            "model.vision_resampler.mm_projector.2": "multi_modal_projector.linear_2",
            "model.vision_tower.vision_tower": "vision_tower",
            # Update the vision tower weights if we find them in the checkpoint (it may be finetuned).
        }
        params_dict = dict(self.named_parameters())
        for name, loaded_weight in hf_model_weights_iterator(
            model_name_or_path, cache_dir, load_format, revision
        ):
            # FIXME: why projector weights read two times?
            if "projector" in name or "vision_tower" in name:
                for weight_name, param_name in projector_weights.items():
                    if weight_name in name:
                        name = name.replace(weight_name, param_name)
                if name in params_dict:
                    param = params_dict[name]
                else:
                    print(f"Warning: {name} not found in the model")
                    continue
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)

        # load language model
        self.language_model.load_weights(
            model_name_or_path, cache_dir, load_format, revision
        )

        monkey_path_clip_vision_embed_forward()

    @property
    def num_patches_per_side(self):
        return self.image_size // self.patch_size


first_call = True


def clip_vision_embed_forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
    batch_size = pixel_values.shape[0]

    # Move this conv layer to CPU to avoid a bug in torch >= 2.1 on A10G.
    global first_call
    if first_call:
        self.patch_embedding.cpu().float()
        first_call = False
    pixel_values = pixel_values.to(dtype=torch.float32, device="cpu")
    patch_embeds = self.patch_embedding(pixel_values).cuda().half()

    patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
    class_embeds = self.class_embedding.expand(batch_size, 1, -1)
    embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
    embeddings = embeddings + self.position_embedding(self.position_ids)
    return embeddings


def monkey_path_clip_vision_embed_forward():
    import transformers

    setattr(
        transformers.models.clip.modeling_clip.CLIPVisionEmbeddings,
        "forward",
        clip_vision_embed_forward,
    )


EntryClass = LlavaVidForCausalLM
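For a sense of how many placeholder positions pad_input_ids reserves, the arithmetic in load_weights can be run by hand. The stride and frame-count defaults (2 and 16) come from the constructor above; the CLIP tower dimensions (image_size=336, patch_size=14) are typical for the ViT-L/14-336 tower LLaVA models use but are an assumption here, since the real values are read from the checkpoint at load time:

    # Worked example of image_feature_len, mirroring the formula in load_weights.
    image_size = 336              # assumed CLIP ViT-L/14-336 input size
    patch_size = 14               # assumed CLIP patch size
    mm_spatial_pool_stride = 2    # default from the constructor
    num_frames = 16               # default from the constructor

    patches_per_side = image_size // patch_size                  # 24
    pooled_side = patches_per_side // mm_spatial_pool_stride     # 12 after AvgPool2d
    tokens_per_frame = pooled_side ** 2                          # 144
    image_feature_len = num_frames * tokens_per_frame            # 2304

    print(image_feature_len)  # 2304 placeholder token positions reserved per video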
python/sglang/srt/models/mixtral.py

@@ -8,34 +8,28 @@ import torch
 import torch.nn.functional as F
 from torch import nn
 from transformers import MixtralConfig
+from vllm.distributed import (
+    get_tensor_model_parallel_rank,
+    get_tensor_model_parallel_world_size,
+    tensor_model_parallel_all_reduce,
+)
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (
     QKVParallelLinear,
     ReplicatedLinear,
     RowParallelLinear,
 )
-from vllm.model_executor.layers.quantization.base_config import (
-    QuantizationConfig)
+from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from vllm.distributed import (
-    tensor_model_parallel_all_reduce,
-)
-from vllm.distributed import (
-    get_tensor_model_parallel_rank,
-    get_tensor_model_parallel_world_size,
-)
-from sglang.srt.weight_utils import (
-    default_weight_loader,
-    hf_model_weights_iterator,
-)
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.managers.router.model_runner import InputMetadata
+from sglang.srt.weight_utils import default_weight_loader, hf_model_weights_iterator


 class MixtralMLP(nn.Module):
python/sglang/srt/models/qwen.py

@@ -3,6 +3,7 @@ from typing import Any, Dict, Optional
 import torch
 from torch import nn
 from transformers import PretrainedConfig
+from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (

@@ -10,24 +11,17 @@ from vllm.model_executor.layers.linear import (
     QKVParallelLinear,
     RowParallelLinear,
 )
-from vllm.model_executor.layers.quantization.base_config import (
-    QuantizationConfig)
+from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from vllm.distributed import (
-    get_tensor_model_parallel_world_size,
-)
-from sglang.srt.weight_utils import (
-    default_weight_loader,
-    hf_model_weights_iterator,
-)
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.managers.router.model_runner import InputMetadata
+from sglang.srt.weight_utils import default_weight_loader, hf_model_weights_iterator


 class QWenMLP(nn.Module):

@@ -132,7 +126,12 @@ class QWenAttention(nn.Module):
 class QWenBlock(nn.Module):
-    def __init__(self, config: PretrainedConfig, layer_id, quant_config: Optional[QuantizationConfig] = None,):
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        layer_id,
+        quant_config: Optional[QuantizationConfig] = None,
+    ):
         super().__init__()
         self.ln_1 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)

@@ -181,7 +180,11 @@ class QWenBlock(nn.Module):
 class QWenModel(nn.Module):
-    def __init__(self, config: PretrainedConfig, quant_config: Optional[QuantizationConfig] = None,):
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+    ):
         super().__init__()
         self.config = config
         self.vocab_size = config.vocab_size

@@ -218,7 +221,11 @@ class QWenModel(nn.Module):
 class QWenLMHeadModel(nn.Module):
-    def __init__(self, config: PretrainedConfig, quant_config: Optional[QuantizationConfig] = None,):
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+    ):
         super().__init__()
         self.config = config
         self.transformer = QWenModel(config, quant_config=quant_config)
python/sglang/srt/models/qwen2.py

@@ -5,6 +5,7 @@ from typing import Any, Dict, Optional, Tuple
 import torch
 from torch import nn
+from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (

@@ -12,24 +13,17 @@ from vllm.model_executor.layers.linear import (
     QKVParallelLinear,
     RowParallelLinear,
 )
-from vllm.model_executor.layers.quantization.base_config import (
-    QuantizationConfig)
+from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from vllm.distributed import (
-    get_tensor_model_parallel_world_size,
-)
-from sglang.srt.weight_utils import (
-    default_weight_loader,
-    hf_model_weights_iterator,
-)
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.managers.router.model_runner import InputMetadata
+from sglang.srt.weight_utils import default_weight_loader, hf_model_weights_iterator

 Qwen2Config = None

@@ -50,7 +44,10 @@ class Qwen2MLP(nn.Module):
             quant_config=quant_config,
         )
         self.down_proj = RowParallelLinear(
-            intermediate_size, hidden_size, bias=False, quant_config=quant_config
+            intermediate_size,
+            hidden_size,
+            bias=False,
+            quant_config=quant_config,
         )
         if hidden_act != "silu":
             raise ValueError(
python/sglang/srt/models/stablelm.py

@@ -7,35 +7,31 @@ from typing import Optional, Tuple
 import torch
 from torch import nn
 from transformers import PretrainedConfig
+from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.linear import (
     MergedColumnParallelLinear,
     QKVParallelLinear,
     RowParallelLinear,
 )
-from vllm.model_executor.layers.quantization.base_config import (
-    QuantizationConfig)
+from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from vllm.distributed import (
-    get_tensor_model_parallel_world_size,
-)
-from sglang.srt.weight_utils import (
-    default_weight_loader,
-    hf_model_weights_iterator,
-)
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.managers.router.model_runner import InputMetadata
+from sglang.srt.weight_utils import default_weight_loader, hf_model_weights_iterator


 class StablelmMLP(nn.Module):
     def __init__(
-        self, config: PretrainedConfig, quant_config: Optional[QuantizationConfig] = None,
+        self,
+        config: PretrainedConfig,
+        quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
         self.config = config

@@ -48,7 +44,10 @@ class StablelmMLP(nn.Module):
             quant_config=quant_config,
         )
         self.down_proj = RowParallelLinear(
-            config.intermediate_size, config.hidden_size, bias=False, quant_config=quant_config
+            config.intermediate_size,
+            config.hidden_size,
+            bias=False,
+            quant_config=quant_config,
         )
         self.act_fn = SiluAndMul()

@@ -181,7 +180,9 @@ class StablelmDecoderLayer(nn.Module):
 class StableLMEpochModel(nn.Module):
     def __init__(
-        self, config: PretrainedConfig, quant_config: Optional[QuantizationConfig] = None,
+        self,
+        config: PretrainedConfig,
+        quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
         self.embed_tokens = VocabParallelEmbedding(
python/sglang/srt/models/yivl.py

@@ -6,16 +6,13 @@ from typing import List, Optional
 import torch
 import torch.nn as nn
 from transformers import CLIPVisionModel, LlavaConfig
-from sglang.srt.weight_utils import (
-    default_weight_loader,
-    hf_model_weights_iterator,
-)
 from sglang.srt.models.llava import (
     LlavaLlamaForCausalLM,
     clip_vision_embed_forward,
     monkey_path_clip_vision_embed_forward,
 )
+from sglang.srt.weight_utils import default_weight_loader, hf_model_weights_iterator


 class YiVLForCausalLM(LlavaLlamaForCausalLM):
python/sglang/srt/server.py

@@ -107,7 +107,7 @@ async def openai_v1_chat_completions(raw_request: Request):
     return await v1_chat_completions(tokenizer_manager, raw_request)


-def launch_server(server_args: ServerArgs, pipe_finish_writer):
+def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_args=None):
     global tokenizer_manager

     logging.basicConfig(

@@ -140,17 +140,13 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer):
     )

     # Launch processes
-    tokenizer_manager = TokenizerManager(server_args, port_args)
+    tokenizer_manager = TokenizerManager(server_args, port_args, model_overide_args)
     pipe_router_reader, pipe_router_writer = mp.Pipe(duplex=False)
     pipe_detoken_reader, pipe_detoken_writer = mp.Pipe(duplex=False)

     proc_router = mp.Process(
         target=start_router_process,
-        args=(
-            server_args,
-            port_args,
-            pipe_router_writer,
-        ),
+        args=(server_args, port_args, pipe_router_writer, model_overide_args),
     )
     proc_router.start()
     proc_detoken = mp.Process(

@@ -170,8 +166,13 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer):
     if router_init_state != "init ok" or detoken_init_state != "init ok":
         proc_router.kill()
         proc_detoken.kill()
-        print(f"Initialization failed. router_init_state: {router_init_state}", flush=True)
-        print(f"Initialization failed. detoken_init_state: {detoken_init_state}", flush=True)
+        print(
+            f"Initialization failed. router_init_state: {router_init_state}", flush=True
+        )
+        print(
+            f"Initialization failed. detoken_init_state: {detoken_init_state}",
+            flush=True,
+        )
         sys.exit(1)

     assert proc_router.is_alive() and proc_detoken.is_alive()

@@ -189,6 +190,7 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer):
         time.sleep(0.5)
         try:
             requests.get(url + "/get_model_info", timeout=5, headers=headers)
+            success = True  # Set flag to True if request succeeds
             break
         except requests.exceptions.RequestException as e:
             pass

@@ -205,7 +207,7 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer):
                 },
             },
             headers=headers,
-            timeout=60,
+            timeout=600,
         )
         assert res.status_code == 200
     except Exception as e:

@@ -235,7 +237,8 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer):
 class Runtime:
     def __init__(
         self,
-        log_evel="error",
+        log_evel: str = "error",
+        model_overide_args: Optional[dict] = None,
         *args,
         **kwargs,
     ):

@@ -244,7 +247,10 @@ class Runtime:
         # Pre-allocate ports
         self.server_args.port, self.server_args.additional_ports = allocate_init_ports(
-            self.server_args.port, self.server_args.additional_ports, self.server_args.tp_size)
+            self.server_args.port,
+            self.server_args.additional_ports,
+            self.server_args.tp_size,
+        )

         self.url = self.server_args.url()
         self.generate_url = (

@@ -253,7 +259,10 @@ class Runtime:
         self.pid = None
         pipe_reader, pipe_writer = mp.Pipe(duplex=False)
-        proc = mp.Process(target=launch_server, args=(self.server_args, pipe_writer))
+        proc = mp.Process(
+            target=launch_server,
+            args=(self.server_args, pipe_writer, model_overide_args),
+        )
         proc.start()
         pipe_writer.close()
         self.pid = proc.pid

@@ -265,7 +274,9 @@ class Runtime:
         if init_state != "init ok":
             self.shutdown()
-            raise RuntimeError("Initialization failed. Please see the error messages above.")
+            raise RuntimeError(
+                "Initialization failed. Please see the error messages above."
+            )
         self.endpoint = RuntimeEndpoint(self.url)
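Taken together, the server changes thread model_overide_args from Runtime through launch_server into TokenizerManager and the router process. A hedged usage sketch (note the parameter really is spelled model_overide_args in this codebase; the checkpoint name and override key below are placeholders, not values taken from this diff):

    import sglang as sgl

    runtime = sgl.Runtime(
        model_path="lmms-lab/LLaVA-NeXT-Video-7B",   # placeholder checkpoint name
        model_overide_args={"num_frames": 16},       # merged into hf_config by ModelConfig
    )

    # ... issue requests against runtime.url / runtime.endpoint as usual ...

    runtime.shutdown()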
python/sglang/srt/server_args.py

@@ -80,10 +80,12 @@ class ServerArgs:
             default=ServerArgs.tokenizer_path,
             help="The path of the tokenizer.",
         )
-        parser.add_argument("--host", type=str, default=ServerArgs.host,
-                            help="The host of the server.")
-        parser.add_argument("--port", type=int, default=ServerArgs.port,
-                            help="The port of the server.")
+        parser.add_argument(
+            "--host", type=str, default=ServerArgs.host, help="The host of the server."
+        )
+        parser.add_argument(
+            "--port", type=int, default=ServerArgs.port, help="The port of the server."
+        )
         parser.add_argument(
             "--additional-ports",
             type=int,
python/sglang/srt/utils.py

@@ -131,11 +131,13 @@ def alloc_usable_network_port(num, used_list=()):
             continue

         with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
+            s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
             try:
                 s.bind(("", port))
+                s.listen(1)  # Attempt to listen on the port
                 port_list.append(port)
             except socket.error:
                 pass
             # If any error occurs, this port is not usable

     if len(port_list) == num:
         return port_list

@@ -265,20 +267,102 @@ def wrap_kernel_launcher(kernel):
 def is_multimodal_model(model):
-    if isinstance(model, str):
-        return "llava" in model or "yi-vl" in model
     from sglang.srt.model_config import ModelConfig
+
+    if isinstance(model, str):
+        model = model.lower()
+        return "llava" in model or "yi-vl" in model or "llava-next" in model
+
     if isinstance(model, ModelConfig):
         model_path = model.path.lower()
-        return "llava" in model_path or "yi-vl" in model_path
-    raise Exception("unrecognized type")
+        return "llava" in model_path or "yi-vl" in model_path or "llava-next" in model_path
+
+    raise ValueError("unrecognized type")
+
+
+def decode_video_base64(video_base64):
+    from PIL import Image
+
+    # Decode the base64 string
+    video_bytes = base64.b64decode(video_base64)
+
+    # Placeholder for the start indices of each PNG image
+    img_starts = []
+
+    frame_format = "PNG"  # str(os.getenv('FRAME_FORMAT', "JPEG"))
+
+    assert frame_format in [
+        "PNG",
+        "JPEG",
+    ], "FRAME_FORMAT must be either 'PNG' or 'JPEG'"
+
+    if frame_format == "PNG":
+        # Find each PNG start signature to isolate images
+        i = 0
+        while i < len(video_bytes) - 7:  # Adjusted for the length of the PNG signature
+            # Check if we found the start of a PNG file
+            if (
+                video_bytes[i] == 0x89
+                and video_bytes[i + 1] == 0x50
+                and video_bytes[i + 2] == 0x4E
+                and video_bytes[i + 3] == 0x47
+                and video_bytes[i + 4] == 0x0D
+                and video_bytes[i + 5] == 0x0A
+                and video_bytes[i + 6] == 0x1A
+                and video_bytes[i + 7] == 0x0A
+            ):
+                img_starts.append(i)
+                i += 8  # Skip the PNG signature
+            else:
+                i += 1
+    else:
+        # Find each JPEG start (0xFFD8) to isolate images
+        i = 0
+        while (
+            i < len(video_bytes) - 1
+        ):  # Adjusted for the length of the JPEG SOI signature
+            # Check if we found the start of a JPEG file
+            if video_bytes[i] == 0xFF and video_bytes[i + 1] == 0xD8:
+                img_starts.append(i)
+                # Move to the next byte to continue searching for the next image start
+                i += 2
+            else:
+                i += 1

+    frames = []
+    for start_idx in img_starts:
+        # Assuming each image is back-to-back, the end of one image is the start of another
+        # The last image goes until the end of the byte string
+        end_idx = (
+            img_starts[img_starts.index(start_idx) + 1]
+            if img_starts.index(start_idx) + 1 < len(img_starts)
+            else len(video_bytes)
+        )
+        img_bytes = video_bytes[start_idx:end_idx]
+
+        # Convert bytes to a PIL Image
+        img = Image.open(BytesIO(img_bytes))
+
+        # Convert PIL Image to a NumPy array
+        frame = np.array(img)
+
+        # Append the frame to the list of frames
+        frames.append(frame)
+
+    # Ensure there's at least one frame to avoid errors with np.stack
+    if frames:
+        return np.stack(frames, axis=0), img.size
+    else:
+        return np.array([]), (
+            0,
+            0,
+        )  # Return an empty array and size tuple if no frames were found
+

 def load_image(image_file):
     from PIL import Image

-    image = None
+    image = image_size = None

     if image_file.startswith("http://") or image_file.startswith("https://"):
         timeout = int(os.getenv("REQUEST_TIMEOUT", "3"))

@@ -289,10 +373,13 @@ def load_image(image_file):
     elif image_file.startswith("data:"):
         image_file = image_file.split(",")[1]
         image = Image.open(BytesIO(base64.b64decode(image_file)))
+    elif image_file.startswith("video:"):
+        image_file = image_file.replace("video:", "")
+        image, image_size = decode_video_base64(image_file)
     else:
         image = Image.open(BytesIO(base64.b64decode(image_file)))

-    return image
+    return image, image_size


 def assert_pkg_version(pkg: str, min_version: str):

@@ -304,7 +391,9 @@ def assert_pkg_version(pkg: str, min_version: str):
                 f"is less than the minimum required version {min_version}"
             )
     except PackageNotFoundError:
-        raise Exception(f"{pkg} with minimum required version {min_version} is not installed")
+        raise Exception(
+            f"{pkg} with minimum required version {min_version} is not installed"
+        )


 API_KEY_HEADER_NAME = "X-API-Key"
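load_image now recognizes a "video:" prefix whose payload is the base64 of back-to-back PNG (or JPEG) frames, which decode_video_base64 splits on the image signatures. A sketch of producing such a payload with PIL; the helper name encode_frames_base64 is invented for this example and is not part of sglang:

    import base64
    from io import BytesIO

    import numpy as np
    from PIL import Image


    def encode_frames_base64(frames):
        """Concatenate PNG-encoded frames and base64 them, the inverse of decode_video_base64.

        `frames` is an iterable of HxWx3 uint8 numpy arrays; this helper is illustrative only.
        """
        buf = bytearray()
        for frame in frames:
            png = BytesIO()
            Image.fromarray(frame).save(png, format="PNG")
            buf += png.getvalue()  # PNG signatures delimit the frames for the decoder
        return "video:" + base64.b64encode(bytes(buf)).decode("utf-8")


    # Round-trip check against the decoder added in this commit
    # (load_image strips the "video:" prefix before calling decode_video_base64):
    #   frames_back, size = decode_video_base64(payload[len("video:"):])
    frames = [np.zeros((8, 8, 3), dtype=np.uint8) for _ in range(4)]
    payload = encode_frames_base64(frames)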
python/sglang/srt/weight_utils.py
View file @
0992d85f
...
@@ -19,11 +19,12 @@ import torch
...
@@ -19,11 +19,12 @@ import torch
from
huggingface_hub
import
HfFileSystem
,
snapshot_download
from
huggingface_hub
import
HfFileSystem
,
snapshot_download
from
safetensors.torch
import
load_file
,
safe_open
,
save_file
from
safetensors.torch
import
load_file
,
safe_open
,
save_file
from
tqdm.auto
import
tqdm
from
tqdm.auto
import
tqdm
from
vllm.config
import
ModelConfig
from
vllm.config
import
ModelConfig
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.quantization
import
(
QuantizationConfig
,
from
vllm.model_executor.layers.quantization
import
(
get_quantization_config
)
QuantizationConfig
,
get_quantization_config
,
)
from
vllm.model_executor.layers.quantization.schema
import
QuantParamSchema
from
vllm.model_executor.layers.quantization.schema
import
QuantParamSchema
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -32,17 +33,21 @@ logger = init_logger(__name__)
...
@@ -32,17 +33,21 @@ logger = init_logger(__name__)
# can share the same lock without error.
# can share the same lock without error.
# lock files in the temp directory will be automatically deleted when the
# lock files in the temp directory will be automatically deleted when the
# system reboots, so users will not complain about annoying lock files
# system reboots, so users will not complain about annoying lock files
temp_dir
=
os
.
environ
.
get
(
'TMPDIR'
)
or
os
.
environ
.
get
(
temp_dir
=
(
'TEMP'
)
or
os
.
environ
.
get
(
'TMP'
)
or
"/tmp/"
os
.
environ
.
get
(
"TMPDIR"
)
or
os
.
environ
.
get
(
"TEMP"
)
or
os
.
environ
.
get
(
"TMP"
)
or
"/tmp/"
)
def
enable_hf_transfer
():
def
enable_hf_transfer
():
"""automatically activates hf_transfer
"""automatically activates hf_transfer"""
"""
if
"HF_HUB_ENABLE_HF_TRANSFER"
not
in
os
.
environ
:
if
"HF_HUB_ENABLE_HF_TRANSFER"
not
in
os
.
environ
:
try
:
try
:
# enable hf hub transfer if available
# enable hf hub transfer if available
import
hf_transfer
# type: ignore # noqa
import
hf_transfer
# type: ignore # noqa
huggingface_hub
.
constants
.
HF_HUB_ENABLE_HF_TRANSFER
=
True
huggingface_hub
.
constants
.
HF_HUB_ENABLE_HF_TRANSFER
=
True
except
ImportError
:
except
ImportError
:
pass
pass
...
@@ -65,8 +70,7 @@ def get_lock(model_name_or_path: str, cache_dir: Optional[str] = None):
...
@@ -65,8 +70,7 @@ def get_lock(model_name_or_path: str, cache_dir: Optional[str] = None):
# add hash to avoid conflict with old users' lock files
# add hash to avoid conflict with old users' lock files
lock_file_name
=
hash_name
+
model_name
+
".lock"
lock_file_name
=
hash_name
+
model_name
+
".lock"
# mode 0o666 is required for the filelock to be shared across users
# mode 0o666 is required for the filelock to be shared across users
lock
=
filelock
.
FileLock
(
os
.
path
.
join
(
lock_dir
,
lock_file_name
),
lock
=
filelock
.
FileLock
(
os
.
path
.
join
(
lock_dir
,
lock_file_name
),
mode
=
0o666
)
mode
=
0o666
)
return
lock
return
lock
...
@@ -104,10 +108,12 @@ def convert_bin_to_safetensor_file(
...
@@ -104,10 +108,12 @@ def convert_bin_to_safetensor_file(
sf_size
=
os
.
stat
(
sf_filename
).
st_size
sf_size
=
os
.
stat
(
sf_filename
).
st_size
pt_size
=
os
.
stat
(
pt_filename
).
st_size
pt_size
=
os
.
stat
(
pt_filename
).
st_size
if
(
sf_size
-
pt_size
)
/
pt_size
>
0.01
:
if
(
sf_size
-
pt_size
)
/
pt_size
>
0.01
:
raise
RuntimeError
(
f
"""The file size different is more than 1%:
raise
RuntimeError
(
f
"""The file size different is more than 1%:
-
{
sf_filename
}
:
{
sf_size
}
-
{
sf_filename
}
:
{
sf_size
}
-
{
pt_filename
}
:
{
pt_size
}
-
{
pt_filename
}
:
{
pt_size
}
"""
)
"""
)
# check if the tensors are the same
# check if the tensors are the same
reloaded
=
load_file
(
sf_filename
)
reloaded
=
load_file
(
sf_filename
)
...
@@ -122,8 +128,7 @@ def convert_bin_to_safetensor_file(
...
@@ -122,8 +128,7 @@ def convert_bin_to_safetensor_file(
def
get_quant_config
(
model_config
:
ModelConfig
)
->
QuantizationConfig
:
def
get_quant_config
(
model_config
:
ModelConfig
)
->
QuantizationConfig
:
quant_cls
=
get_quantization_config
(
model_config
.
quantization
)
quant_cls
=
get_quantization_config
(
model_config
.
quantization
)
# Read the quantization config from the HF model config, if available.
# Read the quantization config from the HF model config, if available.
hf_quant_config
=
getattr
(
model_config
.
hf_config
,
"quantization_config"
,
hf_quant_config
=
getattr
(
model_config
.
hf_config
,
"quantization_config"
,
None
)
None
)
if
hf_quant_config
is
not
None
:
if
hf_quant_config
is
not
None
:
return
quant_cls
.
from_config
(
hf_quant_config
)
return
quant_cls
.
from_config
(
hf_quant_config
)
model_name_or_path
=
model_config
.
model
model_name_or_path
=
model_config
.
model
...
@@ -131,26 +136,29 @@ def get_quant_config(model_config: ModelConfig) -> QuantizationConfig:
    if not is_local:
        # Download the config files.
        with get_lock(model_name_or_path, model_config.download_dir):
            hf_folder = snapshot_download(
                model_name_or_path,
                revision=model_config.revision,
                allow_patterns="*.json",
                cache_dir=model_config.download_dir,
                tqdm_class=Disabledtqdm,
            )
    else:
        hf_folder = model_name_or_path
    config_files = glob.glob(os.path.join(hf_folder, "*.json"))
    quant_config_files = [
        f
        for f in config_files
        if any(f.endswith(x) for x in quant_cls.get_config_filenames())
    ]
    if len(quant_config_files) == 0:
        raise ValueError(
            f"Cannot find the config file for {model_config.quantization}"
        )
    if len(quant_config_files) > 1:
        raise ValueError(
            f"Found multiple config files for {model_config.quantization}: "
            f"{quant_config_files}"
        )
    quant_config_file = quant_config_files[0]
    with open(quant_config_file, "r") as f:
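The selection logic above boils down to matching the snapshot's *.json files against the filenames the quantization backend declares, and insisting on exactly one hit. A self-contained sketch of that filter (file names are illustrative):

config_files = [
    "/cache/snapshot/config.json",
    "/cache/snapshot/generation_config.json",
    "/cache/snapshot/quantize_config.json",   # hypothetical quantization config
]
expected = ["quantize_config.json"]           # filenames a backend might declare

matches = [f for f in config_files if any(f.endswith(x) for x in expected)]
if len(matches) == 0:
    raise ValueError("Cannot find the quantization config file")
if len(matches) > 1:
    raise ValueError(f"Found multiple quantization config files: {matches}")
print(matches[0])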
...
@@ -166,8 +174,7 @@ def prepare_hf_model_weights(
    revision: Optional[str] = None,
) -> Tuple[str, List[str], bool]:
    # Download model weights from huggingface.
    is_local = os.path.isdir(model_name_or_path) and load_format != "tensorizer"
    use_safetensors = False
    # Some quantized models use .pt files for storing the weights.
    if load_format == "auto":
...
@@ -203,11 +210,13 @@ def prepare_hf_model_weights(
            # Use file lock to prevent multiple processes from
            # downloading the same model weights at the same time.
            with get_lock(model_name_or_path, cache_dir):
                hf_folder = snapshot_download(
                    model_name_or_path,
                    allow_patterns=allow_patterns,
                    cache_dir=cache_dir,
                    tqdm_class=Disabledtqdm,
                    revision=revision,
                )
    else:
        hf_folder = model_name_or_path
    hf_weights_files: List[str] = []
...
@@ -228,16 +237,14 @@ def prepare_hf_model_weights(
            "scaler.pt",
        ]
        hf_weights_files = [
            f for f in hf_weights_files if not any(f.endswith(x) for x in blacklist)
        ]
    if load_format == "tensorizer":
        return hf_folder, hf_weights_files, use_safetensors
    if len(hf_weights_files) == 0:
        raise RuntimeError(f"Cannot find any model weights with `{model_name_or_path}`")
    return hf_folder, hf_weights_files, use_safetensors
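Taken together, prepare_hf_model_weights globs the snapshot directory for weight files and strips known training artifacts before handing the list to the loader. A simplified stand-alone sketch of that flow (the directory and the fallback behaviour are illustrative, not the exact elided code):

import glob
import os

hf_folder = "/cache/snapshots/some-model"   # hypothetical snapshot directory
blacklist = ["training_args.bin", "optimizer.bin", "optimizer.pt", "scheduler.pt", "scaler.pt"]

# Prefer safetensors shards; fall back to .bin checkpoints if none are present.
weight_files = glob.glob(os.path.join(hf_folder, "*.safetensors"))
use_safetensors = bool(weight_files)
if not use_safetensors:
    weight_files = glob.glob(os.path.join(hf_folder, "*.bin"))
    weight_files = [f for f in weight_files if not any(f.endswith(x) for x in blacklist)]

print(use_safetensors, weight_files)        # an empty list here would mean "no weights found"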
...
@@ -254,7 +261,8 @@ def hf_model_weights_iterator(
        cache_dir=cache_dir,
        load_format=load_format,
        fall_back_to_pt=fall_back_to_pt,
        revision=revision,
    )
    if load_format == "npcache":
        # Currently np_cache only support *.bin checkpoints
@@ -289,22 +297,25 @@ def hf_model_weights_iterator(
...
@@ -289,22 +297,25 @@ def hf_model_weights_iterator(
param
=
np
.
load
(
f
)
param
=
np
.
load
(
f
)
yield
name
,
torch
.
from_numpy
(
param
)
yield
name
,
torch
.
from_numpy
(
param
)
elif
load_format
==
"tensorizer"
:
elif
load_format
==
"tensorizer"
:
from
vllm.model_executor.tensorizer_loader
import
(
TensorDeserializer
,
from
vllm.model_executor.tensorizer_loader
import
(
TensorDeserializer
,
open_stream
,
open_stream
,
tensorizer_warning
)
tensorizer_warning
,
)
tensorizer_args
=
load_format
.
params
tensorizer_args
=
load_format
.
params
tensorizer_warning
(
tensorizer_warning
(
"Deserializing HuggingFace models is not optimized for "
"Deserializing HuggingFace models is not optimized for "
"loading on vLLM, as tensorizer is forced to load to CPU. "
"loading on vLLM, as tensorizer is forced to load to CPU. "
"Consider deserializing a vLLM model instead for faster "
"Consider deserializing a vLLM model instead for faster "
"load times. See the examples/tensorize_vllm_model.py example "
"load times. See the examples/tensorize_vllm_model.py example "
"script for serializing vLLM models."
)
"script for serializing vLLM models."
)
deserializer_args
=
tensorizer_args
.
deserializer_params
deserializer_args
=
tensorizer_args
.
deserializer_params
stream_params
=
tensorizer_args
.
stream_params
stream_params
=
tensorizer_args
.
stream_params
stream
=
open_stream
(
tensorizer_args
.
tensorizer_uri
,
**
stream_params
)
stream
=
open_stream
(
tensorizer_args
.
tensorizer_uri
,
**
stream_params
)
with
TensorDeserializer
(
stream
,
**
deserializer_args
,
with
TensorDeserializer
(
stream
,
**
deserializer_args
,
device
=
"cpu"
)
as
state
:
device
=
"cpu"
)
as
state
:
for
name
,
param
in
state
.
items
():
for
name
,
param
in
state
.
items
():
yield
name
,
param
yield
name
,
param
del
state
del
state
...
@@ -324,8 +335,12 @@ def hf_model_weights_iterator(
def kv_cache_scales_loader(
    filename: str,
    tp_rank: int,
    tp_size: int,
    num_hidden_layers: int,
    model_type: Optional[str],
) -> Iterable[Tuple[int, float]]:
    """
    A simple utility to read in KV cache scaling factors that have been
    previously serialized to disk. Used by the model to populate the appropriate
...
@@ -343,8 +358,7 @@ def kv_cache_scales_loader(
                "tp_size": tp_size,
            }
            schema_dct = json.load(f)
            schema = QuantParamSchema.model_validate(schema_dct, context=context)
            layer_scales_map = schema.kv_cache.scaling_factor[tp_rank]
            return layer_scales_map.items()
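The net effect of a successful parse is simply a per-layer map of scaling factors for the current tensor-parallel rank. An illustrative stand-in for the returned value (this is not the real QuantParamSchema layout, only the shape of the result):

# Hypothetical result for one TP rank: layer index -> KV cache scaling factor.
layer_scales_map = {0: 1.0, 1: 0.75, 2: 1.25}

for layer_idx, scaling_factor in layer_scales_map.items():
    print(f"layer {layer_idx}: kv-cache scale {scaling_factor}")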
...
@@ -357,9 +371,11 @@ def kv_cache_scales_loader(
    # This section is reached if and only if any of the excepts are hit
    # Return an empty iterable (list) => no KV cache scales are loaded
    # which ultimately defaults to 1.0 scales
    logger.warning(
        "Defaulting to KV cache scaling factors = 1.0 "
        f"for all layers in TP rank {tp_rank} "
        "as an error occurred during loading."
    )
    return []
...
@@ -378,8 +394,7 @@ def convert_pyslice_to_tensor(x: Any) -> torch.Tensor:
    return x


def default_weight_loader(param: torch.Tensor, loaded_weight: torch.Tensor) -> None:
    """Default weight loader."""
    assert param.size() == loaded_weight.size()
    param.data.copy_(loaded_weight)
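A typical caller of default_weight_loader walks a (name, tensor) iterator and copies each tensor into the matching parameter, preferring a parameter-specific weight_loader attribute when one has been attached. A minimal runnable sketch with a toy module (the model and the "checkpoint" below are made up):

import torch
import torch.nn as nn


def default_weight_loader(param: torch.Tensor, loaded_weight: torch.Tensor) -> None:
    """Copy a checkpoint tensor into a parameter of identical shape."""
    assert param.size() == loaded_weight.size()
    param.data.copy_(loaded_weight)


model = nn.Linear(4, 2, bias=False)         # toy stand-in for a real model
checkpoint = {"weight": torch.randn(2, 4)}  # toy stand-in for an iterator of (name, tensor) pairs

params_dict = dict(model.named_parameters())
for name, loaded_weight in checkpoint.items():
    param = params_dict[name]
    # Parameters may carry a custom weight_loader attribute; fall back to the default otherwise.
    weight_loader = getattr(param, "weight_loader", default_weight_loader)
    weight_loader(param, loaded_weight)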
...
python/sglang/utils.py
View file @ 0992d85f
...
@@ -2,13 +2,16 @@
import base64
import json
import os
import sys
import threading
import traceback
import urllib.request
from concurrent.futures import ThreadPoolExecutor
from io import BytesIO
from json import dumps

import numpy as np
import requests
...
@@ -110,6 +113,74 @@ def encode_image_base64(image_path):
    return base64.b64encode(buffered.getvalue()).decode("utf-8")


def encode_frame(frame):
    import cv2  # pip install opencv-python-headless
    from PIL import Image

    # Convert the frame to RGB (OpenCV uses BGR by default)
    frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

    # Convert the frame to PIL Image to easily convert to bytes
    im_pil = Image.fromarray(frame)

    # Convert to bytes
    buffered = BytesIO()
    # frame_format = str(os.getenv('FRAME_FORMAT', "JPEG"))
    im_pil.save(buffered, format="PNG")
    frame_bytes = buffered.getvalue()

    # Return the bytes of the frame
    return frame_bytes


def encode_video_base64(video_path, num_frames=16):
    import cv2

    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        raise IOError(f"Could not open video file: {video_path}")

    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    print(f"target_frames: {num_frames}")

    frame_indices = np.linspace(0, total_frames - 1, num_frames, dtype=int)

    frames = []
    for i in range(total_frames):
        ret, frame = cap.read()
        if ret:
            frames.append(frame)
        else:
            # Handle the case where the frame could not be read
            # print(f"Warning: Could not read frame at index {i}.")
            pass
    cap.release()

    # Safely select frames based on frame_indices, avoiding IndexError
    frames = [frames[i] for i in frame_indices if i < len(frames)]

    # If there are not enough frames, duplicate the last frame until we reach the target
    while len(frames) < num_frames:
        frames.append(frames[-1])

    # Use ThreadPoolExecutor to process and encode frames in parallel
    with ThreadPoolExecutor() as executor:
        encoded_frames = list(executor.map(encode_frame, frames))
    # encoded_frames = list(map(encode_frame, frames))

    # Concatenate all frames bytes
    video_bytes = b"".join(encoded_frames)

    # Encode the concatenated bytes to base64
    video_base64 = "video:" + base64.b64encode(video_bytes).decode("utf-8")

    return video_base64


def _is_chinese_char(cp):
    """Checks whether CP is the codepoint of a CJK character."""
    # This defines a "chinese character" as anything in the CJK Unicode block:
...
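The heart of the new encode_video_base64 helper is its sampling rule: num_frames indices spread evenly across the clip with np.linspace, padded by repeating the last frame when fewer frames could be read. A small worked example of just that step, with made-up frame counts:

import numpy as np

total_frames = 100      # hypothetical number of frames in the decoded clip
num_frames = 16         # frames requested by the caller

frame_indices = np.linspace(0, total_frames - 1, num_frames, dtype=int)
print(frame_indices.tolist())
# -> [0, 6, 13, 19, 26, 33, 39, 46, 52, 59, 66, 72, 79, 85, 92, 99]

frames = list(range(90))    # pretend only 90 frames could actually be read
sampled = [frames[i] for i in frame_indices if i < len(frames)]
while len(sampled) < num_frames:
    sampled.append(sampled[-1])   # pad by repeating the last frame, as in the function above
print(len(sampled))         # 16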