Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
abfc4f33
Unverified
Commit
abfc4f33
authored
Mar 17, 2024
by
Woosuk Kwon
Committed by
GitHub
Mar 17, 2024
Browse files
[Misc] Use dataclass for InputMetadata (#3452)
Co-authored-by:
youkaichao
<
youkaichao@126.com
>
parent
6b78837b
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
24 additions
and
63 deletions
+24
-63
setup.py
setup.py
+0
-1
vllm/model_executor/input_metadata.py
vllm/model_executor/input_metadata.py
+14
-35
vllm/worker/model_runner.py
vllm/worker/model_runner.py
+10
-27
No files found.
setup.py
View file @
abfc4f33
...
...
@@ -2,7 +2,6 @@ import contextlib
import
io
import
os
import
re
import
shutil
import
subprocess
import
warnings
from
pathlib
import
Path
...
...
vllm/model_executor/input_metadata.py
View file @
abfc4f33
from
dataclasses
import
dataclass
from
typing
import
Optional
import
torch
@
dataclass
class
InputMetadata
:
"""Metadata for input sequences. Used in PagedAttention.
...
...
@@ -15,40 +17,17 @@ class InputMetadata:
kv_cache_dtype: Data type to store kv cache.
"""
def
__init__
(
self
,
is_prompt
:
bool
,
slot_mapping
:
torch
.
Tensor
,
prompt_lens
:
Optional
[
torch
.
Tensor
],
max_seq_len
:
Optional
[
int
],
start_loc
:
Optional
[
torch
.
Tensor
],
max_context_len
:
Optional
[
int
],
context_lens
:
Optional
[
torch
.
Tensor
],
block_tables
:
Optional
[
torch
.
Tensor
],
use_cuda_graph
:
bool
,
kv_cache_dtype
:
str
,
)
->
None
:
self
.
is_prompt
=
is_prompt
self
.
prompt_lens
=
prompt_lens
self
.
max_seq_len
=
max_seq_len
self
.
start_loc
=
start_loc
self
.
max_context_len
=
max_context_len
self
.
slot_mapping
=
slot_mapping
self
.
context_lens
=
context_lens
self
.
block_tables
=
block_tables
self
.
use_cuda_graph
=
use_cuda_graph
self
.
kv_cache_dtype
=
kv_cache_dtype
is_prompt
:
bool
slot_mapping
:
torch
.
Tensor
prompt_lens
:
Optional
[
torch
.
Tensor
]
max_seq_len
:
Optional
[
int
]
start_loc
:
Optional
[
torch
.
Tensor
]
max_context_len
:
Optional
[
int
]
context_lens
:
Optional
[
torch
.
Tensor
]
block_tables
:
Optional
[
torch
.
Tensor
]
use_cuda_graph
:
bool
kv_cache_dtype
:
str
# Set during the execution of the first attention op.
#
FIXME(woosuk): This is a hack.
def
__post_init__
(
self
):
#
will not appear in the __repr__ and __init__
self
.
attn_bias
=
None
def
__repr__
(
self
)
->
str
:
return
(
"InputMetadata("
f
"is_prompt=
{
self
.
is_prompt
}
, "
f
"max_context_len=
{
self
.
max_context_len
}
, "
f
"slot_mapping=
{
self
.
slot_mapping
}
, "
f
"context_lens=
{
self
.
context_lens
}
, "
f
"block_tables=
{
self
.
block_tables
}
, "
f
"use_cuda_graph=
{
self
.
use_cuda_graph
}
, "
f
"kv_cache_dtype=
{
self
.
kv_cache_dtype
}
)"
)
vllm/worker/model_runner.py
View file @
abfc4f33
import
contextlib
import
dataclasses
import
time
from
typing
import
Dict
,
List
,
Optional
,
Tuple
,
Set
,
Union
...
...
@@ -521,45 +522,27 @@ class ModelRunner:
metadata_dict
=
{
"input_tokens"
:
input_tokens
,
"input_positions"
:
input_positions
,
"is_prompt"
:
input_metadata
.
is_prompt
,
"slot_mapping"
:
input_metadata
.
slot_mapping
,
"prompt_lens"
:
input_metadata
.
prompt_lens
,
"max_seq_len"
:
input_metadata
.
max_seq_len
,
"start_loc"
:
input_metadata
.
start_loc
,
"max_context_len"
:
input_metadata
.
max_context_len
,
"context_lens"
:
input_metadata
.
context_lens
,
"block_tables"
:
input_metadata
.
block_tables
,
"use_cuda_graph"
:
input_metadata
.
use_cuda_graph
,
"kv_cache_dtype"
:
input_metadata
.
kv_cache_dtype
,
"selected_token_indices"
:
sampling_metadata
.
selected_token_indices
,
"lora_requests"
:
lora_requests
,
"lora_mapping"
:
lora_mapping
,
}
metadata_dict
.
update
(
dataclasses
.
asdict
(
input_metadata
))
broadcast_tensor_dict
(
metadata_dict
,
src
=
0
)
else
:
metadata_dict
=
broadcast_tensor_dict
(
src
=
0
)
input_tokens
=
metadata_dict
[
"input_tokens"
]
input_positions
=
metadata_dict
[
"input_positions"
]
lora_mapping
=
metadata_dict
[
"lora_mapping"
]
lora_requests
=
metadata_dict
[
"lora_requests"
]
input_metadata
=
InputMetadata
(
is_prompt
=
metadata_dict
[
"is_prompt"
],
slot_mapping
=
metadata_dict
[
"slot_mapping"
],
prompt_lens
=
metadata_dict
[
"prompt_lens"
],
max_seq_len
=
metadata_dict
[
"max_seq_len"
],
start_loc
=
metadata_dict
[
"start_loc"
],
max_context_len
=
metadata_dict
[
"max_context_len"
],
context_lens
=
metadata_dict
[
"context_lens"
],
block_tables
=
metadata_dict
[
"block_tables"
],
use_cuda_graph
=
metadata_dict
[
"use_cuda_graph"
],
kv_cache_dtype
=
metadata_dict
[
"kv_cache_dtype"
],
)
input_tokens
=
metadata_dict
.
pop
(
"input_tokens"
)
input_positions
=
metadata_dict
.
pop
(
"input_positions"
)
selected_token_indices
=
metadata_dict
.
pop
(
"selected_token_indices"
)
lora_mapping
=
metadata_dict
.
pop
(
"lora_mapping"
)
lora_requests
=
metadata_dict
.
pop
(
"lora_requests"
)
input_metadata
=
InputMetadata
(
**
metadata_dict
)
sampling_metadata
=
SamplingMetadata
(
seq_groups
=
None
,
seq_data
=
None
,
prompt_lens
=
None
,
selected_token_indices
=
metadata_dict
[
"
selected_token_indices
"
]
,
selected_token_indices
=
selected_token_indices
,
categorized_sample_indices
=
None
,
generators
=
None
,
perform_sampling
=
False
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment