norm / vllm / Commits / d6fa1be3 (unverified)

[Quality] Add code formatter and linter (#326)

Authored Jul 03, 2023 by Zhuohan Li; committed via GitHub on Jul 03, 2023. Parent: 0ffded81.

The commit touches 47 files in total; this page shows 7 of them, with 106 additions and 43 deletions (+106 / -43):
vllm/outputs.py (+4, -2)
vllm/sampling_params.py (+8, -2)
vllm/sequence.py (+70, -19)
vllm/transformers_utils/tokenizer.py (+3, -2)
vllm/utils.py (+3, -2)
vllm/worker/cache_engine.py (+5, -6)
vllm/worker/worker.py (+13, -10)
vllm/outputs.py (+4, -2)

Both hunks in this file are formatting-level: the visible tokens are identical before and after, so only line wrapping and spacing change. The post-commit code reads:

```python
# @@ -55,6 +55,7 @@ class RequestOutput:  (formatting-only)
        outputs: The output sequences of the request.
        finished: Whether the whole request is finished.
    """

    def __init__(
        self,
        request_id: str,
        ...
```

```python
# @@ -75,8 +76,9 @@ class RequestOutput:  (formatting-only reflow)
        n = seq_group.sampling_params.n
        seqs = seq_group.get_seqs()
        assert n <= len(seqs)
        sorted_seqs = sorted(seqs,
                             key=lambda seq: seq.get_cumulative_logprob(),
                             reverse=True)
        top_n_seqs = sorted_seqs[:n]
        # Create the outputs.
        ...
```
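The reflow above is characteristic of yapf, which this commit's title says was added to the project. Below is a hedged sketch of reproducing such a reflow through yapf's Python API (`FormatCode`); the style configuration this commit actually ships is not visible on this page, so `"pep8"` is an assumption:

```python
# Requires: pip install yapf
from yapf.yapflib.yapf_api import FormatCode

# An over-long line similar to the one reflowed in the hunk above.
source = ("sorted_seqs = sorted(seqs, key=lambda s: s.get_cumulative_logprob(),"
          " reverse=True)\n")

# FormatCode returns the formatted text and whether anything changed.
formatted, changed = FormatCode(source, style_config="pep8")
print(changed)    # True: the line exceeds the 79-column pep8 limit
print(formatted)  # yapf wraps the call with aligned continuation lines
```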
vllm/sampling_params.py (+8, -2)

The substantive change replaces a mutable default argument (`stop = []`) with `None` plus explicit normalization in `__init__`.

```python
# @@ -3,6 +3,7 @@  (visible context unchanged; formatting-only)
_SAMPLING_EPS = 1e-5


class SamplingParams:
    """Sampling parameters for text generation.
    ...
```

```diff
@@ -51,7 +52,7 @@ class SamplingParams:
         top_p: float = 1.0,
         top_k: int = -1,
         use_beam_search: bool = False,
-        stop: Union[str, List[str]] = [],
+        stop: Union[None, str, List[str]] = None,
         ignore_eos: bool = False,
         max_tokens: int = 16,
         logprobs: Optional[int] = None,
```

```diff
@@ -64,7 +65,12 @@ class SamplingParams:
         self.top_p = top_p
         self.top_k = top_k
         self.use_beam_search = use_beam_search
-        self.stop = [stop] if isinstance(stop, str) else list(stop)
+        if stop is None:
+            self.stop = []
+        elif isinstance(stop, str):
+            self.stop = [stop]
+        else:
+            self.stop = list(stop)
         self.ignore_eos = ignore_eos
         self.max_tokens = max_tokens
         self.logprobs = logprobs
```
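The `stop` change fixes a classic Python pitfall: a mutable default argument is evaluated once at function definition time, so every call that relies on the default shares the same list. A minimal standalone reproduction (hypothetical class names; not from the diff):

```python
class BadParams:
    # One shared list is created when the def statement is evaluated.
    def __init__(self, stop=[]):
        self.stop = stop


class GoodParams:
    # The pattern the commit introduces: default to None, normalize inside.
    def __init__(self, stop=None):
        if stop is None:
            self.stop = []
        elif isinstance(stop, str):
            self.stop = [stop]
        else:
            self.stop = list(stop)


a, b = BadParams(), BadParams()
a.stop.append("###")
print(b.stop)  # ['###'] -- b silently sees a's mutation

c, d = GoodParams(), GoodParams()
c.stop.append("###")
print(d.stop)  # [] -- each instance gets its own list
```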
vllm/sequence.py (+70, -19)

Most of the additions here are docstrings; the rest are lint fixes (a removed `return None`, a trailing comma, double-quoted f-strings).

```diff
+"""Sequence and its related classes."""
 import copy
 import enum
 from typing import Dict, List, Optional, Union
```

```diff
@@ -7,6 +8,7 @@ from vllm.sampling_params import SamplingParams
 class SequenceStatus(enum.Enum):
+    """Status of a sequence."""
     WAITING = enum.auto()
     RUNNING = enum.auto()
     SWAPPED = enum.auto()
```

```diff
@@ -21,7 +23,7 @@ class SequenceStatus(enum.Enum):
         SequenceStatus.FINISHED_STOPPED,
         SequenceStatus.FINISHED_LENGTH_CAPPED,
         SequenceStatus.FINISHED_ABORTED,
-        SequenceStatus.FINISHED_IGNORED
+        SequenceStatus.FINISHED_IGNORED,
     ]

     @staticmethod
```
```diff
@@ -40,6 +42,17 @@ class SequenceStatus(enum.Enum):
 class SequenceData:
+    """Data associated with a sequence.
+
+    Args:
+        prompt_token_ids: The token IDs of the prompt.
+
+    Attributes:
+        prompt_token_ids: The token IDs of the prompt.
+        output_token_ids: The token IDs of the output.
+        cumulative_logprob: The cumulative log probability of the output.
+    """

     def __init__(
         self,
```

```diff
@@ -75,6 +88,15 @@ class SequenceData:
 class Sequence:
+    """Stores the data, status, and block information of a sequence.
+
+    Args:
+        seq_id: The ID of the sequence.
+        prompt: The prompt of the sequence.
+        prompt_token_ids: The token IDs of the prompt.
+        block_size: The block size of the sequence. Should be the same as the
+            block size used by the block manager and cache engine.
+    """

     def __init__(
         self,
```

```diff
@@ -149,19 +171,27 @@ class Sequence:
     def is_finished(self) -> bool:
         return SequenceStatus.is_finished(self.status)

-    def fork(self, child_seq: 'Sequence') -> None:
+    def fork(self, child_seq: "Sequence") -> None:
         child_seq.logical_token_blocks = copy.deepcopy(
             self.logical_token_blocks)
         child_seq.output_logprobs = copy.deepcopy(self.output_logprobs)
         child_seq.data = copy.deepcopy(self.data)
-        return None

     def __repr__(self) -> str:
-        return (f'Sequence(seq_id={self.seq_id}, '
-                f'status={self.status.name}, '
-                f'num_blocks={len(self.logical_token_blocks)})')
+        return (f"Sequence(seq_id={self.seq_id}, "
+                f"status={self.status.name}, "
+                f"num_blocks={len(self.logical_token_blocks)})")


 class SequenceGroup:
+    """A group of sequences that are generated from the same prompt.
+
+    Args:
+        request_id: The ID of the request.
+        seqs: The list of sequences.
+        sampling_params: The sampling parameters used to generate the outputs.
+        arrival_time: The arrival time of the request.
+    """

     def __init__(
         self,
```
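`fork` relies on `copy.deepcopy` so that a child sequence can diverge from its parent (as happens during beam search) without the two sharing mutable state. A minimal sketch of the distinction, using plain nested lists as stand-ins for `logical_token_blocks`:

```python
import copy

parent_blocks = [[1, 2], [3, 4]]     # stand-in for nested block state

alias = parent_blocks                # plain assignment: same objects
deep = copy.deepcopy(parent_blocks)  # fully independent copy

deep[0].append(99)
print(parent_blocks[0])   # [1, 2] -- the parent is untouched

alias[0].append(7)
print(parent_blocks[0])   # [1, 2, 7] -- mutation leaks through the alias
```

The removed `return None` is redundant: a Python function annotated `-> None` already returns `None` implicitly when it falls off the end, which is why linters flag the explicit statement.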
```diff
@@ -191,7 +221,7 @@ class SequenceGroup:
         for seq in self.seqs:
             if seq.seq_id == seq_id:
                 return seq
-        raise ValueError(f'Sequence {seq_id} not found.')
+        raise ValueError(f"Sequence {seq_id} not found.")

     def is_finished(self) -> bool:
         return all(seq.is_finished() for seq in self.seqs)
```

The two metadata classes gain docstrings, and the inline comments they replace are deleted from the signatures:

```diff
@@ -203,14 +233,25 @@ class SequenceGroup:
 class SequenceGroupMetadata:
+    """Metadata for a sequence group. Used to create `InputMetadata`.
+
+    Args:
+        request_id: The ID of the request.
+        is_prompt: Whether the request is at prompt stage.
+        seq_data: The sequence data. (Seq id -> sequence data)
+        sampling_params: The sampling parameters used to generate the outputs.
+        block_tables: The block tables. (Seq id -> list of physical block
+            numbers)
+    """

     def __init__(
         self,
         request_id: str,
         is_prompt: bool,
-        seq_data: Dict[int, SequenceData],  # Seq id -> sequence data.
+        seq_data: Dict[int, SequenceData],
         sampling_params: SamplingParams,
-        block_tables: Dict[int, List[int]],  # Seq id -> list of physical block numbers.
+        block_tables: Dict[int, List[int]],
     ) -> None:
         self.request_id = request_id
         self.is_prompt = is_prompt
```

```diff
@@ -220,13 +261,23 @@ class SequenceGroupMetadata:
 class SequenceOutputs:
+    """The model output associated with a sequence.
+
+    Args:
+        seq_id: The ID of the sequence.
+        parent_seq_id: The ID of the parent sequence (for forking in beam
+            search).
+        output_token: The output token ID.
+        logprobs: The logprobs of the output token.
+            (Token id -> logP(x_i+1 | x_0, ..., x_i))
+    """

     def __init__(
         self,
         seq_id: int,
         parent_seq_id: int,
         output_token: int,
-        logprobs: Dict[int, float],  # Token id -> logP(x_i+1 | x_0, ..., x_i).
+        logprobs: Dict[int, float],
     ) -> None:
         self.seq_id = seq_id
         self.parent_seq_id = parent_seq_id
```
```diff
@@ -234,15 +285,15 @@ class SequenceOutputs:
         self.logprobs = logprobs

     def __repr__(self) -> str:
-        return (f'SequenceOutputs(seq_id={self.seq_id}, '
-                f'parent_seq_id={self.parent_seq_id}, '
-                f'output_token={self.output_token}), '
-                f'logprobs={self.logprobs}')
+        return (f"SequenceOutputs(seq_id={self.seq_id}, "
+                f"parent_seq_id={self.parent_seq_id}, "
+                f"output_token={self.output_token}), "
+                f"logprobs={self.logprobs}")

     def __eq__(self, other: object) -> bool:
         if not isinstance(other, SequenceOutputs):
             return NotImplemented
-        return (self.seq_id == other.seq_id and
-                self.parent_seq_id == other.parent_seq_id and
-                self.output_token == other.output_token and
-                self.logprobs == other.logprobs)
+        return (self.seq_id == other.seq_id
+                and self.parent_seq_id == other.parent_seq_id
+                and self.output_token == other.output_token
+                and self.logprobs == other.logprobs)
```
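The `__eq__` hunk only moves the `and` operators to line starts (a common formatter preference), but the surrounding pattern deserves a note: returning `NotImplemented`, rather than `False` or raising, lets Python try the reflected comparison on the other operand before giving up. A minimal sketch with a hypothetical class:

```python
class Point:
    def __init__(self, x: int, y: int) -> None:
        self.x, self.y = x, y

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, Point):
            return NotImplemented  # let `other` get a chance to compare
        return (self.x == other.x
                and self.y == other.y)


print(Point(1, 2) == Point(1, 2))    # True
print(Point(1, 2) == "not a point")  # False: Python falls back to identity
                                     # after both sides return NotImplemented
```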
vllm/transformers_utils/tokenizer.py (+3, -2)

`tokenizer_mode` moves after `*args`, making it keyword-only; a long comment is also wrapped.

```diff
@@ -13,8 +13,8 @@ _FAST_LLAMA_TOKENIZER = "hf-internal-testing/llama-tokenizer"
 def get_tokenizer(
     tokenizer_name: str,
-    tokenizer_mode: str = "auto",
     *args,
+    tokenizer_mode: str = "auto",
     **kwargs,
 ) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
     """Gets a tokenizer for the given model name via Huggingface."""
```

```diff
@@ -73,7 +73,8 @@ def detokenize_incrementally(
         output_text = tokenizer.convert_tokens_to_string(output_tokens)
         return new_token, output_text

-    # Adapted from https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/tokenization_utils.py#L921
+    # Adapted from
+    # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/tokenization_utils.py#L921
     # NOTE(woosuk): The following code is slow because it runs a for loop over
     # the output_tokens. In Python, running a for loop over a list can be slow
     # even when the loop body is very simple.
```
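Moving `tokenizer_mode` behind `*args` makes it keyword-only, so an extra positional argument meant for the underlying tokenizer can no longer bind to it by accident. A self-contained illustration with simplified stand-in functions (only the parameter order matches the diff; everything else is hypothetical):

```python
def old_style(name, tokenizer_mode="auto", *args, **kwargs):
    return tokenizer_mode, args


def new_style(name, *args, tokenizer_mode="auto", **kwargs):
    return tokenizer_mode, args


# A caller passes one extra positional argument intended for *args:
print(old_style("llama", "extra"))  # ('extra', ()) -- silently hijacked
print(new_style("llama", "extra"))  # ('auto', ('extra',)) -- as intended
```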
vllm/utils.py (+3, -2)

`Counter.__next__` renames its local variable from `id` to `i`, since `id` shadows the builtin of the same name.

```diff
@@ -17,9 +17,9 @@ class Counter:
         self.counter = start

     def __next__(self) -> int:
-        id = self.counter
+        i = self.counter
         self.counter += 1
-        return id
+        return i

     def reset(self) -> None:
         self.counter = 0
```

```python
# @@ -38,6 +38,7 @@  (visible context unchanged; formatting-only)
def random_uuid() -> str:
    return str(uuid.uuid4().hex)


def in_wsl() -> bool:
    # Reference: https://github.com/microsoft/WSL/issues/4071
    return "microsoft" in " ".join(uname()).lower()
```
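The rename matters because `id` shadows the builtin inside `__next__` (pylint reports this as `redefined-builtin`); any later call to the builtin `id()` in that scope would hit the local integer instead and fail with a `TypeError`. A runnable sketch of the fixed class:

```python
class Counter:
    def __init__(self, start: int = 0) -> None:
        self.counter = start

    def __next__(self) -> int:
        i = self.counter  # previously named `id`, shadowing builtins.id
        self.counter += 1
        return i

    def reset(self) -> None:
        self.counter = 0


counter = Counter()
print(next(counter), next(counter))  # 0 1
counter.reset()
print(next(counter))                 # 0
```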
vllm/worker/cache_engine.py (+5, -6)

`logger.warn` becomes `logger.warning`, and two `cache_ops.swap_blocks` calls are reflowed.

```diff
@@ -93,8 +93,8 @@ class CacheEngine:
         if not pin_memory:
             # Pinning memory in WSL is not supported.
             # https://docs.nvidia.com/cuda/wsl-user-guide/index.html#known-limitations-for-linux-cuda-applications
-            logger.warn("Using 'pin_memory=False' as WSL is detected. "
-                        "This may slow down the performance.")
+            logger.warning("Using 'pin_memory=False' as WSL is detected. "
+                           "This may slow down the performance.")
         for _ in range(self.num_layers):
             key_blocks = torch.empty(
                 size=(self.num_cpu_blocks, *key_block_shape),
```

```python
# @@ -120,11 +120,10 @@  (formatting-only reflow, one line fewer)
            src_key_cache, src_value_cache = src[i]
            dst_key_cache, dst_value_cache = dst[i]
            # Copy the key blocks.
            cache_ops.swap_blocks(src_key_cache, dst_key_cache, src_to_dst)
            # Copy the value blocks.
            cache_ops.swap_blocks(src_value_cache, dst_value_cache,
                                  src_to_dst)
            event = self.events[i]
            event.record(stream=self.cache_stream)
```
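`Logger.warn` is a deprecated alias of `Logger.warning` in Python's standard `logging` module (it emits a `DeprecationWarning` on modern interpreters), which is presumably what the linter flagged. The corrected call in isolation:

```python
import logging

logging.basicConfig(level=logging.WARNING)
logger = logging.getLogger("cache_engine_demo")

# logger.warn(...) still works today but is deprecated; warning() is the
# supported spelling.
logger.warning("Using 'pin_memory=False' as WSL is detected. "
               "This may slow down the performance.")
```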
vllm/worker/worker.py (+13, -10)

Memory profiling pulls `self.model.config.vocab_size` into a local variable before building the worst-case `SamplingParams`; the remaining hunks in this file are reflows.

```diff
@@ -73,8 +73,8 @@ class Worker:
         # number of tokens equal to max_num_batched_tokens.
         # Enable top-k sampling to reflect the accurate memory usage.
-        sampling_params = SamplingParams(top_p=0.99,
-                                         top_k=self.model.config.vocab_size - 1)
+        vocab_size = self.model.config.vocab_size
+        sampling_params = SamplingParams(top_p=0.99, top_k=vocab_size - 1)
         max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens
         max_num_seqs = self.scheduler_config.max_num_seqs
         seqs = []
```
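For the profiling run, `top_k = vocab_size - 1` deliberately keeps almost the entire vocabulary in play, so the sampler allocates close to its worst-case memory; that is what the comment about accurate memory usage refers to. A toy illustration with made-up sizes (the real sampler operates on GPU logit tensors, not Python lists):

```python
vocab_size = 32000  # e.g. a Llama-style vocabulary (assumed, not from the diff)
logits = [float(i) for i in range(vocab_size)]

top_k = vocab_size - 1  # nearly nothing is truncated
kept = sorted(logits, reverse=True)[:top_k]
print(len(kept))        # 31999 of 32000 candidates survive the top-k filter
```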
```python
# @@ -91,7 +91,8 @@  (formatting-only reflow)
            )
            seqs.append(seq)

        input_tokens, input_positions, input_metadata = self._prepare_inputs(
            seqs)

        # Execute the model.
        num_layers = self.model_config.get_num_layers(self.parallel_config)
```

```python
# @@ -110,8 +111,9 @@  (formatting-only reflow)
        total_gpu_memory = get_gpu_memory()
        cache_block_size = CacheEngine.get_cache_block_size(
            block_size, self.model_config, self.parallel_config)
        num_gpu_blocks = int(
            (total_gpu_memory * gpu_memory_utilization - peak_memory) //
            cache_block_size)
        num_cpu_blocks = int(cpu_swap_space // cache_block_size)
        num_gpu_blocks = max(num_gpu_blocks, 0)
        num_cpu_blocks = max(num_cpu_blocks, 0)
```
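The reflowed formula is unchanged: usable KV-cache memory is the utilization-scaled GPU total minus the profiled peak, divided into fixed-size blocks. A worked example with made-up numbers (an 80 GiB GPU, 90% utilization, a 64 GiB profiled peak, a 2 MiB cache block, and 4 GiB of CPU swap; none of these values appear in the diff):

```python
GiB = 1 << 30
MiB = 1 << 20

total_gpu_memory = 80 * GiB
gpu_memory_utilization = 0.9
peak_memory = 64 * GiB
cache_block_size = 2 * MiB
cpu_swap_space = 4 * GiB

num_gpu_blocks = int(
    (total_gpu_memory * gpu_memory_utilization - peak_memory) //
    cache_block_size)
num_cpu_blocks = int(cpu_swap_space // cache_block_size)
num_gpu_blocks = max(num_gpu_blocks, 0)
num_cpu_blocks = max(num_cpu_blocks, 0)

print(num_gpu_blocks)  # 4096  (72 GiB - 64 GiB = 8 GiB of 2 MiB blocks)
print(num_cpu_blocks)  # 2048
```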
```python
# @@ -125,8 +127,8 @@  (formatting-only reflow)
    def init_cache_engine(self, cache_config: CacheConfig) -> None:
        self.cache_config = cache_config
        self.block_size = cache_config.block_size
        self.cache_engine = CacheEngine(self.cache_config, self.model_config,
                                        self.parallel_config)
        self.cache_events = self.cache_engine.events
        self.gpu_cache = self.cache_engine.gpu_cache
```

```python
# @@ -202,8 +204,8 @@  (formatting-only reflow)
                generation_block_tables.append(block_table)

                max_context_len = max(max_context_len, context_len)
                max_num_blocks_per_seq = max(max_num_blocks_per_seq,
                                             len(block_table))
                context_lens.append(context_len)

                block_number = block_table[position // self.block_size]
```

```python
# @@ -223,7 +225,8 @@  (formatting-only reflow)
        context_lens_tensor = torch.cuda.IntTensor(context_lens)
        padded_block_tables = [
            _pad_to_max(block_table, max_num_blocks_per_seq)
            for block_table in generation_block_tables
        ]
        block_tables_tensor = torch.cuda.IntTensor(padded_block_tables)

        seq_data: Dict[int, SequenceData] = {}
```