change / sglang — Commits — 983bfcf3

Unverified commit 983bfcf3, authored Dec 01, 2024 by Chayenne, committed via GitHub on Dec 01, 2024.

Online weight updates from torch.distributed (#2279)

Parent: 28bc60dc

Showing 12 changed files with 1119 additions and 60 deletions (+1119, -60).
Changed files:

  .github/workflows/pr-test.yml                             +7    -0
  python/sglang/srt/managers/io_struct.py                  +35    -0
  python/sglang/srt/managers/scheduler.py                  +34    -0
  python/sglang/srt/managers/tokenizer_manager.py          +60    -0
  python/sglang/srt/managers/tp_worker.py                  +21    -0
  python/sglang/srt/managers/tp_worker_overlap_thread.py   +12    -0
  python/sglang/srt/model_executor/model_runner.py         +85    -1
  python/sglang/srt/models/llama.py                         +2    -0
  python/sglang/srt/server.py                              +76   -13
  python/sglang/srt/utils.py                               +73    -0
  test/srt/test_get_weights_by_name.py                    +100   -46
  test/srt/test_update_weights_from_distributed.py        +614    -0
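The change is easiest to read from the client side. Below is a minimal sketch of the flow the new API enables, mirroring the added test; the model path, port, and tensor shape are illustrative assumptions, not values from this commit, and the calls only complete once a matching rank-0 training process joins the group and broadcasts each tensor (see test_update_weights_from_distributed.py).

import torch
import sglang as sgl

def sync_from_trainer():
    # Inference engine on GPU 1; rank 0 is reserved for the training engine.
    engine = sgl.Engine(model_path="meta-llama/Llama-3.2-1B-Instruct", base_gpu_id=1)

    # 1. Join the weight-update process group shared with the training engine.
    engine.init_weights_update_group(
        master_address="localhost",
        master_port="65500",          # illustrative port
        rank_offset=1,                # this engine occupies ranks [1, 1 + tp_size)
        world_size=2,                 # training rank 0 + one inference GPU
        group_name="weight_update_group",
        backend="nccl",
    )

    # 2. Receive one parameter broadcast from rank 0 and load it into the model.
    engine.update_weights_from_distributed(
        "model.embed_tokens.weight", dtype=torch.bfloat16, shape=[128256, 2048]
    )

    # 3. Read back a truncated slice of the parameter to verify the update.
    return engine.get_weights_by_name("model.embed_tokens.weight", truncate_size=10)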
.github/workflows/pr-test.yml  (+7, -0)

@@ -27,6 +27,7 @@ concurrency:
   cancel-in-progress: true

 jobs:
   unit-test-frontend:
     if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request'
     runs-on: 1-gpu-runner

@@ -98,6 +99,11 @@ jobs:
           python3 test_mla_fp8.py
           python3 test_dp_attention.py

+      - name: Test update weights from distributed
+        timeout-minutes: 10
+        run: |
+          cd test/srt
+          python3 test_update_weights_from_distributed.py

   performance-test-1-gpu-part-1:
     if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request'

@@ -245,6 +251,7 @@ jobs:
           cd test/srt
           python3 test_moe_eval_accuracy_large.py

   finish:
     needs: [
       unit-test-frontend, unit-test-backend-1-gpu, unit-test-backend-2-gpu,
       ...
python/sglang/srt/managers/io_struct.py  (+35, -0)

@@ -365,6 +365,41 @@ class UpdateWeightFromDiskReqOutput:
     message: str


 @dataclass
 class UpdateWeightsFromDistributedReqInput:
     name: str
     dtype: str
     shape: List[int]


 @dataclass
 class UpdateWeightsFromDistributedReqOutput:
     success: bool
     message: str


 @dataclass
 class InitWeightsUpdateGroupReqInput:
     # The master address
     master_address: str
     # The master port
     master_port: int
     # The rank offset
     rank_offset: int
     # The world size
     world_size: int
     # The group name
     group_name: str = "weight_update_group"
     # The backend
     backend: str = "nccl"


 @dataclass
 class InitWeightsUpdateGroupReqOutput:
     success: bool
     message: str


 @dataclass
 class GetWeightsByNameReqInput:
     name: str
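For reference, these dataclasses also define the JSON bodies accepted by the new HTTP endpoints added in server.py below. A sketch of both payloads (the values are examples only, not taken from the diff):

# Illustrative JSON payloads matching the new request dataclasses.
init_group_payload = {
    "master_address": "localhost",
    "master_port": 65500,
    "rank_offset": 1,
    "world_size": 2,
    "group_name": "weight_update_group",   # default value
    "backend": "nccl",                      # default value
}
update_payload = {
    "name": "model.embed_tokens.weight",
    "dtype": "bfloat16",
    "shape": [128256, 2048],                # example shape
}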
python/sglang/srt/managers/scheduler.py  (+34, -0)

@@ -40,6 +40,8 @@ from sglang.srt.managers.io_struct import (
     FlushCacheReq,
     GetWeightsByNameReqInput,
     GetWeightsByNameReqOutput,
     InitWeightsUpdateGroupReqInput,
     InitWeightsUpdateGroupReqOutput,
     OpenSessionReqInput,
     OpenSessionReqOutput,
     ProfileReq,

@@ -47,6 +49,8 @@ from sglang.srt.managers.io_struct import (
     TokenizedGenerateReqInput,
     UpdateWeightFromDiskReqInput,
     UpdateWeightFromDiskReqOutput,
     UpdateWeightsFromDistributedReqInput,
     UpdateWeightsFromDistributedReqOutput,
 )
 from sglang.srt.managers.schedule_batch import (
     FINISH_ABORT,

@@ -516,6 +520,19 @@ class Scheduler:
         elif isinstance(recv_req, GetWeightsByNameReqInput):
             parameter = self.get_weights_by_name(recv_req)
             self.send_to_tokenizer.send_pyobj(GetWeightsByNameReqOutput(parameter))
         elif isinstance(recv_req, InitWeightsUpdateGroupReqInput):
             success, message = self.init_weights_update_group(recv_req)
             self.send_to_tokenizer.send_pyobj(
                 InitWeightsUpdateGroupReqOutput(success, message)
             )
         elif isinstance(recv_req, UpdateWeightsFromDistributedReqInput):
             success, message = self.update_weights_from_distributed(recv_req)
             self.send_to_tokenizer.send_pyobj(
                 UpdateWeightsFromDistributedReqOutput(success, message)
             )
         elif isinstance(recv_req, GetWeightsByNameReqInput):
             parameter = self.get_weights_by_name(recv_req)
             self.send_to_tokenizer.send_pyobj(GetWeightsByNameReqOutput(parameter))
         elif isinstance(recv_req, ProfileReq):
             if recv_req == ProfileReq.START_PROFILE:
                 self.start_profile()

@@ -1378,6 +1395,23 @@ class Scheduler:
             logger.error(message)
         return success, message

     def init_weights_update_group(self, recv_req: InitWeightsUpdateGroupReqInput):
         """Initialize the online model parameter update group."""
         success, message = self.tp_worker.init_weights_update_group(recv_req)
         return success, message

     def update_weights_from_distributed(
         self, recv_req: UpdateWeightsFromDistributedReqInput
     ):
         """Update the online model parameter."""
         success, message = self.tp_worker.update_weights_from_distributed(recv_req)
         if success:
             flash_cache_success = self.flush_cache()
             assert flash_cache_success, "Cache flush failed after updating weights"
         else:
             logger.error(message)
         return success, message

     def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
         parameter = self.tp_worker.get_weights_by_name(recv_req)
         return parameter
python/sglang/srt/managers/tokenizer_manager.py  (+60, -0)

@@ -48,6 +48,8 @@ from sglang.srt.managers.io_struct import (
     GenerateReqInput,
     GetWeightsByNameReqInput,
     GetWeightsByNameReqOutput,
     InitWeightsUpdateGroupReqInput,
     InitWeightsUpdateGroupReqOutput,
     OpenSessionReqInput,
     OpenSessionReqOutput,
     ProfileReq,

@@ -55,6 +57,8 @@ from sglang.srt.managers.io_struct import (
     TokenizedGenerateReqInput,
     UpdateWeightFromDiskReqInput,
     UpdateWeightFromDiskReqOutput,
     UpdateWeightsFromDistributedReqInput,
     UpdateWeightsFromDistributedReqOutput,
 )
 from sglang.srt.metrics.collector import TokenizerMetricsCollector
 from sglang.srt.sampling.sampling_params import SamplingParams

@@ -456,6 +460,48 @@ class TokenizerManager:
         else:
             return False, "Another update is in progress. Please try again later."

     async def init_weights_update_group(
         self,
         obj: InitWeightsUpdateGroupReqInput,
         request: Optional[fastapi.Request] = None,
     ) -> bool:
         if self.to_create_loop:
             self.create_handle_loop()
         self.send_to_scheduler.send_pyobj(obj)

         self.init_weights_update_group_result = asyncio.Future()
         assert (
             self.server_args.dp_size == 1
         ), "dp_size must be 1 for init parameter update group"
         result = await self.init_weights_update_group_result
         return result.success, result.message

     async def update_weights_from_distributed(
         self,
         obj: UpdateWeightsFromDistributedReqInput,
         request: Optional[fastapi.Request] = None,
     ):
         if self.to_create_loop:
             self.create_handle_loop()

         if not self.model_update_lock.locked():
             async with self.model_update_lock:
                 self.send_to_scheduler.send_pyobj(obj)
                 self.parameter_update_result = asyncio.Future()
                 assert (
                     self.server_args.dp_size == 1
                 ), "dp_size must be 1 for update weights from distributed"
                 result = await self.parameter_update_result
                 return result.success, result.message
         else:
             logger.error(
                 f"Another parameter update is in progress in tokenizer manager"
             )
             return (
                 False,
                 "Another parameter update is in progress. Please try again later.",
             )

     async def get_weights_by_name(
         self, obj: GetWeightsByNameReqInput, request: Optional[fastapi.Request] = None
     ):

@@ -546,7 +592,9 @@ class TokenizerManager:
                 BatchEmbeddingOut,
                 BatchTokenIDOut,
                 UpdateWeightFromDiskReqOutput,
                 UpdateWeightsFromDistributedReqOutput,
                 GetWeightsByNameReqOutput,
                 InitWeightsUpdateGroupReqOutput,
             ] = await self.recv_from_detokenizer.recv_pyobj()

             if isinstance(recv_obj, UpdateWeightFromDiskReqOutput):

@@ -558,6 +606,12 @@ class TokenizerManager:
                     if len(self.model_update_tmp) == self.server_args.dp_size:
                         self.model_update_result.set_result(self.model_update_tmp)
                 continue
             elif isinstance(recv_obj, UpdateWeightsFromDistributedReqOutput):
                 assert (
                     self.server_args.dp_size == 1
                 ), "dp_size must be 1 for update weights from distributed"
                 self.parameter_update_result.set_result(recv_obj)
                 continue
             elif isinstance(recv_obj, GetWeightsByNameReqOutput):
                 if self.server_args.dp_size == 1:
                     self.get_weights_by_name_result.set_result(recv_obj)

@@ -568,6 +622,12 @@ class TokenizerManager:
                             self.get_weights_by_name_tmp
                         )
                 continue
             elif isinstance(recv_obj, InitWeightsUpdateGroupReqOutput):
                 assert (
                     self.server_args.dp_size == 1
                 ), "dp_size must be 1 for init parameter update group"
                 self.init_weights_update_group_result.set_result(recv_obj)
                 continue
             elif isinstance(recv_obj, OpenSessionReqOutput):
                 self.session_futures[recv_obj.session_id].set_result(
                     recv_obj.session_id
python/sglang/srt/managers/tp_worker.py  (+21, -0)

@@ -21,7 +21,9 @@ from sglang.srt.configs.model_config import ModelConfig
 from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
 from sglang.srt.managers.io_struct import (
     GetWeightsByNameReqInput,
     InitWeightsUpdateGroupReqInput,
     UpdateWeightFromDiskReqInput,
     UpdateWeightsFromDistributedReqInput,
 )
 from sglang.srt.managers.schedule_batch import ModelWorkerBatch, global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch

@@ -164,6 +166,25 @@ class TpModelWorker:
         )
         return success, message

     def init_weights_update_group(self, recv_req: InitWeightsUpdateGroupReqInput):
         success, message = self.model_runner.init_weights_update_group(
             recv_req.master_address,
             recv_req.master_port,
             recv_req.rank_offset,
             recv_req.world_size,
             recv_req.group_name,
             recv_req.backend,
         )
         return success, message

     def update_weights_from_distributed(
         self, recv_req: UpdateWeightsFromDistributedReqInput
     ):
         success, message = self.model_runner.update_weights_from_distributed(
             recv_req.name, recv_req.dtype, recv_req.shape
         )
         return success, message

     def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
         parameter = self.model_runner.get_weights_by_name(
             recv_req.name, recv_req.truncate_size
python/sglang/srt/managers/tp_worker_overlap_thread.py  (+12, -0)

@@ -25,7 +25,9 @@ import torch
 from sglang.srt.managers.io_struct import (
     GetWeightsByNameReqInput,
     InitWeightsUpdateGroupReqInput,
     UpdateWeightFromDiskReqInput,
     UpdateWeightsFromDistributedReqInput,
 )
 from sglang.srt.managers.schedule_batch import ModelWorkerBatch
 from sglang.srt.managers.tp_worker import TpModelWorker

@@ -211,6 +213,16 @@ class TpModelWorkerClient:
         success, message = self.worker.update_weights_from_disk(recv_req)
         return success, message

     def init_weights_update_group(self, recv_req: InitWeightsUpdateGroupReqInput):
         success, message = self.worker.init_weights_update_group(recv_req)
         return success, message

     def update_weights_from_distributed(
         self, recv_req: UpdateWeightsFromDistributedReqInput
     ):
         success, message = self.worker.update_weights_from_distributed(recv_req)
         return success, message

     def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
         return self.worker.get_weights_by_name(recv_req)
python/sglang/srt/model_executor/model_runner.py  (+85, -1)

@@ -20,10 +20,13 @@ import inspect
 import json
 import logging
 import pkgutil
+import time
 from functools import lru_cache
-from typing import Optional, Type
+from tokenize import tabsize
+from typing import Any, Optional, Type, Union

 import torch
+import torch.distributed as dist
 import torch.nn as nn
 from vllm.config import DeviceConfig, LoadConfig
 from vllm.config import ModelConfig as VllmModelConfig

@@ -59,6 +62,7 @@ from sglang.srt.utils import (
     crash_on_warnings,
     enable_show_time_cost,
     get_available_gpu_memory,
+    init_custom_process_group,
     is_hip,
     monkey_patch_vllm_gguf_config,
     monkey_patch_vllm_model_config,

@@ -404,6 +408,86 @@ class ModelRunner:
         logger.info("Update weights end.")
         return True, "Succeeded to update model weights."

     def init_weights_update_group(
         self,
         master_address,
         master_port,
         rank_offset,
         world_size,
         group_name,
         backend="nccl",
     ):
         """Initialize the Torch process group for model parameter updates.

         `_model_update_group` is used in the RLHF workflow, where rank
         0 is the actor model in the training engine, and the other ranks are
         the inference engine, which is used for rollout.

         In the RLHF workflow, the training engine updates the model
         weights/parameters online, and broadcasts them to the inference
         engine through the `_model_update_group` process group.
         """
         assert (
             torch.distributed.is_initialized()
         ), "Default torch process group must be initialized"
         assert group_name != "", "Group name cannot be empty"

         rank = rank_offset + self.tp_rank

         logger.info(
             f"init custom process group: master_address={master_address}, master_port={master_port}, "
             f"rank_offset={rank_offset}, world_size={world_size}, group_name={group_name}, backend={backend}"
         )

         try:
             self._model_update_group = init_custom_process_group(
                 backend=backend,
                 init_method=f"tcp://{master_address}:{master_port}",
                 world_size=world_size,
                 rank=rank,
                 group_name=group_name,
             )
             dist.barrier(group=self._model_update_group, device_ids=[rank])
             return True, "Succeeded to initialize custom process group."
         except Exception as e:
             message = f"Failed to initialize custom process group: {e}."
             logger.error(message)
             return False, message

     def update_weights_from_distributed(self, name, dtype, shape):
         """
         Update specific parameter in the model weights online
         through `_model_update_group` process group.

         Args:
             name: the name of the parameter to be updated.
             dtype: the data type of the parameter to be updated.
             shape: the shape of the parameter to be updated.
         """
         target_dtype = (
             dtype if isinstance(dtype, torch.dtype) else getattr(torch, dtype)
         )
         current_dtype = self.dtype if isinstance(self.dtype, str) else self.dtype

         assert (
             self._model_update_group is not None
         ), "model update group must be initialized"

         try:
             weights = torch.empty(shape, dtype=target_dtype, device=self.device)
             torch.distributed.broadcast(weights, src=0, group=self._model_update_group)
             self.model.load_weights([(name, weights)])
             return True, f"Succeeded to update parameter {name} online."
         except Exception as e:
             error_msg = (
                 f"Failed to update parameter online: {e}. "
                 f"The full weights of the ModelRunner are partially updated. "
                 f"Please discard the whole weights."
             )
             logger.error(error_msg)
             return False, error_msg

     def get_weights_by_name(
         self, name: str, truncate_size: int = 100
     ) -> Optional[torch.Tensor]:
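The docstring above describes the training-engine side of the handshake, but that side is not part of this diff. A minimal sketch of what rank 0 would do, mirroring test_update_weights_from_distributed.py (the address, port, and group name are assumptions for illustration):

# Hypothetical rank-0 (training engine) counterpart to the engine's
# init_weights_update_group / update_weights_from_distributed calls.
import torch.distributed as dist
from sglang.srt.utils import init_custom_process_group

def push_weights_to_inference_engine(model, world_size):
    # Rank 0 joins the same group the inference engine initializes.
    group = init_custom_process_group(
        backend="nccl",
        init_method="tcp://localhost:65500",
        world_size=world_size,
        rank=0,
        group_name="weight_update_group",
    )
    dist.barrier(group=group, device_ids=[0])
    # For each parameter, the inference side must concurrently call
    # update_weights_from_distributed with the same name/dtype/shape
    # while rank 0 broadcasts the tensor.
    for name, param in model.named_parameters():
        dist.broadcast(param.data, src=0, group=group)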
python/sglang/srt/models/llama.py  (+2, -0)

@@ -307,6 +307,8 @@ class LlamaForCausalLM(nn.Module):
         self.quant_config = quant_config
         self.torchao_config = global_server_args_dict["torchao_config"]
         self.model = LlamaModel(config, quant_config=quant_config)
+        # Llama 3.2 1B Instruct sets tie_word_embeddings to True
+        # Llama 3.1 8B Instruct sets tie_word_embeddings to False
         if self.config.tie_word_embeddings:
             self.lm_head = self.model.embed_tokens
         else:
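Because a tied checkpoint shares one tensor between lm_head and embed_tokens, updating embed_tokens.weight once also updates lm_head, which is why the new test skips broadcasting lm_head.weight for tied models. A quick way to check this on a loaded LlamaForCausalLM (a sketch, not code from the diff):

# Sketch: with tie_word_embeddings=True, lm_head aliases embed_tokens, so the two
# parameters share the same storage and only one broadcast/update is needed.
tied = model.lm_head.weight.data_ptr() == model.model.embed_tokens.weight.data_ptr()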
python/sglang/srt/server.py  (+76, -13)

@@ -53,8 +53,10 @@ from sglang.srt.managers.io_struct import (
     EmbeddingReqInput,
     GenerateReqInput,
     GetWeightsByNameReqInput,
     InitWeightsUpdateGroupReqInput,
     OpenSessionReqInput,
     UpdateWeightFromDiskReqInput,
     UpdateWeightsFromDistributedReqInput,
 )
 from sglang.srt.managers.scheduler import run_scheduler_process
 from sglang.srt.managers.tokenizer_manager import TokenizerManager

@@ -80,6 +82,7 @@ from sglang.srt.utils import (
     assert_pkg_version,
     configure_logger,
     delete_directory,
     init_custom_process_group,
     is_port_available,
     kill_process_tree,
     maybe_set_triton_cache_manager,

@@ -211,6 +214,34 @@ async def update_weights_from_disk(obj: UpdateWeightFromDiskReqInput, request: R
     )


 @app.post("/init_weights_update_group")
 async def init_weights_update_group(
     obj: InitWeightsUpdateGroupReqInput, request: Request
 ):
     """Initialize the parameter update group."""
     success, message = await tokenizer_manager.init_weights_update_group(obj, request)
     content = {"success": success, "message": message}
     if success:
         return ORJSONResponse(content, status_code=200)
     else:
         return ORJSONResponse(content, status_code=HTTPStatus.BAD_REQUEST)


 @app.post("/update_weights_from_distributed")
 async def update_weights_from_distributed(
     obj: UpdateWeightsFromDistributedReqInput, request: Request
 ):
     """Update model parameter from distributed online."""
     success, message = await tokenizer_manager.update_weights_from_distributed(
         obj, request
     )
     content = {"success": success, "message": message}
     if success:
         return ORJSONResponse(content, status_code=200)
     else:
         return ORJSONResponse(content, status_code=HTTPStatus.BAD_REQUEST)


 @app.api_route("/get_weights_by_name", methods=["GET", "POST"])
 async def get_weights_by_name(obj: GetWeightsByNameReqInput, request: Request):
     """Get model parameter by name."""

@@ -288,18 +319,6 @@ async def generate_request(obj: GenerateReqInput, request: Request):
     )

-@time_func_latency
-async def get_weights_by_name_request(obj: GetWeightsByNameReqInput, request: Request):
-    """Handle a get parameter by name request."""
-    try:
-        ret = await tokenizer_manager.get_weights_by_name(obj, request)
-        return ret
-    except ValueError as e:
-        return ORJSONResponse(
-            {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
-        )

 @app.api_route("/encode", methods=["POST", "PUT"])
 @time_func_latency
 async def encode_request(obj: EmbeddingReqInput, request: Request):

@@ -970,7 +989,51 @@ class Engine:
     async def get_server_info(self):
         return await _get_server_info()

     def init_weights_update_group(
         self,
         master_address: str,
         master_port: int,
         rank_offset: int,
         world_size: int,
         group_name: str,
         backend: str = "nccl",
     ):
         """Initialize parameter update group."""
         obj = InitWeightsUpdateGroupReqInput(
             master_address=master_address,
             master_port=master_port,
             rank_offset=rank_offset,
             world_size=world_size,
             group_name=group_name,
             backend=backend,
         )

         async def _init_group():
             return await tokenizer_manager.init_weights_update_group(obj, None)

         loop = asyncio.get_event_loop()
         return loop.run_until_complete(_init_group())

     def update_weights_from_distributed(self, name, dtype, shape):
         """Update weights from distributed source."""
         obj = UpdateWeightsFromDistributedReqInput(
             name=name,
             dtype=dtype,
             shape=shape,
         )

         async def _update_weights():
             return await tokenizer_manager.update_weights_from_distributed(obj, None)

         loop = asyncio.get_event_loop()
         return loop.run_until_complete(_update_weights())

     def get_weights_by_name(self, name, truncate_size=100):
         """Get weights by parameter name."""
         obj = GetWeightsByNameReqInput(name=name, truncate_size=truncate_size)

         async def _get_weights():
             return await tokenizer_manager.get_weights_by_name(obj, None)

         loop = asyncio.get_event_loop()
-        return loop.run_until_complete(get_weights_by_name_request(obj, None))
+        return loop.run_until_complete(_get_weights())
View file @
983bfcf3
...
@@ -39,6 +39,7 @@ import numpy as np
...
@@ -39,6 +39,7 @@ import numpy as np
import
psutil
import
psutil
import
requests
import
requests
import
torch
import
torch
import
torch.distributed
import
torch.distributed
as
dist
import
torch.distributed
as
dist
import
triton
import
triton
import
zmq
import
zmq
...
@@ -962,6 +963,78 @@ def get_nvgpu_memory_capacity():
...
@@ -962,6 +963,78 @@ def get_nvgpu_memory_capacity():
)
)
# Copy from pytorch and OpenRLHF to allow creating multiple main groups.
# https://github.com/pytorch/pytorch/blob/main/torch/distributed/distributed_c10d.py
# https://github.com/OpenRLHF/OpenRLHF/blob/main/openrlhf/utils/distributed_util.py
def
init_custom_process_group
(
backend
=
None
,
init_method
=
None
,
timeout
=
None
,
world_size
=-
1
,
rank
=-
1
,
store
=
None
,
group_name
=
None
,
pg_options
=
None
,
):
from
torch.distributed.distributed_c10d
import
(
Backend
,
PrefixStore
,
_new_process_group_helper
,
_world
,
default_pg_timeout
,
rendezvous
,
)
assert
(
store
is
None
)
or
(
init_method
is
None
),
"Cannot specify both init_method and store."
if
store
is
not
None
:
assert
world_size
>
0
,
"world_size must be positive if using store"
assert
rank
>=
0
,
"rank must be non-negative if using store"
elif
init_method
is
None
:
init_method
=
"env://"
if
backend
:
backend
=
Backend
(
backend
)
else
:
backend
=
Backend
(
"undefined"
)
if
timeout
is
None
:
timeout
=
default_pg_timeout
# backward compatible API
if
store
is
None
:
rendezvous_iterator
=
rendezvous
(
init_method
,
rank
,
world_size
,
timeout
=
timeout
)
store
,
rank
,
world_size
=
next
(
rendezvous_iterator
)
store
.
set_timeout
(
timeout
)
# Use a PrefixStore to avoid accidental overrides of keys used by
# different systems (e.g. RPC) in case the store is multi-tenant.
store
=
PrefixStore
(
group_name
,
store
)
# NOTE: The pg_options parameter was renamed into backend_options in PyTorch 2.6.0
# https://github.com/pytorch/pytorch/commit/a0c7029a75628cd5fa8df83c0de0ea98ee7fd844
# We need to determine the appropriate parameter name based on PyTorch version
pg_options_param_name
=
(
"backend_options"
if
str
(
torch
.
__version__
)
>=
"2.6"
else
"pg_options"
)
pg
,
_
=
_new_process_group_helper
(
world_size
,
rank
,
[],
backend
,
store
,
group_name
=
group_name
,
**
{
pg_options_param_name
:
pg_options
},
timeout
=
timeout
,
)
_world
.
pg_group_ranks
[
pg
]
=
{
i
:
i
for
i
in
range
(
world_size
)}
return
pg
def
crash_on_warnings
():
def
crash_on_warnings
():
# Crash on warning if we are running CI tests
# Crash on warning if we are running CI tests
return
get_bool_env_var
(
"SGLANG_IS_IN_CI"
)
return
get_bool_env_var
(
"SGLANG_IS_IN_CI"
)
...
...
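init_custom_process_group exists so that a second, independent communicator can be created alongside the default group that torch.distributed.init_process_group sets up for TP/DP. A small single-process sketch (gloo backend so it runs without a GPU; the addresses and group name are assumptions, not from the diff):

# Runnable single-process sketch: the custom group coexists with the default group.
import torch.distributed as dist
from sglang.srt.utils import init_custom_process_group

dist.init_process_group(backend="gloo", init_method="tcp://127.0.0.1:29500",
                        world_size=1, rank=0)          # the "default" group
group = init_custom_process_group(backend="gloo",
                                  init_method="tcp://127.0.0.1:29501",
                                  world_size=1, rank=0,
                                  group_name="demo_weight_update_group")
print(dist.get_world_size(group=group))  # -> 1, independent of the default group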
test/srt/test_get_weights_by_name.py  (+100, -46)

@@ -8,47 +8,46 @@ from transformers import AutoModelForCausalLM
 import sglang as sgl
 from sglang.test.test_utils import (
+    DEFAULT_MODEL_NAME_FOR_TEST,
     DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
     DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
     DEFAULT_URL_FOR_TEST,
+    is_in_ci,
     popen_launch_server,
 )
 from sglang.utils import terminate_process


+def _process_return(ret):
+    if isinstance(ret, list) and len(ret) == 2:
+        print(f"running assert_allclose on data parallel")
+        np.testing.assert_allclose(ret[0], ret[1])
+        return np.array(ret[0])
+    return np.array(ret)
+
+
 class TestGetWeightsByName(unittest.TestCase):
-    @classmethod
-    def setUpClass(cls):
-        cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
-        cls.base_url = DEFAULT_URL_FOR_TEST
-        cls.hf_model = AutoModelForCausalLM.from_pretrained(
-            cls.model, torch_dtype="bfloat16"
-        ).to("cuda:0")
-
-    @classmethod
-    def tearDownClass(cls):
-        del cls.hf_model
-        gc.collect()
-        torch.cuda.empty_cache()
+    def init_hf_model(self, model_name, tie_word_embeddings):
+        self.hf_model = AutoModelForCausalLM.from_pretrained(
+            model_name, torch_dtype="bfloat16", tie_word_embeddings=tie_word_embeddings
+        ).to("cuda:0")

-    def init_backend(self, backend, dp, tp):
+    def init_backend(self, backend, dp, tp, model_name):
+        self.engine = None
+        self.process = None
         self.backend = backend
         self.dp = dp
         self.tp = tp
         if backend == "Engine":
             self.engine = sgl.Engine(
-                model_path=self.model,
+                model_path=model_name,
                 random_seed=42,
-                tp_size=self.tp,
-                dp_size=self.dp,
+                tp_size=tp,
+                dp_size=dp,
+                mem_fraction_static=0.85,
             )
         else:
             self.process = popen_launch_server(
-                self.model,
-                self.base_url,
+                model_name,
+                DEFAULT_URL_FOR_TEST,
                 timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
                 other_args=(
                     "--tp-size",
                 ...

@@ -58,12 +57,50 @@ class TestGetWeightsByName(unittest.TestCase):
             ),
         )

-    def close_engine_and_server(self):
-        if self.engine:
+    def clean_up(self):
+        del self.hf_model
+        gc.collect()
+        torch.cuda.empty_cache()
+        if self.backend == "Engine":
             self.engine.shutdown()
-        if self.process:
+        else:
             terminate_process(self.process)

+    def assert_tie_word_embeddings(self, truncate_size):
+        print(f"assert_tie_word_embeddings")
+        if self.backend == "Engine":
+            backend_ret = _process_return(
+                self.engine.get_weights_by_name("lm_head.weight", truncate_size)
+            )
+        else:
+            backend_ret = _process_return(
+                requests.get(
+                    f"{DEFAULT_URL_FOR_TEST}/get_weights_by_name",
+                    json={"name": "lm_head.weight", "truncate_size": truncate_size},
+                ).json()
+            )
+        print(f"assert_tie_word_embeddings of hf and backend")
+        assert np.allclose(
+            self.hf_model.get_parameter("model.embed_tokens.weight")
+            .cpu().detach().float().numpy()[:truncate_size],
+            backend_ret,
+        )
+        assert np.allclose(
+            self.hf_model.get_parameter("lm_head.weight")
+            .cpu().detach().float().numpy()[:truncate_size],
+            self.hf_model.get_parameter("model.embed_tokens.weight")
+            .cpu().detach().float().numpy()[:truncate_size],
+        )
+
     def assert_weights_all_close(self, param_name, truncate_size):
         print(
             f"param_name: {param_name}, backend: {self.backend}, dp: {self.dp}, tp: {self.tp}"
         ...

@@ -73,34 +110,38 @@ class TestGetWeightsByName(unittest.TestCase):
         if self.backend == "Engine":
             engine_ret = self.engine.get_weights_by_name(param_name, truncate_size)
-            engine_ret = self._process_return(engine_ret)
+            engine_ret = _process_return(engine_ret)
             np.testing.assert_allclose(engine_ret, param_np, rtol=1e-5, atol=1e-5)
         if self.backend == "Runtime":
             runtime_ret = requests.get(
-                f"{self.base_url}/get_weights_by_name",
+                f"{DEFAULT_URL_FOR_TEST}/get_weights_by_name",
                 json={"name": param_name, "truncate_size": truncate_size},
             ).json()
-            runtime_ret = self._process_return(runtime_ret)
+            runtime_ret = _process_return(runtime_ret)
             np.testing.assert_allclose(runtime_ret, param_np, rtol=1e-5, atol=1e-5)

-    @staticmethod
-    def _process_return(ret):
-        if isinstance(ret, list) and len(ret) == 2:
-            print("running assert_allclose on data parallel")
-            np.testing.assert_allclose(ret[0], ret[1])
-            return np.array(ret[0])
-        return np.array(ret)
-
-    def test_get_parameters_by_name(self):
-        test_suits = [("Engine", 1, 1), ("Runtime", 1, 1)]
-        if torch.cuda.device_count() >= 2:
-            test_suits.append(("Engine", 1, 2))
-            test_suits.append(("Runtime", 2, 1))
-        if torch.cuda.device_count() >= 4:
-            test_suits.extend([("Engine", 2, 2), ("Runtime", 2, 2)])
+    def test_get_weights_by_name(self):
+        if is_in_ci():
+            test_suits = [
+                ("Engine", 1, 1, DEFAULT_SMALL_MODEL_NAME_FOR_TEST),
+            ]
+        else:
+            test_suits = [
+                ("Runtime", 1, 1, DEFAULT_SMALL_MODEL_NAME_FOR_TEST),
+                ("Engine", 1, 1, DEFAULT_MODEL_NAME_FOR_TEST),
+            ]
+            if torch.cuda.device_count() >= 2:
+                test_suits.append(("Engine", 1, 2, DEFAULT_SMALL_MODEL_NAME_FOR_TEST))
+                test_suits.append(("Runtime", 2, 1, DEFAULT_MODEL_NAME_FOR_TEST))
+            if torch.cuda.device_count() >= 4:
+                test_suits.extend(
+                    [
+                        ("Engine", 2, 2, DEFAULT_SMALL_MODEL_NAME_FOR_TEST),
+                        ("Runtime", 2, 2, DEFAULT_MODEL_NAME_FOR_TEST),
+                    ]
+                )

         parameters = [
             "model.embed_tokens.weight",
             ...

@@ -117,11 +158,24 @@ class TestGetWeightsByName(unittest.TestCase):
             "lm_head.weight",
         ]
+        truncate_size = 100
+
         for test_suit in test_suits:
+            if test_suit[-1] == DEFAULT_MODEL_NAME_FOR_TEST:
+                tie_word_embeddings = False
+            else:
+                tie_word_embeddings = True
+            self.init_hf_model(test_suit[-1], tie_word_embeddings)
             self.init_backend(*test_suit)
             for param_name in parameters:
-                self.assert_weights_all_close(param_name, 100)
-            self.close_engine_and_server()
+                self.assert_weights_all_close(param_name, truncate_size)
+            if tie_word_embeddings:
+                self.assert_tie_word_embeddings(truncate_size)
+            self.clean_up()


 if __name__ == "__main__":
test/srt/test_update_weights_from_distributed.py  (new file, +614)

"""Test distributed weight updates.

This test suite simulates a distributed training environment to ensure
correct weight synchronization. On rank 0, the instruct model represents
pre-training weights, and the base model represents post-training weights.
The base model's weights are broadcasted to other ranks using the online
weight update API.

On other ranks, an engine is initialized with the instruct model, and its
parameters are verified against the Hugging Face model. After updating
weights from the distributed system, post-training weights are loaded
and verified again to ensure consistency and accuracy across the
distributed setup.
"""

import gc
import os
import time
import unittest

import numpy as np
import requests
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from transformers import AutoModelForCausalLM

import sglang as sgl
from sglang.srt.utils import init_custom_process_group
from sglang.test.test_utils import (
    DEFAULT_MODEL_NAME_FOR_TEST,
    DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
    DEFAULT_URL_FOR_TEST,
    is_in_ci,
    popen_launch_server,
)
from sglang.utils import terminate_process

mp.set_start_method("spawn", force=True)


def verify_params_close(params1, params2, error_msg):
    """Verify if two parameter arrays are close enough."""
    try:
        assert np.allclose(np.array(params1), np.array(params2)), error_msg
    except Exception as e:
        print(f"Parameters not close for {error_msg}")
        print("Params1:", np.array(params1))
        print("Params2:", np.array(params2))
        raise e


def verify_params_not_close(params1, params2, error_msg):
    """Verify if two parameter arrays are different enough."""
    assert not np.allclose(np.array(params1), np.array(params2)), error_msg


def init_process(
    rank, world_size, param_queue, truncate_size, state_dict_key_to_shape,
    tp_size, model_name, backend, checking_parameters, tie_word_embeddings,
):
    torch.cuda.set_device(rank)
    if rank == 0:
        init_process_hf(
            rank, world_size, param_queue, truncate_size, model_name,
            checking_parameters, tie_word_embeddings, state_dict_key_to_shape,
        )
    elif rank in [1, 2]:
        init_process_sgl(
            rank, world_size, param_queue, truncate_size, model_name,
            checking_parameters, tie_word_embeddings, state_dict_key_to_shape,
            backend, tp_size,
        )


def init_process_hf(
    rank, world_size, param_queue, truncate_size, model_name,
    checking_parameters, tie_word_embeddings, state_dict_key_to_shape,
):
    # These two environment variables are very important
    # to avoid unexpected behaviors of CUDA and NCCL.
    os.environ["NCCL_CUMEM_ENABLE"] = "0"
    os.environ["NCCL_NVLS_ENABLE"] = "0"

    # Load model and get parameters
    hf_instruct_model = AutoModelForCausalLM.from_pretrained(
        model_name, torch_dtype="bfloat16", tie_word_embeddings=tie_word_embeddings
    ).to("cuda:0")
    base_model_name = model_name.replace("-Instruct", "")
    hf_base_model = AutoModelForCausalLM.from_pretrained(
        base_model_name, torch_dtype="bfloat16", tie_word_embeddings=tie_word_embeddings
    ).to("cuda:0")

    hf_instruct_params = []
    hf_base_params = []
    print(f"get parameter in hf instruct model and base model")
    for parameter_name in checking_parameters:
        hf_instruct_params.append(
            hf_instruct_model.get_parameter(parameter_name)[:truncate_size]
            .cpu().detach().float().numpy().tolist()
        )
        hf_base_params.append(
            hf_base_model.get_parameter(parameter_name)[:truncate_size]
            .cpu().detach().float().numpy().tolist()
        )

    param_queue.put(("hf_instruct_params", hf_instruct_params))
    param_queue.put(("hf_base_params", hf_base_params))

    # Init weight update group for rank 0 (the training engine in RLHF).
    print(f"rank {rank} world_size: {world_size} init custom process group")
    group = init_custom_process_group(
        backend="nccl",
        init_method="tcp://localhost:65500",
        world_size=world_size,
        rank=rank,
        group_name="test_parameter_update_group",
    )
    dist.barrier(group=group, device_ids=[rank])
    torch.cuda.synchronize()
    time_begin_broadcast = time.time()

    # The last parameter is lm_head.weight, which is tied
    # with embed_tokens.weight. Actually, we only need
    # to broadcast embed_tokens.weight once.
    broadcast_parameters = list(state_dict_key_to_shape.keys())
    if tie_word_embeddings:
        broadcast_parameters.remove("lm_head.weight")

    # Broadcast all the weights from the training
    # engine to other ranks (inference engine).
    for parameter_name in broadcast_parameters:
        torch.distributed.broadcast(
            hf_base_model.get_parameter(parameter_name), src=0, group=group
        )
    torch.cuda.synchronize()
    time_end_broadcast = time.time()

    # Measure the latency of broadcasting/weights update.
    broadcast_time = time_end_broadcast - time_begin_broadcast
    print(f"rank {rank} broadcast parameter time: {broadcast_time:.3f}s")
    param_queue.put(("broadcast_time", broadcast_time))

    # Delete the huggingface models to free up memory.
    del hf_instruct_model
    del hf_base_model
    gc.collect()
    torch.cuda.empty_cache()


def init_process_sgl(
    rank, world_size, param_queue, truncate_size, model_name,
    checking_parameters, tie_word_embeddings, state_dict_key_to_shape,
    backend, tp_size,
):
    torch.cuda.set_device(rank)
    torch.cuda.synchronize()
    base_gpu_id = 1 if rank == 1 else 1 + tp_size

    if backend == "Engine":
        engine = sgl.Engine(
            model_path=model_name,
            random_seed=42,
            base_gpu_id=base_gpu_id,
            tp_size=tp_size,
        )
    else:
        if rank == 1:
            url = DEFAULT_URL_FOR_TEST
        else:
            url = DEFAULT_URL_FOR_TEST.replace("2157", "2159")
        process = popen_launch_server(
            model_name,
            url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=("--base-gpu-id", str(base_gpu_id), "--tp-size", str(tp_size)),
        )
    torch.cuda.synchronize()

    if backend == "Engine":
        print(f"rank {rank} init engine")
    else:
        print(f"rank {rank} init server on url: {url}")

    # Get weights of instruct model, i.e. pre-training weights.
    instruct_params = []
    for parameter_name in checking_parameters:
        instruct_params.append(
            engine.get_weights_by_name(parameter_name, truncate_size)
            if backend == "Engine"
            else requests.get(
                f"{url}/get_weights_by_name",
                json={"name": parameter_name, "truncate_size": truncate_size},
            ).json()
        )
    param_queue.put((f"sgl_dp_{rank}_instruct_params", instruct_params))

    # Init weight update group with the training engine.
    if backend == "Engine":
        engine.init_weights_update_group(
            master_address="localhost",
            master_port="65500",
            rank_offset=base_gpu_id,
            world_size=world_size,
            group_name="test_parameter_update_group",
            backend="nccl",
        )
    else:
        requests.post(
            f"{url}/init_weights_update_group",
            json={
                "master_address": "localhost",
                "master_port": "65500",
                "rank_offset": base_gpu_id,
                "world_size": world_size,
                "group_name": "test_parameter_update_group",
                "backend": "nccl",
            },
        )
    torch.cuda.synchronize()
    time_begin_update = time.time()

    # The last parameter is lm_head.weight, which is tied
    # with embed_tokens.weight. Actually, we only need
    # to update embed_tokens.weight once.
    tie_word_embeddings = (
        True if model_name == DEFAULT_SMALL_MODEL_NAME_FOR_TEST else False
    )
    update_parameters = list(state_dict_key_to_shape.keys())
    if tie_word_embeddings:
        update_parameters.remove("lm_head.weight")

    # Get weights from the training engine and update the inference engine.
    for parameter_name in update_parameters:
        if backend == "Engine":
            engine.update_weights_from_distributed(
                parameter_name,
                dtype=torch.bfloat16,
                shape=state_dict_key_to_shape[parameter_name],
            )
        else:
            requests.post(
                f"{url}/update_weights_from_distributed",
                json={
                    "name": parameter_name,
                    "dtype": "bfloat16",
                    "shape": state_dict_key_to_shape[parameter_name],
                },
            )
    torch.cuda.synchronize()
    time_end_update = time.time()

    # Measure the latency of broadcast/weights update.
    update_time = time_end_update - time_begin_update
    print(
        f"fully update model_name {model_name} rank {rank} "
        f"parameter from distributed time: {update_time:.3f}s"
    )
    param_queue.put((f"update_sgl_dp_{rank}_time", update_time))

    # Get the weights of post-training model after weights update for correctness check.
    base_params = []
    for parameter_name in checking_parameters:
        if backend == "Engine":
            base_params.append(engine.get_weights_by_name(parameter_name, truncate_size))
        else:
            base_params.append(
                requests.get(
                    f"{url}/get_weights_by_name",
                    json={"name": parameter_name, "truncate_size": truncate_size},
                ).json()
            )
    param_queue.put((f"sgl_dp_{rank}_base_params", base_params))

    # Shutdown the engine or terminate the server process.
    if backend == "Engine":
        engine.shutdown()
    else:
        terminate_process(process)


def assert_tied_weights(params_list, message, should_be_tied):
    for params in params_list:
        if should_be_tied:
            assert np.allclose(params[0], params[-1]), message
        else:
            assert not np.allclose(params[0], params[-1]), message


def test_update_weights_from_distributed(
    tp_size, dp_size, model_name, backend,
    state_dict_key_to_shape, truncate_size, checking_parameters,
):
    tie_word_embeddings = (
        True if model_name == DEFAULT_SMALL_MODEL_NAME_FOR_TEST else False
    )
    print(
        f"Testing model: {model_name} tp_size: {tp_size}, dp_size: {dp_size} backend: {backend}"
    )
    param_queue = mp.Queue()
    results = {}

    context = mp.spawn(
        init_process,
        args=(
            1 + tp_size * dp_size,
            param_queue,
            truncate_size,
            state_dict_key_to_shape,
            tp_size,
            model_name,
            backend,
            checking_parameters,
            tie_word_embeddings,
        ),
        nprocs=1 + dp_size,
        join=False,
    )

    while len(results) < 3 * (1 + dp_size):
        try:
            key, value = param_queue.get(timeout=5)
            results[key] = value
        except Exception as e:
            if all(not p.is_alive() for p in context.processes):
                break

    context.join()

    if len(results) != 3 * (1 + dp_size):
        raise RuntimeError(
            f"Expected {3 * (1 + dp_size)} parameters but got {len(results)}"
        )

    params = {
        "hf_instruct": results.get("hf_instruct_params"),
        "hf_base": results.get("hf_base_params"),
        "sgl_dp_1_instruct": results.get("sgl_dp_1_instruct_params"),
        "sgl_dp_1_base": results.get("sgl_dp_1_base_params"),
        "broadcast_time": results.get("broadcast_time"),
        "update_sgl_dp_1_time": results.get("update_sgl_dp_1_time"),
    }
    if dp_size == 2:
        dp2_params = {
            "sgl_dp_2_instruct": results.get("sgl_dp_2_instruct_params"),
            "sgl_dp_2_base": results.get("sgl_dp_2_base_params"),
            "update_sgl_dp_2_time": results.get("update_sgl_dp_2_time"),
        }
        assert all(v is not None for v in dp2_params.values())
        params.update(dp2_params)

    # Check the correctness of weights update by verifying
    # the weights of instruct model and base model.
    for i in range(len(params["hf_instruct"])):
        verify_params_close(
            params["hf_instruct"][i],
            params["sgl_dp_1_instruct"][i],
            f"sgl_dp_1_instruct_params rank {i}",
        )
        verify_params_close(
            params["hf_base"][i],
            params["sgl_dp_1_base"][i],
            f"sgl_dp_1_base_params rank {i}",
        )
        verify_params_not_close(
            params["hf_instruct"][i],
            params["hf_base"][i],
            f"hf_instruct_params rank {i}",
        )
        if dp_size == 2:
            verify_params_close(
                params["hf_base"][i],
                params["sgl_dp_2_base"][i],
                f"sgl_dp_2_base_params rank {i}",
            )
            verify_params_close(
                params["hf_instruct"][i],
                params["sgl_dp_2_instruct"][i],
                f"sgl_dp_2_instruct_params rank {i}",
            )

    assert len(params["hf_instruct"]) == len(
        params["hf_base"]
    ), "hf_instruct_params and hf_base_params have different lengths"

    # Check if the weights of lm_head are tied with embed_tokens.
    params_to_check = [
        (params["hf_instruct"], "lm_head.weight is not tied with embed_tokens.weight"),
        (params["hf_base"], "lm_head.weight is not tied with embed_tokens.weight"),
        (params["sgl_dp_1_instruct"], "lm_head.weight is not tied with embed_tokens.weight"),
        (params["sgl_dp_1_base"], "lm_head.weight is not tied with embed_tokens.weight"),
    ]
    if dp_size == 2:
        params_to_check.extend(
            [
                (params["sgl_dp_2_instruct"], "lm_head.weight is not tied with embed_tokens.weight"),
                (params["sgl_dp_2_base"], "lm_head.weight is not tied with embed_tokens.weight"),
            ]
        )
    assert_tied_weights(
        [params for params, _ in params_to_check],
        (
            "lm_head.weight is not tied with embed_tokens.weight"
            if tie_word_embeddings
            else "lm_head.weight is tied with embed_tokens.weight"
        ),
        tie_word_embeddings,
    )

    # Time limit for broadcast and update on CI is 3 / 6
    # On local H100, it's 1 / 2
    time_limit = 3 if model_name == DEFAULT_SMALL_MODEL_NAME_FOR_TEST else 6
    assert (
        params["broadcast_time"] < time_limit
    ), f"broadcast_time exceeds time limit {time_limit}s"
    assert (
        params["update_sgl_dp_1_time"] < time_limit
    ), f"update_sgl_dp_one_time exceeds time limit {time_limit}s"
    if dp_size == 2:
        assert (
            params["update_sgl_dp_2_time"] < time_limit
        ), f"update_sgl_dp_two_time exceeds time limit {time_limit}s"

    # Delete the context and close the parameter queue.
    del context
    param_queue.close()
    param_queue.join_thread()
    gc.collect()
    torch.cuda.empty_cache()


class TestUpdateWeightsFromDistributed(unittest.TestCase):
    def test_update_weights_from_distributed(self):
        assert torch.cuda.device_count() >= 2, "At least 2 GPUs are required"
        # test_suits : tp, dp, model_name, backend
        if is_in_ci():
            test_suits = [
                (1, 1, DEFAULT_SMALL_MODEL_NAME_FOR_TEST, "Engine"),
            ]
        else:
            test_suits = [
                (1, 1, DEFAULT_SMALL_MODEL_NAME_FOR_TEST, "Engine"),
                (1, 1, DEFAULT_MODEL_NAME_FOR_TEST, "Server"),
            ]
            if torch.cuda.device_count() >= 4:
                test_suits.extend(
                    [
                        (2, 1, DEFAULT_SMALL_MODEL_NAME_FOR_TEST, "Engine"),
                        (1, 2, DEFAULT_MODEL_NAME_FOR_TEST, "Server"),
                    ]
                )
            if torch.cuda.device_count() >= 5:
                test_suits.extend(
                    [
                        (2, 2, DEFAULT_SMALL_MODEL_NAME_FOR_TEST, "Engine"),
                        (2, 2, DEFAULT_MODEL_NAME_FOR_TEST, "Server"),
                    ]
                )

        model_state_dict_shapes = {}
        test_models = [test_suit[2] for test_suit in test_suits]
        for model_name in test_models:
            model = AutoModelForCausalLM.from_pretrained(
                model_name, torch_dtype="bfloat16"
            ).to("cuda:0")
            state_dict = model.state_dict()
            state_dict_keys = list(state_dict.keys())
            model_state_dict_shapes[model_name] = {
                key: state_dict[key].shape for key in state_dict_keys
            }
            del model
            gc.collect()
            torch.cuda.empty_cache()

        truncate_size = 10
        checking_parameters = [
            "model.embed_tokens.weight",
            "model.layers.0.input_layernorm.weight",
            "model.layers.1.self_attn.q_proj.weight",
            "model.layers.2.self_attn.k_proj.weight",
            "model.layers.3.self_attn.v_proj.weight",
            "model.layers.4.self_attn.o_proj.weight",
            "model.layers.5.mlp.gate_proj.weight",
            "model.layers.6.mlp.up_proj.weight",
            "model.layers.7.mlp.down_proj.weight",
            "model.layers.8.post_attention_layernorm.weight",
            "model.norm.weight",
            "lm_head.weight",
        ]

        for tp_size, dp_size, model_name, backend in test_suits:
            test_update_weights_from_distributed(
                tp_size,
                dp_size,
                model_name,
                backend,
                model_state_dict_shapes[model_name],
                truncate_size,
                checking_parameters,
            )


if __name__ == "__main__":
    unittest.main()