Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
abfc4f33
Unverified
Commit
abfc4f33
authored
Mar 17, 2024
by
Woosuk Kwon
Committed by
GitHub
Mar 17, 2024
Browse files
[Misc] Use dataclass for InputMetadata (#3452)
Co-authored-by:
youkaichao
<
youkaichao@126.com
>
parent
6b78837b
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
24 additions
and
63 deletions
+24
-63
setup.py
setup.py
+0
-1
vllm/model_executor/input_metadata.py
vllm/model_executor/input_metadata.py
+14
-35
vllm/worker/model_runner.py
vllm/worker/model_runner.py
+10
-27
No files found.
setup.py
View file @
abfc4f33
...
@@ -2,7 +2,6 @@ import contextlib
...
@@ -2,7 +2,6 @@ import contextlib
import
io
import
io
import
os
import
os
import
re
import
re
import
shutil
import
subprocess
import
subprocess
import
warnings
import
warnings
from
pathlib
import
Path
from
pathlib
import
Path
...
...
vllm/model_executor/input_metadata.py
View file @
abfc4f33
from
dataclasses
import
dataclass
from
typing
import
Optional
from
typing
import
Optional
import
torch
import
torch
@
dataclass
class
InputMetadata
:
class
InputMetadata
:
"""Metadata for input sequences. Used in PagedAttention.
"""Metadata for input sequences. Used in PagedAttention.
...
@@ -15,40 +17,17 @@ class InputMetadata:
...
@@ -15,40 +17,17 @@ class InputMetadata:
kv_cache_dtype: Data type to store kv cache.
kv_cache_dtype: Data type to store kv cache.
"""
"""
def
__init__
(
is_prompt
:
bool
self
,
slot_mapping
:
torch
.
Tensor
is_prompt
:
bool
,
prompt_lens
:
Optional
[
torch
.
Tensor
]
slot_mapping
:
torch
.
Tensor
,
max_seq_len
:
Optional
[
int
]
prompt_lens
:
Optional
[
torch
.
Tensor
],
start_loc
:
Optional
[
torch
.
Tensor
]
max_seq_len
:
Optional
[
int
],
max_context_len
:
Optional
[
int
]
start_loc
:
Optional
[
torch
.
Tensor
],
context_lens
:
Optional
[
torch
.
Tensor
]
max_context_len
:
Optional
[
int
],
block_tables
:
Optional
[
torch
.
Tensor
]
context_lens
:
Optional
[
torch
.
Tensor
],
use_cuda_graph
:
bool
block_tables
:
Optional
[
torch
.
Tensor
],
kv_cache_dtype
:
str
use_cuda_graph
:
bool
,
kv_cache_dtype
:
str
,
)
->
None
:
self
.
is_prompt
=
is_prompt
self
.
prompt_lens
=
prompt_lens
self
.
max_seq_len
=
max_seq_len
self
.
start_loc
=
start_loc
self
.
max_context_len
=
max_context_len
self
.
slot_mapping
=
slot_mapping
self
.
context_lens
=
context_lens
self
.
block_tables
=
block_tables
self
.
use_cuda_graph
=
use_cuda_graph
self
.
kv_cache_dtype
=
kv_cache_dtype
# Set during the execution of the first attention op.
def
__post_init__
(
self
):
#
FIXME(woosuk): This is a hack.
#
will not appear in the __repr__ and __init__
self
.
attn_bias
=
None
self
.
attn_bias
=
None
def
__repr__
(
self
)
->
str
:
return
(
"InputMetadata("
f
"is_prompt=
{
self
.
is_prompt
}
, "
f
"max_context_len=
{
self
.
max_context_len
}
, "
f
"slot_mapping=
{
self
.
slot_mapping
}
, "
f
"context_lens=
{
self
.
context_lens
}
, "
f
"block_tables=
{
self
.
block_tables
}
, "
f
"use_cuda_graph=
{
self
.
use_cuda_graph
}
, "
f
"kv_cache_dtype=
{
self
.
kv_cache_dtype
}
)"
)
vllm/worker/model_runner.py
View file @
abfc4f33
import
contextlib
import
contextlib
import
dataclasses
import
time
import
time
from
typing
import
Dict
,
List
,
Optional
,
Tuple
,
Set
,
Union
from
typing
import
Dict
,
List
,
Optional
,
Tuple
,
Set
,
Union
...
@@ -521,45 +522,27 @@ class ModelRunner:
...
@@ -521,45 +522,27 @@ class ModelRunner:
metadata_dict
=
{
metadata_dict
=
{
"input_tokens"
:
input_tokens
,
"input_tokens"
:
input_tokens
,
"input_positions"
:
input_positions
,
"input_positions"
:
input_positions
,
"is_prompt"
:
input_metadata
.
is_prompt
,
"slot_mapping"
:
input_metadata
.
slot_mapping
,
"prompt_lens"
:
input_metadata
.
prompt_lens
,
"max_seq_len"
:
input_metadata
.
max_seq_len
,
"start_loc"
:
input_metadata
.
start_loc
,
"max_context_len"
:
input_metadata
.
max_context_len
,
"context_lens"
:
input_metadata
.
context_lens
,
"block_tables"
:
input_metadata
.
block_tables
,
"use_cuda_graph"
:
input_metadata
.
use_cuda_graph
,
"kv_cache_dtype"
:
input_metadata
.
kv_cache_dtype
,
"selected_token_indices"
:
"selected_token_indices"
:
sampling_metadata
.
selected_token_indices
,
sampling_metadata
.
selected_token_indices
,
"lora_requests"
:
lora_requests
,
"lora_requests"
:
lora_requests
,
"lora_mapping"
:
lora_mapping
,
"lora_mapping"
:
lora_mapping
,
}
}
metadata_dict
.
update
(
dataclasses
.
asdict
(
input_metadata
))
broadcast_tensor_dict
(
metadata_dict
,
src
=
0
)
broadcast_tensor_dict
(
metadata_dict
,
src
=
0
)
else
:
else
:
metadata_dict
=
broadcast_tensor_dict
(
src
=
0
)
metadata_dict
=
broadcast_tensor_dict
(
src
=
0
)
input_tokens
=
metadata_dict
[
"input_tokens"
]
input_tokens
=
metadata_dict
.
pop
(
"input_tokens"
)
input_positions
=
metadata_dict
[
"input_positions"
]
input_positions
=
metadata_dict
.
pop
(
"input_positions"
)
lora_mapping
=
metadata_dict
[
"lora_mapping"
]
selected_token_indices
=
metadata_dict
.
pop
(
lora_requests
=
metadata_dict
[
"lora_requests"
]
"selected_token_indices"
)
input_metadata
=
InputMetadata
(
lora_mapping
=
metadata_dict
.
pop
(
"lora_mapping"
)
is_prompt
=
metadata_dict
[
"is_prompt"
],
lora_requests
=
metadata_dict
.
pop
(
"lora_requests"
)
slot_mapping
=
metadata_dict
[
"slot_mapping"
],
input_metadata
=
InputMetadata
(
**
metadata_dict
)
prompt_lens
=
metadata_dict
[
"prompt_lens"
],
max_seq_len
=
metadata_dict
[
"max_seq_len"
],
start_loc
=
metadata_dict
[
"start_loc"
],
max_context_len
=
metadata_dict
[
"max_context_len"
],
context_lens
=
metadata_dict
[
"context_lens"
],
block_tables
=
metadata_dict
[
"block_tables"
],
use_cuda_graph
=
metadata_dict
[
"use_cuda_graph"
],
kv_cache_dtype
=
metadata_dict
[
"kv_cache_dtype"
],
)
sampling_metadata
=
SamplingMetadata
(
sampling_metadata
=
SamplingMetadata
(
seq_groups
=
None
,
seq_groups
=
None
,
seq_data
=
None
,
seq_data
=
None
,
prompt_lens
=
None
,
prompt_lens
=
None
,
selected_token_indices
=
metadata_dict
[
"
selected_token_indices
"
]
,
selected_token_indices
=
selected_token_indices
,
categorized_sample_indices
=
None
,
categorized_sample_indices
=
None
,
generators
=
None
,
generators
=
None
,
perform_sampling
=
False
,
perform_sampling
=
False
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment