change / sglang / Commits / 071a1f51

Commit 071a1f51 (Unverified)
Authored Jun 29, 2025 by Lianmin Zheng; committed by GitHub on Jun 29, 2025

[Minor] clean up multimodal processor and tokenizer manager (#7624)

Parent: 7c0db3a6

Showing 9 changed files with 141 additions and 159 deletions (+141, -159)
python/sglang/srt/entrypoints/http_server.py (+11, -12)
python/sglang/srt/managers/io_struct.py (+4, -4)
python/sglang/srt/managers/multimodal_processor.py (+0, -1)
python/sglang/srt/managers/schedule_batch.py (+22, -20)
python/sglang/srt/managers/tokenizer_manager.py (+54, -69)
python/sglang/srt/multimodal/processors/base_processor.py (+1, -0)
python/sglang/srt/sampling/sampling_batch_info.py (+39, -1)
python/sglang/srt/speculative/build_eagle_tree.py (+8, -8)
python/sglang/srt/utils.py (+2, -44)
python/sglang/srt/entrypoints/http_server.py

@@ -353,8 +353,7 @@ async def generate_from_file_request(file: UploadFile, request: Request):
     obj = GenerateReqInput(
         input_embeds=input_embeds,
         sampling_params={
-            "repetition_penalty": 1.2,
-            "temperature": 0.0,
+            "temperature": 0.2,
             "max_new_tokens": 512,
         },
     )

@@ -393,16 +392,6 @@ async def classify_request(obj: EmbeddingReqInput, request: Request):
         return _create_error_response(e)


-@app.api_route(
-    "/v1/rerank", methods=["POST", "PUT"], dependencies=[Depends(validate_json_request)]
-)
-async def v1_rerank_request(request: V1RerankReqInput, raw_request: Request):
-    """Endpoint for reranking documents based on query relevance."""
-    return await raw_request.app.state.openai_serving_rerank.handle_request(
-        request, raw_request
-    )
-
-
 @app.api_route("/flush_cache", methods=["GET", "POST"])
 async def flush_cache():
     """Flush the radix cache."""

@@ -841,6 +830,16 @@ async def v1_score_request(request: ScoringRequest, raw_request: Request):
     )


+@app.api_route(
+    "/v1/rerank", methods=["POST", "PUT"], dependencies=[Depends(validate_json_request)]
+)
+async def v1_rerank_request(request: V1RerankReqInput, raw_request: Request):
+    """Endpoint for reranking documents based on query relevance."""
+    return await raw_request.app.state.openai_serving_rerank.handle_request(
+        request, raw_request
+    )
+
+
 def _create_error_response(e):
     return ORJSONResponse(
         {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
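The relocated /v1/rerank route accepts POST and PUT and validates the JSON body before dispatching to openai_serving_rerank. A minimal client sketch, assuming the default local port; the payload field names are an assumed shape of V1RerankReqInput, not confirmed by this diff:

    import requests

    # Hypothetical call against a local sglang server on the default port;
    # "query"/"documents" are assumed field names of V1RerankReqInput.
    resp = requests.post(
        "http://localhost:30000/v1/rerank",
        json={"query": "what is sglang?", "documents": ["doc a", "doc b"]},
    )
    print(resp.json())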
python/sglang/srt/managers/io_struct.py

@@ -22,17 +22,16 @@ from dataclasses import dataclass, field
 from enum import Enum
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Union

+from sglang.srt.managers.schedule_batch import BaseFinishReason
 from sglang.srt.multimodal.mm_utils import has_valid_data
+from sglang.srt.sampling.sampling_params import SamplingParams

-# handle serialization of Image for pydantic
+# Handle serialization of Image for pydantic
 if TYPE_CHECKING:
     from PIL.Image import Image
 else:
     Image = Any

-from sglang.srt.managers.schedule_batch import BaseFinishReason
-from sglang.srt.sampling.sampling_params import SamplingParams
-

 @dataclass
 class SessionParams:

@@ -182,6 +181,7 @@ class GenerateReqInput:
         # Determine parallel sample count
         if self.sampling_params is None:
             self.parallel_sample_num = 1
+            return
         elif isinstance(self.sampling_params, dict):
             self.parallel_sample_num = self.sampling_params.get("n", 1)
         else:  # isinstance(self.sampling_params, list):
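The new early `return` makes the no-sampling-params path explicit: parallel_sample_num stays 1 and the dict/list handling below is skipped. A standalone sketch of the resulting normalization behavior; the list branch is paraphrased from the trailing comment, the exact handling lives in GenerateReqInput:

    def parallel_sample_num(sampling_params):
        # Mirrors the normalization above: None -> 1 (early return),
        # dict -> its "n" field, list -> assumed per-request dicts.
        if sampling_params is None:
            return 1
        elif isinstance(sampling_params, dict):
            return sampling_params.get("n", 1)
        else:  # isinstance(sampling_params, list)
            return sampling_params[0].get("n", 1)

    assert parallel_sample_num(None) == 1
    assert parallel_sample_num({"n": 4, "temperature": 0.7}) == 4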
python/sglang/srt/managers/multimodal_processor.py

@@ -25,7 +25,6 @@ def get_dummy_processor():
         return DummyMultimodalProcessor()


-@lru_cache()
 def import_processors():
     package_name = "sglang.srt.multimodal.processors"
     package = importlib.import_module(package_name)
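Dropping `@lru_cache()` here is likely safe because `importlib.import_module` is already memoized by the interpreter through `sys.modules` (the commit message does not state the motivation, so this is an inference). A quick demonstration of that built-in caching:

    import importlib
    import sys

    # Importing the same module twice returns the same cached object,
    # so an extra lru_cache around import_module adds little.
    m1 = importlib.import_module("json")
    m2 = importlib.import_module("json")
    assert m1 is m2 and sys.modules["json"] is m1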
python/sglang/srt/managers/schedule_batch.py

@@ -180,46 +180,48 @@ class Modality(Enum):
 @dataclasses.dataclass
 class MultimodalDataItem:
     """
-    A single multimodal data, from a single image/video/audio or others
+    A single multimodal data, from a single image/video/audio or others.
+    We put the common fields first and the model-specific fields last.
     """

     modality: Modality

     hash: int = None
     pad_value: int = None
-    aspect_ratio_id: Optional[List[torch.Tensor]] = None
-    aspect_ratio_mask: Optional[List[torch.Tensor]] = None
     image_sizes: Tuple[int, int] = None
     image_offsets: Optional[list] = None

     # the real data, pixel_values or audio_features
     # data: Union[List[torch.Tensor], List[np.ndarray]]
     pixel_values: Union[torch.Tensor, np.ndarray] = None
+    audio_features: Union[torch.Tensor, np.ndarray] = None
+    audio_feature_lens: Optional[List[torch.Tensor]] = None
+    audio_offsets: Optional[List[Tuple[int, int]]] = None
+    precomputed_features: Optional[Union[torch.Tensor, np.ndarray]] = None
+
+    # For qwen-vl
     image_grid_thw: Union[torch.Tensor, np.ndarray] = None
-    video_grid_thws: Union[torch.Tensor, np.ndarray] = None
+    second_per_grid_ts: Optional[List[torch.Tensor]] = None

+    # For deepseek-vl
     image_emb_mask: Optional[torch.Tensor] = None
     image_spatial_crop: Optional[torch.Tensor] = None
-    second_per_grid_ts: Optional[List[torch.Tensor]] = None

+    # For minicpmv
     # [num_images, (n, w, h)]
     tgt_size: Tuple[int, int] = None

-    # kimi-vl related
-    image_grid_hws: Optional[List[torch.Tensor]] = None
+    # For mllama
+    aspect_ratio_id: Optional[List[torch.Tensor]] = None
+    aspect_ratio_mask: Optional[List[torch.Tensor]] = None

-    audio_features: Union[torch.Tensor, np.ndarray] = None
-    audio_feature_lens: Optional[List[torch.Tensor]] = None
-    audio_offsets: Optional[List[Tuple[int, int]]] = None
+    # For kimi-vl
+    image_grid_hws: Optional[List[torch.Tensor]] = None

-    # gemma3n related
+    # For gemma3n
     input_features: Optional[torch.Tensor] = None
     input_features_mask: Optional[torch.Tensor] = None

-    precomputed_features: Optional[Union[torch.Tensor, np.ndarray]] = None
-
     @staticmethod
     def is_empty_list(l):
         if l is None:

@@ -339,10 +341,6 @@ class MultimodalInputs:
     image_pad_len: Optional[list] = None
     num_image_tokens: Optional[int] = None

-    # QWen2-VL related
-    mrope_positions: Optional[torch.Tensor] = None
-    mrope_position_delta: Optional[torch.Tensor] = None
-
     # image
     im_token_id: Optional[int] = None
     im_start_id: Optional[int] = None

@@ -358,6 +356,10 @@ class MultimodalInputs:
     audio_start_id: Optional[int] = None
     audio_end_id: Optional[int] = None

+    # QWen2-VL related
+    mrope_positions: Optional[torch.Tensor] = None
+    mrope_position_delta: Optional[torch.Tensor] = None
+
     @staticmethod
     def from_dict(obj: dict):
         ret = MultimodalInputs(
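With the reordering, constructing a MultimodalDataItem reads in the same order as the dataclass: common fields first, then model-specific extras. A hedged sketch, assuming sglang is importable and that Modality exposes an IMAGE member (not confirmed by this diff):

    import torch

    from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem

    item = MultimodalDataItem(
        modality=Modality.IMAGE,                     # common field (assumed member)
        pixel_values=torch.zeros(1, 3, 224, 224),    # common "real data" field
        image_grid_thw=torch.tensor([[1, 16, 16]]),  # qwen-vl specific field
    )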
python/sglang/srt/managers/tokenizer_manager.py

@@ -150,7 +150,9 @@ class ReqState:
     # For streaming output
     last_output_offset: int = 0

+    # For incremental state update.
+    # TODO(lianmin): do not initialize some lists if not needed.
     text: str = ""
     output_ids: List[int] = dataclasses.field(default_factory=list)
     input_token_logprobs_val: List[float] = dataclasses.field(default_factory=list)

@@ -199,7 +201,6 @@ class TokenizerManager:
         self.model_path = server_args.model_path
         self.served_model_name = server_args.served_model_name
         self.model_config = ModelConfig.from_server_args(server_args)
-
         self.is_generation = self.model_config.is_generation
         self.is_image_gen = self.model_config.is_image_gen
         self.context_len = self.model_config.context_len

@@ -251,19 +252,36 @@ class TokenizerManager:
         self.dump_requests_threshold = 1000
         self.dump_request_list: List[Tuple] = []
         self.log_request_metadata = self.get_log_request_metadata()
+        self.asyncio_tasks = set()
+        self.session_futures = {}  # session_id -> asyncio event
+        self.max_req_input_len = None

         # The event to notify the weight sync is finished.
         self.model_update_lock = RWLock()
         self.model_update_result: Optional[Awaitable[UpdateWeightFromDiskReqOutput]] = (
             None
         )
-        self.asyncio_tasks = set()

-        # For session info
-        self.session_futures = {}  # session_id -> asyncio event
+        # For pd disaggregtion
+        self.disaggregation_mode = DisaggregationMode(
+            self.server_args.disaggregation_mode
+        )
+        self.transfer_backend = TransferBackend(
+            self.server_args.disaggregation_transfer_backend
+        )
+        # Start kv boostrap server on prefill
+        if self.disaggregation_mode == DisaggregationMode.PREFILL:
+            # only start bootstrap server on prefill tm
+            kv_bootstrap_server_class = get_kv_class(
+                self.transfer_backend, KVClassType.BOOTSTRAP_SERVER
+            )
+            self.bootstrap_server = kv_bootstrap_server_class(
+                self.server_args.disaggregation_bootstrap_port
+            )

-        # Set after scheduler is initialized
-        self.max_req_input_len = None
+        # For load balancing
+        self.current_load = 0
+        self.current_load_lock = asyncio.Lock()

         # Metrics
         if self.enable_metrics:

@@ -393,34 +411,14 @@ class TokenizerManager:
             ]
         )

-        # For pd disaggregtion
-        self.disaggregation_mode = DisaggregationMode(
-            self.server_args.disaggregation_mode
-        )
-        self.transfer_backend = TransferBackend(
-            self.server_args.disaggregation_transfer_backend
-        )
-        # Start kv boostrap server on prefill
-        if self.disaggregation_mode == DisaggregationMode.PREFILL:
-            # only start bootstrap server on prefill tm
-            kv_bootstrap_server_class = get_kv_class(
-                self.transfer_backend, KVClassType.BOOTSTRAP_SERVER
-            )
-            self.bootstrap_server = kv_bootstrap_server_class(
-                self.server_args.disaggregation_bootstrap_port
-            )
-
-        self.current_load = 0
-        self.current_load_lock = asyncio.Lock()
-
     async def generate_request(
         self,
         obj: Union[GenerateReqInput, EmbeddingReqInput],
         request: Optional[fastapi.Request] = None,
     ):
         created_time = time.time()
         self.auto_create_handle_loop()
+        obj.normalize_batch_and_arguments()

         if isinstance(obj, EmbeddingReqInput) and self.is_generation:
             raise ValueError(

@@ -428,22 +426,6 @@ class TokenizerManager:
                 "Please add `--is-embedding` when launching the server or try another model."
             )

-        obj.normalize_batch_and_arguments()
-
-        if isinstance(obj, GenerateReqInput):
-            return_hidden_states = obj.return_hidden_states
-            has_return_hidden_states = return_hidden_states == True or (
-                isinstance(return_hidden_states, list) and any(return_hidden_states)
-            )
-            if (
-                not self.server_args.enable_return_hidden_states
-                and has_return_hidden_states
-            ):
-                raise ValueError(
-                    "return_hidden_states=True requires the server to be started "
-                    "with --enable-return-hidden-states (ServerArgs.enable_return_hidden_states)."
-                )
-
         if self.log_requests:
             max_length, skip_names, _ = self.log_request_metadata
             logger.info(

@@ -451,8 +433,7 @@ class TokenizerManager:
             )

         async with self.model_update_lock.reader_lock:
-            is_single = obj.is_single
-            if is_single:
+            if obj.is_single:
                 tokenized_obj = await self._tokenize_one_request(obj)
                 state = self._send_one_request(obj, tokenized_obj, created_time)
                 async for response in self._wait_one_response(obj, state, request):

@@ -514,12 +495,12 @@ class TokenizerManager:
         else:
             image_inputs: Optional[Dict] = None

-        self._validate_token_len(obj, input_ids)
+        self._validate_one_request(obj, input_ids)
         return self._create_tokenized_object(
             obj, input_text, input_ids, input_embeds, image_inputs, token_type_ids
         )

-    def _validate_token_len(
+    def _validate_one_request(
         self, obj: Union[GenerateReqInput, EmbeddingReqInput], input_ids: List[int]
     ) -> None:
         """Validates that the input token count and the requested token count doesn't exceed the model's context length."""

@@ -548,24 +529,14 @@ class TokenizerManager:
             )
             raise ValueError(error_msg)

-    def _create_tokenized_object(
-        self,
-        obj: Union[GenerateReqInput, EmbeddingReqInput],
-        input_text: str,
-        input_ids: List[int],
-        input_embeds: Optional[Union[List[float], None]] = None,
-        image_inputs: Optional[Dict] = None,
-        token_type_ids: Optional[List[int]] = None,
-    ) -> Union[TokenizedGenerateReqInput, TokenizedEmbeddingReqInput]:
-        """Create a tokenized request object from common parameters."""
-        if self.is_generation:
-            return_logprob = obj.return_logprob
-            logprob_start_len = obj.logprob_start_len
-            top_logprobs_num = obj.top_logprobs_num
-            token_ids_logprob = obj.token_ids_logprob
-            session_params = (
-                SessionParams(**obj.session_params) if obj.session_params else None
-            )
+        if isinstance(obj, GenerateReqInput):
+            if (
+                obj.return_hidden_states
+                and not self.server_args.enable_return_hidden_states
+            ):
+                raise ValueError(
+                    "The server is not configured to return the hidden states. "
+                    "Please set `--enable-return-hidden-states` to enable this feature."
+                )

             if (
                 obj.custom_logit_processor

@@ -576,6 +547,16 @@ class TokenizerManager:
                     "Please set `--enable-custom-logits-processor` to enable this feature."
                 )

+    def _create_tokenized_object(
+        self,
+        obj: Union[GenerateReqInput, EmbeddingReqInput],
+        input_text: str,
+        input_ids: List[int],
+        input_embeds: Optional[Union[List[float], None]] = None,
+        image_inputs: Optional[Dict] = None,
+        token_type_ids: Optional[List[int]] = None,
+    ) -> Union[TokenizedGenerateReqInput, TokenizedEmbeddingReqInput]:
+        """Create a tokenized request object from common parameters."""
         # Parse sampling parameters
         # Note: if there are preferred sampling params, we use them if they are not
         # explicitly passed in sampling_params

@@ -589,16 +570,20 @@ class TokenizerManager:
         # Build return object
         if isinstance(obj, GenerateReqInput):
+            session_params = (
+                SessionParams(**obj.session_params) if obj.session_params else None
+            )
+
             tokenized_obj = TokenizedGenerateReqInput(
                 obj.rid,
                 input_text,
                 input_ids,
                 image_inputs,
                 sampling_params,
-                return_logprob,
-                logprob_start_len,
-                top_logprobs_num,
-                token_ids_logprob,
+                obj.return_logprob,
+                obj.logprob_start_len,
+                obj.top_logprobs_num,
+                obj.token_ids_logprob,
                 obj.stream,
                 bootstrap_host=obj.bootstrap_host,
                 bootstrap_port=obj.bootstrap_port,
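The hidden-states check now lives in per-request validation (`_validate_one_request`) rather than in `generate_request`, so every request path hits it once, right after the length check. A standalone sketch of the relocated check, using the field names visible in the diff above:

    def validate_hidden_states(obj, server_args) -> None:
        # Reject hidden-state requests unless the server was launched with
        # --enable-return-hidden-states (mirrors the check added above).
        if getattr(obj, "return_hidden_states", False) and not (
            server_args.enable_return_hidden_states
        ):
            raise ValueError(
                "The server is not configured to return the hidden states. "
                "Please set `--enable-return-hidden-states` to enable this feature."
            )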
python/sglang/srt/multimodal/processors/base_processor.py

@@ -98,6 +98,7 @@ class BaseMultimodalProcessor(ABC):
         self._processor = _processor
         self.arch = hf_config.architectures[0]
+        self.server_args = server_args

         # FIXME: not accurate, model and image specific
         self.NUM_TOKEN_PER_FRAME = 330
python/sglang/srt/sampling/sampling_batch_info.py

@@ -10,7 +10,6 @@ import torch
 import sglang.srt.sampling.penaltylib as penaltylib
 from sglang.srt.sampling.custom_logit_processor import CustomLogitProcessor
 from sglang.srt.sampling.sampling_params import TOP_K_ALL
-from sglang.srt.utils import merge_bias_tensor

 if TYPE_CHECKING:
     from sglang.srt.managers.schedule_batch import ScheduleBatch

@@ -345,3 +344,42 @@ class SamplingBatchInfo:
         self.logit_bias = merge_bias_tensor(
             self.logit_bias, other.logit_bias, len(self), len(other), self.device, 0.0
         )
+
+
+def merge_bias_tensor(
+    lhs: Optional[torch.Tensor],
+    rhs: Optional[torch.Tensor],
+    bs1: int,
+    bs2: int,
+    device: str,
+    default: float,
+):
+    """Merge two bias tensors for batch merging.
+
+    Args:
+        lhs: Left-hand side tensor
+        rhs: Right-hand side tensor
+        bs1: Batch size of left-hand side tensor
+        bs2: Batch size of right-hand side tensor
+        device: Device to place the merged tensor on
+        default: Default value for missing tensor elements
+
+    Returns:
+        Merged tensor or None if both inputs are None
+    """
+    if lhs is None and rhs is None:
+        return None
+
+    if lhs is not None and rhs is not None:
+        return torch.cat([lhs, rhs])
+    else:
+        if lhs is not None:
+            shape, dtype = lhs.shape[1:], lhs.dtype
+        else:
+            shape, dtype = rhs.shape[1:], rhs.dtype
+
+        if lhs is None:
+            lhs = torch.empty((bs1, *shape), device=device, dtype=dtype).fill_(default)
+        if rhs is None:
+            rhs = torch.empty((bs2, *shape), device=device, dtype=dtype).fill_(default)
+
+        return torch.cat([lhs, rhs])
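Moving `merge_bias_tensor` next to its only caller keeps the batch-merge logic in one file. Its semantics: when one side has no bias tensor, that side is materialized filled with `default` so the merged batch keeps one bias row per request. A small usage sketch of the function as defined above:

    import torch

    lhs = torch.tensor([[1.0, 2.0]])  # batch of 1 request, 2-entry bias
    merged = merge_bias_tensor(lhs, None, bs1=1, bs2=2, device="cpu", default=0.0)
    assert merged.tolist() == [[1.0, 2.0], [0.0, 0.0], [0.0, 0.0]]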
python/sglang/srt/speculative/build_eagle_tree.py

@@ -4,7 +4,7 @@ from typing import List
 import torch

-from sglang.srt.utils import is_cuda, is_hip, rank0_print
+from sglang.srt.utils import is_cuda, is_hip, rank0_log

 if is_cuda() or is_hip():
     from sgl_kernel import (

@@ -344,13 +344,13 @@ def test_build_tree_kernel_efficient():
         num_verify_tokens=num_draft_token,
     )

-    rank0_print("=========== build tree kernel efficient ==========")
-    # rank0_print(f"{tree_mask=}", flush=True)
-    rank0_print(f"{position=}", flush=True)
-    rank0_print(f"{retrive_index=}", flush=True)
-    rank0_print(f"{retrive_next_token=}", flush=True)
-    rank0_print(f"{retrive_next_sibling=}", flush=True)
-    rank0_print(f"{draft_tokens=}", flush=True)
+    rank0_log("=========== build tree kernel efficient ==========")
+    # rank0_log(f"{tree_mask=}")
+    rank0_log(f"{position=}")
+    rank0_log(f"{retrive_index=}")
+    rank0_log(f"{retrive_next_token=}")
+    rank0_log(f"{retrive_next_sibling=}")
+    rank0_log(f"{draft_tokens=}")

     assert position.tolist() == [5, 6, 6, 7, 7, 8, 8, 9, 10, 11, 12, 12, 12, 12, 13, 14]
     assert retrive_index.tolist() == [
         [0, 1, 2, 3, 4, 5, 6, 7],
python/sglang/srt/utils.py

@@ -1917,14 +1917,11 @@ def configure_ipv6(dist_init_addr):
     return port, host


-def rank0_print(msg: str):
+def rank0_log(msg: str):
     from sglang.srt.distributed import get_tensor_model_parallel_rank

     if get_tensor_model_parallel_rank() == 0:
-        print(msg, flush=True)
-
-
-rank0_log = rank0_print
+        logger.info(msg)


 def get_cuda_version():

@@ -2344,45 +2341,6 @@ def require_mlp_sync(server_args):
     return server_args.enable_dp_attention or require_gathered_buffer(server_args)


-def merge_bias_tensor(
-    lhs: Optional[torch.Tensor],
-    rhs: Optional[torch.Tensor],
-    bs1: int,
-    bs2: int,
-    device: str,
-    default: float,
-):
-    """Merge two bias tensors for batch merging.
-
-    Args:
-        lhs: Left-hand side tensor
-        rhs: Right-hand side tensor
-        bs1: Batch size of left-hand side tensor
-        bs2: Batch size of right-hand side tensor
-        device: Device to place the merged tensor on
-        default: Default value for missing tensor elements
-
-    Returns:
-        Merged tensor or None if both inputs are None
-    """
-    if lhs is None and rhs is None:
-        return None
-
-    if lhs is not None and rhs is not None:
-        return torch.cat([lhs, rhs])
-    else:
-        if lhs is not None:
-            shape, dtype = lhs.shape[1:], lhs.dtype
-        else:
-            shape, dtype = rhs.shape[1:], rhs.dtype
-
-        if lhs is None:
-            lhs = torch.empty((bs1, *shape), device=device, dtype=dtype).fill_(default)
-        if rhs is None:
-            rhs = torch.empty((bs2, *shape), device=device, dtype=dtype).fill_(default)
-
-        return torch.cat([lhs, rhs])
-
-
 def find_local_repo_dir(repo_id: str, revision: Optional[str] = None) -> Optional[str]:
     import huggingface_hub as hf
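The rename from rank0_print to rank0_log also changes the sink: messages now go through `logger.info`, picking up the configured handlers, format, and level filtering instead of a raw `print(..., flush=True)`. A minimal sketch of the pattern; the real helper gates on get_tensor_model_parallel_rank(), which the stand-in `rank` parameter here merely imitates:

    import logging

    logger = logging.getLogger(__name__)

    def rank0_log_sketch(msg: str, rank: int = 0) -> None:
        # Only the tensor-parallel rank-0 process emits the message.
        if rank == 0:
            logger.info(msg)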