Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
norm
vllm
Commits
8917782a
Unverified
Commit
8917782a
authored
May 08, 2023
by
Woosuk Kwon
Committed by
GitHub
May 08, 2023
Browse files
Add a system logger (#85)
parent
7addca59
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
85 additions
and
13 deletions
+85
-13
cacheflow/logger.py
cacheflow/logger.py
+51
-0
cacheflow/master/server.py
cacheflow/master/server.py
+16
-3
cacheflow/master/simple_frontend.py
cacheflow/master/simple_frontend.py
+6
-2
cacheflow/models/memory_analyzer.py
cacheflow/models/memory_analyzer.py
+11
-7
simple_server.py
simple_server.py
+1
-1
No files found.
cacheflow/logger.py
0 → 100644
View file @
8917782a
# Adapted from https://github.com/skypilot-org/skypilot/blob/86dc0f6283a335e4aa37b3c10716f90999f48ab6/sky/sky_logging.py
import
logging
import
sys
# Log line layout: "LEVEL MM-DD HH:MM:SS filename.py:lineno] message".
_FORMAT = "%(levelname)s %(asctime)s %(filename)s:%(lineno)d] %(message)s"
# Date layout used for %(asctime)s in _FORMAT (month-day, 24h clock, no year).
_DATE_FORMAT = "%m-%d %H:%M:%S"
class NewLineFormatter(logging.Formatter):
    """Adds logging prefix to newlines to align multi-line messages."""

    def __init__(self, fmt, datefmt=None):
        # Delegate to the stock Formatter; this subclass only changes how
        # the final string is assembled in format().
        super().__init__(fmt, datefmt)

    def format(self, record):
        """Format *record*, repeating the log prefix after every newline.

        The prefix (level, timestamp, location — everything the format
        string places before the message body) is re-inserted after each
        embedded newline so continuation lines line up under the first.
        """
        formatted = super().format(record)
        if not record.message:
            # Empty message: nothing to split on, return as-is.
            return formatted
        # The text preceding the message body is the log prefix.
        prefix = formatted.split(record.message)[0]
        return formatted.replace("\n", "\r\n" + prefix)
# Package-level logger shared by all cacheflow modules.
_root_logger = logging.getLogger("cacheflow")
# Lazily-created stdout handler; attached at most once by _setup_logger().
_default_handler = None


def _setup_logger():
    """Configure the "cacheflow" root logger with a single stdout handler.

    Idempotent with respect to the handler: the StreamHandler is created
    and attached only on the first call; the formatter and propagation
    flag are (re)applied on every call.
    """
    global _default_handler

    _root_logger.setLevel(logging.DEBUG)
    if _default_handler is None:
        handler = logging.StreamHandler(sys.stdout)
        # Route handler flushes straight to stdout's flush.
        handler.flush = sys.stdout.flush  # type: ignore
        handler.setLevel(logging.INFO)
        _root_logger.addHandler(handler)
        _default_handler = handler
    _default_handler.setFormatter(
        NewLineFormatter(_FORMAT, datefmt=_DATE_FORMAT))
    # Setting this will avoid the message
    # being propagated to the parent logger.
    _root_logger.propagate = False


# The logger is initialized when the module is imported.
# This is thread-safe as the module is only imported once,
# guaranteed by the Python GIL.
_setup_logger()
def init_logger(name: str):
    """Return the logger registered under *name*.

    Loggers named "cacheflow.*" inherit the handler and formatting set up
    at import time on the package root logger.
    """
    # logging.getLogger caches instances, so repeated calls with the same
    # name return the same logger object.
    return logging.getLogger(name)
cacheflow/master/server.py
View file @
8917782a
...
@@ -8,6 +8,7 @@ try:
...
@@ -8,6 +8,7 @@ try:
except
ImportError
:
except
ImportError
:
ray
=
None
ray
=
None
from
cacheflow.logger
import
init_logger
from
cacheflow.master.scheduler
import
Scheduler
from
cacheflow.master.scheduler
import
Scheduler
from
cacheflow.master.simple_frontend
import
SimpleFrontend
from
cacheflow.master.simple_frontend
import
SimpleFrontend
from
cacheflow.models
import
get_memory_analyzer
from
cacheflow.models
import
get_memory_analyzer
...
@@ -17,6 +18,9 @@ from cacheflow.sampling_params import SamplingParams
...
@@ -17,6 +18,9 @@ from cacheflow.sampling_params import SamplingParams
from
cacheflow.utils
import
get_gpu_memory
,
get_cpu_memory
from
cacheflow.utils
import
get_gpu_memory
,
get_cpu_memory
logger
=
init_logger
(
__name__
)
class
Server
:
class
Server
:
def
__init__
(
def
__init__
(
self
,
self
,
...
@@ -42,6 +46,17 @@ class Server:
...
@@ -42,6 +46,17 @@ class Server:
collect_stats
:
bool
=
False
,
collect_stats
:
bool
=
False
,
do_memory_analysis
:
bool
=
False
,
do_memory_analysis
:
bool
=
False
,
):
):
logger
.
info
(
"Initializing a server with config: "
f
"model=
{
model
!
r
}
, "
f
"dtype=
{
dtype
}
, "
f
"use_dummy_weights=
{
use_dummy_weights
}
, "
f
"cache_dir=
{
cache_dir
}
, "
f
"use_np_cache=
{
use_np_cache
}
, "
f
"tensor_parallel_size=
{
tensor_parallel_size
}
, "
f
"block_size=
{
block_size
}
, "
f
"seed=
{
seed
}
)"
)
self
.
num_nodes
=
num_nodes
self
.
num_nodes
=
num_nodes
self
.
num_devices_per_node
=
num_devices_per_node
self
.
num_devices_per_node
=
num_devices_per_node
self
.
world_size
=
pipeline_parallel_size
*
tensor_parallel_size
self
.
world_size
=
pipeline_parallel_size
*
tensor_parallel_size
...
@@ -61,9 +76,7 @@ class Server:
...
@@ -61,9 +76,7 @@ class Server:
self
.
num_gpu_blocks
=
self
.
memory_analyzer
.
get_max_num_gpu_blocks
(
self
.
num_gpu_blocks
=
self
.
memory_analyzer
.
get_max_num_gpu_blocks
(
max_num_batched_tokens
=
max_num_batched_tokens
)
max_num_batched_tokens
=
max_num_batched_tokens
)
self
.
num_cpu_blocks
=
self
.
memory_analyzer
.
get_max_num_cpu_blocks
(
self
.
num_cpu_blocks
=
self
.
memory_analyzer
.
get_max_num_cpu_blocks
(
swap_space
=
swap_space
)
swap_space_gib
=
swap_space
)
print
(
f
'# GPU blocks:
{
self
.
num_gpu_blocks
}
, '
f
'# CPU blocks:
{
self
.
num_cpu_blocks
}
'
)
# Create a controller for each pipeline stage.
# Create a controller for each pipeline stage.
self
.
controllers
:
List
[
Controller
]
=
[]
self
.
controllers
:
List
[
Controller
]
=
[]
...
...
cacheflow/master/simple_frontend.py
View file @
8917782a
import
time
import
time
from
typing
import
List
,
Optional
,
Set
,
Tuple
from
typing
import
List
,
Optional
,
Tuple
from
transformers
import
AutoTokenizer
from
transformers
import
AutoTokenizer
from
cacheflow.logger
import
init_logger
from
cacheflow.sampling_params
import
SamplingParams
from
cacheflow.sampling_params
import
SamplingParams
from
cacheflow.sequence
import
Sequence
,
SequenceGroup
from
cacheflow.sequence
import
Sequence
,
SequenceGroup
from
cacheflow.utils
import
Counter
from
cacheflow.utils
import
Counter
logger
=
init_logger
(
__name__
)
class
SimpleFrontend
:
class
SimpleFrontend
:
def
__init__
(
def
__init__
(
...
@@ -66,4 +70,4 @@ class SimpleFrontend:
...
@@ -66,4 +70,4 @@ class SimpleFrontend:
token_ids
=
seq
.
get_token_ids
()
token_ids
=
seq
.
get_token_ids
()
output
=
self
.
tokenizer
.
decode
(
token_ids
,
skip_special_tokens
=
True
)
output
=
self
.
tokenizer
.
decode
(
token_ids
,
skip_special_tokens
=
True
)
output
=
output
.
strip
()
output
=
output
.
strip
()
print
(
f
'
Seq
{
seq
.
seq_id
}
:
{
output
!
r
}
'
)
logger
.
info
(
f
"
Seq
{
seq
.
seq_id
}
:
{
output
!
r
}
"
)
cacheflow/models/memory_analyzer.py
View file @
8917782a
import
torch
import
torch
from
transformers
import
AutoConfig
from
transformers
import
AutoConfig
from
cacheflow.logger
import
init_logger
from
cacheflow.models.utils
import
get_dtype_size
from
cacheflow.models.utils
import
get_dtype_size
logger
=
init_logger
(
__name__
)
_GiB
=
1
<<
30
_GiB
=
1
<<
30
...
@@ -23,20 +27,20 @@ class CacheFlowMemoryAnalyzer:
...
@@ -23,20 +27,20 @@ class CacheFlowMemoryAnalyzer:
def
get_max_num_cpu_blocks
(
def
get_max_num_cpu_blocks
(
self
,
self
,
swap_space
:
int
,
swap_space
_gib
:
int
,
)
->
int
:
)
->
int
:
swap_space
=
swap_space
*
_GiB
swap_space
=
swap_space
_gib
*
_GiB
cpu_memory
=
self
.
cpu_memory
cpu_memory
=
self
.
cpu_memory
if
swap_space
>
0.8
*
cpu_memory
:
if
swap_space
>
0.8
*
cpu_memory
:
raise
ValueError
(
f
'The swap space (
{
swap_space
/
_GiB
:.
2
f
}
GiB) '
raise
ValueError
(
f
'The swap space (
{
swap_space
_gib
:.
2
f
}
GiB) '
'takes more than 80% of the available memory '
'takes more than 80% of the available memory '
f
'(
{
cpu_memory
/
_GiB
:.
2
f
}
GiB).'
f
'(
{
cpu_memory
/
_GiB
:.
2
f
}
GiB).'
'Please check the swap space size.'
)
'Please check the swap space size.'
)
if
swap_space
>
0.5
*
cpu_memory
:
if
swap_space
>
0.5
*
cpu_memory
:
print
(
f
'WARNING: The swap space (
{
swap_space
/
_GiB
:.
2
f
}
GiB) '
logger
.
info
(
f
'WARNING: The swap space (
{
swap_space
_gib
:.
2
f
}
GiB) '
'takes more than 50% of the available memory '
'takes more than 50% of the available memory '
f
'(
{
cpu_memory
/
_GiB
:.
2
f
}
GiB).'
f
'(
{
cpu_memory
/
_GiB
:.
2
f
}
GiB).'
'This may slow the system performance.'
)
'This may slow the system performance.'
)
max_num_blocks
=
swap_space
//
self
.
get_cache_block_size
()
max_num_blocks
=
swap_space
//
self
.
get_cache_block_size
()
return
max_num_blocks
return
max_num_blocks
...
...
simple_server.py
View file @
8917782a
import
argparse
import
argparse
from
typing
import
List
from
cacheflow.master.server
import
(
from
cacheflow.master.server
import
(
add_server_arguments
,
process_server_arguments
,
add_server_arguments
,
process_server_arguments
,
init_local_server_and_frontend_with_arguments
)
init_local_server_and_frontend_with_arguments
)
from
cacheflow.sampling_params
import
SamplingParams
from
cacheflow.sampling_params
import
SamplingParams
def
main
(
args
:
argparse
.
Namespace
):
def
main
(
args
:
argparse
.
Namespace
):
server
,
frontend
=
init_local_server_and_frontend_with_arguments
(
args
)
server
,
frontend
=
init_local_server_and_frontend_with_arguments
(
args
)
# Test the following inputs.
# Test the following inputs.
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment