Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
fd28640d
Unverified
Commit
fd28640d
authored
Dec 29, 2024
by
fzyzcjy
Committed by
GitHub
Dec 28, 2024
Browse files
Add `update_weights_from_tensor` (#2631)
parent
7863e436
Changes
10
Hide whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
120 additions
and
1 deletion
+120
-1
.gitignore
.gitignore
+2
-0
python/sglang/srt/managers/io_struct.py
python/sglang/srt/managers/io_struct.py
+14
-0
python/sglang/srt/managers/scheduler.py
python/sglang/srt/managers/scheduler.py
+19
-1
python/sglang/srt/managers/tokenizer_manager.py
python/sglang/srt/managers/tokenizer_manager.py
+26
-0
python/sglang/srt/managers/tp_worker.py
python/sglang/srt/managers/tp_worker.py
+7
-0
python/sglang/srt/managers/tp_worker_overlap_thread.py
python/sglang/srt/managers/tp_worker_overlap_thread.py
+5
-0
python/sglang/srt/model_executor/model_runner.py
python/sglang/srt/model_executor/model_runner.py
+4
-0
python/sglang/srt/server.py
python/sglang/srt/server.py
+10
-0
test/srt/run_suite.py
test/srt/run_suite.py
+1
-0
test/srt/test_update_weights_from_tensor.py
test/srt/test_update_weights_from_tensor.py
+32
-0
No files found.
.gitignore
View file @
fd28640d
...
...
@@ -220,3 +220,5 @@ work_dirs/
*.app
compile_commands.json
*.iml
python/sglang/srt/managers/io_struct.py
View file @
fd28640d
...
...
@@ -21,6 +21,8 @@ from dataclasses import dataclass
from
enum
import
Enum
from
typing
import
Dict
,
List
,
Optional
,
Tuple
,
Union
import
torch
from
sglang.srt.managers.schedule_batch
import
BaseFinishReason
from
sglang.srt.sampling.sampling_params
import
SamplingParams
...
...
@@ -407,6 +409,18 @@ class UpdateWeightsFromDistributedReqOutput:
message
:
str
@dataclass
class UpdateWeightsFromTensorReqInput:
    """Request to update a single named model parameter from a tensor."""

    # Name of the parameter to update.
    name: str
    # Tensor supplied as the parameter's new value.
    tensor: torch.Tensor
@dataclass
class UpdateWeightsFromTensorReqOutput:
    """Result of an update-weights-from-tensor request."""

    # Whether the update succeeded.
    success: bool
    # Status or error text accompanying the result.
    message: str
@
dataclass
class
InitWeightsUpdateGroupReqInput
:
# The master address
...
...
python/sglang/srt/managers/scheduler.py
View file @
fd28640d
...
...
@@ -22,7 +22,7 @@ import warnings
from
collections
import
deque
from
concurrent
import
futures
from
types
import
SimpleNamespace
from
typing
import
Callable
,
Dict
,
List
,
Optional
,
Tuple
from
typing
import
Dict
,
List
,
Optional
import
psutil
import
setproctitle
...
...
@@ -52,6 +52,8 @@ from sglang.srt.managers.io_struct import (
UpdateWeightFromDiskReqOutput
,
UpdateWeightsFromDistributedReqInput
,
UpdateWeightsFromDistributedReqOutput
,
UpdateWeightsFromTensorReqInput
,
UpdateWeightsFromTensorReqOutput
,
)
from
sglang.srt.managers.schedule_batch
import
(
FINISH_ABORT
,
...
...
@@ -478,6 +480,11 @@ class Scheduler:
self
.
send_to_tokenizer
.
send_pyobj
(
UpdateWeightsFromDistributedReqOutput
(
success
,
message
)
)
elif
isinstance
(
recv_req
,
UpdateWeightsFromTensorReqInput
):
success
,
message
=
self
.
update_weights_from_tensor
(
recv_req
)
self
.
send_to_tokenizer
.
send_pyobj
(
UpdateWeightsFromTensorReqOutput
(
success
,
message
)
)
elif
isinstance
(
recv_req
,
GetWeightsByNameReqInput
):
parameter
=
self
.
get_weights_by_name
(
recv_req
)
self
.
send_to_tokenizer
.
send_pyobj
(
GetWeightsByNameReqOutput
(
parameter
))
...
...
@@ -1458,6 +1465,17 @@ class Scheduler:
logger
.
error
(
message
)
return
success
,
message
def update_weights_from_tensor(self, recv_req: UpdateWeightsFromTensorReqInput):
    """Update the online model parameter from tensors.

    Delegates the update to the TP worker. On failure the worker's message
    is logged; on success the cache is flushed and the flush is asserted to
    have succeeded. The worker's ``(success, message)`` pair is returned.
    """
    success, message = self.tp_worker.update_weights_from_tensor(recv_req)
    # TODO extract common code b/t update_weights_from_distributed and update_weights_from_tensor later
    if not success:
        logger.error(message)
        return success, message

    flush_cache_success = self.flush_cache()
    assert flush_cache_success, "Cache flush failed after updating weights"
    return success, message
def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
    """Return whatever the TP worker reports for the requested weights."""
    return self.tp_worker.get_weights_by_name(recv_req)
...
...
python/sglang/srt/managers/tokenizer_manager.py
View file @
fd28640d
...
...
@@ -59,6 +59,8 @@ from sglang.srt.managers.io_struct import (
UpdateWeightFromDiskReqOutput
,
UpdateWeightsFromDistributedReqInput
,
UpdateWeightsFromDistributedReqOutput
,
UpdateWeightsFromTensorReqInput
,
UpdateWeightsFromTensorReqOutput
,
)
from
sglang.srt.metrics.collector
import
TokenizerMetricsCollector
from
sglang.srt.sampling.sampling_params
import
SamplingParams
...
...
@@ -179,6 +181,9 @@ class TokenizerManager:
self
.
update_weights_from_distributed_communicator
=
_Communicator
(
self
.
send_to_scheduler
,
server_args
.
dp_size
)
self
.
update_weights_from_tensor_communicator
=
_Communicator
(
self
.
send_to_scheduler
,
server_args
.
dp_size
)
self
.
get_weights_by_name_communicator
=
_Communicator
(
self
.
send_to_scheduler
,
server_args
.
dp_size
)
...
...
@@ -515,6 +520,22 @@ class TokenizerManager:
result
=
(
await
self
.
update_weights_from_distributed_communicator
(
obj
))[
0
]
return
result
.
success
,
result
.
message
async def update_weights_from_tensor(
    self,
    obj: UpdateWeightsFromTensorReqInput,
    request: Optional[fastapi.Request] = None,
) -> Tuple[bool, str]:
    """Send a tensor weight update to the scheduler and await its result.

    Args:
        obj: The update request to send through the communicator.
        request: Accepted for signature parity with the other update
            methods; not used in this method.

    Returns:
        The ``(success, message)`` pair taken from the first response.

    Raises:
        AssertionError: If ``server_args.dp_size`` is not 1.
    """
    self.auto_create_handle_loop()
    assert (
        self.server_args.dp_size == 1
    ), "dp_size must be 1 for update weights from tensor"

    # This means that weight sync
    # cannot run while requests are in progress.
    async with self.model_update_lock.writer_lock:
        # Only the first response is used.
        result = (await self.update_weights_from_tensor_communicator(obj))[0]
        return result.success, result.message
async
def
get_weights_by_name
(
self
,
obj
:
GetWeightsByNameReqInput
,
request
:
Optional
[
fastapi
.
Request
]
=
None
):
...
...
@@ -708,6 +729,11 @@ class TokenizerManager:
self
.
server_args
.
dp_size
==
1
),
"dp_size must be 1 for update weights from distributed"
self
.
update_weights_from_distributed_communicator
.
handle_recv
(
recv_obj
)
elif
isinstance
(
recv_obj
,
UpdateWeightsFromTensorReqOutput
):
assert
(
self
.
server_args
.
dp_size
==
1
),
"dp_size must be 1 for update weights from distributed"
self
.
update_weights_from_tensor_communicator
.
handle_recv
(
recv_obj
)
elif
isinstance
(
recv_obj
,
GetWeightsByNameReqOutput
):
self
.
get_weights_by_name_communicator
.
handle_recv
(
recv_obj
)
else
:
...
...
python/sglang/srt/managers/tp_worker.py
View file @
fd28640d
...
...
@@ -24,6 +24,7 @@ from sglang.srt.managers.io_struct import (
InitWeightsUpdateGroupReqInput
,
UpdateWeightFromDiskReqInput
,
UpdateWeightsFromDistributedReqInput
,
UpdateWeightsFromTensorReqInput
,
)
from
sglang.srt.managers.schedule_batch
import
ModelWorkerBatch
,
global_server_args_dict
from
sglang.srt.model_executor.forward_batch_info
import
ForwardBatch
...
...
@@ -188,6 +189,12 @@ class TpModelWorker:
)
return
success
,
message
def update_weights_from_tensor(self, recv_req: UpdateWeightsFromTensorReqInput):
    """Forward a tensor weight update to the model runner.

    Passes the request's ``name`` and ``tensor`` to the runner and returns
    the runner's ``(success, message)`` pair.
    """
    runner_update = self.model_runner.update_weights_from_tensor
    ok, msg = runner_update(recv_req.name, recv_req.tensor)
    return ok, msg
def
get_weights_by_name
(
self
,
recv_req
:
GetWeightsByNameReqInput
):
parameter
=
self
.
model_runner
.
get_weights_by_name
(
recv_req
.
name
,
recv_req
.
truncate_size
...
...
python/sglang/srt/managers/tp_worker_overlap_thread.py
View file @
fd28640d
...
...
@@ -28,6 +28,7 @@ from sglang.srt.managers.io_struct import (
InitWeightsUpdateGroupReqInput
,
UpdateWeightFromDiskReqInput
,
UpdateWeightsFromDistributedReqInput
,
UpdateWeightsFromTensorReqInput
,
)
from
sglang.srt.managers.schedule_batch
import
ModelWorkerBatch
from
sglang.srt.managers.tp_worker
import
TpModelWorker
...
...
@@ -225,6 +226,10 @@ class TpModelWorkerClient:
success
,
message
=
self
.
worker
.
update_weights_from_distributed
(
recv_req
)
return
success
,
message
def update_weights_from_tensor(self, recv_req: UpdateWeightsFromTensorReqInput):
    """Delegate a tensor weight update to the wrapped worker.

    Returns the worker's ``(success, message)`` pair.
    """
    ok, msg = self.worker.update_weights_from_tensor(recv_req)
    return ok, msg
def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
    """Delegate a weight lookup to the wrapped worker and return its result."""
    weights = self.worker.get_weights_by_name(recv_req)
    return weights
...
...
python/sglang/srt/model_executor/model_runner.py
View file @
fd28640d
...
...
@@ -429,6 +429,10 @@ class ModelRunner:
logger
.
error
(
error_msg
)
return
False
,
error_msg
def update_weights_from_tensor(self, name, tensor: torch.Tensor):
    """Load ``tensor`` into the model as the parameter called ``name``.

    Args:
        name: Name of the parameter, passed to ``self.model.load_weights``.
        tensor: New value for that parameter.

    Returns:
        ``(True, "Success")`` when ``load_weights`` completes, otherwise
        ``(False, message)`` where ``message`` describes the failure.
    """
    try:
        self.model.load_weights([(name, tensor)])
    except Exception as e:
        # Report the failure through the return value instead of raising,
        # so callers that branch on ``success`` can surface the message.
        return False, f"Failed to update parameter {name} from tensor: {e}"
    return True, "Success"
def
get_weights_by_name
(
self
,
name
:
str
,
truncate_size
:
int
=
100
)
->
Optional
[
torch
.
Tensor
]:
...
...
python/sglang/srt/server.py
View file @
fd28640d
...
...
@@ -57,6 +57,7 @@ from sglang.srt.managers.io_struct import (
OpenSessionReqInput
,
UpdateWeightFromDiskReqInput
,
UpdateWeightsFromDistributedReqInput
,
UpdateWeightsFromTensorReqInput
,
)
from
sglang.srt.managers.scheduler
import
run_scheduler_process
from
sglang.srt.managers.tokenizer_manager
import
TokenizerManager
...
...
@@ -109,6 +110,7 @@ app.add_middleware(
tokenizer_manager
:
TokenizerManager
=
None
scheduler_info
:
Dict
=
None
##### Native API endpoints #####
...
...
@@ -866,6 +868,14 @@ class Engine:
tokenizer_manager
.
update_weights_from_distributed
(
obj
,
None
)
)
def update_weights_from_tensor(self, name, tensor):
    """Update a named model parameter from an in-memory tensor.

    Wraps the arguments in an ``UpdateWeightsFromTensorReqInput`` and runs
    the tokenizer manager's coroutine to completion on the event loop,
    returning its result.
    """
    request_input = UpdateWeightsFromTensorReqInput(name=name, tensor=tensor)
    event_loop = asyncio.get_event_loop()
    pending = tokenizer_manager.update_weights_from_tensor(request_input, None)
    return event_loop.run_until_complete(pending)
def
get_weights_by_name
(
self
,
name
,
truncate_size
=
100
):
"""Get weights by parameter name."""
obj
=
GetWeightsByNameReqInput
(
name
=
name
,
truncate_size
=
truncate_size
)
...
...
test/srt/run_suite.py
View file @
fd28640d
...
...
@@ -40,6 +40,7 @@ suites = {
"test_triton_attention_kernels.py"
,
"test_triton_attention_backend.py"
,
"test_update_weights_from_disk.py"
,
"test_update_weights_from_tensor.py"
,
"test_vision_chunked_prefill.py"
,
"test_vision_openai_server.py"
,
"test_session_control.py"
,
...
...
test/srt/test_update_weights_from_tensor.py
0 → 100644
View file @
fd28640d
import
unittest
import
torch
import
sglang
as
sgl
from
sglang.test.test_utils
import
DEFAULT_SMALL_MODEL_NAME_FOR_TEST
class TestReleaseGPUOccupation(unittest.TestCase):
    # NOTE(review): the class and method names look copied from another test;
    # what is exercised here is Engine.update_weights_from_tensor. The names
    # are kept so existing test ids stay valid.

    def test_release_and_resume_occupation(self):
        """Check that a tensor weight update is visible via get_weights_by_name."""
        engine = sgl.Engine(model_path=DEFAULT_SMALL_MODEL_NAME_FOR_TEST)
        param_name = "model.layers.2.self_attn.k_proj.weight"

        def _check_param(expect_values):
            # Compare only the first five values of the first row.
            actual_values = torch.tensor(engine.get_weights_by_name(param_name))[0, :5]
            self.assertTrue(
                torch.allclose(
                    actual_values, torch.tensor(expect_values), atol=0.001
                ),
                f"{actual_values=}",
            )

        try:
            _check_param([0.0571, -0.0114, 0.0444, 0.0215, -0.0149])

            new_tensor = torch.full((3072, 2048), 1.5)
            engine.update_weights_from_tensor(param_name, new_tensor)

            _check_param([1.5] * 5)
        finally:
            # Always shut the engine down, even when a check above fails.
            engine.shutdown()
# Run the tests when this file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment