Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
a1c02058
Unverified
Commit
a1c02058
authored
Dec 14, 2024
by
youkaichao
Committed by
GitHub
Dec 14, 2024
Browse files
[torch.compile] allow tracking forward time (#11081)
Signed-off-by:
youkaichao
<
youkaichao@gmail.com
>
parent
15859f23
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
42 additions
and
19 deletions
+42
-19
vllm/forward_context.py
vllm/forward_context.py
+42
-19
No files found.
vllm/forward_context.py
View file @
a1c02058
import
time
import
time
from
collections
import
Counter
from
collections
import
defaultdict
from
contextlib
import
contextmanager
from
contextlib
import
contextmanager
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
typing
import
Any
,
Dict
,
Optional
from
typing
import
Any
,
Dict
,
Optional
import
torch
import
vllm.envs
as
envs
import
vllm.envs
as
envs
from
vllm.config
import
VllmConfig
from
vllm.config
import
VllmConfig
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
...
@@ -11,9 +13,10 @@ from vllm.logger import init_logger
...
@@ -11,9 +13,10 @@ from vllm.logger import init_logger
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
track_batchsize
:
bool
=
envs
.
VLLM_LOG_BATCHSIZE_INTERVAL
>=
0
track_batchsize
:
bool
=
envs
.
VLLM_LOG_BATCHSIZE_INTERVAL
>=
0
batchsize_counter
:
Counter
=
Counter
()
last_logging_time
:
float
=
0
last_logging_time
:
float
=
0
forward_start_time
:
float
=
0
batchsize_logging_interval
:
float
=
envs
.
VLLM_LOG_BATCHSIZE_INTERVAL
batchsize_logging_interval
:
float
=
envs
.
VLLM_LOG_BATCHSIZE_INTERVAL
batchsize_forward_time
:
defaultdict
=
defaultdict
(
list
)
@
dataclass
@
dataclass
...
@@ -40,23 +43,10 @@ def set_forward_context(context: Any, vllm_config: VllmConfig):
...
@@ -40,23 +43,10 @@ def set_forward_context(context: Any, vllm_config: VllmConfig):
can be attention metadata, etc.
can be attention metadata, etc.
Here we can inject common logic for every model forward pass.
Here we can inject common logic for every model forward pass.
"""
"""
global
track_batchsize
,
batchsize_counter
global
forward_start_time
global
last_logging_time
,
batchsize_logging_interval
need_to_track_batchsize
=
track_batchsize
and
context
is
not
None
if
track_batchsize
and
context
is
not
None
:
if
need_to_track_batchsize
:
if
hasattr
(
context
,
"num_prefill_tokens"
):
forward_start_time
=
time
.
perf_counter
()
# for v0 attention backends
batchsize
=
context
.
num_prefill_tokens
+
context
.
num_decode_tokens
else
:
# for v1 attention backends
batchsize
=
context
.
num_input_tokens
batchsize_counter
[
batchsize
]
+=
1
if
time
.
monotonic
()
-
last_logging_time
>
batchsize_logging_interval
:
last_logging_time
=
time
.
monotonic
()
sorted_data
=
sorted
(
batchsize_counter
.
items
(),
key
=
lambda
x
:
x
[
1
],
reverse
=
True
)
logger
.
info
(
"Batchsize distribution (batchsize, count): %s"
,
sorted_data
)
global
_forward_context
global
_forward_context
prev_context
=
_forward_context
prev_context
=
_forward_context
_forward_context
=
ForwardContext
(
_forward_context
=
ForwardContext
(
...
@@ -66,4 +56,37 @@ def set_forward_context(context: Any, vllm_config: VllmConfig):
...
@@ -66,4 +56,37 @@ def set_forward_context(context: Any, vllm_config: VllmConfig):
try
:
try
:
yield
yield
finally
:
finally
:
global
batchsize_counter
global
last_logging_time
,
batchsize_logging_interval
if
need_to_track_batchsize
:
if
hasattr
(
context
,
"num_prefill_tokens"
):
# for v0 attention backends
batchsize
=
context
.
num_prefill_tokens
+
\
context
.
num_decode_tokens
else
:
# for v1 attention backends
batchsize
=
context
.
num_input_tokens
# we use synchronous scheduling right now,
# adding a sync point here should not affect
# scheduling of the next batch
torch
.
cuda
.
synchronize
()
now
=
time
.
perf_counter
()
# time measurement is in milliseconds
batchsize_forward_time
[
batchsize
].
append
(
(
now
-
forward_start_time
)
*
1000
)
if
now
-
last_logging_time
>
batchsize_logging_interval
:
last_logging_time
=
now
forward_stats
=
[]
for
bs
,
times
in
batchsize_forward_time
.
items
():
if
len
(
times
)
<=
1
:
# can be cudagraph / profiling run
continue
medium
=
torch
.
quantile
(
torch
.
tensor
(
times
),
q
=
0.5
).
item
()
medium
=
round
(
medium
,
2
)
forward_stats
.
append
((
bs
,
len
(
times
),
medium
))
forward_stats
.
sort
(
key
=
lambda
x
:
x
[
1
],
reverse
=
True
)
if
forward_stats
:
logger
.
info
((
"Batchsize forward time stats "
"(batchsize, count, median_time(ms)): %s"
),
forward_stats
)
_forward_context
=
prev_context
_forward_context
=
prev_context
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment