Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
75f89dc4
Unverified
Commit
75f89dc4
authored
Dec 10, 2024
by
youkaichao
Committed by
GitHub
Dec 10, 2024
Browse files
[torch.compile] add a flag to track batchsize statistics (#11059)
Signed-off-by:
youkaichao
<
youkaichao@gmail.com
>
parent
e7391949
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
37 additions
and
1 deletion
+37
-1
vllm/envs.py
vllm/envs.py
+3
-0
vllm/forward_context.py
vllm/forward_context.py
+31
-1
vllm/v1/attention/backends/flash_attn.py
vllm/v1/attention/backends/flash_attn.py
+1
-0
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+2
-0
No files found.
vllm/envs.py
View file @
75f89dc4
...
@@ -69,6 +69,7 @@ if TYPE_CHECKING:
...
@@ -69,6 +69,7 @@ if TYPE_CHECKING:
VLLM_DISABLED_KERNELS
:
List
[
str
]
=
[]
VLLM_DISABLED_KERNELS
:
List
[
str
]
=
[]
VLLM_USE_V1
:
bool
=
False
VLLM_USE_V1
:
bool
=
False
VLLM_ENABLE_V1_MULTIPROCESSING
:
bool
=
False
VLLM_ENABLE_V1_MULTIPROCESSING
:
bool
=
False
VLLM_LOG_BATCHSIZE_INTERVAL
:
float
=
-
1
def
get_default_cache_root
():
def
get_default_cache_root
():
...
@@ -452,6 +453,8 @@ environment_variables: Dict[str, Callable[[], Any]] = {
...
@@ -452,6 +453,8 @@ environment_variables: Dict[str, Callable[[], Any]] = {
# If set, enable multiprocessing in LLM for the V1 code path.
# If set, enable multiprocessing in LLM for the V1 code path.
"VLLM_ENABLE_V1_MULTIPROCESSING"
:
"VLLM_ENABLE_V1_MULTIPROCESSING"
:
lambda
:
bool
(
int
(
os
.
getenv
(
"VLLM_ENABLE_V1_MULTIPROCESSING"
,
"0"
))),
lambda
:
bool
(
int
(
os
.
getenv
(
"VLLM_ENABLE_V1_MULTIPROCESSING"
,
"0"
))),
"VLLM_LOG_BATCHSIZE_INTERVAL"
:
lambda
:
float
(
os
.
getenv
(
"VLLM_LOG_BATCHSIZE_INTERVAL"
,
"-1"
)),
}
}
# end-env-vars-definition
# end-env-vars-definition
...
...
vllm/forward_context.py
View file @
75f89dc4
import
time
from
collections
import
Counter
from
contextlib
import
contextmanager
from
contextlib
import
contextmanager
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
typing
import
Any
,
Dict
,
Optional
from
typing
import
Any
,
Dict
,
Optional
import
vllm.envs
as
envs
from
vllm.config
import
VllmConfig
from
vllm.config
import
VllmConfig
from
vllm.logger
import
init_logger
logger
=
init_logger
(
__name__
)
track_batchsize
:
bool
=
envs
.
VLLM_LOG_BATCHSIZE_INTERVAL
>=
0
batchsize_counter
:
Counter
=
Counter
()
last_logging_time
:
float
=
0
batchsize_logging_interval
:
float
=
envs
.
VLLM_LOG_BATCHSIZE_INTERVAL
@
dataclass
@
dataclass
...
@@ -26,7 +37,26 @@ def get_forward_context() -> ForwardContext:
...
@@ -26,7 +37,26 @@ def get_forward_context() -> ForwardContext:
@
contextmanager
@
contextmanager
def
set_forward_context
(
context
:
Any
,
vllm_config
:
VllmConfig
):
def
set_forward_context
(
context
:
Any
,
vllm_config
:
VllmConfig
):
"""A context manager that stores the current forward context,
"""A context manager that stores the current forward context,
can be attention metadata, etc."""
can be attention metadata, etc.
Here we can inject common logic for every model forward pass.
"""
global
track_batchsize
,
batchsize_counter
global
last_logging_time
,
batchsize_logging_interval
if
track_batchsize
and
context
is
not
None
:
if
hasattr
(
context
,
"num_prefill_tokens"
):
# for v0 attention backends
batchsize
=
context
.
num_prefill_tokens
+
context
.
num_decode_tokens
else
:
# for v1 attention backends
batchsize
=
context
.
num_input_tokens
batchsize_counter
[
batchsize
]
+=
1
if
time
.
monotonic
()
-
last_logging_time
>
batchsize_logging_interval
:
last_logging_time
=
time
.
monotonic
()
sorted_data
=
sorted
(
batchsize_counter
.
items
(),
key
=
lambda
x
:
x
[
1
],
reverse
=
True
)
logger
.
info
(
"Batchsize distribution (batchsize, count): %s"
,
sorted_data
)
global
_forward_context
global
_forward_context
prev_context
=
_forward_context
prev_context
=
_forward_context
_forward_context
=
ForwardContext
(
_forward_context
=
ForwardContext
(
...
...
vllm/v1/attention/backends/flash_attn.py
View file @
75f89dc4
...
@@ -56,6 +56,7 @@ class FlashAttentionMetadata:
...
@@ -56,6 +56,7 @@ class FlashAttentionMetadata:
seq_start_loc
:
torch
.
Tensor
seq_start_loc
:
torch
.
Tensor
block_table
:
torch
.
Tensor
block_table
:
torch
.
Tensor
slot_mapping
:
torch
.
Tensor
slot_mapping
:
torch
.
Tensor
num_input_tokens
:
int
=
0
# Number of tokens including padding.
class
FlashAttentionImpl
(
AttentionImpl
):
class
FlashAttentionImpl
(
AttentionImpl
):
...
...
vllm/v1/worker/gpu_model_runner.py
View file @
75f89dc4
...
@@ -445,6 +445,8 @@ class GPUModelRunner:
...
@@ -445,6 +445,8 @@ class GPUModelRunner:
# Eager mode.
# Eager mode.
num_input_tokens
=
num_scheduled_tokens
num_input_tokens
=
num_scheduled_tokens
attn_metadata
.
num_input_tokens
=
num_input_tokens
# Get the inputs embeds.
# Get the inputs embeds.
if
encoder_outputs
:
if
encoder_outputs
:
inputs_embeds
=
self
.
model
.
get_input_embeddings
(
inputs_embeds
=
self
.
model
.
get_input_embeddings
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment