Commit 38625e21 (unverified)
Authored Nov 17, 2024 by Lianmin Zheng; committed by GitHub on Nov 17, 2024

Remove monkey_patch_vllm_dummy_weight_loader (#2064)

Parent: c1f401fc
Showing 6 changed files with 17 additions and 70 deletions (+17 -70):
python/sglang/srt/managers/scheduler.py                 +2  -2
python/sglang/srt/managers/tp_worker_overlap_thread.py  +1  -1
python/sglang/srt/model_executor/model_runner.py        +1  -3
python/sglang/srt/utils.py                              +0  -51
test/srt/test_bench_latency.py                          +2  -2
test/srt/test_bench_serving.py                          +11 -11
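Taken together, the diffs below do four things: delete the monkey_patch_vllm_dummy_weight_loader helper from python/sglang/srt/utils.py along with its import and call site in ModelRunner; fix the resulve_batch_result -> resolve_batch_result misspelling in the scheduler and the overlap worker; move the self.dtype assignment in ModelRunner to after the model is loaded; and convert the CI benchmark checks in both test files from bare assert statements to unittest assertions with retuned thresholds.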
python/sglang/srt/managers/scheduler.py @ 38625e21

@@ -895,7 +895,7 @@ class Scheduler:
         logits_output, next_token_ids, bid = result

         if self.enable_overlap:
-            logits_output, next_token_ids = self.tp_worker.resulve_batch_result(bid)
+            logits_output, next_token_ids = self.tp_worker.resolve_batch_result(bid)
         else:
             # Move next_token_ids and logprobs to cpu
             if batch.return_logprob:

@@ -970,7 +970,7 @@ class Scheduler:
         self.num_generated_tokens += len(batch.reqs)

         if self.enable_overlap:
-            logits_output, next_token_ids = self.tp_worker.resulve_batch_result(bid)
+            logits_output, next_token_ids = self.tp_worker.resolve_batch_result(bid)
             next_token_logprobs = logits_output.next_token_logprobs
         else:
             # Move next_token_ids and logprobs to cpu
python/sglang/srt/managers/tp_worker_overlap_thread.py @ 38625e21

@@ -141,7 +141,7 @@ class TpModelWorkerClient:
         self.launch_event.set()
         self.output_queue.put((copy_event, logits_output, next_token_ids))

-    def resulve_batch_result(self, bid: int):
+    def resolve_batch_result(self, bid: int):
        copy_event, logits_output, next_token_ids = self.output_queue.get()
        while not copy_event.query():
            time.sleep(1e-5)
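Context for the renamed method: the worker thread publishes (copy_event, logits_output, next_token_ids) on output_queue, and resolve_batch_result spin-waits on the CUDA copy event with query() plus a microsecond sleep instead of issuing a blocking synchronize, so other Python threads keep running while the device-to-host copy finishes. Below is a minimal, self-contained sketch of that handshake; all names are illustrative, not sglang's actual API.

import queue
import threading
import time

import torch

output_queue: queue.Queue = queue.Queue()

def producer(stream: torch.cuda.Stream) -> None:
    # Worker thread: launch the device->host copy asynchronously on its own
    # stream, record an event after it, and publish without blocking.
    # (For a truly async copy the CPU tensor should be pinned.)
    with torch.cuda.stream(stream):
        gpu_ids = torch.randint(0, 32000, (8,), device="cuda")
        cpu_ids = gpu_ids.to("cpu", non_blocking=True)
        copy_event = torch.cuda.Event()
        copy_event.record(stream)
    output_queue.put((copy_event, cpu_ids))

def resolve() -> torch.Tensor:
    # Consumer: poll the event instead of synchronizing, mirroring the
    # while-not-query()/sleep(1e-5) loop in resolve_batch_result above.
    copy_event, cpu_ids = output_queue.get()
    while not copy_event.query():
        time.sleep(1e-5)
    return cpu_ids

if torch.cuda.is_available():
    stream = torch.cuda.Stream()
    t = threading.Thread(target=producer, args=(stream,))
    t.start()
    print(resolve())
    t.join()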
python/sglang/srt/model_executor/model_runner.py @ 38625e21

@@ -58,7 +58,6 @@ from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import (
     enable_show_time_cost,
     get_available_gpu_memory,
-    monkey_patch_vllm_dummy_weight_loader,
     monkey_patch_vllm_p2p_access_check,
 )

@@ -242,7 +241,6 @@ class ModelRunner:
            raise RuntimeError("SGLang only supports sm75 and above.")

         # Prepare the vllm model config
-        monkey_patch_vllm_dummy_weight_loader()
         self.load_config = LoadConfig(
             load_format=self.server_args.load_format,
             download_dir=self.server_args.download_dir,

@@ -261,7 +259,6 @@ class ModelRunner:
         self.vllm_model_config.hf_config.update(
             self.model_config.model_override_args
         )
-        self.dtype = self.vllm_model_config.dtype

         # Load the model
         self.model = get_model(

@@ -278,6 +275,7 @@ class ModelRunner:
             if hasattr(self.model, "get_attention_sliding_window_size")
             else None
         )
+        self.dtype = self.vllm_model_config.dtype

         logger.info(
             f"Load weight end. "
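Besides dropping the import and call of the monkey patch, this file moves the self.dtype assignment from before get_model to after the model is built, so the dtype is read from the finalized vllm model config. A minimal sketch of the LoadConfig plumbing these hunks touch; the import path matches vllm releases of this era, but treat it as an assumption rather than a pinned API.

# Sketch only: the dummy load format fills the model with random weights,
# which is what the removed monkey patch used to post-process.
from vllm.config import LoadConfig

load_config = LoadConfig(
    load_format="dummy",   # e.g. --load-format dummy for benchmarking, or "auto"
    download_dir=None,
)
# ... get_model(...) builds the model using load_config ...
# dtype is now derived only after the model exists:
# self.dtype = self.vllm_model_config.dtype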
python/sglang/srt/utils.py @ 38625e21

@@ -405,57 +405,6 @@ def monkey_patch_vllm_p2p_access_check(gpu_id: int):
     setattr(tgt, "gpu_p2p_access_check", lambda *arg, **kwargs: True)


-def monkey_patch_vllm_dummy_weight_loader():
-    """
-    Monkey patch the dummy weight loader in vllm to call process_weights_after_loading.
-    """
-    from vllm.model_executor.model_loader.loader import (
-        CacheConfig,
-        DeviceConfig,
-        DummyModelLoader,
-        LoRAConfig,
-        ModelConfig,
-        ParallelConfig,
-        SchedulerConfig,
-        _initialize_model,
-        initialize_dummy_weights,
-        nn,
-        set_default_torch_dtype,
-    )
-
-    def load_model(
-        self,
-        *,
-        model_config: ModelConfig,
-        device_config: DeviceConfig,
-        lora_config: Optional[LoRAConfig],
-        parallel_config: ParallelConfig,
-        scheduler_config: SchedulerConfig,
-        cache_config: CacheConfig,
-    ) -> nn.Module:
-        with set_default_torch_dtype(model_config.dtype):
-            with torch.device(device_config.device):
-                model = _initialize_model(
-                    model_config,
-                    self.load_config,
-                    lora_config,
-                    cache_config,
-                )
-
-                for _, module in model.named_modules():
-                    quant_method = getattr(module, "quant_method", None)
-                    if quant_method is not None:
-                        quant_method.process_weights_after_loading(module)
-
-        # NOTE(woosuk): For accurate performance evaluation, we assign
-        # random values to the weights.
-        initialize_dummy_weights(model)
-        return model.eval()
-
-    setattr(DummyModelLoader, "load_model", load_model)
-
-
 vllm_all_gather_backup = None
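What the deleted helper did, in brief: it swapped vllm's DummyModelLoader.load_model for a version that, after building the model on-device, ran each module's quantization post-processing hook and only then filled the weights with random values. The essential pattern is a setattr-based method override plus a per-module hook sweep. Here is a generic, self-contained sketch of that pattern, assuming nothing about vllm; the class and hook names are illustrative only.

import torch
import torch.nn as nn

class DummyLoader:
    """Stand-in for a third-party loader class (illustrative only)."""
    def load_model(self, model: nn.Module) -> nn.Module:
        return model

def patched_load_model(self, model: nn.Module) -> nn.Module:
    # Run each module's post-load hook, mirroring the deleted helper's
    # quant_method.process_weights_after_loading(module) sweep.
    for _, module in model.named_modules():
        hook = getattr(module, "process_weights_after_loading", None)
        if hook is not None:
            hook()
    # Randomize weights for benchmarking, as the deleted helper did via
    # initialize_dummy_weights(model).
    with torch.no_grad():
        for p in model.parameters():
            p.uniform_(-1e-3, 1e-3)
    return model.eval()

# The monkey patch itself: replace the method on the class at runtime.
setattr(DummyLoader, "load_model", patched_load_model)

model = DummyLoader().load_model(nn.Linear(4, 4))
print(type(model), model.weight.abs().max().item())

Such patches are fragile because they pin a copy of third-party internals (here, vllm's loader) inside sglang; once upstream performs the post-processing itself, the override becomes dead weight, which is presumably why this commit can remove it.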
test/srt/test_bench_latency.py @ 38625e21

@@ -13,7 +13,7 @@ class TestBenchLatency(unittest.TestCase):
         output_throughput = run_bench_latency(DEFAULT_MODEL_NAME_FOR_TEST, [])

         if is_in_ci():
-            assert output_throughput > 130, f"{output_throughput=}"
+            self.assertGreater(output_throughput, 135)

     def test_moe_default(self):
         output_throughput = run_bench_latency(

@@ -21,7 +21,7 @@ class TestBenchLatency(unittest.TestCase):
         )

         if is_in_ci():
-            assert output_throughput > 125, f"{output_throughput=}"
+            self.assertGreater(output_throughput, 125)


 if __name__ == "__main__":
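Both this file and test_bench_serving.py below replace bare assert statements with unittest's self.assertGreater / self.assertLess. A minimal sketch of why that conversion matters: unittest assertions still fire under python -O (which strips assert statements entirely) and report both operands on failure without needing a hand-written message.

import unittest

class TestThresholds(unittest.TestCase):
    def test_throughput_floor(self):
        output_throughput = 142.0  # stand-in for a real benchmark result
        # Unlike `assert output_throughput > 135, f"{output_throughput=}"`,
        # this survives `python -O` and prints both values on failure,
        # e.g. "142.0 not greater than 150.0".
        self.assertGreater(output_throughput, 135)

if __name__ == "__main__":
    unittest.main()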
test/srt/test_bench_serving.py @ 38625e21

@@ -20,7 +20,7 @@ class TestBenchServing(unittest.TestCase):
         )

         if is_in_ci():
-            assert res["output_throughput"] > 2830
+            self.assertGreater(res["output_throughput"], 2850)

     def test_offline_throughput_non_stream_small_batch_size(self):
         res = run_bench_serving(

@@ -35,7 +35,7 @@ class TestBenchServing(unittest.TestCase):
         )

         if is_in_ci():
-            assert res["output_throughput"] > 1000
+            self.assertGreater(res["output_throughput"], 950)

     def test_offline_throughput_without_radix_cache(self):
         res = run_bench_serving(

@@ -46,7 +46,7 @@ class TestBenchServing(unittest.TestCase):
         )

         if is_in_ci():
-            assert res["output_throughput"] > 2880
+            self.assertGreater(res["output_throughput"], 2900)

     def test_offline_throughput_without_chunked_prefill(self):
         res = run_bench_serving(

@@ -57,7 +57,7 @@ class TestBenchServing(unittest.TestCase):
         )

         if is_in_ci():
-            assert res["output_throughput"] > 2600
+            self.assertGreater(res["output_throughput"], 2600)

     def test_offline_throughput_with_triton_attention_backend(self):
         res = run_bench_serving(

@@ -73,7 +73,7 @@ class TestBenchServing(unittest.TestCase):
         )

         if is_in_ci():
-            assert res["output_throughput"] > 2930
+            self.assertGreater(res["output_throughput"], 2950)

     def test_offline_throughput_default_fp8(self):
         res = run_bench_serving(

@@ -84,7 +84,7 @@ class TestBenchServing(unittest.TestCase):
         )

         if is_in_ci():
-            assert res["output_throughput"] > 3100
+            self.assertGreater(res["output_throughput"], 3200)

     def test_online_latency_default(self):
         res = run_bench_serving(

@@ -95,9 +95,9 @@ class TestBenchServing(unittest.TestCase):
         )

         if is_in_ci():
-            assert res["median_e2e_latency_ms"] < 12000
-            assert res["median_ttft_ms"] < 80
-            assert res["median_itl_ms"] < 12
+            self.assertLess(res["median_e2e_latency_ms"], 12000)
+            self.assertLess(res["median_ttft_ms"], 80)
+            self.assertLess(res["median_itl_ms"], 11)

     def test_moe_offline_throughput_default(self):
         res = run_bench_serving(

@@ -108,7 +108,7 @@ class TestBenchServing(unittest.TestCase):
         )

         if is_in_ci():
-            assert res["output_throughput"] > 1850
+            self.assertGreater(res["output_throughput"], 1900)

     def test_moe_offline_throughput_without_radix_cache(self):
         res = run_bench_serving(

@@ -119,7 +119,7 @@ class TestBenchServing(unittest.TestCase):
         )

         if is_in_ci():
-            assert res["output_throughput"] > 1950
+            self.assertGreater(res["output_throughput"], 1950)


 if __name__ == "__main__":