Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
b6cf3532
Unverified
Commit
b6cf3532
authored
May 08, 2025
by
fzyzcjy
Committed by
GitHub
May 08, 2025
Browse files
Tiny refactor ModelConfig.from_server_args (#5219)
parent
3b2680a4
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
23 additions
and
53 deletions
+23
-53
python/sglang/bench_one_batch.py
python/sglang/bench_one_batch.py
+1
-11
python/sglang/srt/configs/model_config.py
python/sglang/srt/configs/model_config.py
+16
-0
python/sglang/srt/managers/scheduler.py
python/sglang/srt/managers/scheduler.py
+1
-11
python/sglang/srt/managers/tokenizer_manager.py
python/sglang/srt/managers/tokenizer_manager.py
+1
-11
python/sglang/srt/managers/tp_worker.py
python/sglang/srt/managers/tp_worker.py
+3
-10
test/srt/test_gptqmodel_dynamic.py
test/srt/test_gptqmodel_dynamic.py
+1
-10
No files found.
python/sglang/bench_one_batch.py
View file @
b6cf3532
...
...
@@ -137,17 +137,7 @@ def load_model(server_args, port_args, tp_rank):
     suppress_other_loggers()
     rank_print = print if tp_rank == 0 else lambda *args, **kwargs: None
-    model_config = ModelConfig(
-        server_args.model_path,
-        trust_remote_code=server_args.trust_remote_code,
-        revision=server_args.revision,
-        context_length=server_args.context_length,
-        model_override_args=server_args.json_model_override_args,
-        is_embedding=server_args.is_embedding,
-        enable_multimodal=server_args.enable_multimodal,
-        dtype=server_args.dtype,
-        quantization=server_args.quantization,
-    )
+    model_config = ModelConfig.from_server_args(server_args)
     model_runner = ModelRunner(
         model_config=model_config,
         mem_fraction_static=server_args.mem_fraction_static,
...
...
python/sglang/srt/configs/model_config.py
View file @
b6cf3532
...
...
@@ -24,6 +24,7 @@ from transformers import PretrainedConfig
 from sglang.srt.hf_transformers_utils import get_config, get_context_length
 from sglang.srt.layers.quantization import QUANTIZATION_METHODS
+from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import get_bool_env_var, is_hip

 logger = logging.getLogger(__name__)
...
...
@@ -210,6 +211,21 @@ class ModelConfig:
         self.hf_eos_token_id = self.get_hf_eos_token_id()
         self.image_token_id = getattr(self.hf_config, "image_token_id", None)

+    @staticmethod
+    def from_server_args(server_args: ServerArgs, model_path: str = None, **kwargs):
+        return ModelConfig(
+            model_path=model_path or server_args.model_path,
+            trust_remote_code=server_args.trust_remote_code,
+            revision=server_args.revision,
+            context_length=server_args.context_length,
+            model_override_args=server_args.json_model_override_args,
+            is_embedding=server_args.is_embedding,
+            enable_multimodal=server_args.enable_multimodal,
+            dtype=server_args.dtype,
+            quantization=server_args.quantization,
+            **kwargs,
+        )
+
     # adapted from https://github.com/vllm-project/vllm/blob/main/vllm/config.py#L289
     def get_total_num_kv_heads(self) -> int:
         """Returns the total number of KV heads."""
...
...
python/sglang/srt/managers/scheduler.py
View file @
b6cf3532
...
...
@@ -455,17 +455,7 @@ class Scheduler(
     def init_tokenizer(self):
         server_args = self.server_args
-        self.model_config = ModelConfig(
-            server_args.model_path,
-            trust_remote_code=server_args.trust_remote_code,
-            revision=server_args.revision,
-            context_length=server_args.context_length,
-            model_override_args=server_args.json_model_override_args,
-            is_embedding=server_args.is_embedding,
-            enable_multimodal=server_args.enable_multimodal,
-            dtype=server_args.dtype,
-            quantization=server_args.quantization,
-        )
+        self.model_config = ModelConfig.from_server_args(server_args)
         self.is_generation = self.model_config.is_generation

         if server_args.skip_tokenizer_init:
...
...
python/sglang/srt/managers/tokenizer_manager.py
View file @
b6cf3532
...
...
@@ -165,17 +165,7 @@ class TokenizerManager:
         # Read model args
         self.model_path = server_args.model_path
         self.served_model_name = server_args.served_model_name
-        self.model_config = ModelConfig(
-            server_args.model_path,
-            trust_remote_code=server_args.trust_remote_code,
-            revision=server_args.revision,
-            context_length=server_args.context_length,
-            model_override_args=server_args.json_model_override_args,
-            is_embedding=server_args.is_embedding,
-            enable_multimodal=server_args.enable_multimodal,
-            dtype=server_args.dtype,
-            quantization=server_args.quantization,
-        )
+        self.model_config = ModelConfig.from_server_args(server_args)
         self.is_generation = self.model_config.is_generation
         self.is_image_gen = self.model_config.is_image_gen
...
...
python/sglang/srt/managers/tp_worker.py
View file @
b6cf3532
...
...
@@ -65,20 +65,13 @@ class TpModelWorker:
         self.pp_rank = pp_rank

         # Init model and tokenizer
-        self.model_config = ModelConfig(
-            (
+        self.model_config = ModelConfig.from_server_args(
+            server_args,
+            model_path=(
                 server_args.model_path
                 if not is_draft_worker
                 else server_args.speculative_draft_model_path
             ),
-            trust_remote_code=server_args.trust_remote_code,
-            revision=server_args.revision,
-            context_length=server_args.context_length,
-            model_override_args=server_args.json_model_override_args,
-            is_embedding=server_args.is_embedding,
-            enable_multimodal=server_args.enable_multimodal,
-            dtype=server_args.dtype,
-            quantization=server_args.quantization,
             is_draft_model=is_draft_worker,
         )
...
...
test/srt/test_gptqmodel_dynamic.py
View file @
b6cf3532
...
...
@@ -43,16 +43,7 @@ def check_quant_method(model_path: str, use_marlin_kernel: bool):
         pass

     server_args = ServerArgs(model_path=model_path, dtype=torch.float16)
-    model_config = ModelConfig(
-        server_args.model_path,
-        trust_remote_code=server_args.trust_remote_code,
-        revision=server_args.revision,
-        context_length=server_args.context_length,
-        model_override_args=server_args.json_model_override_args,
-        is_embedding=server_args.is_embedding,
-        dtype=server_args.dtype,
-        quantization=server_args.quantization,
-    )
+    model_config = ModelConfig.from_server_args(server_args)
     load_config = LoadConfig()
     device_config = DeviceConfig("cuda")
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment