sglang · Commit 27b557ae (unverified)

Clean up model loader (#1440)

Authored by Lianmin Zheng on Sep 16, 2024; committed by GitHub on Sep 16, 2024
Parent commit: 93dffd69
Showing 5 changed files with 33 additions and 80 deletions (+33, -80):
- python/sglang/srt/managers/tp_worker.py (+3, -1)
- python/sglang/srt/model_executor/model_runner.py (+18, -23)
- python/sglang/srt/server.py (+1, -1)
- python/sglang/srt/utils.py (+2, -51)
- test/srt/test_update_weights.py (+9, -4)
python/sglang/srt/managers/tp_worker.py

@@ -415,7 +415,7 @@ class ModelTpServer:
         # Truncate prompts that are too long
         if len(req.origin_input_ids) >= self.max_req_input_len:
-            logger.warn(
+            logger.warning(
                 "Request length is longer than the KV cache pool size or "
                 "the max context length. Truncated!!!"
             )
@@ -936,6 +936,8 @@ class ModelTpServer:
         if success:
             flash_cache_success = self.flush_cache()
             assert flash_cache_success, "Cache flush failed after updating weights"
+        else:
+            logger.error(message)
         return success, message
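The second hunk moves failure logging up to the TP worker: the model runner now just returns a (success, message) pair, and the worker logs the message only when the update fails. Below is a minimal sketch of that contract with stand-in functions rather than sglang's real classes; only the names success, message, flush_cache, and logger.error come from the diff, everything else is illustrative.

import logging

logger = logging.getLogger(__name__)


def runner_update_weights(ok: bool):
    # Stand-in for ModelRunner.update_weights: report errors via the returned
    # message instead of logging them here.
    if ok:
        return True, "Succeeded to update model weights."
    return False, "Failed to get weights iterator: <reason>."


def worker_update_weights(ok: bool):
    # Stand-in for ModelTpServer.update_weights: flush the cache on success,
    # log the returned message on failure, then pass the pair on.
    success, message = runner_update_weights(ok)
    if success:
        pass  # self.flush_cache() in the real worker
    else:
        logger.error(message)
    return success, message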
python/sglang/srt/model_executor/model_runner.py

@@ -54,11 +54,9 @@ from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import (
     get_available_gpu_memory,
     is_generation_model,
-    is_llama3_405b_fp8_head_16,
     is_multimodal_model,
     monkey_patch_vllm_dummy_weight_loader,
     monkey_patch_vllm_p2p_access_check,
-    monkey_patch_vllm_qvk_linear_loader,
 )

 logger = logging.getLogger(__name__)
@@ -166,10 +164,13 @@ class ModelRunner:
         return min_per_gpu_memory

     def load_model(self):
-        torch.set_num_threads(1)
         logger.info(
             f"Load weight begin. avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
         )
+
+        # This can reduce thread conflicts and speed up weight loading.
+        torch.set_num_threads(1)
+
         if torch.cuda.get_device_capability()[0] < 8:
             logger.info(
                 "Compute capability below sm80. Use float16 due to lack of bfloat16 support."
@@ -178,6 +179,7 @@ class ModelRunner:
         if torch.cuda.get_device_capability()[1] < 5:
             raise RuntimeError("SGLang only supports sm75 and above.")

+        # Prepare the vllm model config
         monkey_patch_vllm_dummy_weight_loader()
         self.device_config = DeviceConfig()
         self.load_config = LoadConfig(load_format=self.server_args.load_format)
@@ -188,23 +190,16 @@ class ModelRunner:
             tokenizer_mode=None,
             trust_remote_code=self.server_args.trust_remote_code,
             dtype=self.server_args.dtype,
-            seed=42,
+            seed=self.server_args.random_seed,
             skip_tokenizer_init=True,
         )
-
-        # A temporary hack to fix the num_heads for meta-llama/Meta-Llama-3.1-405B-FP8 checkpoints
-        # Drop this after Sept, 2024.
-        if is_llama3_405b_fp8_head_16(self.model_config) and self.tp_size <= 8:
-            self.model_config.hf_config.num_key_value_heads = 8
-            self.vllm_model_config.hf_config.num_key_value_heads = 8
-            monkey_patch_vllm_qvk_linear_loader()
-
-        self.dtype = self.vllm_model_config.dtype
         if self.model_config.model_override_args is not None:
             self.vllm_model_config.hf_config.update(
                 self.model_config.model_override_args
             )
+        self.dtype = self.vllm_model_config.dtype

+        # Load the model
         self.model = get_model(
             model_config=self.vllm_model_config,
             load_config=self.load_config,
@@ -255,20 +250,20 @@ class ModelRunner:
                 tokenizer_mode=None,
                 trust_remote_code=self.server_args.trust_remote_code,
                 dtype=self.server_args.dtype,
-                seed=42,
+                seed=self.server_args.random_seed,
                 skip_tokenizer_init=True,
             )
         except Exception as e:
-            logger.error(f"Failed to load model config: {e}")
-            return False, "Failed to update model weights"
+            message = f"Failed to load model config: {e}."
+            return False, message

         load_config = LoadConfig(load_format=load_format)

         # Only support vllm DefaultModelLoader for now
         loader = get_model_loader(load_config)
         if not isinstance(loader, DefaultModelLoader):
-            logger.error("Failed to get weights iterator: Unsupported loader")
-            return False, "Failed to update model weights"
+            message = f"Failed to get model loader: {loader}."
+            return False, message

         def get_weight_iter(config):
             iter = loader._get_weights_iterator(
@@ -293,14 +288,14 @@ class ModelRunner:
         try:
             iter = get_weight_iter(vllm_model_config)
         except Exception as e:
-            message = f"Failed to get weights iterator: {e}"
-            logger.error(message)
+            message = f"Failed to get weights iterator: {e}."
             return False, message
         try:
             model = model_load_weights(self.model, iter)
         except Exception as e:
-            message = f"Failed to update weights: {e}. \n Rolling back to original weights"
-            logger.error(message)
+            message = (
+                f"Failed to update weights: {e}.\nRolling back to original weights."
+            )
             del iter
             gc.collect()
             iter = get_weight_iter(self.vllm_model_config)
@@ -315,7 +310,7 @@ class ModelRunner:
         self.model_config.path = model_path

         logger.info("Update weights end.")
-        return True, "Succeeded to update model weights"
+        return True, "Succeeded to update model weights."

     def init_lora_manager(self):
         self.lora_manager = LoRAManager(
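One detail worth noting in load_model: torch.set_num_threads(1) is kept but moved below the initial memory log, and it now carries a comment explaining that limiting CPU threads reduces thread conflicts and speeds up weight loading. A standalone sketch of that ordering follows; the log text mirrors the diff, while get_available_gpu_memory and the actual get_model call are omitted so the snippet stays self-contained.

import logging

import torch

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("load_model_sketch")


def load_model_sketch():
    logger.info("Load weight begin.")
    # This can reduce thread conflicts and speed up weight loading.
    torch.set_num_threads(1)
    if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] < 8:
        logger.info(
            "Compute capability below sm80. Use float16 due to lack of bfloat16 support."
        )
    # ... actual weight loading (get_model) would follow here ...


if __name__ == "__main__":
    load_model_sketch()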
python/sglang/srt/server.py

@@ -152,7 +152,7 @@ async def flush_cache():
 async def update_weights(obj: UpdateWeightReqInput, request: Request):
     success, message = await tokenizer_manager.update_weights(obj, request)
-    content = {"message": message, "success": str(success)}
+    content = {"success": success, "message": message}
     if success:
         return JSONResponse(
             content,
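Besides reordering the payload, this change makes success a real JSON boolean instead of the string produced by str(success). A hedged usage sketch of the endpoint is below; the route name matches the handler function, but the host, port, and model path are illustrative assumptions for a locally running server.

import requests

resp = requests.post(
    "http://127.0.0.1:30000/update_weights",  # assumed host/port for a local sglang server
    json={"model_path": "meta-llama/Meta-Llama-3.1-8B"},
)
ret = resp.json()
if ret["success"]:  # a real JSON boolean after this change, not the string "True"/"False"
    print("weights updated:", ret["message"])
else:
    print("update failed:", ret["message"])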
python/sglang/srt/utils.py

@@ -187,7 +187,7 @@ def allocate_init_ports(
         cur_port += 1

     if port is not None and ret_ports[0] != port:
-        logger.warn(
+        logger.warning(
             f"WARNING: Port {port} is not available. Use port {ret_ports[0]} instead."
         )
@@ -623,56 +623,7 @@ def set_ulimit(target_soft_limit=65535):
         try:
             resource.setrlimit(resource_type, (target_soft_limit, current_hard))
         except ValueError as e:
-            logger.warn(f"Fail to set RLIMIT_NOFILE: {e}")
+            logger.warning(f"Fail to set RLIMIT_NOFILE: {e}")


-def is_llama3_405b_fp8_head_16(model_config):
-    """Return whether the model is meta-llama/Meta-Llama-3.1-405B-FP8 with 16 kv heads."""
-    if (
-        model_config.hf_config.architectures[0] == "LlamaForCausalLM"
-        and model_config.hf_config.hidden_size == 16384
-        and model_config.hf_config.intermediate_size == 53248
-        and model_config.hf_config.num_hidden_layers == 126
-        and model_config.hf_config.num_key_value_heads == 16
-        and hasattr(model_config.hf_config, "quantization_config")
-        and model_config.hf_config.quantization_config["quant_method"] == "fbgemm_fp8"
-    ):
-        return True
-    return False
-
-
-def monkey_patch_vllm_qvk_linear_loader():
-    """A temporary hack to fix the num_heads for meta-llama/Meta-Llama-3.1-405B-FP8 checkpoints."""
-    from vllm.model_executor.layers.linear import QKVParallelLinear
-
-    origin_weight_loader = QKVParallelLinear.weight_loader
-
-    def get_original_weight(loaded_weight, head_dim):
-        n_kv_head = loaded_weight.shape[0] // (2 * head_dim)
-        dim = loaded_weight.shape[1]
-        for i in range(n_kv_head):
-            loaded_weight[i * head_dim : (i + 1) * head_dim, :] = loaded_weight[
-                2 * i * head_dim : (2 * i + 1) * head_dim, :
-            ]
-        original_kv_weight = loaded_weight[: n_kv_head * head_dim, :]
-        assert original_kv_weight.shape == (n_kv_head * head_dim, dim)
-        return original_kv_weight
-
-    def weight_loader_srt(
-        self,
-        param: Parameter,
-        loaded_weight: torch.Tensor,
-        loaded_shard_id: Optional[str] = None,
-    ):
-        if (
-            loaded_shard_id in ["k", "v"]
-            and loaded_weight.shape[0] == self.head_size * self.total_num_kv_heads * 2
-        ):
-            loaded_weight = get_original_weight(loaded_weight, self.head_size)
-
-        origin_weight_loader(self, param, loaded_weight, loaded_shard_id)
-
-    setattr(QKVParallelLinear, "weight_loader", weight_loader_srt)
-
-
 def add_api_key_middleware(app, api_key: str):
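Both logger.warn call sites are switched to logger.warning, the canonical spelling; Logger.warn has been a deprecated alias since Python 3.3. The small self-contained check below shows the DeprecationWarning the alias emits on current CPython (guarded with hasattr in case an interpreter drops the alias entirely).

import logging
import warnings

logging.basicConfig(level=logging.WARNING)
logger = logging.getLogger("demo")

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    if hasattr(logger, "warn"):        # deprecated alias; may be absent on newer interpreters
        logger.warn("old-style call")
    logger.warning("preferred call")   # canonical spelling

print([str(w.message) for w in caught])  # typically includes a DeprecationWarning for warn()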
test/srt/test_update_weights.py

@@ -44,7 +44,6 @@ class TestReplaceWeights(unittest.TestCase):
             )
         )
         print(json.dumps(response.json()))
         print("=" * 100)
-        # return the "text" in response
         text = response.json()["text"]
         return text
@@ -61,7 +60,9 @@ class TestReplaceWeights(unittest.TestCase):
                 "model_path": model_path,
             },
         )
+        ret = response.json()
         print(json.dumps(response.json()))
+        return ret

     def test_replace_weights(self):
         origin_model_path = self.get_model_info()
@@ -70,7 +71,8 @@ class TestReplaceWeights(unittest.TestCase):
         # update weights
         new_model_path = "meta-llama/Meta-Llama-3.1-8B"
-        self.run_update_weights(new_model_path)
+        ret = self.run_update_weights(new_model_path)
+        assert ret["success"]

         updated_model_path = self.get_model_info()
         print(f"updated_model_path: {updated_model_path}")
@@ -81,7 +83,9 @@ class TestReplaceWeights(unittest.TestCase):
         assert origin_response[:32] != updated_response[:32]

         # update weights back
-        self.run_update_weights(origin_model_path)
+        ret = self.run_update_weights(origin_model_path)
+        assert ret["success"]
+
         updated_model_path = self.get_model_info()
         assert updated_model_path == origin_model_path
@@ -95,7 +99,8 @@ class TestReplaceWeights(unittest.TestCase):
         # update weights
         new_model_path = "meta-llama/Meta-Llama-3.1-8B-1"
-        self.run_update_weights(new_model_path)
+        ret = self.run_update_weights(new_model_path)
+        assert not ret["success"]

         updated_model_path = self.get_model_info()
         print(f"updated_model_path: {updated_model_path}")
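The new assert ret["success"] checks only work because run_update_weights now returns the parsed response and, after the server.py change above, success arrives as a real JSON boolean. With the old str(success) payload a failed update would have slipped through, since the string "False" is truthy. A tiny sketch of the difference:

import json

old_style = json.loads(json.dumps({"message": "err", "success": str(False)}))
new_style = json.loads(json.dumps({"success": False, "message": "err"}))

assert bool(old_style["success"]) is True    # "False" is a non-empty string, hence truthy
assert bool(new_style["success"]) is False   # a genuine boolean survives the round trip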