zhaoyu6 / sglang · Commits · 27b557ae
"torchvision/vscode:/vscode.git/clone" did not exist on "caeaeb650e76e667006305277064e3e324915256"
Commit 27b557ae (Unverified)
Authored Sep 16, 2024 by Lianmin Zheng; committed by GitHub on Sep 16, 2024

Clean up model loader (#1440)

Parent: 93dffd69
Showing 5 changed files with 33 additions and 80 deletions (+33 −80):
- python/sglang/srt/managers/tp_worker.py (+3 −1)
- python/sglang/srt/model_executor/model_runner.py (+18 −23)
- python/sglang/srt/server.py (+1 −1)
- python/sglang/srt/utils.py (+2 −51)
- test/srt/test_update_weights.py (+9 −4)
python/sglang/srt/managers/tp_worker.py
```diff
@@ -415,7 +415,7 @@ class ModelTpServer:
             # Truncate prompts that are too long
             if len(req.origin_input_ids) >= self.max_req_input_len:
-                logger.warn(
+                logger.warning(
                     "Request length is longer than the KV cache pool size or "
                     "the max context length. Truncated!!!"
                 )
```
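Both this change and the matching one in utils.py replace `logger.warn` with `logger.warning`: `Logger.warn` is a deprecated alias in Python's standard logging module and may emit a `DeprecationWarning`, while `logger.warning` is the supported spelling with identical behavior. A minimal illustration:

```python
import logging

logging.basicConfig(level=logging.WARNING)
logger = logging.getLogger(__name__)

# logger.warn(...) still works but is deprecated; logger.warning(...) is preferred.
logger.warning(
    "Request length is longer than the KV cache pool size or "
    "the max context length. Truncated!!!"
)
```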
```diff
@@ -936,6 +936,8 @@ class ModelTpServer:
         if success:
             flash_cache_success = self.flush_cache()
             assert flash_cache_success, "Cache flush failed after updating weights"
+        else:
+            logger.error(message)
         return success, message
```
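After this hunk, ModelTpServer flushes the radix/KV cache whenever a weight update succeeds, because cached entries were computed with the old weights and must not be reused, and it logs the message returned by ModelRunner.update_weights when the update fails (the logging that model_runner.py itself no longer does; see the hunks below). A small sketch of the resulting flow; the helper name is hypothetical, the calls come from the diff:

```python
import logging

logger = logging.getLogger(__name__)


def finish_weight_update(worker, success: bool, message: str):
    """Hypothetical helper illustrating the (success, message) contract above."""
    if success:
        # Stale KV-cache entries were computed with the old weights; drop them all.
        flash_cache_success = worker.flush_cache()
        assert flash_cache_success, "Cache flush failed after updating weights"
    else:
        # ModelRunner.update_weights now only builds the message; the worker logs it.
        logger.error(message)
    return success, message
```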
python/sglang/srt/model_executor/model_runner.py
```diff
@@ -54,11 +54,9 @@ from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import (
     get_available_gpu_memory,
     is_generation_model,
-    is_llama3_405b_fp8_head_16,
     is_multimodal_model,
     monkey_patch_vllm_dummy_weight_loader,
     monkey_patch_vllm_p2p_access_check,
-    monkey_patch_vllm_qvk_linear_loader,
 )

 logger = logging.getLogger(__name__)
```
```diff
@@ -166,10 +164,13 @@ class ModelRunner:
         return min_per_gpu_memory

     def load_model(self):
-        torch.set_num_threads(1)
         logger.info(
             f"Load weight begin. avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
         )
+
+        # This can reduce thread conflicts and speed up weight loading.
+        torch.set_num_threads(1)
+
         if torch.cuda.get_device_capability()[0] < 8:
             logger.info(
                 "Compute capability below sm80. Use float16 due to lack of bfloat16 support."
```
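Two notes on this hunk: torch.set_num_threads(1) caps PyTorch's intra-op CPU thread pool, which the new comment explains reduces thread contention while checkpoints are deserialized, and torch.cuda.get_device_capability() returns a (major, minor) pair, e.g. (7, 5) for sm75 and (8, 0) for sm80; bfloat16 requires sm80 or newer. A standalone sketch of the capability gate (not the exact sglang control flow):

```python
import torch

torch.set_num_threads(1)  # limit intra-op CPU threads during weight loading

major, minor = torch.cuda.get_device_capability()  # e.g. (8, 0) == sm80
if (major, minor) < (7, 5):
    raise RuntimeError("SGLang only supports sm75 and above.")
dtype = torch.float16 if major < 8 else torch.bfloat16  # bf16 needs sm80+
```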
```diff
@@ -178,6 +179,7 @@ class ModelRunner:
         if torch.cuda.get_device_capability()[1] < 5:
             raise RuntimeError("SGLang only supports sm75 and above.")

         # Prepare the vllm model config
+        monkey_patch_vllm_dummy_weight_loader()
         self.device_config = DeviceConfig()
         self.load_config = LoadConfig(load_format=self.server_args.load_format)
```
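For context, DeviceConfig and LoadConfig are vllm config objects; the diff shows them constructed with defaults and with the server's load_format, respectively. A minimal sketch, assuming the vllm.config import path of the vllm version sglang pinned at the time:

```python
# Assumption: both classes live in vllm.config in the pinned vllm release.
from vllm.config import DeviceConfig, LoadConfig

device_config = DeviceConfig()                # default device selection
load_config = LoadConfig(load_format="auto")  # same string server_args.load_format carries
```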
```diff
@@ -188,23 +190,16 @@ class ModelRunner:
             tokenizer_mode=None,
             trust_remote_code=self.server_args.trust_remote_code,
             dtype=self.server_args.dtype,
-            seed=42,
+            seed=self.server_args.random_seed,
             skip_tokenizer_init=True,
         )

-        # A temporary hack to fix the num_heads for meta-llama/Meta-Llama-3.1-405B-FP8 checkpoints
-        # Drop this after Sept, 2024.
-        if is_llama3_405b_fp8_head_16(self.model_config) and self.tp_size <= 8:
-            self.model_config.hf_config.num_key_value_heads = 8
-            self.vllm_model_config.hf_config.num_key_value_heads = 8
-            monkey_patch_vllm_qvk_linear_loader()
-
-        self.dtype = self.vllm_model_config.dtype
         if self.model_config.model_override_args is not None:
             self.vllm_model_config.hf_config.update(self.model_config.model_override_args)
+        self.dtype = self.vllm_model_config.dtype

         # Load the model
         self.model = get_model(
             model_config=self.vllm_model_config,
             load_config=self.load_config,
```
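Two behavioral points here: the vllm model config is now seeded from server_args.random_seed instead of a hard-coded 42, and model_override_args is applied through HuggingFace's PretrainedConfig.update, which copies the given key/value pairs onto the config in place so later readers see the overridden values. A small, self-contained illustration of the latter (the config class and override key are chosen for the example only):

```python
from transformers import LlamaConfig

config = LlamaConfig()  # default config object; no download required
config.update({"max_position_embeddings": 65536})  # applies overrides in place
assert config.max_position_embeddings == 65536
```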
```diff
@@ -255,20 +250,20 @@ class ModelRunner:
                 tokenizer_mode=None,
                 trust_remote_code=self.server_args.trust_remote_code,
                 dtype=self.server_args.dtype,
-                seed=42,
+                seed=self.server_args.random_seed,
                 skip_tokenizer_init=True,
             )
         except Exception as e:
-            logger.error(f"Failed to load model config: {e}")
-            return False, "Failed to update model weights"
+            message = f"Failed to load model config: {e}."
+            return False, message

         load_config = LoadConfig(load_format=load_format)

         # Only support vllm DefaultModelLoader for now
         loader = get_model_loader(load_config)
         if not isinstance(loader, DefaultModelLoader):
-            logger.error("Failed to get weights iterator: Unsupported loader")
-            return False, "Failed to update model weights"
+            message = f"Failed to get model loader: {loader}."
+            return False, message

         def get_weight_iter(config):
             iter = loader._get_weights_iterator(
```
```diff
@@ -293,14 +288,14 @@ class ModelRunner:
         try:
             iter = get_weight_iter(vllm_model_config)
         except Exception as e:
-            message = f"Failed to get weights iterator: {e}"
-            logger.error(message)
+            message = f"Failed to get weights iterator: {e}."
             return False, message
         try:
             model = model_load_weights(self.model, iter)
         except Exception as e:
-            message = f"Failed to update weights: {e}.\nRolling back to original weights"
-            logger.error(message)
+            message = (
+                f"Failed to update weights: {e}.\nRolling back to original weights."
+            )
             del iter
             gc.collect()
             iter = get_weight_iter(self.vllm_model_config)
```
```diff
@@ -315,7 +310,7 @@ class ModelRunner:
         self.model_config.path = model_path

         logger.info("Update weights end.")
-        return True, "Succeeded to update model weights"
+        return True, "Succeeded to update model weights."

     def init_lora_manager(self):
         self.lora_manager = LoRAManager(
```
python/sglang/srt/server.py
```diff
@@ -152,7 +152,7 @@ async def flush_cache():
 async def update_weights(obj: UpdateWeightReqInput, request: Request):
     success, message = await tokenizer_manager.update_weights(obj, request)
-    content = {"message": message, "success": str(success)}
+    content = {"success": success, "message": message}
     if success:
         return JSONResponse(
             content,
```
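With this change the endpoint returns success as a real JSON boolean instead of a stringified Python bool. A minimal client sketch; the route path and port are assumptions based on the accompanying test and common sglang defaults, not part of this diff:

```python
import requests

# Assumed defaults: server on localhost:30000, handler exposed at /update_weights.
resp = requests.post(
    "http://127.0.0.1:30000/update_weights",
    json={"model_path": "meta-llama/Meta-Llama-3.1-8B"},
)
ret = resp.json()
assert isinstance(ret["success"], bool)  # previously str(success), e.g. "True"
print(ret["message"])
```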
python/sglang/srt/utils.py
```diff
@@ -187,7 +187,7 @@ def allocate_init_ports(
         cur_port += 1

     if port is not None and ret_ports[0] != port:
-        logger.warn(
+        logger.warning(
             f"WARNING: Port {port} is not available. Use port {ret_ports[0]} instead."
         )
```
```diff
@@ -623,56 +623,7 @@ def set_ulimit(target_soft_limit=65535):
         try:
             resource.setrlimit(resource_type, (target_soft_limit, current_hard))
         except ValueError as e:
-            logger.warn(f"Fail to set RLIMIT_NOFILE: {e}")
-
-
-def is_llama3_405b_fp8_head_16(model_config):
-    """Return whether the model is meta-llama/Meta-Llama-3.1-405B-FP8 with 16 kv heads."""
-    if (
-        model_config.hf_config.architectures[0] == "LlamaForCausalLM"
-        and model_config.hf_config.hidden_size == 16384
-        and model_config.hf_config.intermediate_size == 53248
-        and model_config.hf_config.num_hidden_layers == 126
-        and model_config.hf_config.num_key_value_heads == 16
-        and hasattr(model_config.hf_config, "quantization_config")
-        and model_config.hf_config.quantization_config["quant_method"] == "fbgemm_fp8"
-    ):
-        return True
-    return False
-
-
-def monkey_patch_vllm_qvk_linear_loader():
-    """A temporary hack to fix the num_heads for meta-llama/Meta-Llama-3.1-405B-FP8 checkpoints."""
-    from vllm.model_executor.layers.linear import QKVParallelLinear
-
-    origin_weight_loader = QKVParallelLinear.weight_loader
-
-    def get_original_weight(loaded_weight, head_dim):
-        n_kv_head = loaded_weight.shape[0] // (2 * head_dim)
-        dim = loaded_weight.shape[1]
-        for i in range(n_kv_head):
-            loaded_weight[i * head_dim : (i + 1) * head_dim, :] = loaded_weight[
-                2 * i * head_dim : (2 * i + 1) * head_dim, :
-            ]
-        original_kv_weight = loaded_weight[: n_kv_head * head_dim, :]
-        assert original_kv_weight.shape == (n_kv_head * head_dim, dim)
-        return original_kv_weight
-
-    def weight_loader_srt(
-        self,
-        param: Parameter,
-        loaded_weight: torch.Tensor,
-        loaded_shard_id: Optional[str] = None,
-    ):
-        if (
-            loaded_shard_id in ["k", "v"]
-            and loaded_weight.shape[0] == self.head_size * self.total_num_kv_heads * 2
-        ):
-            loaded_weight = get_original_weight(loaded_weight, self.head_size)
-
-        origin_weight_loader(self, param, loaded_weight, loaded_shard_id)
-
-    setattr(QKVParallelLinear, "weight_loader", weight_loader_srt)
+            logger.warning(f"Fail to set RLIMIT_NOFILE: {e}")


 def add_api_key_middleware(app, api_key: str):
```
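The two deleted helpers were a temporary workaround for the meta-llama/Meta-Llama-3.1-405B-FP8 checkpoint, whose K/V projection shards store 16 KV heads for a model that uses 8: get_original_weight compacts the even-indexed head_dim-sized blocks to the front and keeps the first n_kv_head * head_dim rows. A self-contained check of that arithmetic, assuming each duplicate block immediately follows its head (which is what the copy pattern implies):

```python
import torch


def get_original_weight(loaded_weight: torch.Tensor, head_dim: int) -> torch.Tensor:
    # Same logic as the removed helper: keep every even head_dim-sized block.
    n_kv_head = loaded_weight.shape[0] // (2 * head_dim)
    dim = loaded_weight.shape[1]
    for i in range(n_kv_head):
        loaded_weight[i * head_dim : (i + 1) * head_dim, :] = loaded_weight[
            2 * i * head_dim : (2 * i + 1) * head_dim, :
        ]
    original_kv_weight = loaded_weight[: n_kv_head * head_dim, :]
    assert original_kv_weight.shape == (n_kv_head * head_dim, dim)
    return original_kv_weight


# Toy K shard in which every head block appears twice, back to back.
head_dim, n_kv_head, dim = 4, 2, 8
blocks = [torch.full((head_dim, dim), float(i)) for i in range(n_kv_head)]
duplicated = torch.cat([b for blk in blocks for b in (blk, blk)], dim=0)

restored = get_original_weight(duplicated.clone(), head_dim)
assert torch.equal(restored, torch.cat(blocks, dim=0))
```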
test/srt/test_update_weights.py
```diff
@@ -44,7 +44,6 @@ class TestReplaceWeights(unittest.TestCase):
         )
         print(json.dumps(response.json()))
         print("=" * 100)
-        # return the "text" in response
         text = response.json()["text"]
         return text
```
```diff
@@ -61,7 +60,9 @@ class TestReplaceWeights(unittest.TestCase):
                 "model_path": model_path,
             },
         )
+        ret = response.json()
         print(json.dumps(response.json()))
+        return ret

     def test_replace_weights(self):
         origin_model_path = self.get_model_info()
```
```diff
@@ -70,7 +71,8 @@ class TestReplaceWeights(unittest.TestCase):
         # update weights
         new_model_path = "meta-llama/Meta-Llama-3.1-8B"
-        self.run_update_weights(new_model_path)
+        ret = self.run_update_weights(new_model_path)
+        assert ret["success"]

         updated_model_path = self.get_model_info()
         print(f"updated_model_path: {updated_model_path}")
```
```diff
@@ -81,7 +83,9 @@ class TestReplaceWeights(unittest.TestCase):
         assert origin_response[:32] != updated_response[:32]

         # update weights back
-        self.run_update_weights(origin_model_path)
+        ret = self.run_update_weights(origin_model_path)
+        assert ret["success"]
+
         updated_model_path = self.get_model_info()
         assert updated_model_path == origin_model_path
```
```diff
@@ -95,7 +99,8 @@ class TestReplaceWeights(unittest.TestCase):
         # update weights
         new_model_path = "meta-llama/Meta-Llama-3.1-8B-1"
-        self.run_update_weights(new_model_path)
+        ret = self.run_update_weights(new_model_path)
+        assert not ret["success"]

         updated_model_path = self.get_model_info()
         print(f"updated_model_path: {updated_model_path}")
```