change / sglang · Commits · 071a1f51

Commit 071a1f51 (Unverified), authored Jun 29, 2025 by Lianmin Zheng; committed by GitHub on Jun 29, 2025

[Minor] clean up multimodal processor and tokenizer manager (#7624)
Parent: 7c0db3a6

Showing 9 changed files with 141 additions and 159 deletions.
python/sglang/srt/entrypoints/http_server.py                 +11  -12
python/sglang/srt/managers/io_struct.py                       +4   -4
python/sglang/srt/managers/multimodal_processor.py            +0   -1
python/sglang/srt/managers/schedule_batch.py                 +22  -20
python/sglang/srt/managers/tokenizer_manager.py              +54  -69
python/sglang/srt/multimodal/processors/base_processor.py     +1   -0
python/sglang/srt/sampling/sampling_batch_info.py            +39   -1
python/sglang/srt/speculative/build_eagle_tree.py             +8   -8
python/sglang/srt/utils.py                                     +2  -44
python/sglang/srt/entrypoints/http_server.py
...
...
@@ -353,8 +353,7 @@ async def generate_from_file_request(file: UploadFile, request: Request):
     obj = GenerateReqInput(
         input_embeds=input_embeds,
         sampling_params={
-            "repetition_penalty": 1.2,
-            "temperature": 0.2,
+            "temperature": 0.0,
             "max_new_tokens": 512,
         },
     )
...
...
@@ -393,16 +392,6 @@ async def classify_request(obj: EmbeddingReqInput, request: Request):
         return _create_error_response(e)
 
 
-@app.api_route(
-    "/v1/rerank", methods=["POST", "PUT"], dependencies=[Depends(validate_json_request)]
-)
-async def v1_rerank_request(request: V1RerankReqInput, raw_request: Request):
-    """Endpoint for reranking documents based on query relevance."""
-    return await raw_request.app.state.openai_serving_rerank.handle_request(
-        request, raw_request
-    )
-
-
 @app.api_route("/flush_cache", methods=["GET", "POST"])
 async def flush_cache():
     """Flush the radix cache."""
...
...
@@ -841,6 +830,16 @@ async def v1_score_request(request: ScoringRequest, raw_request: Request):
     )
 
 
+@app.api_route(
+    "/v1/rerank", methods=["POST", "PUT"], dependencies=[Depends(validate_json_request)]
+)
+async def v1_rerank_request(request: V1RerankReqInput, raw_request: Request):
+    """Endpoint for reranking documents based on query relevance."""
+    return await raw_request.app.state.openai_serving_rerank.handle_request(
+        request, raw_request
+    )
+
+
 def _create_error_response(e):
     return ORJSONResponse(
         {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
...
...
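For reference, a minimal client-side sketch against the relocated /v1/rerank route (the payload fields and the port are assumptions for illustration; V1RerankReqInput's actual schema is not shown in this diff):

import requests

# Hypothetical payload; V1RerankReqInput's real fields are not part of this diff.
payload = {
    "query": "how to flush the radix cache",
    "documents": ["Call /flush_cache.", "Unrelated text."],
}

# The decorator above registers both POST and PUT for this route.
resp = requests.post("http://localhost:30000/v1/rerank", json=payload)
print(resp.status_code, resp.json())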
python/sglang/srt/managers/io_struct.py
...
...
@@ -22,17 +22,16 @@ from dataclasses import dataclass, field
 from enum import Enum
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Union
 
-from sglang.srt.managers.schedule_batch import BaseFinishReason
-from sglang.srt.multimodal.mm_utils import has_valid_data
-from sglang.srt.sampling.sampling_params import SamplingParams
-
-# handle serialization of Image for pydantic
+# Handle serialization of Image for pydantic
 if TYPE_CHECKING:
     from PIL.Image import Image
 else:
     Image = Any
 
+from sglang.srt.managers.schedule_batch import BaseFinishReason
+from sglang.srt.sampling.sampling_params import SamplingParams
+
 
 @dataclass
 class SessionParams:
...
...
@@ -182,6 +181,7 @@ class GenerateReqInput:
         # Determine parallel sample count
         if self.sampling_params is None:
             self.parallel_sample_num = 1
+            return
         elif isinstance(self.sampling_params, dict):
             self.parallel_sample_num = self.sampling_params.get("n", 1)
         else:  # isinstance(self.sampling_params, list):
...
...
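A standalone sketch of the normalization branch above (not the actual method; GenerateReqInput's other fields and the list-of-params case are omitted here):

# Sketch of the "determine parallel sample count" logic shown in the hunk above.
def resolve_parallel_sample_num(sampling_params):
    if sampling_params is None:
        return 1                              # no params -> single sample, early return
    elif isinstance(sampling_params, dict):
        return sampling_params.get("n", 1)    # "n" requests n parallel samples
    else:
        # list case (per-request params in a batch); not covered by this sketch
        return 1

print(resolve_parallel_sample_num({"temperature": 0.0, "n": 4}))  # 4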
python/sglang/srt/managers/multimodal_processor.py
...
...
@@ -25,7 +25,6 @@ def get_dummy_processor():
         return DummyMultimodalProcessor()
 
 
-@lru_cache()
 def import_processors():
     package_name = "sglang.srt.multimodal.processors"
     package = importlib.import_module(package_name)
...
...
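For context, a small generic illustration of what the dropped decorator did (functools.lru_cache memoizes the call; this is not sglang code):

import functools
import importlib

@functools.lru_cache()
def import_package_cached(package_name: str):
    # With lru_cache, the body runs once per distinct package_name;
    # subsequent calls return the cached module object directly.
    return importlib.import_module(package_name)

# Without the decorator (as in this commit), every call re-enters the function,
# though Python's sys.modules cache still keeps repeated imports cheap.
mod = import_package_cached("json")
print(mod is import_package_cached("json"))  # True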
python/sglang/srt/managers/schedule_batch.py
...
...
@@ -180,46 +180,48 @@ class Modality(Enum):
 @dataclasses.dataclass
 class MultimodalDataItem:
     """
-    A single multimodal data, from a single image/video/audio or others
+    A single multimodal data, from a single image/video/audio or others.
+    We put the common fields first and the model-specific fields last.
     """
 
     modality: Modality
 
     hash: int = None
     pad_value: int = None
-    aspect_ratio_id: Optional[List[torch.Tensor]] = None
-    aspect_ratio_mask: Optional[List[torch.Tensor]] = None
+
     image_sizes: Tuple[int, int] = None
     image_offsets: Optional[list] = None
 
     # the real data, pixel_values or audio_features
     # data: Union[List[torch.Tensor], List[np.ndarray]]
     pixel_values: Union[torch.Tensor, np.ndarray] = None
+    audio_features: Union[torch.Tensor, np.ndarray] = None
+    audio_feature_lens: Optional[List[torch.Tensor]] = None
+    audio_offsets: Optional[List[Tuple[int, int]]] = None
+    precomputed_features: Optional[Union[torch.Tensor, np.ndarray]] = None
+
+    # For qwen-vl
     image_grid_thw: Union[torch.Tensor, np.ndarray] = None
     video_grid_thws: Union[torch.Tensor, np.ndarray] = None
+    second_per_grid_ts: Optional[List[torch.Tensor]] = None
 
+    # For deepseek-vl
     image_emb_mask: Optional[torch.Tensor] = None
     image_spatial_crop: Optional[torch.Tensor] = None
-    second_per_grid_ts: Optional[List[torch.Tensor]] = None
 
+    # For minicpmv
     # [num_images, (n, w, h)]
     tgt_size: Tuple[int, int] = None
 
-    # kimi-vl related
-    image_grid_hws: Optional[List[torch.Tensor]] = None
-
-    audio_features: Union[torch.Tensor, np.ndarray] = None
-    audio_feature_lens: Optional[List[torch.Tensor]] = None
-    audio_offsets: Optional[List[Tuple[int, int]]] = None
+    # For mllama
+    aspect_ratio_id: Optional[List[torch.Tensor]] = None
+    aspect_ratio_mask: Optional[List[torch.Tensor]] = None
 
-    # gemma3n related
+    # For kimi-vl
+    image_grid_hws: Optional[List[torch.Tensor]] = None
+
+    # For gemma3n
     input_features: Optional[torch.Tensor] = None
     input_features_mask: Optional[torch.Tensor] = None
 
-    precomputed_features: Optional[Union[torch.Tensor, np.ndarray]] = None
-
     @staticmethod
     def is_empty_list(l):
         if l is None:
...
...
@@ -339,10 +341,6 @@ class MultimodalInputs:
     image_pad_len: Optional[list] = None
     num_image_tokens: Optional[int] = None
 
-    # QWen2-VL related
-    mrope_positions: Optional[torch.Tensor] = None
-    mrope_position_delta: Optional[torch.Tensor] = None
-
     # image
     im_token_id: Optional[int] = None
     im_start_id: Optional[int] = None
...
...
@@ -358,6 +356,10 @@ class MultimodalInputs:
     audio_start_id: Optional[int] = None
     audio_end_id: Optional[int] = None
 
+    # QWen2-VL related
+    mrope_positions: Optional[torch.Tensor] = None
+    mrope_position_delta: Optional[torch.Tensor] = None
+
     @staticmethod
     def from_dict(obj: dict):
         ret = MultimodalInputs(
...
...
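As a rough usage sketch of the reorganized dataclass (hypothetical values; Modality.IMAGE is assumed to be an enum member, and real callers go through the multimodal processors rather than constructing items by hand):

import numpy as np

# Hypothetical construction; the hash/pad_value bookkeeping that the processors
# normally fill in is omitted here.
item = MultimodalDataItem(
    modality=Modality.IMAGE,          # assumed enum member
    pixel_values=np.zeros((3, 336, 336), dtype=np.float32),
    image_sizes=(336, 336),
)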
python/sglang/srt/managers/tokenizer_manager.py
...
...
@@ -150,7 +150,9 @@ class ReqState:
     # For streaming output
     last_output_offset: int = 0
 
+    # For incremental state update.
+    # TODO(lianmin): do not initialize some lists if not needed.
     text: str = ""
     output_ids: List[int] = dataclasses.field(default_factory=list)
     input_token_logprobs_val: List[float] = dataclasses.field(default_factory=list)
...
...
@@ -199,7 +201,6 @@ class TokenizerManager:
         self.model_path = server_args.model_path
         self.served_model_name = server_args.served_model_name
-
         self.model_config = ModelConfig.from_server_args(server_args)
         self.is_generation = self.model_config.is_generation
         self.is_image_gen = self.model_config.is_image_gen
         self.context_len = self.model_config.context_len
...
...
@@ -251,19 +252,36 @@ class TokenizerManager:
         self.dump_requests_threshold = 1000
         self.dump_request_list: List[Tuple] = []
         self.log_request_metadata = self.get_log_request_metadata()
-        self.asyncio_tasks = set()
-        self.session_futures = {}  # session_id -> asyncio event
-        self.max_req_input_len = None
 
         # The event to notify the weight sync is finished.
         self.model_update_lock = RWLock()
         self.model_update_result: Optional[Awaitable[UpdateWeightFromDiskReqOutput]] = (
             None
         )
+        self.asyncio_tasks = set()
+
+        # For session info
+        self.session_futures = {}  # session_id -> asyncio event
+
+        # For pd disaggregtion
+        self.disaggregation_mode = DisaggregationMode(
+            self.server_args.disaggregation_mode
+        )
+        self.transfer_backend = TransferBackend(
+            self.server_args.disaggregation_transfer_backend
+        )
+        # Start kv boostrap server on prefill
+        if self.disaggregation_mode == DisaggregationMode.PREFILL:
+            # only start bootstrap server on prefill tm
+            kv_bootstrap_server_class = get_kv_class(
+                self.transfer_backend, KVClassType.BOOTSTRAP_SERVER
+            )
+            self.bootstrap_server = kv_bootstrap_server_class(
+                self.server_args.disaggregation_bootstrap_port
+            )
+
+        # Set after scheduler is initialized
+        self.max_req_input_len = None
+
+        # For load balancing
+        self.current_load = 0
+        self.current_load_lock = asyncio.Lock()
 
         # Metrics
         if self.enable_metrics:
...
...
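For orientation, the mode parsing above is plain Enum construction from the server-arg string; a self-contained sketch (member values besides PREFILL are illustrative assumptions, not the real enum):

from enum import Enum

class DisaggregationModeSketch(Enum):
    # Only PREFILL appears in this diff; NULL and DECODE are illustrative.
    NULL = "null"
    PREFILL = "prefill"
    DECODE = "decode"

mode = DisaggregationModeSketch("prefill")
# The bootstrap KV server is started only on the prefill tokenizer manager.
print(mode == DisaggregationModeSketch.PREFILL)  # True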
@@ -393,34 +411,14 @@ class TokenizerManager:
             ]
         )
 
-        # For pd disaggregtion
-        self.disaggregation_mode = DisaggregationMode(
-            self.server_args.disaggregation_mode
-        )
-        self.transfer_backend = TransferBackend(
-            self.server_args.disaggregation_transfer_backend
-        )
-        # Start kv boostrap server on prefill
-        if self.disaggregation_mode == DisaggregationMode.PREFILL:
-            # only start bootstrap server on prefill tm
-            kv_bootstrap_server_class = get_kv_class(
-                self.transfer_backend, KVClassType.BOOTSTRAP_SERVER
-            )
-            self.bootstrap_server = kv_bootstrap_server_class(
-                self.server_args.disaggregation_bootstrap_port
-            )
-
-        self.current_load = 0
-        self.current_load_lock = asyncio.Lock()
-
     async def generate_request(
         self,
         obj: Union[GenerateReqInput, EmbeddingReqInput],
         request: Optional[fastapi.Request] = None,
     ):
         created_time = time.time()
 
         self.auto_create_handle_loop()
+        obj.normalize_batch_and_arguments()
 
         if isinstance(obj, EmbeddingReqInput) and self.is_generation:
             raise ValueError(
...
...
@@ -428,22 +426,6 @@ class TokenizerManager:
                 "Please add `--is-embedding` when launching the server or try another model."
             )
 
-        obj.normalize_batch_and_arguments()
-
-        if isinstance(obj, GenerateReqInput):
-            return_hidden_states = obj.return_hidden_states
-            has_return_hidden_states = return_hidden_states == True or (
-                isinstance(return_hidden_states, list) and any(return_hidden_states)
-            )
-            if (
-                not self.server_args.enable_return_hidden_states
-                and has_return_hidden_states
-            ):
-                raise ValueError(
-                    "return_hidden_states=True requires the server to be started "
-                    "with --enable-return-hidden-states (ServerArgs.enable_return_hidden_states)."
-                )
-
         if self.log_requests:
             max_length, skip_names, _ = self.log_request_metadata
             logger.info(
...
...
@@ -451,8 +433,7 @@ class TokenizerManager:
         )
 
         async with self.model_update_lock.reader_lock:
-            is_single = obj.is_single
-            if is_single:
+            if obj.is_single:
                 tokenized_obj = await self._tokenize_one_request(obj)
                 state = self._send_one_request(obj, tokenized_obj, created_time)
                 async for response in self._wait_one_response(obj, state, request):
...
...
@@ -514,12 +495,12 @@ class TokenizerManager:
         else:
             image_inputs: Optional[Dict] = None
 
-        self._validate_token_len(obj, input_ids)
+        self._validate_one_request(obj, input_ids)
         return self._create_tokenized_object(
             obj, input_text, input_ids, input_embeds, image_inputs, token_type_ids
         )
 
-    def _validate_token_len(
+    def _validate_one_request(
         self, obj: Union[GenerateReqInput, EmbeddingReqInput], input_ids: List[int]
     ) -> None:
         """Validates that the input token count and the requested token count doesn't exceed the model's context length."""
...
...
@@ -548,6 +529,24 @@ class TokenizerManager:
             )
             raise ValueError(error_msg)
 
+        if isinstance(obj, GenerateReqInput):
+            if (
+                obj.return_hidden_states
+                and not self.server_args.enable_return_hidden_states
+            ):
+                raise ValueError(
+                    "The server is not configured to return the hidden states. "
+                    "Please set `--enable-return-hidden-states` to enable this feature."
+                )
+            if (
+                obj.custom_logit_processor
+                and not self.server_args.enable_custom_logit_processor
+            ):
+                raise ValueError(
+                    "The server is not configured to enable custom logit processor. "
+                    "Please set `--enable-custom-logits-processor` to enable this feature."
+                )
+
     def _create_tokenized_object(
         self,
         obj: Union[GenerateReqInput, EmbeddingReqInput],
...
...
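The consolidated checks above surface to clients as request-validation errors; a hedged client-side sketch (the port is an assumption, and the server here is assumed to run without --enable-return-hidden-states):

import requests

# Hypothetical request that trips the first check in _validate_one_request
# when the server was not started with --enable-return-hidden-states.
resp = requests.post(
    "http://localhost:30000/generate",
    json={
        "text": "Hello",
        "sampling_params": {"max_new_tokens": 8},
        "return_hidden_states": True,
    },
)
print(resp.status_code)  # expected error response carrying the message shown above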
@@ -558,24 +557,6 @@ class TokenizerManager:
         token_type_ids: Optional[List[int]] = None,
     ) -> Union[TokenizedGenerateReqInput, TokenizedEmbeddingReqInput]:
         """Create a tokenized request object from common parameters."""
 
-        if self.is_generation:
-            return_logprob = obj.return_logprob
-            logprob_start_len = obj.logprob_start_len
-            top_logprobs_num = obj.top_logprobs_num
-            token_ids_logprob = obj.token_ids_logprob
-            session_params = (
-                SessionParams(**obj.session_params) if obj.session_params else None
-            )
-
-            if (
-                obj.custom_logit_processor
-                and not self.server_args.enable_custom_logit_processor
-            ):
-                raise ValueError(
-                    "The server is not configured to enable custom logit processor. "
-                    "Please set `--enable-custom-logits-processor` to enable this feature."
-                )
-
         # Parse sampling parameters
         # Note: if there are preferred sampling params, we use them if they are not
         # explicitly passed in sampling_params
...
...
@@ -589,16 +570,20 @@ class TokenizerManager:
         # Build return object
         if isinstance(obj, GenerateReqInput):
+            session_params = (
+                SessionParams(**obj.session_params) if obj.session_params else None
+            )
+
             tokenized_obj = TokenizedGenerateReqInput(
                 obj.rid,
                 input_text,
                 input_ids,
                 image_inputs,
                 sampling_params,
-                return_logprob,
-                logprob_start_len,
-                top_logprobs_num,
-                token_ids_logprob,
+                obj.return_logprob,
+                obj.logprob_start_len,
+                obj.top_logprobs_num,
+                obj.token_ids_logprob,
                 obj.stream,
                 bootstrap_host=obj.bootstrap_host,
                 bootstrap_port=obj.bootstrap_port,
...
...
python/sglang/srt/multimodal/processors/base_processor.py
...
...
@@ -98,6 +98,7 @@ class BaseMultimodalProcessor(ABC):
         self._processor = _processor
         self.arch = hf_config.architectures[0]
         self.server_args = server_args
 
+        # FIXME: not accurate, model and image specific
         self.NUM_TOKEN_PER_FRAME = 330
...
...
python/sglang/srt/sampling/sampling_batch_info.py
...
...
@@ -10,7 +10,6 @@ import torch
 import sglang.srt.sampling.penaltylib as penaltylib
 from sglang.srt.sampling.custom_logit_processor import CustomLogitProcessor
 from sglang.srt.sampling.sampling_params import TOP_K_ALL
-from sglang.srt.utils import merge_bias_tensor
 
 if TYPE_CHECKING:
     from sglang.srt.managers.schedule_batch import ScheduleBatch
...
...
@@ -345,3 +344,42 @@ class SamplingBatchInfo:
         self.logit_bias = merge_bias_tensor(
             self.logit_bias, other.logit_bias, len(self), len(other), self.device, 0.0
         )
+
+
+def merge_bias_tensor(
+    lhs: Optional[torch.Tensor],
+    rhs: Optional[torch.Tensor],
+    bs1: int,
+    bs2: int,
+    device: str,
+    default: float,
+):
+    """Merge two bias tensors for batch merging.
+
+    Args:
+        lhs: Left-hand side tensor
+        rhs: Right-hand side tensor
+        bs1: Batch size of left-hand side tensor
+        bs2: Batch size of right-hand side tensor
+        device: Device to place the merged tensor on
+        default: Default value for missing tensor elements
+
+    Returns:
+        Merged tensor or None if both inputs are None
+    """
+    if lhs is None and rhs is None:
+        return None
+
+    if lhs is not None and rhs is not None:
+        return torch.cat([lhs, rhs])
+    else:
+        if lhs is not None:
+            shape, dtype = lhs.shape[1:], lhs.dtype
+        else:
+            shape, dtype = rhs.shape[1:], rhs.dtype
+
+        if lhs is None:
+            lhs = torch.empty((bs1, *shape), device=device, dtype=dtype).fill_(default)
+        if rhs is None:
+            rhs = torch.empty((bs2, *shape), device=device, dtype=dtype).fill_(default)
+        return torch.cat([lhs, rhs])
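
A quick sanity-check sketch of the relocated helper (illustrative shapes only):

import torch

# One batch has an explicit logit bias, the other has none; the helper pads the
# missing side with the default value so the two batches can be concatenated.
lhs = torch.tensor([[0.5, -1.0, 0.0]])          # batch size 1, vocab size 3
merged = merge_bias_tensor(lhs, None, bs1=1, bs2=2, device="cpu", default=0.0)
print(merged.shape)  # torch.Size([3, 3])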
python/sglang/srt/speculative/build_eagle_tree.py
...
...
@@ -4,7 +4,7 @@ from typing import List
 
 import torch
 
-from sglang.srt.utils import is_cuda, is_hip, rank0_print
+from sglang.srt.utils import is_cuda, is_hip, rank0_log
 
 if is_cuda() or is_hip():
     from sgl_kernel import (
...
...
@@ -344,13 +344,13 @@ def test_build_tree_kernel_efficient():
         num_verify_tokens=num_draft_token,
     )
 
-    rank0_print("=========== build tree kernel efficient ==========")
-    # rank0_print(f"{tree_mask=}", flush=True)
-    rank0_print(f"{position=}", flush=True)
-    rank0_print(f"{retrive_index=}", flush=True)
-    rank0_print(f"{retrive_next_token=}", flush=True)
-    rank0_print(f"{retrive_next_sibling=}", flush=True)
-    rank0_print(f"{draft_tokens=}", flush=True)
+    rank0_log("=========== build tree kernel efficient ==========")
+    # rank0_log(f"{tree_mask=}")
+    rank0_log(f"{position=}")
+    rank0_log(f"{retrive_index=}")
+    rank0_log(f"{retrive_next_token=}")
+    rank0_log(f"{retrive_next_sibling=}")
+    rank0_log(f"{draft_tokens=}")
 
     assert position.tolist() == [5, 6, 6, 7, 7, 8, 8, 9, 10, 11, 12, 12, 12, 12, 13, 14]
     assert retrive_index.tolist() == [
         [0, 1, 2, 3, 4, 5, 6, 7],
...
...
python/sglang/srt/utils.py
...
...
@@ -1917,14 +1917,11 @@ def configure_ipv6(dist_init_addr):
     return port, host
 
 
-def rank0_print(msg: str):
+def rank0_log(msg: str):
     from sglang.srt.distributed import get_tensor_model_parallel_rank
 
     if get_tensor_model_parallel_rank() == 0:
-        print(msg, flush=True)
-
-
-rank0_log = rank0_print
+        logger.info(msg)
 
 
 def get_cuda_version():
...
...
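For orientation, a minimal sketch of the behavioral change (generic logging setup assumed; this is not the sglang module itself):

import logging

logger = logging.getLogger("sglang_example")
logging.basicConfig(level=logging.INFO)

def rank0_log_sketch(msg: str, rank: int = 0):
    # The old rank0_print wrote straight to stdout with print(..., flush=True);
    # the renamed rank0_log routes the message through the module logger instead,
    # so it picks up the standard log formatting and handlers.
    if rank == 0:
        logger.info(msg)

rank0_log_sketch("hello from rank 0")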
@@ -2344,45 +2341,6 @@ def require_mlp_sync(server_args):
     return server_args.enable_dp_attention or require_gathered_buffer(server_args)
 
 
-def merge_bias_tensor(
-    lhs: Optional[torch.Tensor],
-    rhs: Optional[torch.Tensor],
-    bs1: int,
-    bs2: int,
-    device: str,
-    default: float,
-):
-    """Merge two bias tensors for batch merging.
-
-    Args:
-        lhs: Left-hand side tensor
-        rhs: Right-hand side tensor
-        bs1: Batch size of left-hand side tensor
-        bs2: Batch size of right-hand side tensor
-        device: Device to place the merged tensor on
-        default: Default value for missing tensor elements
-
-    Returns:
-        Merged tensor or None if both inputs are None
-    """
-    if lhs is None and rhs is None:
-        return None
-
-    if lhs is not None and rhs is not None:
-        return torch.cat([lhs, rhs])
-    else:
-        if lhs is not None:
-            shape, dtype = lhs.shape[1:], lhs.dtype
-        else:
-            shape, dtype = rhs.shape[1:], rhs.dtype
-
-        if lhs is None:
-            lhs = torch.empty((bs1, *shape), device=device, dtype=dtype).fill_(default)
-        if rhs is None:
-            rhs = torch.empty((bs2, *shape), device=device, dtype=dtype).fill_(default)
-        return torch.cat([lhs, rhs])
-
-
 def find_local_repo_dir(repo_id: str, revision: Optional[str] = None) -> Optional[str]:
     import huggingface_hub as hf
...
...