norm / vllm · Commits · 1a956e13

Unverified commit 1a956e13, authored Jun 05, 2023 by Zhuohan Li, committed by GitHub on Jun 05, 2023.

Fix various issues of async servers (#135)

Parent: 8274ca23

11 changed files with 289 additions and 121 deletions (+289, -121).
Files changed:

  benchmarks/benchmark_async_llm_server.py           +58   -0
  cacheflow/config.py                                  +3   -3
  cacheflow/core/block_manager.py                      +6   -3
  cacheflow/core/scheduler.py                         +13   -1
  cacheflow/entrypoints/openai/openai_frontend.py     +23   -7
  cacheflow/entrypoints/simple_fastapi_frontend.py    +17   -8
  cacheflow/sequence.py                                +9   -1
  cacheflow/server/arg_utils.py                       +72  -59
  cacheflow/server/async_llm_server.py                +71  -22
  cacheflow/server/llm_server.py                       +7   -9
  cacheflow/server/ray_utils.py                       +10   -8
benchmarks/benchmark_async_llm_server.py (new file, mode 100644)

import argparse
import json
import threading
import time

import requests


def main(args: argparse.Namespace):
    prompts = [f"Tell me a story with more than {''.join([str(i + 1)] * 5)} words"
               for i in range(args.n_threads)]

    headers = {"User-Agent": "CacheFlow Benchmark Client"}
    ploads = [{
        "prompt": p,
        "max_tokens": args.max_tokens,
        "temperature": 0.0,
        "ignore_eos": True,
    } for p in prompts]

    def send_request(results, i):
        response = requests.post(args.api_url, headers=headers,
                                 json=ploads[i], stream=True)
        results[i] = response

    # use args.n_threads to prompt the backend
    tik = time.time()
    threads = []
    results = [None] * args.n_threads
    for i in range(args.n_threads):
        t = threading.Thread(target=send_request, args=(results, i))
        t.start()
        threads.append(t)

    for t in threads:
        t.join()

    print(f"Time (POST): {time.time() - tik} s")
    n_words = 0

    for i, response in enumerate(results):
        k = list(response.iter_lines(chunk_size=8192, decode_unicode=False,
                                     delimiter=b"\0"))
        response_new_words = json.loads(k[-2].decode("utf-8"))["text"][0]
        n_words += (len(response_new_words.split(" ")) -
                    len(prompts[i].split(" ")))

    time_seconds = time.time() - tik
    print(f"Time (total): {time_seconds:.3f} s to finish, n_threads: {args.n_threads}, "
          f"throughput: {n_words / time_seconds} words/s.")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--api-url", type=str,
                        default="http://localhost:8001/generate")
    parser.add_argument("--max-tokens", type=int, default=128)
    parser.add_argument("--n-threads", type=int, default=128)
    args = parser.parse_args()

    main(args)
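
For reference, a single-request sketch of the protocol this benchmark assumes: the simple FastAPI frontend (diff below) streams null-byte-delimited JSON objects whose "text" field holds a list with the text generated so far. The URL and payload below mirror the script's defaults and are not guaranteed for other deployments.

import json

import requests

response = requests.post(
    "http://localhost:8001/generate",
    headers={"User-Agent": "CacheFlow Benchmark Client"},
    json={"prompt": "Tell me a story", "max_tokens": 32, "temperature": 0.0},
    stream=True,
)
# Each "\0"-delimited chunk is a JSON object; "text" is a list of outputs.
for chunk in response.iter_lines(chunk_size=8192, decode_unicode=False,
                                 delimiter=b"\0"):
    if chunk:
        print(json.loads(chunk.decode("utf-8"))["text"][0])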
cacheflow/config.py

@@ -116,15 +116,15 @@ class ParallelConfig:
         self,
         pipeline_parallel_size: int,
         tensor_parallel_size: int,
-        use_ray: bool,
+        worker_use_ray: bool,
     ) -> None:
         self.pipeline_parallel_size = pipeline_parallel_size
         self.tensor_parallel_size = tensor_parallel_size
-        self.use_ray = use_ray
+        self.worker_use_ray = worker_use_ray

         self.world_size = pipeline_parallel_size * tensor_parallel_size
         if self.world_size > 1:
-            self.use_ray = True
+            self.worker_use_ray = True
         self._verify_args()

     def _verify_args(self) -> None:
cacheflow/core/block_manager.py

@@ -148,7 +148,7 @@ class BlockSpaceManager:
         # the sequences in the same group.
         blocks: Set[PhysicalTokenBlock] = set()
         for seq in seq_group.get_seqs():
-            if SequenceStatus.is_finished(seq.status):
+            if seq.is_finished():
                 continue
             block_table = self.block_tables[seq.seq_id]
             for block in block_table:

@@ -169,7 +169,7 @@ class BlockSpaceManager:
         # CPU block -> GPU block.
         mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock] = {}
         for seq in seq_group.get_seqs():
-            if SequenceStatus.is_finished(seq.status):
+            if seq.is_finished():
                 continue
             new_block_table: BlockTable = []
             block_table = self.block_tables[seq.seq_id]

@@ -200,7 +200,7 @@ class BlockSpaceManager:
         # GPU block -> CPU block.
         mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock] = {}
         for seq in seq_group.get_seqs():
-            if SequenceStatus.is_finished(seq.status):
+            if seq.is_finished():
                 continue
             new_block_table: BlockTable = []
             block_table = self.block_tables[seq.seq_id]

@@ -231,6 +231,9 @@ class BlockSpaceManager:
             self.cpu_allocator.free(block)

     def free(self, seq: Sequence) -> None:
+        if seq.seq_id not in self.block_tables:
+            # Already freed or haven't been scheduled yet.
+            return
         block_table = self.block_tables[seq.seq_id]
         self._free_block_table(block_table)
         del self.block_tables[seq.seq_id]
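
The guard added to free() makes it safe to free a sequence that was never scheduled or has already been freed, which the new abort path can now trigger. A minimal standalone sketch of the same idea (ToyBlockManager and its methods are illustrative stand-ins, not the real BlockSpaceManager):

from typing import Dict, List


class ToyBlockManager:
    """Illustrative only: mirrors the early-return guard added to free()."""

    def __init__(self) -> None:
        self.block_tables: Dict[int, List[int]] = {}

    def allocate(self, seq_id: int, blocks: List[int]) -> None:
        self.block_tables[seq_id] = blocks

    def free(self, seq_id: int) -> None:
        if seq_id not in self.block_tables:
            # Already freed or never scheduled: nothing to release.
            return
        del self.block_tables[seq_id]


manager = ToyBlockManager()
manager.allocate(0, [1, 2, 3])
manager.free(0)
manager.free(0)   # Second free is now a harmless no-op.
manager.free(42)  # Aborted before scheduling: also a no-op.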
cacheflow/core/scheduler.py

@@ -12,7 +12,7 @@ from cacheflow.sequence import (Sequence, SequenceData, SequenceGroup,
 logger = init_logger(__name__)

-_LOGGING_INTERVAL_SEC = 10
+_LOGGING_INTERVAL_SEC = 5


 class PreemptionMode(enum.Enum):

@@ -84,6 +84,18 @@ class Scheduler:
         # Add sequence groups to the waiting queue.
         self.waiting.append(seq_group)

+    def abort_seq_group(self, request_id: str) -> None:
+        for state_queue in [self.waiting, self.running, self.swapped]:
+            for seq_group in state_queue:
+                if seq_group.request_id == request_id:
+                    # Remove the sequence group from the state queue.
+                    state_queue.remove(seq_group)
+                    for seq in seq_group.seqs:
+                        if seq.is_finished():
+                            continue
+                        self.free_seq(seq, SequenceStatus.FINISHED_ABORTED)
+                    return
+
     def has_unfinished_seqs(self) -> bool:
         return self.waiting or self.running or self.swapped
cacheflow/entrypoints/openai/openai_frontend.py

@@ -7,13 +7,14 @@ import time
 from typing import AsyncGenerator, Dict, List, Optional

 import fastapi
+from fastapi import BackgroundTasks, Request
 from fastapi.exceptions import RequestValidationError
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import StreamingResponse, JSONResponse
 import uvicorn

 from cacheflow.outputs import RequestOutput
-from cacheflow.server.arg_utils import ServerArgs
+from cacheflow.server.arg_utils import AsyncServerArgs
 from cacheflow.server.async_llm_server import AsyncLLMServer
 from cacheflow.server.tokenizer_utils import get_tokenizer
 from cacheflow.logger import init_logger

@@ -33,6 +34,7 @@ from cacheflow.entrypoints.openai.protocol import (
     UsageInfo,
 )

+TIMEOUT_KEEP_ALIVE = 5  # seconds

 logger = init_logger(__name__)
 served_model = None

@@ -93,7 +95,8 @@ def create_logprobs(token_ids: List[int],
 @app.post("/v1/completions")
-async def create_completion(request: CompletionRequest):
+async def create_completion(raw_request: Request):
+    request = CompletionRequest(**await raw_request.json())
     logger.info(f"Received completion request: {request}")

     error_check_ret = await check_model(request)

@@ -139,7 +142,7 @@ async def create_completion(request: CompletionRequest):
         return create_error_response(HTTPStatus.BAD_REQUEST, str(e))

     result_generator = server.generate(prompt, sampling_params,
-                                       request_id=request_id)
+                                       request_id)

     # Similar to the OpenAI API, when n != best_of, we do not stream the
     # results. In addition, we do not stream the results when use beam search.

@@ -147,6 +150,9 @@ async def create_completion(request: CompletionRequest):
         (request.best_of is None or request.n == request.best_of) and
         not request.use_beam_search)

+    async def abort_request() -> None:
+        await server.abort(request_id)
+
     def create_stream_response_json(index: int,
                                     text: str,
                                     logprobs: Optional[LogProbs] = None,

@@ -203,12 +209,21 @@ async def create_completion(request: CompletionRequest):
     # Streaming response
     if stream:
+        background_tasks = BackgroundTasks()
+        # Abort the request if the client disconnects.
+        background_tasks.add_task(abort_request)
         return StreamingResponse(completion_stream_generator(),
-                                 media_type="text/event-stream")
+                                 media_type="text/event-stream",
+                                 background=background_tasks)

     # Non-streaming response
     final_res: RequestOutput = None
     async for res in result_generator:
+        if await raw_request.is_disconnected():
+            # Abort the request if the client disconnects.
+            await server.abort(request_id)
+            return create_error_response(HTTPStatus.BAD_REQUEST,
+                                         "Client disconnected")
         final_res = res
     assert final_res is not None
     choices = []

@@ -276,7 +291,7 @@ if __name__ == "__main__":
                         help="The model name used in the API. If not specified, "
                              "the model name will be the same as the "
                              "huggingface name.")
-    parser = ServerArgs.add_cli_args(parser)
+    parser = AsyncServerArgs.add_cli_args(parser)
     args = parser.parse_args()

     app.add_middleware(

@@ -291,10 +306,11 @@ if __name__ == "__main__":
     served_model = args.served_model_name or args.model

-    server_args = ServerArgs.from_cli_args(args)
+    server_args = AsyncServerArgs.from_cli_args(args)
     server = AsyncLLMServer.from_server_args(server_args)

     # A separate tokenizer to map token IDs to strings.
     tokenizer = get_tokenizer(args.model)

-    uvicorn.run(app, host=args.host, port=args.port, log_level="info")
+    uvicorn.run(app, host=args.host, port=args.port, log_level="info",
+                timeout_keep_alive=TIMEOUT_KEEP_ALIVE)
cacheflow/entrypoints/simple_fastapi_frontend.py

@@ -2,15 +2,16 @@ import argparse
 import json
 from typing import AsyncGenerator

-from fastapi import FastAPI, Request
+from fastapi import BackgroundTasks, FastAPI, Request
 from fastapi.responses import StreamingResponse
 import uvicorn

 from cacheflow.sampling_params import SamplingParams
-from cacheflow.server.arg_utils import ServerArgs
+from cacheflow.server.arg_utils import AsyncServerArgs
 from cacheflow.server.async_llm_server import AsyncLLMServer
-from cacheflow.server.ray_utils import initialize_cluster
+from cacheflow.utils import random_uuid

+TIMEOUT_KEEP_ALIVE = 5  # seconds.
 TIMEOUT_TO_PREVENT_DEADLOCK = 1  # seconds
 app = FastAPI()

@@ -20,7 +21,8 @@ async def generate_stream(request: Request) -> StreamingResponse:
     request_dict = await request.json()
     prompt = request_dict.pop("prompt")
     sampling_params = SamplingParams(**request_dict)
-    results_generator = server.generate(prompt, sampling_params)
+    request_id = random_uuid()
+    results_generator = server.generate(prompt, sampling_params, request_id)

     async def stream_results() -> AsyncGenerator[bytes, None]:
         async for request_output in results_generator:

@@ -35,17 +37,24 @@ async def generate_stream(request: Request) -> StreamingResponse:
             }
             yield (json.dumps(ret) + "\0").encode("utf-8")

-    return StreamingResponse(stream_results())
+    async def abort_request() -> None:
+        await server.abort(request_id)
+
+    background_tasks = BackgroundTasks()
+    # Abort the request if the client disconnects.
+    background_tasks.add_task(abort_request)
+    return StreamingResponse(stream_results(), background=background_tasks)


 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument("--host", type=str, default="localhost")
     parser.add_argument("--port", type=int, default=8001)
-    parser = ServerArgs.add_cli_args(parser)
+    parser = AsyncServerArgs.add_cli_args(parser)
     args = parser.parse_args()

-    server_args = ServerArgs.from_cli_args(args)
+    server_args = AsyncServerArgs.from_cli_args(args)
     server = AsyncLLMServer.from_server_args(server_args)

-    uvicorn.run(app, host=args.host, port=args.port, log_level="info")
+    uvicorn.run(app, host=args.host, port=args.port, log_level="debug",
+                timeout_keep_alive=TIMEOUT_KEEP_ALIVE)
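
Both frontends now hand a BackgroundTasks object to StreamingResponse so an abort callback runs once the response is done, including when the client disconnects mid-stream. A self-contained toy app showing just that wiring (the endpoint body and abort handler below are illustrative, not CacheFlow's):

import asyncio
import json

from fastapi import BackgroundTasks, FastAPI
from fastapi.responses import StreamingResponse
import uvicorn

app = FastAPI()


@app.post("/generate")
async def generate():
    async def stream_results():
        for i in range(100):
            await asyncio.sleep(0.1)
            yield (json.dumps({"text": [f"token {i}"]}) + "\0").encode("utf-8")

    async def abort_request() -> None:
        # In CacheFlow this is where `await server.abort(request_id)` happens.
        print("response finished or client went away; aborting generation")

    background_tasks = BackgroundTasks()
    # Runs once the response is finished or the client disconnects.
    background_tasks.add_task(abort_request)
    return StreamingResponse(stream_results(), background=background_tasks)


if __name__ == "__main__":
    uvicorn.run(app, host="localhost", port=8001, log_level="info")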
cacheflow/sequence.py

@@ -12,12 +12,14 @@ class SequenceStatus(enum.Enum):
     SWAPPED = enum.auto()
     FINISHED_STOPPED = enum.auto()
     FINISHED_LENGTH_CAPPED = enum.auto()
+    FINISHED_ABORTED = enum.auto()

     @staticmethod
     def is_finished(status: "SequenceStatus") -> bool:
         return status in [
             SequenceStatus.FINISHED_STOPPED,
             SequenceStatus.FINISHED_LENGTH_CAPPED,
+            SequenceStatus.FINISHED_ABORTED,
         ]

     @staticmethod

@@ -26,10 +28,13 @@ class SequenceStatus(enum.Enum):
             finish_reason = "stop"
         elif status == SequenceStatus.FINISHED_LENGTH_CAPPED:
             finish_reason = "length"
+        elif status == SequenceStatus.FINISHED_ABORTED:
+            finish_reason = "abort"
         else:
             finish_reason = None
         return finish_reason


 class SequenceData:

     def __init__(

@@ -137,6 +142,9 @@ class Sequence:
     def get_cumulative_logprob(self) -> float:
         return self.data.cumulative_logprob

+    def is_finished(self) -> bool:
+        return SequenceStatus.is_finished(self.status)
+
     def fork(self, child_seq: 'Sequence') -> None:
         child_seq.logical_token_blocks = copy.deepcopy(self.logical_token_blocks)
         child_seq.output_logprobs = copy.deepcopy(self.output_logprobs)

@@ -182,7 +190,7 @@ class SequenceGroup:
         raise ValueError(f'Sequence {seq_id} not found.')

     def is_finished(self) -> bool:
-        return all(SequenceStatus.is_finished(seq.status) for seq in self.seqs)
+        return all(seq.is_finished() for seq in self.seqs)

     def __repr__(self) -> str:
         return (f"SequenceGroup(request_id={self.request_id}, "
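
A stripped-down sketch of the status handling after this change, mirroring the new FINISHED_ABORTED member and its "abort" finish reason (the class and method names in this toy are illustrative; the real SequenceStatus has additional members and helpers):

import enum
from typing import Optional


class ToySequenceStatus(enum.Enum):
    RUNNING = enum.auto()
    FINISHED_STOPPED = enum.auto()
    FINISHED_LENGTH_CAPPED = enum.auto()
    FINISHED_ABORTED = enum.auto()

    @staticmethod
    def is_finished(status: "ToySequenceStatus") -> bool:
        return status in [
            ToySequenceStatus.FINISHED_STOPPED,
            ToySequenceStatus.FINISHED_LENGTH_CAPPED,
            ToySequenceStatus.FINISHED_ABORTED,
        ]

    @staticmethod
    def finish_reason(status: "ToySequenceStatus") -> Optional[str]:
        if status == ToySequenceStatus.FINISHED_STOPPED:
            return "stop"
        if status == ToySequenceStatus.FINISHED_LENGTH_CAPPED:
            return "length"
        if status == ToySequenceStatus.FINISHED_ABORTED:
            return "abort"
        return None


# Aborted sequences now count as finished and report "abort" as the reason.
assert ToySequenceStatus.is_finished(ToySequenceStatus.FINISHED_ABORTED)
assert ToySequenceStatus.finish_reason(ToySequenceStatus.FINISHED_ABORTED) == "abort"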
cacheflow/server/arg_utils.py

@@ -15,7 +15,7 @@ class ServerArgs:
     use_dummy_weights: bool = False
     dtype: str = "default"
     seed: int = 0
-    use_ray: bool = False
+    worker_use_ray: bool = False
     pipeline_parallel_size: int = 1
     tensor_parallel_size: int = 1
     block_size: int = 16

@@ -32,7 +32,63 @@ class ServerArgs:
     def add_cli_args(
             parser: argparse.ArgumentParser,
     ) -> argparse.ArgumentParser:
-        return _add_server_arguments(parser)
+        """Shared CLI arguments for CacheFlow servers."""
+        # Model arguments
+        parser.add_argument('--model', type=str, default='facebook/opt-125m',
+                            help='name or path of the huggingface model to use')
+        parser.add_argument('--download-dir', type=str,
+                            default=ServerArgs.download_dir,
+                            help='directory to download and load the weights, '
+                                 'default to the default cache dir of '
+                                 'huggingface')
+        parser.add_argument('--use-np-weights', action='store_true',
+                            help='save a numpy copy of model weights for '
+                                 'faster loading. This can increase the disk '
+                                 'usage by up to 2x.')
+        parser.add_argument('--use-dummy-weights', action='store_true',
+                            help='use dummy values for model weights')
+        # TODO(woosuk): Support FP32.
+        parser.add_argument('--dtype', type=str, default=ServerArgs.dtype,
+                            choices=['default', 'half', 'bfloat16'],
+                            help='data type for model weights and activations. '
+                                 'The "default" option will use FP16 precision '
+                                 'for FP32 and FP16 models, and BF16 precision '
+                                 'for BF16 models.')
+        # Parallel arguments
+        parser.add_argument('--worker-use-ray', action='store_true',
+                            help='use Ray for distributed serving, will be '
+                                 'automatically set when using more than 1 GPU')
+        parser.add_argument('--pipeline-parallel-size', '-pp', type=int,
+                            default=ServerArgs.pipeline_parallel_size,
+                            help='number of pipeline stages')
+        parser.add_argument('--tensor-parallel-size', '-tp', type=int,
+                            default=ServerArgs.tensor_parallel_size,
+                            help='number of tensor parallel replicas')
+        # KV cache arguments
+        parser.add_argument('--block-size', type=int,
+                            default=ServerArgs.block_size,
+                            choices=[1, 2, 4, 8, 16, 32, 64, 128, 256],
+                            help='token block size')
+        # TODO(woosuk): Support fine-grained seeds (e.g., seed per request).
+        parser.add_argument('--seed', type=int, default=ServerArgs.seed,
+                            help='random seed')
+        parser.add_argument('--swap-space', type=int,
+                            default=ServerArgs.swap_space,
+                            help='CPU swap space size (GiB) per GPU')
+        parser.add_argument('--gpu-memory-utilization', type=float,
+                            default=ServerArgs.gpu_memory_utilization,
+                            help='the percentage of GPU memory to be used for'
+                                 'the model executor')
+        parser.add_argument('--max-num-batched-tokens', type=int,
+                            default=ServerArgs.max_num_batched_tokens,
+                            help='maximum number of batched tokens per '
+                                 'iteration')
+        parser.add_argument('--max-num-seqs', type=int,
+                            default=ServerArgs.max_num_seqs,
+                            help='maximum number of sequences per iteration')
+        parser.add_argument('--disable-log-stats', action='store_true',
+                            help='disable logging statistics')
+        return parser

     @classmethod
     def from_cli_args(cls, args: argparse.Namespace) -> "ServerArgs":

@@ -53,65 +109,22 @@ class ServerArgs:
                                    self.swap_space)
         parallel_config = ParallelConfig(self.pipeline_parallel_size,
                                          self.tensor_parallel_size,
-                                         self.use_ray)
+                                         self.worker_use_ray)
         scheduler_config = SchedulerConfig(self.max_num_batched_tokens,
                                            self.max_num_seqs)
         return model_config, cache_config, parallel_config, scheduler_config


-def _add_server_arguments(
-        parser: argparse.ArgumentParser,
-) -> argparse.ArgumentParser:
-    """Shared CLI arguments for CacheFlow servers."""
-    # Model arguments
-    parser.add_argument('--model', type=str, default='facebook/opt-125m',
-                        help='name or path of the huggingface model to use')
-    parser.add_argument('--download-dir', type=str,
-                        default=ServerArgs.download_dir,
-                        help='directory to download and load the weights, '
-                             'default to the default cache dir of huggingface')
-    parser.add_argument('--use-np-weights', action='store_true',
-                        help='save a numpy copy of model weights for faster '
-                             'loading. This can increase the disk usage by up '
-                             'to 2x.')
-    parser.add_argument('--use-dummy-weights', action='store_true',
-                        help='use dummy values for model weights')
-    # TODO(woosuk): Support FP32.
-    parser.add_argument('--dtype', type=str, default=ServerArgs.dtype,
-                        choices=['default', 'half', 'bfloat16'],
-                        help=('data type for model weights and activations. '
-                              'The "default" option will use FP16 precision '
-                              'for FP32 and FP16 models, and BF16 precision '
-                              'for BF16 models.'))
-    # Parallel arguments
-    parser.add_argument('--use-ray', action='store_true',
-                        help='use Ray for distributed serving, will be '
-                             'automatically set when using more than 1 GPU')
-    parser.add_argument('--pipeline-parallel-size', '-pp', type=int,
-                        default=ServerArgs.pipeline_parallel_size,
-                        help='number of pipeline stages')
-    parser.add_argument('--tensor-parallel-size', '-tp', type=int,
-                        default=ServerArgs.tensor_parallel_size,
-                        help='number of tensor parallel replicas')
-    # KV cache arguments
-    parser.add_argument('--block-size', type=int,
-                        default=ServerArgs.block_size,
-                        choices=[1, 2, 4, 8, 16, 32, 64, 128, 256],
-                        help='token block size')
-    # TODO(woosuk): Support fine-grained seeds (e.g., seed per request).
-    parser.add_argument('--seed', type=int, default=ServerArgs.seed,
-                        help='random seed')
-    parser.add_argument('--swap-space', type=int,
-                        default=ServerArgs.swap_space,
-                        help='CPU swap space size (GiB) per GPU')
-    parser.add_argument('--gpu-memory-utilization', type=float,
-                        default=ServerArgs.gpu_memory_utilization,
-                        help='the percentage of GPU memory to be used for the '
-                             'model executor')
-    parser.add_argument('--max-num-batched-tokens', type=int,
-                        default=ServerArgs.max_num_batched_tokens,
-                        help='maximum number of batched tokens per iteration')
-    parser.add_argument('--max-num-seqs', type=int,
-                        default=ServerArgs.max_num_seqs,
-                        help='maximum number of sequences per iteration')
-    parser.add_argument('--disable-log-stats', action='store_true',
-                        help='disable logging statistics')
-    return parser
+@dataclass
+class AsyncServerArgs(ServerArgs):
+    server_use_ray: bool = False
+
+    @staticmethod
+    def add_cli_args(
+            parser: argparse.ArgumentParser,
+    ) -> argparse.ArgumentParser:
+        parser = ServerArgs.add_cli_args(parser)
+        parser.add_argument('--server-use-ray', action='store_true',
+                            help='use Ray to start the LLM server in a '
+                                 'separate process as the web server process.')
+        return parser
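
The AsyncServerArgs pattern, a dataclass subclass that reuses the parent's CLI arguments and appends its own flag, can be sketched in isolation as below. The class and option names are toy stand-ins, and from_cli_args here is an assumption about reading dataclass fields back from argparse, not a copy of the CacheFlow implementation.

import argparse
from dataclasses import dataclass


@dataclass
class BaseArgs:
    model: str = 'facebook/opt-125m'
    seed: int = 0

    @staticmethod
    def add_cli_args(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
        parser.add_argument('--model', type=str, default=BaseArgs.model)
        parser.add_argument('--seed', type=int, default=BaseArgs.seed)
        return parser

    @classmethod
    def from_cli_args(cls, args: argparse.Namespace):
        # Pick out only the fields this dataclass (or subclass) declares.
        attrs = list(cls.__dataclass_fields__)
        return cls(**{attr: getattr(args, attr) for attr in attrs})


@dataclass
class AsyncArgs(BaseArgs):
    server_use_ray: bool = False

    @staticmethod
    def add_cli_args(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
        parser = BaseArgs.add_cli_args(parser)
        parser.add_argument('--server-use-ray', action='store_true')
        return parser


parser = AsyncArgs.add_cli_args(argparse.ArgumentParser())
args = parser.parse_args(['--server-use-ray'])
print(AsyncArgs.from_cli_args(args))  # AsyncArgs(model='facebook/opt-125m', seed=0, server_use_ray=True)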
cacheflow/server/async_llm_server.py

@@ -2,37 +2,52 @@ import asyncio
 import time
 from typing import Dict, Optional

-import ray
-
+from cacheflow.logger import init_logger
 from cacheflow.outputs import RequestOutput
 from cacheflow.sampling_params import SamplingParams
-from cacheflow.server.arg_utils import ServerArgs
+from cacheflow.server.arg_utils import AsyncServerArgs
 from cacheflow.server.llm_server import LLMServer
-from cacheflow.server.ray_utils import initialize_cluster
-from cacheflow.utils import random_uuid
+from cacheflow.server.ray_utils import ray, initialize_cluster
+
+logger = init_logger(__name__)

 TIMEOUT_TO_PREVENT_DEADLOCK = 1  # seconds


 class AsyncLLMServer:

-    def __init__(self, server_use_ray: bool, *args, **kwargs) -> None:
-        if server_use_ray:
-            remote_server_class = ray.remote(num_cpus=0)(LLMServer)
-        else:
-            remote_server_class = ray.remote(num_gpus=1)(LLMServer)
-        self.server = remote_server_class.remote(*args, **kwargs)
+    def __init__(self, worker_use_ray: bool, server_use_ray: bool,
+                 *args, **kwargs) -> None:
+        self.worker_use_ray = worker_use_ray
+        self.server_use_ray = server_use_ray
+        if not self.server_use_ray:
+            server_class = LLMServer
+        elif self.worker_use_ray:
+            server_class = ray.remote(num_cpus=0)(LLMServer).remote
+        else:
+            server_class = ray.remote(num_gpus=1)(LLMServer).remote
+        self.server = server_class(*args, **kwargs)
         # Request id -> request output.
         self.request_outputs: Dict[str, RequestOutput] = {}
         # Request id -> event to notify that there is new output.
         self.request_events: Dict[str, asyncio.Event] = {}
         self.is_server_running = False
+        self.kicking_request_id: Optional[str] = None

-    async def server_step(self):
+    async def server_step(self, kicking_request_id: Optional[str] = None):
         self.is_server_running = True
-        request_outputs = await self.server.step.remote()
+        self.kicking_request_id = kicking_request_id
+        if self.server_use_ray:
+            request_outputs = await self.server.step.remote()
+        else:
+            # Yield to the event loop to allow other coroutines to run
+            # while is_server_running is True. This let the server to add new
+            # requests into the queue.
+            await asyncio.sleep(0)
+            request_outputs = self.server.step()
         self.is_server_running = False
+        self.kicking_request_id = None

         # Notify the waiting coroutines that there are new outputs ready.
         for request_output in request_outputs:
             request_id = request_output.request_id

@@ -40,20 +55,26 @@ class AsyncLLMServer:
                 self.request_events[request_id].set()

     async def generate(self, prompt: str, sampling_params: SamplingParams,
-                       request_id: Optional[str] = None) -> RequestOutput:
+                       request_id: str) -> RequestOutput:
         # Preprocess the request.
         arrival_time = time.time()

         # Create an event to notify us that there is new output from the
         # cacheflow server.
-        if request_id is None:
-            request_id = random_uuid()
         request_event = asyncio.Event()
         self.request_events[request_id] = request_event

+        logger.info(f"Received request {request_id}: "
+                    f"prompt: {prompt!r}, "
+                    f"sampling params: {sampling_params}.")
+
         # Add the request into the cacheflow server's waiting queue.
-        await self.server.add_request.remote(
-            request_id, prompt, sampling_params, arrival_time=arrival_time)
+        if self.server_use_ray:
+            await self.server.add_request.remote(
+                request_id, prompt, sampling_params, arrival_time=arrival_time)
+        else:
+            self.server.add_request(
+                request_id, prompt, sampling_params, arrival_time=arrival_time)

         # The cacheflow server does not have a background loop that keeps
         # processing incoming requests. Therefore, we need to keep kicking

@@ -61,7 +82,7 @@ class AsyncLLMServer:
         while True:
             # Kick the server if the server is not running.
             if not self.is_server_running:
-                await self.server_step()
+                await self.server_step(request_id)

             # Wait for new output. The group_event will be set in server_step
             # when there is new output available for the sequence group.

@@ -80,6 +101,8 @@ class AsyncLLMServer:
             # Once finished, release the resources of the sequence group.
             if request_output.finished():
+                logger.info(f"Finished request {request_id}.")
+
                 del self.request_outputs[request_id]
                 del self.request_events[request_id]

                 # Kick the server if the server is not running. This is to

@@ -89,15 +112,41 @@ class AsyncLLMServer:
                     await self.server_step()
                 break

+    async def abort(self, request_id: str) -> None:
+        if request_id not in self.request_events:
+            # The request has already finished or been aborted.
+            return
+
+        logger.info(f"Aborted request {request_id}.")
+
+        if self.server_use_ray:
+            await self.server.abort_request.remote(request_id)
+        else:
+            self.server.abort_request(request_id)
+
+        if request_id in self.request_events:
+            del self.request_events[request_id]
+        if request_id in self.request_outputs:
+            del self.request_outputs[request_id]
+
+        # To prevent deadlock when a request is aborted while the server is
+        # running.
+        if self.kicking_request_id == request_id:
+            self.is_server_running = False
+            self.kicking_request_id = None
+
     @classmethod
-    def from_server_args(cls, server_args: ServerArgs) -> "AsyncLLMServer":
+    def from_server_args(cls, server_args: AsyncServerArgs) -> "AsyncLLMServer":
         # Create the server configs.
         server_configs = server_args.create_server_configs()
         parallel_config = server_configs[2]
         # Initialize the cluster.
-        distributed_init_method, devices = initialize_cluster(parallel_config)
+        distributed_init_method, devices = initialize_cluster(
+            parallel_config, server_args.server_use_ray)
         # Create the LLM server.
-        server = cls(server_args.use_ray, *server_configs,
-                     distributed_init_method, devices,
+        server = cls(server_args.worker_use_ray, server_args.server_use_ray,
+                     *server_configs, distributed_init_method, devices,
                      log_stats=not server_args.disable_log_stats)
         return server
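
When the server runs in-process (no Ray), server_step awaits asyncio.sleep(0) before the blocking step so other request coroutines get a chance to enqueue work while is_server_running is set. A minimal sketch of that kick-and-yield pattern with a toy engine (not the real LLMServer, and without the event-based result delivery):

import asyncio
from typing import List


class ToyEngine:
    """Synchronous engine: one step drains whatever is currently queued."""

    def __init__(self) -> None:
        self.queue: List[str] = []

    def add_request(self, request_id: str) -> None:
        self.queue.append(request_id)

    def step(self) -> List[str]:
        finished, self.queue = self.queue, []
        return finished


engine = ToyEngine()
is_running = False


async def engine_step() -> List[str]:
    global is_running
    is_running = True
    # Yield once so other coroutines can add requests before the blocking step.
    await asyncio.sleep(0)
    outputs = engine.step()
    is_running = False
    return outputs


async def generate(request_id: str) -> None:
    engine.add_request(request_id)
    # Kick the engine only if no other coroutine is already stepping it.
    if not is_running:
        done = await engine_step()
        print(f"{request_id} kicked a step that finished: {done}")


asyncio.run(asyncio.gather(generate("req-0"), generate("req-1")))
# req-0 kicks the step; because of the sleep(0) yield, req-1 is enqueued
# before the step runs, so both requests are processed in one batch.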
cacheflow/server/llm_server.py

 import time
 from typing import Any, List, Optional

-try:
-    import ray
-except ImportError:
-    ray = None
-
 from cacheflow.config import (CacheConfig, ModelConfig, ParallelConfig,
                               SchedulerConfig)
 from cacheflow.core.scheduler import Scheduler

@@ -13,7 +8,7 @@ from cacheflow.logger import init_logger
 from cacheflow.outputs import RequestOutput
 from cacheflow.sampling_params import SamplingParams
 from cacheflow.server.arg_utils import ServerArgs
-from cacheflow.server.ray_utils import initialize_cluster
+from cacheflow.server.ray_utils import ray, initialize_cluster
 from cacheflow.server.tokenizer_utils import (get_tokenizer,
                                               detokenize_incrementally)
 from cacheflow.sequence import Sequence, SequenceGroup, SequenceStatus

@@ -62,7 +57,7 @@ class LLMServer:
         assert len(stage_devices) == 1, "Only support one stage for now."
         for rank, node_resource, _ in stage_devices[0]:
             worker_cls = Worker
-            if self.parallel_config.use_ray:
+            if self.parallel_config.worker_use_ray:
                 worker_cls = ray.remote(
                     num_cpus=0,
                     num_gpus=1,

@@ -152,6 +147,9 @@ class LLMServer:
         # Add the sequence group to the scheduler.
         self.scheduler.add_seq_group(seq_group)

+    def abort_request(self, request_id: str) -> None:
+        self.scheduler.abort_seq_group(request_id)
+
     def get_num_unfinished_requests(self) -> int:
         return self.scheduler.get_num_unfinished_seq_groups()

@@ -243,13 +241,13 @@ class LLMServer:
         all_outputs = []
         for worker in self.workers:
             executor = getattr(worker, method)
-            if self.parallel_config.use_ray:
+            if self.parallel_config.worker_use_ray:
                 executor = executor.remote

             output = executor(*args, **kwargs)
             all_outputs.append(output)

-        if self.parallel_config.use_ray:
+        if self.parallel_config.worker_use_ray:
             all_outputs = ray.get(all_outputs)

         if get_all_outputs:
cacheflow/server/ray_utils.py

@@ -13,9 +13,18 @@ DeviceID = Tuple[int, Optional[str], int]  # rank, node resource (node IP), device id
 def initialize_cluster(
     parallel_config: ParallelConfig,
+    server_use_ray: bool = False,
     address: Optional[str] = None,
 ) -> Tuple[str, List[List[DeviceID]]]:
-    if not parallel_config.use_ray:
+    if parallel_config.worker_use_ray or server_use_ray:
+        if ray is None:
+            raise ImportError(
+                "Ray is not installed. Please install Ray to use distributed "
+                "serving.")
+        # Connect to a ray cluster.
+        ray.init(address=address)
+
+    if not parallel_config.worker_use_ray:
         # Initialize cluster locally.
         port = random.randint(10000, 20000)
         # We need to setup the distributed init method to make sure

@@ -24,13 +33,6 @@ def initialize_cluster(
         all_stage_devices = [[(0, None, 0)]]
         return distributed_init_method, all_stage_devices

-    if ray is None:
-        raise ImportError(
-            "Ray is not installed. Please install Ray to use distributed "
-            "serving.")
-    # Connect to a ray cluster.
-    ray.init(address=address)
-
     # Assume we have a uniform cluster that each node has the same number of
     # GPUs for now.
     valid_node_resources = []