sglang · commit cd10654e (unverified)
[Feat] Support update weights without restart server (#1157)
Authored by Shan Yu on Aug 20, 2024; committed via GitHub on Aug 20, 2024
Parent: 350a8160
Showing 7 changed files with 303 additions and 13 deletions (+303, -13).
python/sglang/srt/managers/detokenizer_manager.py    +5    -0
python/sglang/srt/managers/io_struct.py              +14   -0
python/sglang/srt/managers/tokenizer_manager.py      +42   -3
python/sglang/srt/managers/tp_worker.py              +17   -0
python/sglang/srt/model_executor/model_runner.py     +97   -9
python/sglang/srt/server.py                          +22   -1
test/srt/test_update_weights.py                      +106  -0
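Taken together, these changes add an /update_weights HTTP endpoint: the server drains in-flight generation requests, reloads weights through vLLM's DefaultModelLoader, flushes the KV cache, and rolls back to the original weights if loading fails. A minimal client-side sketch of calling the new endpoint (the base URL/port is illustrative; the payload fields model_path and optional load_format mirror UpdateWeightReqInput below):

import requests

# Illustrative base URL; point this at a running sglang server.
base_url = "http://localhost:30000"

# model_path is required; load_format is optional and falls back to the
# server's configured load format (see UpdateWeightReqInput below).
response = requests.post(
    base_url + "/update_weights",
    json={"model_path": "meta-llama/Meta-Llama-3.1-8B"},
)
print(response.status_code)  # 200 if the update succeeded, 400 otherwise
print(response.json())       # {"success": "True", "message": "..."}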
python/sglang/srt/managers/detokenizer_manager.py

@@ -28,6 +28,7 @@ from sglang.srt.managers.io_struct import (
     BatchEmbeddingOut,
     BatchStrOut,
     BatchTokenIDOut,
+    UpdateWeightReqOutput,
 )
 from sglang.srt.managers.schedule_batch import FINISH_MATCHED_STR
 from sglang.srt.server_args import PortArgs, ServerArgs

@@ -84,6 +85,10 @@ class DetokenizerManager:
                 )
                 continue

+            if isinstance(recv_obj, UpdateWeightReqOutput):
+                self.send_to_tokenizer.send_pyobj(recv_obj)
+                continue
+
             assert isinstance(recv_obj, BatchTokenIDOut)
             bs = len(recv_obj.rids)
python/sglang/srt/managers/io_struct.py

@@ -278,6 +278,20 @@ class FlushCacheReq:
     pass


+@dataclass
+class UpdateWeightReqInput:
+    # The model path with the new weights
+    model_path: str
+    # The format to load the weights
+    load_format: Optional[str] = None
+
+
+@dataclass
+class UpdateWeightReqOutput:
+    success: bool
+    message: str
+
+
 @dataclass
 class AbortReq:
     # The request id
python/sglang/srt/managers/tokenizer_manager.py

@@ -46,6 +46,8 @@ from sglang.srt.managers.io_struct import (
     GenerateReqInput,
     TokenizedEmbeddingReqInput,
     TokenizedGenerateReqInput,
+    UpdateWeightReqInput,
+    UpdateWeightReqOutput,
 )
 from sglang.srt.mm_utils import expand2square, process_anyres_image
 from sglang.srt.sampling_params import SamplingParams

@@ -121,6 +123,10 @@ class TokenizerManager:
         self.to_create_loop = True
         self.rid_to_state: Dict[str, ReqState] = {}

+        # for update model weights
+        self.model_update_lock = asyncio.Lock()
+        self.model_update_result = None
+
     async def get_pixel_values(self, image_data):
         aspect_ratio = getattr(self.hf_config, "image_aspect_ratio", None)
         grid_pinpoints = (

@@ -146,6 +152,9 @@ class TokenizerManager:
         if self.to_create_loop:
             self.create_handle_loop()

+        while self.model_update_lock.locked():
+            await asyncio.sleep(0)
+
         obj.post_init()
         is_single = obj.is_single

@@ -513,6 +522,30 @@ class TokenizerManager:
         req = FlushCacheReq()
         self.send_to_router.send_pyobj(req)

+    async def update_weights(self, obj: UpdateWeightReqInput, request):
+        if self.to_create_loop:
+            self.create_handle_loop()
+
+        # default the load format to the server_args
+        if obj.load_format is None:
+            obj.load_format = self.server_args.load_format
+
+        if not self.model_update_lock.locked():
+            async with self.model_update_lock:
+                # wait for the previous generation requests to finish
+                while len(self.rid_to_state) > 0:
+                    await asyncio.sleep(0)
+                self.send_to_router.send_pyobj(obj)
+                self.model_update_result = asyncio.Future()
+                result = await self.model_update_result
+                if result.success:
+                    self.server_args.model_path = obj.model_path
+                    self.server_args.load_format = obj.load_format
+                    self.model_path = obj.model_path
+                return result.success, result.message
+        else:
+            return False, "Another update is in progress. Please try again later."
+
     def abort_request(self, rid: str):
         if rid not in self.rid_to_state:
             return

@@ -541,12 +574,18 @@ class TokenizerManager:
     async def handle_loop(self):
         while True:
-            recv_obj: Union[BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut] = (
-                await self.recv_from_detokenizer.recv_pyobj()
-            )
+            recv_obj: Union[
+                BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut, UpdateWeightReqOutput
+            ] = await self.recv_from_detokenizer.recv_pyobj()
+
+            if isinstance(recv_obj, UpdateWeightReqOutput):
+                self.model_update_result.set_result(recv_obj)
+                continue
+
             assert isinstance(
                 recv_obj, (BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut)
             ), f"Unexpected obj received: {type(recv_obj)}"

             for i, rid in enumerate(recv_obj.rids):
                 state = self.rid_to_state.get(rid, None)
                 if state is None:
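The TokenizerManager change coordinates weight updates with in-flight generation through a simple two-sided handshake: generate requests spin while model_update_lock is held, and update_weights acquires the lock and then waits until rid_to_state is empty before forwarding the request to the router. A self-contained sketch of that drain-then-swap pattern (illustrative only, not sglang code; names like Server and in_flight are made up for the example):

import asyncio


class Server:
    """Toy model of the drain-then-swap coordination used above."""

    def __init__(self):
        self.update_lock = asyncio.Lock()
        self.in_flight = set()

    async def generate(self, rid):
        # New requests are held back while an update holds the lock.
        while self.update_lock.locked():
            await asyncio.sleep(0)
        self.in_flight.add(rid)
        try:
            await asyncio.sleep(0.01)  # stand-in for real decoding work
        finally:
            self.in_flight.discard(rid)

    async def update_weights(self, new_path):
        if self.update_lock.locked():
            return False, "Another update is in progress."
        async with self.update_lock:
            # Wait for already-running requests to finish before swapping.
            while self.in_flight:
                await asyncio.sleep(0)
            return True, f"swapped to {new_path}"


async def main():
    server = Server()
    results = await asyncio.gather(
        server.generate("req-1"),
        server.update_weights("new-model"),
    )
    print(results[1])  # (True, 'swapped to new-model')


asyncio.run(main())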
python/sglang/srt/managers/tp_worker.py

@@ -39,6 +39,8 @@ from sglang.srt.managers.io_struct import (
     FlushCacheReq,
     TokenizedEmbeddingReqInput,
     TokenizedGenerateReqInput,
+    UpdateWeightReqInput,
+    UpdateWeightReqOutput,
 )
 from sglang.srt.managers.policy_scheduler import PolicyScheduler, PrefillAdder
 from sglang.srt.managers.schedule_batch import (

@@ -214,6 +216,9 @@ class ModelTpServer:
                 self.flush_cache()
             elif isinstance(recv_req, AbortReq):
                 self.abort_request(recv_req)
+            elif isinstance(recv_req, UpdateWeightReqInput):
+                success, message = self.update_weights(recv_req)
+                self.out_pyobjs.append(UpdateWeightReqOutput(success, message))
             else:
                 raise ValueError(f"Invalid request: {recv_req}")

@@ -773,12 +778,15 @@ class ModelTpServer:
             self.token_to_kv_pool.clear()
             torch.cuda.empty_cache()
             logger.info("Cache flushed successfully!")
+            if_success = True
         else:
             logging.warning(
                 f"Cache not flushed because there are pending requests. "
                 f"#queue-req: {len(self.waiting_queue)}, "
                 f"#running-req: {0 if self.running_batch is None else len(self.running_batch.reqs)}"
             )
+            if_success = False
+        return if_success

     def abort_request(self, recv_req):
         # Delete requests in the waiting queue

@@ -798,6 +806,15 @@ class ModelTpServer:
                     req.finished_reason = FINISH_ABORT()
                     break

+    def update_weights(self, recv_req):
+        success, message = self.model_runner.update_weights(
+            recv_req.model_path, recv_req.load_format
+        )
+        if success:
+            flash_cache_success = self.flush_cache()
+            assert flash_cache_success, "Cache flush failed after updating weights"
+        return success, message
+

 def run_tp_server(
     gpu_id: int,
python/sglang/srt/model_executor/model_runner.py

@@ -15,6 +15,7 @@ limitations under the License.

 """ModelRunner runs the forward passes of the models."""

+import gc
 import importlib
 import importlib.resources
 import logging

@@ -157,9 +158,9 @@ class ModelRunner:
             self.server_args.dtype = "float16"

         monkey_patch_vllm_dummy_weight_loader()
-        device_config = DeviceConfig()
-        load_config = LoadConfig(load_format=self.server_args.load_format)
-        vllm_model_config = VllmModelConfig(
+        self.device_config = DeviceConfig()
+        self.load_config = LoadConfig(load_format=self.server_args.load_format)
+        self.vllm_model_config = VllmModelConfig(
             model=self.server_args.model_path,
             quantization=self.server_args.quantization,
             tokenizer=None,

@@ -173,17 +174,19 @@ class ModelRunner:
         if is_llama3_405b_fp8_head_16(self.model_config) and self.tp_size <= 8:
             # A temporary hack to fix the num_heads for meta-llama/Meta-Llama-3.1-405B-FP8 checkpoints
             self.model_config.hf_config.num_key_value_heads = 8
-            vllm_model_config.hf_config.num_key_value_heads = 8
+            self.vllm_model_config.hf_config.num_key_value_heads = 8
             monkey_patch_vllm_qvk_linear_loader()

-        self.dtype = vllm_model_config.dtype
+        self.dtype = self.vllm_model_config.dtype
         if self.model_config.model_overide_args is not None:
-            vllm_model_config.hf_config.update(self.model_config.model_overide_args)
+            self.vllm_model_config.hf_config.update(
+                self.model_config.model_overide_args
+            )

         self.model = get_model(
-            model_config=vllm_model_config,
-            device_config=device_config,
-            load_config=load_config,
+            model_config=self.vllm_model_config,
+            device_config=self.device_config,
+            load_config=self.load_config,
             lora_config=None,
             multimodal_config=None,
             parallel_config=None,

@@ -206,6 +209,91 @@ class ModelRunner:
             f"avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
         )

+    def update_weights(self, model_path, load_format):
+        from vllm.model_executor.model_loader.loader import (
+            DefaultModelLoader,
+            device_loading_context,
+            get_model_loader,
+        )
+        from vllm.model_executor.model_loader.utils import set_default_torch_dtype
+
+        logger.info(
+            f"[gpu={self.gpu_id}] Update weights begin. "
+            f"avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
+        )
+
+        target_device = torch.device(self.device_config.device)
+
+        try:
+            vllm_model_config = VllmModelConfig(
+                model=model_path,
+                quantization=self.server_args.quantization,
+                tokenizer=None,
+                tokenizer_mode=None,
+                trust_remote_code=self.server_args.trust_remote_code,
+                dtype=self.server_args.dtype,
+                seed=42,
+                skip_tokenizer_init=True,
+            )
+        except Exception as e:
+            logger.error(f"Failed to load model config: {e}")
+            return False, "Failed to update model weights"
+
+        load_config = LoadConfig(load_format=load_format)
+
+        # Only support vllm DefaultModelLoader for now
+        loader = get_model_loader(load_config)
+        if not isinstance(loader, DefaultModelLoader):
+            logger.error("Failed to get weights iterator: Unsupported loader")
+            return False, "Failed to update model weights"
+
+        def get_weight_iter(config):
+            iter = loader._get_weights_iterator(
+                config.model,
+                config.revision,
+                fall_back_to_pt=getattr(self.model, "fall_back_to_pt_during_load", True),
+            )
+            return iter
+
+        def model_load_weights(model, iter):
+            model.load_weights(iter)
+            for _, module in self.model.named_modules():
+                quant_method = getattr(module, "quant_method", None)
+                if quant_method is not None:
+                    with device_loading_context(module, target_device):
+                        quant_method.process_weights_after_loading(module)
+            return model
+
+        with set_default_torch_dtype(vllm_model_config.dtype):
+            try:
+                iter = get_weight_iter(vllm_model_config)
+            except Exception as e:
+                message = f"Failed to get weights iterator: {e}"
+                logger.error(message)
+                return False, message
+            try:
+                model = model_load_weights(self.model, iter)
+            except Exception as e:
+                message = (
+                    f"Failed to update weights: {e}.\nRolling back to original weights"
+                )
+                logger.error(message)
+                del iter
+                gc.collect()
+                iter = get_weight_iter(self.vllm_model_config)
+                self.model = model_load_weights(self.model, iter)
+                return False, message
+
+        self.model = model
+        self.server_args.model_path = model_path
+        self.server_args.load_format = load_format
+        self.vllm_model_config = vllm_model_config
+        self.load_config = load_config
+        self.model_config.path = model_path
+
+        logger.info(f"[gpu={self.gpu_id}] Update weights end.")
+        return True, "Succeeded to update model weights"
+
     def profile_max_num_token(self, total_gpu_memory):
         available_gpu_memory = get_available_gpu_memory(
             self.gpu_id, distributed=self.tp_size > 1
python/sglang/srt/server.py

@@ -51,7 +51,11 @@ from sglang.srt.managers.controller_single import (
     start_controller_process as start_controller_process_single,
 )
 from sglang.srt.managers.detokenizer_manager import start_detokenizer_process
-from sglang.srt.managers.io_struct import EmbeddingReqInput, GenerateReqInput
+from sglang.srt.managers.io_struct import (
+    EmbeddingReqInput,
+    GenerateReqInput,
+    UpdateWeightReqInput,
+)
 from sglang.srt.managers.tokenizer_manager import TokenizerManager
 from sglang.srt.openai_api.adapter import (
     load_chat_template_for_openai_api,

@@ -136,6 +140,23 @@ async def flush_cache():
     )

+@app.post("/update_weights")
+async def update_weights(obj: UpdateWeightReqInput, request: Request):
+    success, message = await tokenizer_manager.update_weights(obj, request)
+    content = {"message": message, "success": str(success)}
+    if success:
+        return JSONResponse(
+            content,
+            status_code=HTTPStatus.OK,
+        )
+    else:
+        return JSONResponse(
+            content,
+            status_code=HTTPStatus.BAD_REQUEST,
+        )
+
+
 async def generate_request(obj: GenerateReqInput, request: Request):
     """Handle a generate request."""
     if obj.stream:
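One detail of the handler above worth noting: the response stringifies the boolean ("success": str(success)), so the JSON field is the string "True" or "False" rather than a JSON boolean. A small client helper, sketched under that assumption (the base_url argument is illustrative):

import requests


def update_weights(base_url, model_path, load_format=None):
    """POST to /update_weights and normalize the stringified success flag."""
    payload = {"model_path": model_path}
    if load_format is not None:
        payload["load_format"] = load_format
    response = requests.post(base_url + "/update_weights", json=payload)
    body = response.json()
    # The endpoint returns HTTP 200 on success and 400 on failure, with
    # body["success"] encoded as the string "True" or "False".
    ok = response.status_code == 200 and body["success"] == "True"
    return ok, body["message"]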
test/srt/test_update_weights.py (new file, mode 100644)

import json
import unittest

import requests

from sglang.srt.utils import kill_child_process
from sglang.test.test_utils import (
    DEFAULT_MODEL_NAME_FOR_TEST,
    DEFAULT_URL_FOR_UNIT_TEST,
    popen_launch_server,
)


class TestReplaceWeights(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        cls.model = DEFAULT_MODEL_NAME_FOR_TEST
        cls.base_url = DEFAULT_URL_FOR_UNIT_TEST
        cls.process = popen_launch_server(cls.model, cls.base_url, timeout=300)

    @classmethod
    def tearDownClass(cls):
        kill_child_process(cls.process.pid)

    def run_decode(self):
        response = requests.post(
            self.base_url + "/generate",
            json={
                "text": "The capital of France is",
                "sampling_params": {
                    "temperature": 0,
                    "max_new_tokens": 32,
                    "n": 1,
                },
                "stream": False,
                "return_logprob": False,
                "top_logprobs_num": 0,
                "return_text_in_logprobs": False,
                "logprob_start_len": 0,
            },
        )
        print(json.dumps(response.json()))
        print("=" * 100)
        # return the "text" in response
        text = response.json()["text"]
        return text

    def get_model_info(self):
        response = requests.get(self.base_url + "/get_model_info")
        model_path = response.json()["model_path"]
        print(json.dumps(response.json()))
        return model_path

    def run_update_weights(self, model_path):
        response = requests.post(
            self.base_url + "/update_weights",
            json={
                "model_path": model_path,
            },
        )
        print(json.dumps(response.json()))

    def test_replace_weights(self):
        origin_model_path = self.get_model_info()
        print(f"origin_model_path: {origin_model_path}")
        origin_response = self.run_decode()

        # update weights
        new_model_path = "meta-llama/Meta-Llama-3.1-8B"
        self.run_update_weights(new_model_path)

        updated_model_path = self.get_model_info()
        print(f"updated_model_path: {updated_model_path}")
        assert updated_model_path == new_model_path
        assert updated_model_path != origin_model_path

        updated_response = self.run_decode()
        assert origin_response[:32] != updated_response[:32]

        # update weights back
        self.run_update_weights(origin_model_path)
        updated_model_path = self.get_model_info()
        assert updated_model_path == origin_model_path

        updated_response = self.run_decode()
        assert origin_response[:32] == updated_response[:32]

    def test_replace_weights_unexist_model(self):
        origin_model_path = self.get_model_info()
        print(f"origin_model_path: {origin_model_path}")
        origin_response = self.run_decode()

        # update weights
        new_model_path = "meta-llama/Meta-Llama-3.1-8B-1"
        self.run_update_weights(new_model_path)

        updated_model_path = self.get_model_info()
        print(f"updated_model_path: {updated_model_path}")
        assert updated_model_path == origin_model_path

        updated_response = self.run_decode()
        assert origin_response[:32] == updated_response[:32]


if __name__ == "__main__":
    unittest.main()