sglang · Commits · 18108abe

[Minor] Fix code style (#2311)

Unverified commit 18108abe, authored Dec 02, 2024 by Lianmin Zheng, committed by GitHub on Dec 02, 2024. Parent: c54bda30.

Showing 5 changed files with 249 additions and 274 deletions (+249, -274):
python/sglang/srt/managers/tokenizer_manager.py     +76  -77
python/sglang/srt/model_executor/model_runner.py     +0  -10
python/sglang/srt/server.py                         +170 -172
test/srt/test_get_weights_by_name.py                  +2   -2
test/srt/test_update_weights_from_distributed.py      +1  -13
python/sglang/srt/managers/tokenizer_manager.py

```diff
@@ -25,7 +25,6 @@ import uuid
 from typing import Dict, List, Optional, Tuple, Union
 
 import fastapi
-import torch
 import uvloop
 import zmq
 import zmq.asyncio
@@ -337,6 +336,12 @@ class TokenizerManager:
                     rids.append(tmp_obj.rid)
             else:
                 # FIXME: When using batch and parallel_sample_num together, the perf is not optimal.
+                if batch_size > 128:
+                    logger.warning(
+                        "Sending a single large batch with parallel sampling (n > 1) has not been well optimized. "
+                        "The performance might be better if you just duplicate the requests n times or use "
+                        "many threads to send them one by one with parallel sampling (n > 1)."
+                    )
 
                 # Tokenize all requests
                 objs = [obj[i] for i in range(batch_size)]
@@ -494,9 +499,7 @@ class TokenizerManager:
             result = await self.parameter_update_result
             return result.success, result.message
         else:
-            logger.error(
-                f"Another parameter update is in progress in tokenizer manager"
-            )
+            logger.error("Another parameter update is in progress in tokenizer manager")
             return (
                 False,
                 "Another parameter update is in progress. Please try again later.",
@@ -597,7 +600,68 @@ class TokenizerManager:
                 InitWeightsUpdateGroupReqOutput,
             ] = await self.recv_from_detokenizer.recv_pyobj()
 
-            if isinstance(recv_obj, UpdateWeightFromDiskReqOutput):
+            if isinstance(recv_obj, (BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut)):
+                for i, rid in enumerate(recv_obj.rids):
+                    state = self.rid_to_state.get(rid, None)
+                    if state is None:
+                        continue
+
+                    recv_obj.meta_info[i]["id"] = rid
+                    if isinstance(recv_obj, BatchStrOut):
+                        out_dict = {
+                            "text": recv_obj.output_strs[i],
+                            "meta_info": recv_obj.meta_info[i],
+                        }
+                    elif isinstance(recv_obj, BatchTokenIDOut):
+                        out_dict = {
+                            "token_ids": recv_obj.output_ids[i],
+                            "meta_info": recv_obj.meta_info[i],
+                        }
+                    else:
+                        assert isinstance(recv_obj, BatchEmbeddingOut)
+                        out_dict = {
+                            "embedding": recv_obj.embeddings[i],
+                            "meta_info": recv_obj.meta_info[i],
+                        }
+                    state.out_list.append(out_dict)
+                    state.finished = recv_obj.finished_reason[i] is not None
+                    state.event.set()
+
+                    if self.enable_metrics:
+                        completion_tokens = recv_obj.meta_info[i]["completion_tokens"]
+
+                        if state.first_token_time is None:
+                            state.first_token_time = time.time()
+                            self.metrics_collector.observe_time_to_first_token(
+                                state.first_token_time - state.created_time
+                            )
+                        else:
+                            if completion_tokens >= 2:
+                                self.metrics_collector.observe_time_per_output_token(
+                                    (time.time() - state.first_token_time)
+                                    / (completion_tokens - 1)
+                                )
+
+                        if state.finished:
+                            self.metrics_collector.inc_prompt_tokens(
+                                recv_obj.meta_info[i]["prompt_tokens"]
+                            )
+                            self.metrics_collector.inc_generation_tokens(
+                                completion_tokens
+                            )
+                            self.metrics_collector.observe_e2e_request_latency(
+                                time.time() - state.created_time
+                            )
+                            if completion_tokens >= 1:
+                                self.metrics_collector.observe_time_per_output_token(
+                                    (time.time() - state.created_time)
+                                    / completion_tokens
+                                )
+            elif isinstance(recv_obj, OpenSessionReqOutput):
+                self.session_futures[recv_obj.session_id].set_result(
+                    recv_obj.session_id
+                )
+            elif isinstance(recv_obj, UpdateWeightFromDiskReqOutput):
                 if self.server_args.dp_size == 1:
                     self.model_update_result.set_result(recv_obj)
                 else:  # self.server_args.dp_size > 1
@@ -605,13 +669,16 @@ class TokenizerManager:
                     # set future if the all results are recevied
                     if len(self.model_update_tmp) == self.server_args.dp_size:
                         self.model_update_result.set_result(self.model_update_tmp)
-                continue
+            elif isinstance(recv_obj, InitWeightsUpdateGroupReqOutput):
+                assert (
+                    self.server_args.dp_size == 1
+                ), "dp_size must be 1 for init parameter update group"
+                self.init_weights_update_group_result.set_result(recv_obj)
             elif isinstance(recv_obj, UpdateWeightsFromDistributedReqOutput):
                 assert (
                     self.server_args.dp_size == 1
                 ), "dp_size must be 1 for update weights from distributed"
                 self.parameter_update_result.set_result(recv_obj)
-                continue
             elif isinstance(recv_obj, GetWeightsByNameReqOutput):
                 if self.server_args.dp_size == 1:
                     self.get_weights_by_name_result.set_result(recv_obj)
@@ -621,76 +688,8 @@ class TokenizerManager:
                         self.get_weights_by_name_result.set_result(
                             self.get_weights_by_name_tmp
                         )
-                continue
-            elif isinstance(recv_obj, InitWeightsUpdateGroupReqOutput):
-                assert (
-                    self.server_args.dp_size == 1
-                ), "dp_size must be 1 for init parameter update group"
-                self.init_weights_update_group_result.set_result(recv_obj)
-                continue
-            elif isinstance(recv_obj, OpenSessionReqOutput):
-                self.session_futures[recv_obj.session_id].set_result(
-                    recv_obj.session_id
-                )
-                continue
-
-            assert isinstance(
-                recv_obj, (BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut)
-            ), f"Unexpected obj received: {type(recv_obj)}"
-
-            for i, rid in enumerate(recv_obj.rids):
-                state = self.rid_to_state.get(rid, None)
-                if state is None:
-                    continue
-
-                recv_obj.meta_info[i]["id"] = rid
-                if isinstance(recv_obj, BatchStrOut):
-                    out_dict = {
-                        "text": recv_obj.output_strs[i],
-                        "meta_info": recv_obj.meta_info[i],
-                    }
-                elif isinstance(recv_obj, BatchTokenIDOut):
-                    out_dict = {
-                        "token_ids": recv_obj.output_ids[i],
-                        "meta_info": recv_obj.meta_info[i],
-                    }
-                else:
-                    assert isinstance(recv_obj, BatchEmbeddingOut)
-                    out_dict = {
-                        "embedding": recv_obj.embeddings[i],
-                        "meta_info": recv_obj.meta_info[i],
-                    }
-                state.out_list.append(out_dict)
-                state.finished = recv_obj.finished_reason[i] is not None
-                state.event.set()
-
-                if self.enable_metrics:
-                    completion_tokens = recv_obj.meta_info[i]["completion_tokens"]
-
-                    if state.first_token_time is None:
-                        state.first_token_time = time.time()
-                        self.metrics_collector.observe_time_to_first_token(
-                            state.first_token_time - state.created_time
-                        )
-                    else:
-                        if completion_tokens >= 2:
-                            self.metrics_collector.observe_time_per_output_token(
-                                (time.time() - state.first_token_time)
-                                / (completion_tokens - 1)
-                            )
-
-                    if state.finished:
-                        self.metrics_collector.inc_prompt_tokens(
-                            recv_obj.meta_info[i]["prompt_tokens"]
-                        )
-                        self.metrics_collector.inc_generation_tokens(completion_tokens)
-                        self.metrics_collector.observe_e2e_request_latency(
-                            time.time() - state.created_time
-                        )
-                        if completion_tokens >= 1:
-                            self.metrics_collector.observe_time_per_output_token(
-                                (time.time() - state.created_time) / completion_tokens
-                            )
+            else:
+                raise ValueError(f"Invalid object: {recv_obj=}")
 
     def convert_logprob_style(
         self,
```
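The metrics code that moves into the new batch-handling branch records time-to-first-token, inter-token latency, and end-to-end latency from three per-request timestamps. Below is a minimal standalone sketch of the same arithmetic; the function name and the sample timestamps are illustrative and not part of sglang.

```python
import time

# Illustrative recomputation of the latencies recorded in the handle_loop block above.
# `created`, `first_token`, and `completion_tokens` stand in for state.created_time,
# state.first_token_time, and recv_obj.meta_info[i]["completion_tokens"].
def request_latency_metrics(created: float, first_token: float, completion_tokens: int):
    now = time.time()
    ttft = first_token - created   # observe_time_to_first_token
    e2e = now - created            # observe_e2e_request_latency
    # Inter-token latency excludes the first token, hence the division by (n - 1).
    itl = (now - first_token) / (completion_tokens - 1) if completion_tokens >= 2 else None
    # At request finish the diff also records an average over all generated tokens.
    avg_per_token = e2e / completion_tokens if completion_tokens >= 1 else None
    return {"ttft": ttft, "e2e": e2e, "inter_token": itl, "avg_per_token": avg_per_token}

# Example: a request created one second ago whose first token arrived 0.2 s after creation.
t0 = time.time() - 1.0
print(request_latency_metrics(t0, t0 + 0.2, completion_tokens=16))
```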
python/sglang/srt/model_executor/model_runner.py

```diff
@@ -218,16 +218,6 @@ class ModelRunner:
         )
         self.tp_group = get_tp_group()
 
-        # Currently, there is a bug with mulit-node tensor parallelsim + padded cuda graph,
-        # so we disable padding in cuda graph.
-        if self.device == "cuda" and not all(
-            in_the_same_node_as(self.tp_group.cpu_group, source_rank=0)
-        ):
-            self.server_args.disable_cuda_graph_padding = True
-            logger.info(
-                "Setting disable_cuda_graph_padding to True because of multi-node tensor parallelism."
-            )
-
         # Check memory for tensor parallelism
         if self.tp_size > 1:
             local_gpu_memory = get_available_gpu_memory(self.device, self.gpu_id)
```
python/sglang/srt/server.py

```diff
@@ -82,7 +82,6 @@ from sglang.srt.utils import (
     assert_pkg_version,
     configure_logger,
     delete_directory,
-    init_custom_process_group,
     is_port_available,
     kill_process_tree,
     maybe_set_triton_cache_manager,
@@ -154,13 +153,11 @@ async def get_model_info():
 
 @app.get("/get_server_info")
 async def get_server_info():
-    try:
-        return await _get_server_info()
-    except Exception as e:
-        return ORJSONResponse(
-            {"error": {"message": str(e)}},
-            status_code=HTTPStatus.BAD_REQUEST
-        )
+    return {
+        **dataclasses.asdict(tokenizer_manager.server_args),  # server args
+        **scheduler_info,
+        "version": __version__,
+    }
 
 
 @app.post("/flush_cache")
@@ -567,14 +564,6 @@ def launch_server(
         t.join()
 
 
-async def _get_server_info():
-    return {
-        **dataclasses.asdict(tokenizer_manager.server_args),  # server args
-        **scheduler_info,
-        "version": __version__,
-    }
-
-
 def _set_envs_and_config(server_args: ServerArgs):
     # Set global environments
     os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
@@ -687,160 +676,6 @@ def _wait_and_warmup(server_args, pipe_finish_writer):
         delete_directory(server_args.model_path)
 
 
-class Runtime:
-    """
-    A wrapper for the server.
-    This is used for launching the server in a python program without
-    using the commond line interface.
-    """
-
-    def __init__(
-        self,
-        log_level: str = "error",
-        *args,
-        **kwargs,
-    ):
-        """See the arguments in server_args.py::ServerArgs"""
-        self.server_args = ServerArgs(*args, log_level=log_level, **kwargs)
-
-        # before python program terminates, call shutdown implicitly. Therefore, users don't have to explicitly call .shutdown()
-        atexit.register(self.shutdown)
-
-        # Pre-allocate ports
-        for port in range(10000, 40000):
-            if is_port_available(port):
-                break
-            port += 1
-        self.server_args.port = port
-
-        self.url = self.server_args.url()
-        self.generate_url = self.url + "/generate"
-
-        # NOTE: We store pid instead of proc to fix some issues during __delete__
-        self.pid = None
-        pipe_reader, pipe_writer = mp.Pipe(duplex=False)
-
-        proc = mp.Process(
-            target=launch_server,
-            args=(self.server_args, pipe_writer),
-        )
-        proc.start()
-        pipe_writer.close()
-        self.pid = proc.pid
-
-        try:
-            init_state = pipe_reader.recv()
-        except EOFError:
-            init_state = ""
-
-        if init_state != "ready":
-            self.shutdown()
-            raise RuntimeError(
-                "Initialization failed. Please see the error messages above."
-            )
-
-        self.endpoint = RuntimeEndpoint(self.url)
-
-    def shutdown(self):
-        if self.pid is not None:
-            kill_process_tree(self.pid)
-            self.pid = None
-
-    def cache_prefix(self, prefix: str):
-        self.endpoint.cache_prefix(prefix)
-
-    def get_tokenizer(self):
-        return get_tokenizer(
-            self.server_args.tokenizer_path,
-            tokenizer_mode=self.server_args.tokenizer_mode,
-            trust_remote_code=self.server_args.trust_remote_code,
-        )
-
-    async def async_generate(
-        self,
-        prompt: str,
-        sampling_params: Optional[Dict] = None,
-    ):
-        if self.server_args.skip_tokenizer_init:
-            json_data = {
-                "input_ids": prompt,
-                "sampling_params": sampling_params,
-                "stream": True,
-            }
-        else:
-            json_data = {
-                "text": prompt,
-                "sampling_params": sampling_params,
-                "stream": True,
-            }
-        pos = 0
-
-        timeout = aiohttp.ClientTimeout(total=3 * 3600)
-        async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session:
-            async with session.post(self.generate_url, json=json_data) as response:
-                async for chunk, _ in response.content.iter_chunks():
-                    chunk = chunk.decode("utf-8")
-                    if chunk and chunk.startswith("data:"):
-                        if chunk == "data: [DONE]\n\n":
-                            break
-                        data = json.loads(chunk[5:].strip("\n"))
-                        if "text" in data:
-                            cur = data["text"][pos:]
-                            if cur:
-                                yield cur
-                            pos += len(cur)
-                        else:
-                            yield data
-
-    add_request = async_generate
-
-    def generate(
-        self,
-        prompt: Union[str, List[str]],
-        sampling_params: Optional[Dict] = None,
-        return_logprob: Optional[Union[List[bool], bool]] = False,
-        logprob_start_len: Optional[Union[List[int], int]] = None,
-        top_logprobs_num: Optional[Union[List[int], int]] = None,
-        lora_path: Optional[List[Optional[str]]] = None,
-    ):
-        json_data = {
-            "text": prompt,
-            "sampling_params": sampling_params,
-            "return_logprob": return_logprob,
-            "logprob_start_len": logprob_start_len,
-            "top_logprobs_num": top_logprobs_num,
-            "lora_path": lora_path,
-        }
-        assert not isinstance(lora_path, list) or len(lora_path) == len(prompt)
-        response = requests.post(
-            self.url + "/generate",
-            json=json_data,
-        )
-        return json.dumps(response.json())
-
-    def encode(
-        self,
-        prompt: Union[str, List[str], List[Dict], List[List[Dict]]],
-    ):
-        json_data = {"text": prompt}
-        response = requests.post(self.url + "/encode", json=json_data)
-        return json.dumps(response.json())
-
-    async def get_server_info(self):
-        async with aiohttp.ClientSession() as session:
-            async with session.get(f"{self.url}/get_server_info") as response:
-                if response.status == 200:
-                    return await response.json()
-                else:
-                    error_data = await response.json()
-                    raise RuntimeError(
-                        f"Failed to get server info. {error_data['error']['message']}"
-                    )
-
-    def __del__(self):
-        self.shutdown()
-
-
 STREAM_END_SYMBOL = b"data: [DONE]"
 STREAM_CHUNK_START_SYMBOL = b"data:"
@@ -854,6 +689,8 @@ class Engine:
     """
 
     def __init__(self, log_level: str = "error", *args, **kwargs):
+        """See the arguments in server_args.py::ServerArgs"""
+
         # before python program terminates, call shutdown implicitly. Therefore, users don't have to explicitly call .shutdown()
         atexit.register(self.shutdown)
@@ -986,8 +823,12 @@ class Engine:
     def stop_profile(self):
         tokenizer_manager.stop_profile()
 
-    async def get_server_info(self):
-        return await _get_server_info()
+    def get_server_info(self):
+        return {
+            **dataclasses.asdict(tokenizer_manager.server_args),  # server args
+            **scheduler_info,
+            "version": __version__,
+        }
 
     def init_weights_update_group(
         self,
@@ -1037,3 +878,160 @@ class Engine:
         loop = asyncio.get_event_loop()
         return loop.run_until_complete(_get_weights())
+
+
+class Runtime:
+    """
+    A wrapper for the HTTP server.
+    This is used for launching the server in a python program without
+    using the commond line interface.
+    It is mainly used for the frontend language.
+    You should use the Engine class if you want to do normal offline processing.
+    """
+
+    def __init__(
+        self,
+        log_level: str = "error",
+        *args,
+        **kwargs,
+    ):
+        """See the arguments in server_args.py::ServerArgs"""
+        self.server_args = ServerArgs(*args, log_level=log_level, **kwargs)
+
+        # before python program terminates, call shutdown implicitly. Therefore, users don't have to explicitly call .shutdown()
+        atexit.register(self.shutdown)
+
+        # Pre-allocate ports
+        for port in range(10000, 40000):
+            if is_port_available(port):
+                break
+            port += 1
+        self.server_args.port = port
+
+        self.url = self.server_args.url()
+        self.generate_url = self.url + "/generate"
+
+        # NOTE: We store pid instead of proc to fix some issues during __delete__
+        self.pid = None
+        pipe_reader, pipe_writer = mp.Pipe(duplex=False)
+
+        proc = mp.Process(
+            target=launch_server,
+            args=(self.server_args, pipe_writer),
+        )
+        proc.start()
+        pipe_writer.close()
+        self.pid = proc.pid
+
+        try:
+            init_state = pipe_reader.recv()
+        except EOFError:
+            init_state = ""
+
+        if init_state != "ready":
+            self.shutdown()
+            raise RuntimeError(
+                "Initialization failed. Please see the error messages above."
+            )
+
+        self.endpoint = RuntimeEndpoint(self.url)
+
+    def shutdown(self):
+        if self.pid is not None:
+            kill_process_tree(self.pid)
+            self.pid = None
+
+    def cache_prefix(self, prefix: str):
+        self.endpoint.cache_prefix(prefix)
+
+    def get_tokenizer(self):
+        return get_tokenizer(
+            self.server_args.tokenizer_path,
+            tokenizer_mode=self.server_args.tokenizer_mode,
+            trust_remote_code=self.server_args.trust_remote_code,
+        )
+
+    async def async_generate(
+        self,
+        prompt: str,
+        sampling_params: Optional[Dict] = None,
+    ):
+        if self.server_args.skip_tokenizer_init:
+            json_data = {
+                "input_ids": prompt,
+                "sampling_params": sampling_params,
+                "stream": True,
+            }
+        else:
+            json_data = {
+                "text": prompt,
+                "sampling_params": sampling_params,
+                "stream": True,
+            }
+        pos = 0
+
+        timeout = aiohttp.ClientTimeout(total=3 * 3600)
+        async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session:
+            async with session.post(self.generate_url, json=json_data) as response:
+                async for chunk, _ in response.content.iter_chunks():
+                    chunk = chunk.decode("utf-8")
+                    if chunk and chunk.startswith("data:"):
+                        if chunk == "data: [DONE]\n\n":
+                            break
+                        data = json.loads(chunk[5:].strip("\n"))
+                        if "text" in data:
+                            cur = data["text"][pos:]
+                            if cur:
+                                yield cur
+                            pos += len(cur)
+                        else:
+                            yield data
+
+    add_request = async_generate
+
+    def generate(
+        self,
+        prompt: Union[str, List[str]],
+        sampling_params: Optional[Dict] = None,
+        return_logprob: Optional[Union[List[bool], bool]] = False,
+        logprob_start_len: Optional[Union[List[int], int]] = None,
+        top_logprobs_num: Optional[Union[List[int], int]] = None,
+        lora_path: Optional[List[Optional[str]]] = None,
+    ):
+        json_data = {
+            "text": prompt,
+            "sampling_params": sampling_params,
+            "return_logprob": return_logprob,
+            "logprob_start_len": logprob_start_len,
+            "top_logprobs_num": top_logprobs_num,
+            "lora_path": lora_path,
+        }
+        assert not isinstance(lora_path, list) or len(lora_path) == len(prompt)
+        response = requests.post(
+            self.url + "/generate",
+            json=json_data,
+        )
+        return json.dumps(response.json())
+
+    def encode(
+        self,
+        prompt: Union[str, List[str], List[Dict], List[List[Dict]]],
+    ):
+        json_data = {"text": prompt}
+        response = requests.post(self.url + "/encode", json=json_data)
+        return json.dumps(response.json())
+
+    async def get_server_info(self):
+        async with aiohttp.ClientSession() as session:
+            async with session.get(f"{self.url}/get_server_info") as response:
+                if response.status == 200:
+                    return await response.json()
+                else:
+                    error_data = await response.json()
+                    raise RuntimeError(
+                        f"Failed to get server info. {error_data['error']['message']}"
+                    )
+
+    def __del__(self):
+        self.shutdown()
```
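After this change, both the `/get_server_info` endpoint and `Engine.get_server_info` build the merged dictionary (the `ServerArgs` fields, `scheduler_info`, and a `"version"` key) directly instead of going through the removed `_get_server_info` helper. A minimal client sketch, assuming a sglang server is already running locally; the URL, port, and the presence of a `model_path` field are assumptions about that server's configuration:

```python
import requests

# Query the /get_server_info endpoint shown in the diff above.
# The URL and port are illustrative; adjust to wherever the server listens.
info = requests.get("http://127.0.0.1:30000/get_server_info", timeout=10).json()

# The response merges dataclasses.asdict(server_args), scheduler_info, and "version".
print(info["version"])
print(info.get("model_path"))  # one of the ServerArgs fields, if present
```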
test/srt/test_get_weights_by_name.py

```diff
@@ -67,7 +67,7 @@ class TestGetWeightsByName(unittest.TestCase):
         terminate_process(self.process)
 
     def assert_tie_word_embeddings(self, truncate_size):
-        print(f"assert_tie_word_embeddings")
+        print("assert_tie_word_embeddings")
         if self.backend == "Engine":
             backend_ret = _process_return(
                 self.engine.get_weights_by_name("lm_head.weight", truncate_size)
@@ -79,7 +79,7 @@ class TestGetWeightsByName(unittest.TestCase):
                     json={"name": "lm_head.weight", "truncate_size": truncate_size},
                 ).json()
             )
-        print(f"assert_tie_word_embeddings of hf and backend")
+        print("assert_tie_word_embeddings of hf and backend")
        assert np.allclose(
            self.hf_model.get_parameter("model.embed_tokens.weight")
            .cpu()
```
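Both edits in this test drop an `f` prefix from strings that contain no placeholders; an f-string without replacement fields is just a plain string literal and is commonly flagged by linters (for example, flake8 reports it as F541). A minimal illustration:

```python
# An f-string with no {placeholders} behaves exactly like a plain string literal,
# so the prefix only adds noise; removing it does not change the output.
before = f"assert_tie_word_embeddings"
after = "assert_tie_word_embeddings"
assert before == after
print(after)
```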
test/srt/test_update_weights_from_distributed.py

```diff
@@ -127,7 +127,7 @@ def init_process_hf(
 
     hf_instruct_params = []
     hf_base_params = []
-    print(f"get parameter in hf instruct model and base model")
+    print("get parameter in hf instruct model and base model")
     for parameter_name in checking_parameters:
         hf_instruct_params.append(
             hf_instruct_model.get_parameter(parameter_name)[:truncate_size]
@@ -186,7 +186,6 @@ def init_process_hf(
     param_queue.put(("broadcast_time", broadcast_time))
 
-
     # Delete the huggingface models to free up memory.
     del hf_instruct_model
     del hf_base_model
     gc.collect()
@@ -238,7 +237,6 @@ def init_process_sgl(
     print(f"rank {rank} init server on url: {url}")
 
-
     # Get weights of instruct model, i.e. pre-training weights.
     instruct_params = []
     for parameter_name in checking_parameters:
         instruct_params.append(
@@ -253,7 +251,6 @@ def init_process_sgl(
     param_queue.put((f"sgl_dp_{rank}_instruct_params", instruct_params))
 
-
     # Init weight update group with the training engine.
     if backend == "Engine":
         engine.init_weights_update_group(
             master_address="localhost",
@@ -282,7 +279,6 @@ def init_process_sgl(
     # The last parameter is lm_head.weight, which is tied
     # with embed_tokens.weight. Actually, we only need
     # to update embed_tokens.weight once.
     tie_word_embeddings = (
         True if model_name == DEFAULT_SMALL_MODEL_NAME_FOR_TEST else False
     )
-
@@ -291,7 +287,6 @@ def init_process_sgl(
     update_parameters.remove("lm_head.weight")
 
-
     # Get weights from the training engine and update the inference engine.
     for parameter_name in update_parameters:
         if backend == "Engine":
             engine.update_weights_from_distributed(
@@ -312,7 +307,6 @@ def init_process_sgl(
     time_end_update = time.time()
 
-
     # Measure the latency of broadcast/weights update.
     update_time = time_end_update - time_begin_update
     print(
         f"fully update model_name {model_name} rank {rank} parameter from distributed time: {update_time:.3f}s"
@@ -320,7 +314,6 @@ def init_process_sgl(
     param_queue.put((f"update_sgl_dp_{rank}_time", update_time))
 
-
     # Get the weights of post-training model after weights update for correctness check.
     base_params = []
     for parameter_name in checking_parameters:
         if backend == "Engine":
@@ -340,7 +333,6 @@ def init_process_sgl(
     param_queue.put((f"sgl_dp_{rank}_base_params", base_params))
 
-
     # Shutdown the engine or terminate the server process.
     if backend == "Engine":
         engine.shutdown()
     else:
@@ -426,7 +418,6 @@ def test_update_weights_from_distributed(
     # Check the correctness of weights update by verifying
     # the weights of instruct model and base model.
 
-
     for i in range(len(params["hf_instruct"])):
         verify_params_close(
             params["hf_instruct"][i],
@@ -463,7 +454,6 @@ def test_update_weights_from_distributed(
     ), "hf_instruct_params and hf_base_params have different lengths"
 
-
     # Check if the weights of lm_head are tied with embed_tokens.
     params_to_check = [
         (
             params["hf_instruct"],
@@ -509,7 +499,6 @@ def test_update_weights_from_distributed(
 
     # Time limit for broadcast and update on CI is 3 / 6
     # On local H100, it's 1 / 2
     time_limit = 3 if model_name == DEFAULT_SMALL_MODEL_NAME_FOR_TEST else 6
 
-
     assert (
@@ -526,7 +515,6 @@ def test_update_weights_from_distributed(
     ), f"update_sgl_dp_two_time exceeds time limit {time_limit}s"
 
-
     # Delete the context and close the parameter queue.
     del context
     param_queue.close()
     param_queue.join_thread()
```
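Aside from one more redundant `f` prefix, the remaining hunks in this file each drop a single line that carries no visible code, which reads as blank-line cleanup. For reference, the test measures weight-update latency by differencing wall-clock timestamps and asserting it stays under a CI limit. A standalone sketch of that pattern, where the sleep stands in for the real `engine.update_weights_from_distributed(...)` call and the 3-second limit mirrors `time_limit` above:

```python
import time

# Timing-and-assert pattern used by the test around weight updates.
time_begin_update = time.time()
time.sleep(0.01)  # stand-in for the actual engine.update_weights_from_distributed(...) call
time_end_update = time.time()

# Measure the latency of the weights update, as the test does.
update_time = time_end_update - time_begin_update
time_limit = 3  # CI limit for the small test model, per the hunk above
assert update_time < time_limit, f"update time {update_time:.3f}s exceeds time limit {time_limit}s"
print(f"update took {update_time:.3f}s")
```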