xdb4_94051 / vllm · Commit 1a956e13 (unverified)

Fix various issues of async servers (#135)

Authored Jun 05, 2023 by Zhuohan Li; committed by GitHub on Jun 05, 2023.
Parent: 8274ca23

Showing 11 changed files with 289 additions and 121 deletions.
  benchmarks/benchmark_async_llm_server.py           +58   -0
  cacheflow/config.py                                 +3   -3
  cacheflow/core/block_manager.py                     +6   -3
  cacheflow/core/scheduler.py                        +13   -1
  cacheflow/entrypoints/openai/openai_frontend.py    +23   -7
  cacheflow/entrypoints/simple_fastapi_frontend.py   +17   -8
  cacheflow/sequence.py                               +9   -1
  cacheflow/server/arg_utils.py                      +72  -59
  cacheflow/server/async_llm_server.py               +71  -22
  cacheflow/server/llm_server.py                      +7   -9
  cacheflow/server/ray_utils.py                      +10   -8
benchmarks/benchmark_async_llm_server.py (new file, 0 → 100644)

import argparse
import json
import threading
import time

import requests


def main(args: argparse.Namespace):
    prompts = [f"Tell me a story with more than {''.join([str(i + 1)] * 5)} words"
               for i in range(args.n_threads)]

    headers = {"User-Agent": "CacheFlow Benchmark Client"}
    ploads = [{
        "prompt": p,
        "max_tokens": args.max_tokens,
        "temperature": 0.0,
        "ignore_eos": True,
    } for p in prompts]

    def send_request(results, i):
        response = requests.post(args.api_url, headers=headers,
                                 json=ploads[i], stream=True)
        results[i] = response

    # use args.n_threads to prompt the backend
    tik = time.time()
    threads = []
    results = [None] * args.n_threads
    for i in range(args.n_threads):
        t = threading.Thread(target=send_request, args=(results, i))
        t.start()
        threads.append(t)

    for t in threads:
        t.join()

    print(f"Time (POST): {time.time() - tik} s")
    n_words = 0

    for i, response in enumerate(results):
        k = list(response.iter_lines(chunk_size=8192, decode_unicode=False,
                                     delimiter=b"\0"))
        response_new_words = json.loads(k[-2].decode("utf-8"))["text"][0]
        n_words += len(response_new_words.split(" ")) - len(prompts[i].split(" "))

    time_seconds = time.time() - tik
    print(f"Time (total): {time_seconds:.3f} s to finish, n_threads: {args.n_threads}, "
          f"throughput: {n_words / time_seconds} words/s.")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--api-url", type=str, default="http://localhost:8001/generate")
    parser.add_argument("--max-tokens", type=int, default=128)
    parser.add_argument("--n-threads", type=int, default=128)
    args = parser.parse_args()

    main(args)
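For reference, the benchmark above fans out many threads at once; a minimal single-request client for the same streaming endpoint might look like the sketch below. It assumes the simple FastAPI frontend is running locally on its default port 8001; the prompt text and token count are arbitrary and not part of this commit.

import json

import requests

# Minimal sketch: one streaming request against the simple FastAPI frontend,
# assuming it is serving http://localhost:8001/generate (the benchmark default).
response = requests.post(
    "http://localhost:8001/generate",
    headers={"User-Agent": "CacheFlow Benchmark Client"},
    json={"prompt": "Tell me a story", "max_tokens": 32,
          "temperature": 0.0, "ignore_eos": True},
    stream=True)
# The frontend streams NUL-delimited JSON chunks; each chunk carries the text
# generated so far under the "text" key.
for chunk in response.iter_lines(chunk_size=8192, decode_unicode=False,
                                 delimiter=b"\0"):
    if chunk:
        print(json.loads(chunk.decode("utf-8"))["text"][0])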
cacheflow/config.py

@@ -116,15 +116,15 @@ class ParallelConfig:
         self,
         pipeline_parallel_size: int,
         tensor_parallel_size: int,
-        use_ray: bool,
+        worker_use_ray: bool,
     ) -> None:
         self.pipeline_parallel_size = pipeline_parallel_size
         self.tensor_parallel_size = tensor_parallel_size
-        self.use_ray = use_ray
+        self.worker_use_ray = worker_use_ray

         self.world_size = pipeline_parallel_size * tensor_parallel_size
         if self.world_size > 1:
-            self.use_ray = True
+            self.worker_use_ray = True
         self._verify_args()

     def _verify_args(self) -> None:
cacheflow/core/block_manager.py

@@ -148,7 +148,7 @@ class BlockSpaceManager:
         # the sequences in the same group.
         blocks: Set[PhysicalTokenBlock] = set()
         for seq in seq_group.get_seqs():
-            if SequenceStatus.is_finished(seq.status):
+            if seq.is_finished():
                 continue
             block_table = self.block_tables[seq.seq_id]
             for block in block_table:
@@ -169,7 +169,7 @@ class BlockSpaceManager:
         # CPU block -> GPU block.
         mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock] = {}
         for seq in seq_group.get_seqs():
-            if SequenceStatus.is_finished(seq.status):
+            if seq.is_finished():
                 continue
             new_block_table: BlockTable = []
             block_table = self.block_tables[seq.seq_id]
@@ -200,7 +200,7 @@ class BlockSpaceManager:
         # GPU block -> CPU block.
         mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock] = {}
         for seq in seq_group.get_seqs():
-            if SequenceStatus.is_finished(seq.status):
+            if seq.is_finished():
                 continue
             new_block_table: BlockTable = []
             block_table = self.block_tables[seq.seq_id]
@@ -231,6 +231,9 @@ class BlockSpaceManager:
             self.cpu_allocator.free(block)

     def free(self, seq: Sequence) -> None:
+        if seq.seq_id not in self.block_tables:
+            # Already freed or haven't been scheduled yet.
+            return
         block_table = self.block_tables[seq.seq_id]
         self._free_block_table(block_table)
         del self.block_tables[seq.seq_id]
cacheflow/core/scheduler.py

@@ -12,7 +12,7 @@ from cacheflow.sequence import (Sequence, SequenceData, SequenceGroup,

 logger = init_logger(__name__)

-_LOGGING_INTERVAL_SEC = 10
+_LOGGING_INTERVAL_SEC = 5


 class PreemptionMode(enum.Enum):
@@ -84,6 +84,18 @@ class Scheduler:
         # Add sequence groups to the waiting queue.
         self.waiting.append(seq_group)

+    def abort_seq_group(self, request_id: str) -> None:
+        for state_queue in [self.waiting, self.running, self.swapped]:
+            for seq_group in state_queue:
+                if seq_group.request_id == request_id:
+                    # Remove the sequence group from the state queue.
+                    state_queue.remove(seq_group)
+                    for seq in seq_group.seqs:
+                        if seq.is_finished():
+                            continue
+                        self.free_seq(seq, SequenceStatus.FINISHED_ABORTED)
+                    return
+
     def has_unfinished_seqs(self) -> bool:
         return self.waiting or self.running or self.swapped
cacheflow/entrypoints/openai/openai_frontend.py

@@ -7,13 +7,14 @@ import time
 from typing import AsyncGenerator, Dict, List, Optional

 import fastapi
+from fastapi import BackgroundTasks, Request
 from fastapi.exceptions import RequestValidationError
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import StreamingResponse, JSONResponse
 import uvicorn

 from cacheflow.outputs import RequestOutput
-from cacheflow.server.arg_utils import ServerArgs
+from cacheflow.server.arg_utils import AsyncServerArgs
 from cacheflow.server.async_llm_server import AsyncLLMServer
 from cacheflow.server.tokenizer_utils import get_tokenizer
 from cacheflow.logger import init_logger
@@ -33,6 +34,7 @@ from cacheflow.entrypoints.openai.protocol import (
     UsageInfo,
 )

+TIMEOUT_KEEP_ALIVE = 5  # seconds

 logger = init_logger(__name__)
 served_model = None
@@ -93,7 +95,8 @@ def create_logprobs(token_ids: List[int],

 @app.post("/v1/completions")
-async def create_completion(request: CompletionRequest):
+async def create_completion(raw_request: Request):
+    request = CompletionRequest(**await raw_request.json())
     logger.info(f"Received completion request: {request}")

     error_check_ret = await check_model(request)
@@ -139,7 +142,7 @@ async def create_completion(request: CompletionRequest):
         return create_error_response(HTTPStatus.BAD_REQUEST, str(e))

     result_generator = server.generate(prompt, sampling_params,
-                                       request_id=request_id)
+                                       request_id)

     # Similar to the OpenAI API, when n != best_of, we do not stream the
     # results. In addition, we do not stream the results when use beam search.
@@ -147,6 +150,9 @@ async def create_completion(request: CompletionRequest):
         (request.best_of is None or request.n == request.best_of) and
         not request.use_beam_search)

+    async def abort_request() -> None:
+        await server.abort(request_id)
+
     def create_stream_response_json(index: int,
                                     text: str,
                                     logprobs: Optional[LogProbs] = None,
@@ -203,12 +209,21 @@ async def create_completion(request: CompletionRequest):
     # Streaming response
     if stream:
+        background_tasks = BackgroundTasks()
+        # Abort the request if the client disconnects.
+        background_tasks.add_task(abort_request)
         return StreamingResponse(completion_stream_generator(),
-                                 media_type="text/event-stream")
+                                 media_type="text/event-stream",
+                                 background=background_tasks)

     # Non-streaming response
     final_res: RequestOutput = None
     async for res in result_generator:
+        if await raw_request.is_disconnected():
+            # Abort the request if the client disconnects.
+            await server.abort(request_id)
+            return create_error_response(HTTPStatus.BAD_REQUEST, "Client disconnected")
         final_res = res
     assert final_res is not None
     choices = []
@@ -276,7 +291,7 @@ if __name__ == "__main__":
                         help="The model name used in the API. If not specified, "
                              "the model name will be the same as the "
                              "huggingface name.")
-    parser = ServerArgs.add_cli_args(parser)
+    parser = AsyncServerArgs.add_cli_args(parser)
    args = parser.parse_args()

     app.add_middleware(
@@ -291,10 +306,11 @@ if __name__ == "__main__":
     served_model = args.served_model_name or args.model

-    server_args = ServerArgs.from_cli_args(args)
+    server_args = AsyncServerArgs.from_cli_args(args)
     server = AsyncLLMServer.from_server_args(server_args)

     # A separate tokenizer to map token IDs to strings.
     tokenizer = get_tokenizer(args.model)

-    uvicorn.run(app, host=args.host, port=args.port, log_level="info")
+    uvicorn.run(app, host=args.host, port=args.port, log_level="info",
+                timeout_keep_alive=TIMEOUT_KEEP_ALIVE)
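This frontend (and the simple frontend below) registers the abort coroutine as a Starlette background task on the StreamingResponse, so it runs once the stream is over, including when the client drops the connection mid-stream (hence the "# Abort the request if the client disconnects." comments in the diff). A stripped-down, hypothetical FastAPI app showing just that mechanism, independent of CacheFlow:

import asyncio

from fastapi import BackgroundTasks, FastAPI
from fastapi.responses import StreamingResponse

app = FastAPI()


@app.post("/demo")
async def demo():
    async def stream():
        for i in range(10):
            yield f"chunk-{i}\0".encode("utf-8")
            await asyncio.sleep(0.5)

    async def cleanup() -> None:
        # Stand-in for server.abort(request_id): runs after the response
        # finishes, whether it completed normally or the client went away.
        print("releasing request resources")

    background_tasks = BackgroundTasks()
    background_tasks.add_task(cleanup)
    return StreamingResponse(stream(), background=background_tasks)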
cacheflow/entrypoints/simple_fastapi_frontend.py

@@ -2,15 +2,16 @@ import argparse
 import json
 from typing import AsyncGenerator

-from fastapi import FastAPI, Request
+from fastapi import BackgroundTasks, FastAPI, Request
 from fastapi.responses import StreamingResponse
 import uvicorn

 from cacheflow.sampling_params import SamplingParams
-from cacheflow.server.arg_utils import ServerArgs
+from cacheflow.server.arg_utils import AsyncServerArgs
 from cacheflow.server.async_llm_server import AsyncLLMServer
-from cacheflow.server.ray_utils import initialize_cluster
+from cacheflow.utils import random_uuid

+TIMEOUT_KEEP_ALIVE = 5  # seconds.
 TIMEOUT_TO_PREVENT_DEADLOCK = 1  # seconds
 app = FastAPI()
@@ -20,7 +21,8 @@ async def generate_stream(request: Request) -> StreamingResponse:
     request_dict = await request.json()
     prompt = request_dict.pop("prompt")
     sampling_params = SamplingParams(**request_dict)
-    results_generator = server.generate(prompt, sampling_params)
+    request_id = random_uuid()
+    results_generator = server.generate(prompt, sampling_params, request_id)

     async def stream_results() -> AsyncGenerator[bytes, None]:
         async for request_output in results_generator:
@@ -35,17 +37,24 @@ async def generate_stream(request: Request) -> StreamingResponse:
             }
             yield (json.dumps(ret) + "\0").encode("utf-8")

-    return StreamingResponse(stream_results())
+    async def abort_request() -> None:
+        await server.abort(request_id)
+
+    background_tasks = BackgroundTasks()
+    # Abort the request if the client disconnects.
+    background_tasks.add_task(abort_request)
+    return StreamingResponse(stream_results(), background=background_tasks)


 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument("--host", type=str, default="localhost")
     parser.add_argument("--port", type=int, default=8001)
-    parser = ServerArgs.add_cli_args(parser)
+    parser = AsyncServerArgs.add_cli_args(parser)
     args = parser.parse_args()

-    server_args = ServerArgs.from_cli_args(args)
+    server_args = AsyncServerArgs.from_cli_args(args)
     server = AsyncLLMServer.from_server_args(server_args)

-    uvicorn.run(app, host=args.host, port=args.port, log_level="info")
+    uvicorn.run(app, host=args.host, port=args.port, log_level="debug",
+                timeout_keep_alive=TIMEOUT_KEEP_ALIVE)
cacheflow/sequence.py

@@ -12,12 +12,14 @@ class SequenceStatus(enum.Enum):
     SWAPPED = enum.auto()
     FINISHED_STOPPED = enum.auto()
     FINISHED_LENGTH_CAPPED = enum.auto()
+    FINISHED_ABORTED = enum.auto()

     @staticmethod
     def is_finished(status: "SequenceStatus") -> bool:
         return status in [
             SequenceStatus.FINISHED_STOPPED,
             SequenceStatus.FINISHED_LENGTH_CAPPED,
+            SequenceStatus.FINISHED_ABORTED,
         ]

     @staticmethod
@@ -26,10 +28,13 @@ class SequenceStatus(enum.Enum):
             finish_reason = "stop"
         elif status == SequenceStatus.FINISHED_LENGTH_CAPPED:
             finish_reason = "length"
+        elif status == SequenceStatus.FINISHED_ABORTED:
+            finish_reason = "abort"
         else:
             finish_reason = None
         return finish_reason


 class SequenceData:

     def __init__(
@@ -137,6 +142,9 @@ class Sequence:
     def get_cumulative_logprob(self) -> float:
         return self.data.cumulative_logprob

+    def is_finished(self) -> bool:
+        return SequenceStatus.is_finished(self.status)
+
     def fork(self, child_seq: 'Sequence') -> None:
         child_seq.logical_token_blocks = copy.deepcopy(self.logical_token_blocks)
         child_seq.output_logprobs = copy.deepcopy(self.output_logprobs)
@@ -182,7 +190,7 @@ class SequenceGroup:
         raise ValueError(f'Sequence {seq_id} not found.')

     def is_finished(self) -> bool:
-        return all(SequenceStatus.is_finished(seq.status) for seq in self.seqs)
+        return all(seq.is_finished() for seq in self.seqs)

     def __repr__(self) -> str:
         return (f"SequenceGroup(request_id={self.request_id}, "
cacheflow/server/arg_utils.py

@@ -15,7 +15,7 @@ class ServerArgs:
     use_dummy_weights: bool = False
     dtype: str = "default"
     seed: int = 0
-    use_ray: bool = False
+    worker_use_ray: bool = False
     pipeline_parallel_size: int = 1
     tensor_parallel_size: int = 1
     block_size: int = 16
@@ -32,7 +32,63 @@ class ServerArgs:
     def add_cli_args(
         parser: argparse.ArgumentParser,
     ) -> argparse.ArgumentParser:
-        return _add_server_arguments(parser)
+        """Shared CLI arguments for CacheFlow servers."""
+        # Model arguments
+        parser.add_argument('--model', type=str, default='facebook/opt-125m',
+                            help='name or path of the huggingface model to use')
+        parser.add_argument('--download-dir', type=str,
+                            default=ServerArgs.download_dir,
+                            help='directory to download and load the weights, '
+                                 'default to the default cache dir of '
+                                 'huggingface')
+        parser.add_argument('--use-np-weights', action='store_true',
+                            help='save a numpy copy of model weights for '
+                                 'faster loading. This can increase the disk '
+                                 'usage by up to 2x.')
+        parser.add_argument('--use-dummy-weights', action='store_true',
+                            help='use dummy values for model weights')
+        # TODO(woosuk): Support FP32.
+        parser.add_argument('--dtype', type=str, default=ServerArgs.dtype,
+                            choices=['default', 'half', 'bfloat16'],
+                            help='data type for model weights and activations. '
+                                 'The "default" option will use FP16 precision '
+                                 'for FP32 and FP16 models, and BF16 precision '
+                                 'for BF16 models.')
+        # Parallel arguments
+        parser.add_argument('--worker-use-ray', action='store_true',
+                            help='use Ray for distributed serving, will be '
+                                 'automatically set when using more than 1 GPU')
+        parser.add_argument('--pipeline-parallel-size', '-pp', type=int,
+                            default=ServerArgs.pipeline_parallel_size,
+                            help='number of pipeline stages')
+        parser.add_argument('--tensor-parallel-size', '-tp', type=int,
+                            default=ServerArgs.tensor_parallel_size,
+                            help='number of tensor parallel replicas')
+        # KV cache arguments
+        parser.add_argument('--block-size', type=int,
+                            default=ServerArgs.block_size,
+                            choices=[1, 2, 4, 8, 16, 32, 64, 128, 256],
+                            help='token block size')
+        # TODO(woosuk): Support fine-grained seeds (e.g., seed per request).
+        parser.add_argument('--seed', type=int, default=ServerArgs.seed,
+                            help='random seed')
+        parser.add_argument('--swap-space', type=int,
+                            default=ServerArgs.swap_space,
+                            help='CPU swap space size (GiB) per GPU')
+        parser.add_argument('--gpu-memory-utilization', type=float,
+                            default=ServerArgs.gpu_memory_utilization,
+                            help='the percentage of GPU memory to be used for'
+                                 'the model executor')
+        parser.add_argument('--max-num-batched-tokens', type=int,
+                            default=ServerArgs.max_num_batched_tokens,
+                            help='maximum number of batched tokens per '
+                                 'iteration')
+        parser.add_argument('--max-num-seqs', type=int,
+                            default=ServerArgs.max_num_seqs,
+                            help='maximum number of sequences per iteration')
+        parser.add_argument('--disable-log-stats', action='store_true',
+                            help='disable logging statistics')
+        return parser

     @classmethod
     def from_cli_args(cls, args: argparse.Namespace) -> "ServerArgs":
@@ -53,65 +109,22 @@ class ServerArgs:
                                    self.swap_space)
         parallel_config = ParallelConfig(self.pipeline_parallel_size,
                                          self.tensor_parallel_size,
-                                         self.use_ray)
+                                         self.worker_use_ray)
         scheduler_config = SchedulerConfig(self.max_num_batched_tokens,
                                            self.max_num_seqs)
         return model_config, cache_config, parallel_config, scheduler_config


-def _add_server_arguments(
-    parser: argparse.ArgumentParser,
-) -> argparse.ArgumentParser:
-    """Shared CLI arguments for CacheFlow servers."""
-    # Model arguments
-    parser.add_argument('--model', type=str, default='facebook/opt-125m',
-                        help='name or path of the huggingface model to use')
-    parser.add_argument('--download-dir', type=str,
-                        default=ServerArgs.download_dir,
-                        help='directory to download and load the weights, '
-                             'default to the default cache dir of huggingface')
-    parser.add_argument('--use-np-weights', action='store_true',
-                        help='save a numpy copy of model weights for faster '
-                             'loading. This can increase the disk usage by up '
-                             'to 2x.')
-    parser.add_argument('--use-dummy-weights', action='store_true',
-                        help='use dummy values for model weights')
-    # TODO(woosuk): Support FP32.
-    parser.add_argument('--dtype', type=str, default=ServerArgs.dtype,
-                        choices=['default', 'half', 'bfloat16'],
-                        help=('data type for model weights and activations. '
-                              'The "default" option will use FP16 precision '
-                              'for FP32 and FP16 models, and BF16 precision '
-                              'for BF16 models.'))
-    # Parallel arguments
-    parser.add_argument('--use-ray', action='store_true',
-                        help='use Ray for distributed serving, will be '
-                             'automatically set when using more than 1 GPU')
-    parser.add_argument('--pipeline-parallel-size', '-pp', type=int,
-                        default=ServerArgs.pipeline_parallel_size,
-                        help='number of pipeline stages')
-    parser.add_argument('--tensor-parallel-size', '-tp', type=int,
-                        default=ServerArgs.tensor_parallel_size,
-                        help='number of tensor parallel replicas')
-    # KV cache arguments
-    parser.add_argument('--block-size', type=int,
-                        default=ServerArgs.block_size,
-                        choices=[1, 2, 4, 8, 16, 32, 64, 128, 256],
-                        help='token block size')
-    # TODO(woosuk): Support fine-grained seeds (e.g., seed per request).
-    parser.add_argument('--seed', type=int, default=ServerArgs.seed,
-                        help='random seed')
-    parser.add_argument('--swap-space', type=int,
-                        default=ServerArgs.swap_space,
-                        help='CPU swap space size (GiB) per GPU')
-    parser.add_argument('--gpu-memory-utilization', type=float,
-                        default=ServerArgs.gpu_memory_utilization,
-                        help='the percentage of GPU memory to be used for the '
-                             'model executor')
-    parser.add_argument('--max-num-batched-tokens', type=int,
-                        default=ServerArgs.max_num_batched_tokens,
-                        help='maximum number of batched tokens per iteration')
-    parser.add_argument('--max-num-seqs', type=int,
-                        default=ServerArgs.max_num_seqs,
-                        help='maximum number of sequences per iteration')
-    parser.add_argument('--disable-log-stats', action='store_true',
-                        help='disable logging statistics')
-    return parser
+@dataclass
+class AsyncServerArgs(ServerArgs):
+    server_use_ray: bool = False
+
+    @staticmethod
+    def add_cli_args(
+        parser: argparse.ArgumentParser,
+    ) -> argparse.ArgumentParser:
+        parser = ServerArgs.add_cli_args(parser)
+        parser.add_argument('--server-use-ray', action='store_true',
+                            help='use Ray to start the LLM server in a '
+                                 'separate process as the web server process.')
+        return parser
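Using the new dataclass mirrors what the two frontends above now do; a compact sketch of that wiring (no behavior beyond what the diffs show) is:

import argparse

from cacheflow.server.arg_utils import AsyncServerArgs
from cacheflow.server.async_llm_server import AsyncLLMServer

parser = argparse.ArgumentParser()
# Adds --model, --worker-use-ray, --server-use-ray, and the other shared flags.
parser = AsyncServerArgs.add_cli_args(parser)
args = parser.parse_args()

server_args = AsyncServerArgs.from_cli_args(args)
server = AsyncLLMServer.from_server_args(server_args)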
cacheflow/server/async_llm_server.py

@@ -2,37 +2,52 @@ import asyncio
 import time
 from typing import Dict, Optional

-import ray
-
+from cacheflow.logger import init_logger
 from cacheflow.outputs import RequestOutput
 from cacheflow.sampling_params import SamplingParams
-from cacheflow.server.arg_utils import ServerArgs
+from cacheflow.server.arg_utils import AsyncServerArgs
 from cacheflow.server.llm_server import LLMServer
-from cacheflow.server.ray_utils import initialize_cluster
-from cacheflow.utils import random_uuid
+from cacheflow.server.ray_utils import ray, initialize_cluster
+
+logger = init_logger(__name__)

 TIMEOUT_TO_PREVENT_DEADLOCK = 1  # seconds


 class AsyncLLMServer:

-    def __init__(self, server_use_ray: bool, *args, **kwargs) -> None:
-        if server_use_ray:
-            remote_server_class = ray.remote(num_cpus=0)(LLMServer)
+    def __init__(self, worker_use_ray: bool, server_use_ray: bool,
+                 *args, **kwargs) -> None:
+        self.worker_use_ray = worker_use_ray
+        self.server_use_ray = server_use_ray
+        if not self.server_use_ray:
+            server_class = LLMServer
+        elif self.worker_use_ray:
+            server_class = ray.remote(num_cpus=0)(LLMServer).remote
         else:
-            remote_server_class = ray.remote(num_gpus=1)(LLMServer)
-        self.server = remote_server_class.remote(*args, **kwargs)
+            server_class = ray.remote(num_gpus=1)(LLMServer).remote
+        self.server = server_class(*args, **kwargs)
         # Request id -> request output.
         self.request_outputs: Dict[str, RequestOutput] = {}
         # Request id -> event to notify that there is new output.
         self.request_events: Dict[str, asyncio.Event] = {}
         self.is_server_running = False
+        self.kicking_request_id: Optional[str] = None

-    async def server_step(self):
+    async def server_step(self, kicking_request_id: Optional[str] = None):
         self.is_server_running = True
-        request_outputs = await self.server.step.remote()
+        self.kicking_request_id = kicking_request_id
+        if self.server_use_ray:
+            request_outputs = await self.server.step.remote()
+        else:
+            # Yield to the event loop to allow other coroutines to run
+            # while is_server_running is True. This let the server to add new
+            # requests into the queue.
+            await asyncio.sleep(0)
+            request_outputs = self.server.step()
         self.is_server_running = False
+        self.kicking_request_id = None
         # Notify the waiting coroutines that there are new outputs ready.
         for request_output in request_outputs:
             request_id = request_output.request_id
@@ -40,20 +55,26 @@ class AsyncLLMServer:
             self.request_events[request_id].set()

     async def generate(self, prompt: str, sampling_params: SamplingParams,
-                       request_id: Optional[str] = None) -> RequestOutput:
+                       request_id: str) -> RequestOutput:
         # Preprocess the request.
         arrival_time = time.time()

         # Create an event to notify us that there is new output from the
         # cacheflow server.
-        if request_id is None:
-            request_id = random_uuid()
         request_event = asyncio.Event()
         self.request_events[request_id] = request_event

+        logger.info(f"Received request {request_id}: "
+                    f"prompt: {prompt!r}, "
+                    f"sampling params: {sampling_params}.")
+
         # Add the request into the cacheflow server's waiting queue.
-        await self.server.add_request.remote(
-            request_id, prompt, sampling_params, arrival_time=arrival_time)
+        if self.server_use_ray:
+            await self.server.add_request.remote(
+                request_id, prompt, sampling_params, arrival_time=arrival_time)
+        else:
+            self.server.add_request(
+                request_id, prompt, sampling_params, arrival_time=arrival_time)

         # The cacheflow server does not have a background loop that keeps
         # processing incoming requests. Therefore, we need to keep kicking
@@ -61,7 +82,7 @@ class AsyncLLMServer:
         while True:
             # Kick the server if the server is not running.
             if not self.is_server_running:
-                await self.server_step()
+                await self.server_step(request_id)

             # Wait for new output. The group_event will be set in server_step
             # when there is new output available for the sequence group.
@@ -80,6 +101,8 @@ class AsyncLLMServer:
             # Once finished, release the resources of the sequence group.
             if request_output.finished():
+                logger.info(f"Finished request {request_id}.")
+
                 del self.request_outputs[request_id]
                 del self.request_events[request_id]
                 # Kick the server if the server is not running. This is to
@@ -89,15 +112,41 @@ class AsyncLLMServer:
                     await self.server_step()
                 break

+    async def abort(self, request_id: str) -> None:
+        if request_id not in self.request_events:
+            # The request has already finished or been aborted.
+            return
+
+        logger.info(f"Aborted request {request_id}.")
+
+        if self.server_use_ray:
+            await self.server.abort_request.remote(request_id)
+        else:
+            self.server.abort_request(request_id)
+
+        if request_id in self.request_events:
+            del self.request_events[request_id]
+        if request_id in self.request_outputs:
+            del self.request_outputs[request_id]
+
+        # To prevent deadlock when a request is aborted while the server is
+        # running.
+        if self.kicking_request_id == request_id:
+            self.is_server_running = False
+            self.kicking_request_id = None
+
     @classmethod
-    def from_server_args(cls, server_args: ServerArgs) -> "AsyncLLMServer":
+    def from_server_args(cls, server_args: AsyncServerArgs) -> "AsyncLLMServer":
         # Create the server configs.
         server_configs = server_args.create_server_configs()
         parallel_config = server_configs[2]
         # Initialize the cluster.
-        distributed_init_method, devices = initialize_cluster(parallel_config)
+        distributed_init_method, devices = initialize_cluster(
+            parallel_config, server_args.server_use_ray)
         # Create the LLM server.
-        server = cls(server_args.use_ray, *server_configs,
+        server = cls(server_args.worker_use_ray, server_args.server_use_ray,
+                     *server_configs,
                      distributed_init_method, devices,
                      log_stats=not server_args.disable_log_stats)
         return server
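The abort path introduced here runs end to end as AsyncLLMServer.abort -> LLMServer.abort_request (see the llm_server.py diff below) -> Scheduler.abort_seq_group, which frees the aborted sequences with the new FINISHED_ABORTED status while BlockSpaceManager.free now tolerates sequences that were never scheduled or were already freed. A minimal asyncio sketch of driving generate and abort directly, assuming a server built as in the frontends above and SamplingParams kwargs taken from the benchmark payload:

import asyncio

from cacheflow.sampling_params import SamplingParams
from cacheflow.utils import random_uuid


async def run_one_request(server, prompt: str) -> None:
    # The caller now supplies the request id; generate() no longer creates one.
    request_id = random_uuid()
    results = server.generate(prompt, SamplingParams(temperature=0.0, max_tokens=16),
                              request_id)
    try:
        async for request_output in results:
            print(request_output)
    except asyncio.CancelledError:
        # If this coroutine is cancelled (e.g. the HTTP client disconnected),
        # release the request on the server side as well.
        await server.abort(request_id)
        raise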
cacheflow/server/llm_server.py

 import time
 from typing import Any, List, Optional

-try:
-    import ray
-except ImportError:
-    ray = None
-
 from cacheflow.config import (CacheConfig, ModelConfig, ParallelConfig,
                               SchedulerConfig)
 from cacheflow.core.scheduler import Scheduler
@@ -13,7 +8,7 @@ from cacheflow.logger import init_logger
 from cacheflow.outputs import RequestOutput
 from cacheflow.sampling_params import SamplingParams
 from cacheflow.server.arg_utils import ServerArgs
-from cacheflow.server.ray_utils import initialize_cluster
+from cacheflow.server.ray_utils import ray, initialize_cluster
 from cacheflow.server.tokenizer_utils import (get_tokenizer,
                                               detokenize_incrementally)
 from cacheflow.sequence import Sequence, SequenceGroup, SequenceStatus
@@ -62,7 +57,7 @@ class LLMServer:
         assert len(stage_devices) == 1, "Only support one stage for now."
         for rank, node_resource, _ in stage_devices[0]:
             worker_cls = Worker
-            if self.parallel_config.use_ray:
+            if self.parallel_config.worker_use_ray:
                 worker_cls = ray.remote(
                     num_cpus=0,
                     num_gpus=1,
@@ -152,6 +147,9 @@ class LLMServer:
         # Add the sequence group to the scheduler.
         self.scheduler.add_seq_group(seq_group)

+    def abort_request(self, request_id: str) -> None:
+        self.scheduler.abort_seq_group(request_id)
+
     def get_num_unfinished_requests(self) -> int:
         return self.scheduler.get_num_unfinished_seq_groups()
@@ -243,13 +241,13 @@ class LLMServer:
         all_outputs = []
         for worker in self.workers:
             executor = getattr(worker, method)
-            if self.parallel_config.use_ray:
+            if self.parallel_config.worker_use_ray:
                 executor = executor.remote

             output = executor(*args, **kwargs)
             all_outputs.append(output)

-        if self.parallel_config.use_ray:
+        if self.parallel_config.worker_use_ray:
             all_outputs = ray.get(all_outputs)

         if get_all_outputs:
cacheflow/server/ray_utils.py

@@ -13,9 +13,18 @@ DeviceID = Tuple[int, Optional[str], int]  # rank, node resource (node IP), devi

 def initialize_cluster(
     parallel_config: ParallelConfig,
+    server_use_ray: bool = False,
     address: Optional[str] = None,
 ) -> Tuple[str, List[List[DeviceID]]]:
-    if not parallel_config.use_ray:
+    if parallel_config.worker_use_ray or server_use_ray:
+        if ray is None:
+            raise ImportError(
+                "Ray is not installed. Please install Ray to use distributed "
+                "serving.")
+        # Connect to a ray cluster.
+        ray.init(address=address)
+
+    if not parallel_config.worker_use_ray:
         # Initialize cluster locally.
         port = random.randint(10000, 20000)
         # We need to setup the distributed init method to make sure
@@ -24,13 +33,6 @@ def initialize_cluster(
         all_stage_devices = [[(0, None, 0)]]
         return distributed_init_method, all_stage_devices

-    if ray is None:
-        raise ImportError(
-            "Ray is not installed. Please install Ray to use distributed "
-            "serving.")
-    # Connect to a ray cluster.
-    ray.init(address=address)
-
     # Assume we have a uniform cluster that each node has the same number of
     # GPUs for now.
     valid_node_resources = []