Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
a9b15c60
Unverified
Commit
a9b15c60
authored
Sep 27, 2024
by
youkaichao
Committed by
GitHub
Sep 27, 2024
Browse files
[torch.compile] use empty tensor instead of None for profiling (#8875)
parent
8df2dc3c
Changes
15
Show whitespace changes
Inline
Side-by-side
Showing
15 changed files
with
84 additions
and
32 deletions
+84
-32
tests/kernels/test_encoder_decoder_attn.py
tests/kernels/test_encoder_decoder_attn.py
+6
-2
vllm/attention/backends/blocksparse_attn.py
vllm/attention/backends/blocksparse_attn.py
+4
-2
vllm/attention/backends/flash_attn.py
vllm/attention/backends/flash_attn.py
+4
-2
vllm/attention/backends/flashinfer.py
vllm/attention/backends/flashinfer.py
+3
-3
vllm/attention/backends/ipex_attn.py
vllm/attention/backends/ipex_attn.py
+6
-3
vllm/attention/backends/pallas.py
vllm/attention/backends/pallas.py
+7
-5
vllm/attention/backends/rocm_flash_attn.py
vllm/attention/backends/rocm_flash_attn.py
+4
-2
vllm/attention/backends/torch_sdpa.py
vllm/attention/backends/torch_sdpa.py
+6
-3
vllm/attention/backends/xformers.py
vllm/attention/backends/xformers.py
+5
-3
vllm/worker/embedding_model_runner.py
vllm/worker/embedding_model_runner.py
+7
-1
vllm/worker/enc_dec_model_runner.py
vllm/worker/enc_dec_model_runner.py
+7
-1
vllm/worker/model_runner.py
vllm/worker/model_runner.py
+7
-1
vllm/worker/tpu_model_runner.py
vllm/worker/tpu_model_runner.py
+2
-2
vllm/worker/tpu_worker.py
vllm/worker/tpu_worker.py
+9
-1
vllm/worker/xpu_model_runner.py
vllm/worker/xpu_model_runner.py
+7
-1
No files found.
tests/kernels/test_encoder_decoder_attn.py
View file @
a9b15c60
...
...
@@ -136,7 +136,9 @@ def _make_test_resources(test_pt: TestPoint, ) -> TestResources:
)
if
test_pt
.
num_blocks
is
None
or
test_pt
.
num_heads
is
None
:
# Caller does not require a KV cache
return
TestResources
(
scale
,
attn_backend
,
attn
,
None
)
return
TestResources
(
scale
,
attn_backend
,
attn
,
torch
.
tensor
([],
dtype
=
torch
.
float32
,
device
=
CUDA_DEVICE
))
# Construct KV cache
kv_cache
=
make_kv_cache
(
test_pt
.
num_blocks
,
...
...
@@ -620,7 +622,9 @@ def _run_encoder_attention_test(
return
attn
.
forward
(
packed_qkv
.
query
,
packed_qkv
.
key
,
packed_qkv
.
value
,
None
,
torch
.
tensor
([],
dtype
=
torch
.
float32
,
device
=
packed_qkv
.
query
.
device
),
attn_metadata
,
attn_type
=
attn_type
)
...
...
vllm/attention/backends/blocksparse_attn.py
View file @
a9b15c60
...
...
@@ -357,6 +357,8 @@ class BlocksparseFlashAttentionImpl(AttentionImpl):
key: shape = [num_tokens, num_kv_heads * head_size]
value: shape = [num_tokens, num_kv_heads * head_size]
kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size]
NOTE: kv_cache will be an empty tensor with shape [0]
for profiling run.
attn_metadata: Metadata for attention.
Returns:
shape = [num_tokens, num_heads * head_size]
...
...
@@ -373,7 +375,7 @@ class BlocksparseFlashAttentionImpl(AttentionImpl):
key
=
key
.
view
(
-
1
,
self
.
num_kv_heads
,
self
.
head_size
)
value
=
value
.
view
(
-
1
,
self
.
num_kv_heads
,
self
.
head_size
)
if
kv_cache
is
not
None
:
if
kv_cache
.
numel
()
>
0
:
key_cache
,
value_cache
=
PagedAttention
.
split_kv_cache
(
kv_cache
,
self
.
num_kv_heads
,
self
.
head_size
)
...
...
@@ -399,7 +401,7 @@ class BlocksparseFlashAttentionImpl(AttentionImpl):
# When block_tables are not filled, it means q and k are the
# prompt, and they have the same length.
assert
kv_cache
is
None
\
assert
kv_cache
.
numel
()
==
0
\
or
prefill_meta
.
block_tables
is
None
\
or
prefill_meta
.
block_tables
.
numel
()
==
0
,
\
"Does not support prefix-enabled attention."
...
...
vllm/attention/backends/flash_attn.py
View file @
a9b15c60
...
...
@@ -665,6 +665,8 @@ class FlashAttentionImpl(AttentionImpl):
key: shape = [num_tokens, num_kv_heads * head_size]
value: shape = [num_tokens, num_kv_heads * head_size]
kv_cache = [2, num_blocks, block_size, num_kv_heads, head_size]
NOTE: kv_cache will be an empty tensor with shape [0]
for profiling run.
attn_metadata: Metadata for attention.
Returns:
shape = [num_tokens, num_heads * head_size]
...
...
@@ -685,7 +687,7 @@ class FlashAttentionImpl(AttentionImpl):
key
=
key
.
view
(
-
1
,
self
.
num_kv_heads
,
self
.
head_size
)
value
=
value
.
view
(
-
1
,
self
.
num_kv_heads
,
self
.
head_size
)
if
kv_cache
is
not
None
:
if
kv_cache
.
numel
()
>
0
:
key_cache
=
kv_cache
[
0
]
value_cache
=
kv_cache
[
1
]
...
...
@@ -722,7 +724,7 @@ class FlashAttentionImpl(AttentionImpl):
if
prefill_meta
:
=
attn_metadata
.
prefill_metadata
:
# Prompt run.
if
(
kv_cache
is
None
or
prefill_meta
.
block_tables
is
None
if
(
kv_cache
.
numel
()
==
0
or
prefill_meta
.
block_tables
is
None
or
prefill_meta
.
block_tables
.
numel
()
==
0
):
# normal attention
# When block_tables are not filled, it means q and k are the
...
...
vllm/attention/backends/flashinfer.py
View file @
a9b15c60
...
...
@@ -746,7 +746,7 @@ class FlashInferImpl(AttentionImpl):
query
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
kv_cache
:
Optional
[
torch
.
Tensor
]
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
FlashInferMetadata
,
k_scale
:
float
=
1.0
,
v_scale
:
float
=
1.0
,
...
...
@@ -770,7 +770,7 @@ class FlashInferImpl(AttentionImpl):
if
attn_metadata
.
num_decode_tokens
>
0
:
assert
attn_metadata
.
num_prefill_tokens
==
0
,
(
"Chunked prefill is not supported with flashinfer yet."
)
if
kv_cache
is
not
None
:
if
kv_cache
.
numel
()
>
0
:
# Use the same reshape and cache kernel as flash attention.
ops
.
reshape_and_cache_flash
(
key
,
...
...
@@ -796,7 +796,7 @@ class FlashInferImpl(AttentionImpl):
# when kv_cache is not provided.
# This happens when vllm runs the profiling to
# determine the number of blocks.
if
kv_cache
is
None
:
if
kv_cache
.
numel
()
==
0
:
output
=
torch
.
ops
.
vllm
.
flash_attn_varlen_func
(
q
=
query
,
k
=
key
,
...
...
vllm/attention/backends/ipex_attn.py
View file @
a9b15c60
...
...
@@ -167,7 +167,7 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]):
query
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
kv_cache
:
Optional
[
torch
.
Tensor
]
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
IpexAttnMetadata
,
# type: ignore
k_scale
:
float
=
1.0
,
v_scale
:
float
=
1.0
,
...
...
@@ -180,6 +180,8 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]):
key: shape = [num_tokens, num_kv_heads * head_size]
value: shape = [num_tokens, num_kv_heads * head_size]
kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size]
NOTE: kv_cache will be an empty tensor with shape [0]
for profiling run.
attn_metadata: Metadata for attention.
Returns:
shape = [num_tokens, num_heads * head_size]
...
...
@@ -196,7 +198,7 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]):
key
=
key
.
view
(
-
1
,
self
.
num_kv_heads
,
self
.
head_size
)
value
=
value
.
view
(
-
1
,
self
.
num_kv_heads
,
self
.
head_size
)
if
kv_cache
is
not
None
:
if
kv_cache
.
numel
()
>
0
:
key_cache
,
value_cache
=
self
.
split_kv_cache
(
kv_cache
,
self
.
num_kv_heads
,
self
.
head_size
)
ipex_ops
.
reshape_and_cache
(
...
...
@@ -212,7 +214,8 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]):
if
attn_metadata
.
is_prompt
:
assert
attn_metadata
.
seq_lens
is
not
None
if
(
kv_cache
is
None
or
attn_metadata
.
block_tables
.
numel
()
==
0
):
if
(
kv_cache
.
numel
()
==
0
or
attn_metadata
.
block_tables
.
numel
()
==
0
):
if
self
.
num_kv_heads
!=
self
.
num_heads
:
key
=
key
.
repeat_interleave
(
self
.
num_queries_per_kv
,
dim
=
1
)
value
=
value
.
repeat_interleave
(
self
.
num_queries_per_kv
,
...
...
vllm/attention/backends/pallas.py
View file @
a9b15c60
...
...
@@ -143,7 +143,7 @@ class PallasAttentionBackendImpl(AttentionImpl):
query
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
kv_cache
:
Tuple
[
Optional
[
torch
.
Tensor
]
,
Optional
[
torch
.
Tensor
]
]
,
kv_cache
:
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
],
attn_metadata
:
PallasMetadata
,
k_scale
:
float
=
1.0
,
v_scale
:
float
=
1.0
,
...
...
@@ -155,8 +155,10 @@ class PallasAttentionBackendImpl(AttentionImpl):
query: shape = [batch_size, seq_len, num_heads * head_size]
key: shape = [batch_size, seq_len, num_kv_heads * head_size]
value: shape = [batch_size, seq_len, num_kv_heads * head_size]
key_cache = [num_kv_heads, num_blocks, block_size, head_size]
value_cache = [num_kv_heads, num_blocks, block_size, head_size]
kv_cache[0] = [num_kv_heads, num_blocks, block_size, head_size]
kv_cache[1] = [num_kv_heads, num_blocks, block_size, head_size]
NOTE: kv_cache[0] and kv_cache[1] will each be an empty tensor
with shape [0] for the profiling run.
attn_metadata: Metadata for attention.
Returns:
shape = [batch_size, seq_len, num_heads * head_size]
...
...
@@ -173,7 +175,7 @@ class PallasAttentionBackendImpl(AttentionImpl):
value
=
value
.
view
(
batch_size
,
seq_len
,
self
.
num_kv_heads
,
self
.
head_size
)
if
kv_cache
[
0
]
is
not
None
:
if
kv_cache
[
0
]
.
numel
()
>
0
:
slot_mapping
=
attn_metadata
.
slot_mapping
key_cache
,
value_cache
=
kv_cache
write_to_kv_cache
(
key
,
value
,
key_cache
,
value_cache
,
slot_mapping
)
...
...
@@ -205,7 +207,7 @@ class PallasAttentionBackendImpl(AttentionImpl):
output
=
output
.
permute
(
0
,
2
,
1
,
3
)
else
:
# Decoding run.
assert
kv_cache
is
not
None
assert
kv_cache
[
0
].
numel
()
>
0
pages_per_compute_block
=
16
# TODO(woosuk): Tune this value.
if
self
.
megacore_mode
==
"batch"
and
batch_size
%
2
!=
0
:
...
...
vllm/attention/backends/rocm_flash_attn.py
View file @
a9b15c60
...
...
@@ -396,6 +396,8 @@ class ROCmFlashAttentionImpl(AttentionImpl):
key: shape = [num_tokens, num_kv_heads * head_size]
value: shape = [num_tokens, num_kv_heads * head_size]
kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size]
NOTE: kv_cache will be an empty tensor with shape [0]
for profiling run.
attn_metadata: Metadata for attention.
Returns:
shape = [num_tokens, num_heads * head_size]
...
...
@@ -412,7 +414,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
key
=
key
.
view
(
-
1
,
self
.
num_kv_heads
,
self
.
head_size
)
value
=
value
.
view
(
-
1
,
self
.
num_kv_heads
,
self
.
head_size
)
if
kv_cache
is
not
None
:
if
kv_cache
.
numel
()
>
0
:
key_cache
,
value_cache
=
PagedAttention
.
split_kv_cache
(
kv_cache
,
self
.
num_kv_heads
,
self
.
head_size
)
...
...
@@ -449,7 +451,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
if
prefill_meta
:
=
attn_metadata
.
prefill_metadata
:
# Prompt run.
assert
prefill_meta
.
seq_lens
is
not
None
if
kv_cache
is
None
or
prefill_meta
.
block_tables
.
numel
()
==
0
:
if
kv_cache
.
numel
()
==
0
or
prefill_meta
.
block_tables
.
numel
()
==
0
:
# triton attention
# When block_tables are not filled, it means q and k are the
# prompt, and they have the same length.
...
...
vllm/attention/backends/torch_sdpa.py
View file @
a9b15c60
...
...
@@ -151,7 +151,7 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
query
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
kv_cache
:
Optional
[
torch
.
Tensor
]
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
TorchSDPAMetadata
,
# type: ignore
k_scale
:
float
=
1.0
,
v_scale
:
float
=
1.0
,
...
...
@@ -164,6 +164,8 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
key: shape = [num_tokens, num_kv_heads * head_size]
value: shape = [num_tokens, num_kv_heads * head_size]
kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size]
NOTE: kv_cache will be an empty tensor with shape [0]
for profiling run.
attn_metadata: Metadata for attention.
Returns:
shape = [num_tokens, num_heads * head_size]
...
...
@@ -180,7 +182,7 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
key
=
key
.
view
(
-
1
,
self
.
num_kv_heads
,
self
.
head_size
)
value
=
value
.
view
(
-
1
,
self
.
num_kv_heads
,
self
.
head_size
)
if
kv_cache
is
not
None
:
if
kv_cache
.
numel
()
>
0
:
key_cache
,
value_cache
=
PagedAttention
.
split_kv_cache
(
kv_cache
,
self
.
num_kv_heads
,
self
.
head_size
)
PagedAttention
.
write_to_paged_cache
(
key
,
value
,
key_cache
,
...
...
@@ -191,7 +193,8 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
if
attn_metadata
.
is_prompt
:
assert
attn_metadata
.
seq_lens
is
not
None
if
(
kv_cache
is
None
or
attn_metadata
.
block_tables
.
numel
()
==
0
):
if
(
kv_cache
.
numel
()
==
0
or
attn_metadata
.
block_tables
.
numel
()
==
0
):
if
self
.
num_kv_heads
!=
self
.
num_heads
:
key
=
key
.
repeat_interleave
(
self
.
num_queries_per_kv
,
dim
=
1
)
value
=
value
.
repeat_interleave
(
self
.
num_queries_per_kv
,
...
...
vllm/attention/backends/xformers.py
View file @
a9b15c60
...
...
@@ -445,7 +445,7 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
query
:
torch
.
Tensor
,
key
:
Optional
[
torch
.
Tensor
],
value
:
Optional
[
torch
.
Tensor
],
kv_cache
:
Optional
[
torch
.
Tensor
]
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
"XFormersMetadata"
,
k_scale
:
float
=
1.0
,
v_scale
:
float
=
1.0
,
...
...
@@ -489,6 +489,8 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
key: shape = [num_tokens, num_kv_heads * head_size]
value: shape = [num_tokens, num_kv_heads * head_size]
kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size]
NOTE: kv_cache will be an empty tensor with shape [0]
for profiling run.
attn_metadata: Metadata for attention.
attn_type: Select attention type, between encoder attention,
decoder self-attention, or encoder/decoder cross-
...
...
@@ -522,7 +524,7 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
# which KV cache memory-mapping & which
# seqlen datastructures we utilize
if
(
attn_type
!=
AttentionType
.
ENCODER
and
kv_cache
is
not
None
):
if
(
attn_type
!=
AttentionType
.
ENCODER
and
kv_cache
.
numel
()
>
0
):
# KV-cache during decoder-self- or
# encoder-decoder-cross-attention, but not
# during encoder attention.
...
...
@@ -588,7 +590,7 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
if
prefill_meta
:
=
attn_metadata
.
prefill_metadata
:
# Prompt run.
if
kv_cache
is
None
or
prefill_meta
.
block_tables
.
numel
()
==
0
:
if
kv_cache
.
numel
()
==
0
or
prefill_meta
.
block_tables
.
numel
()
==
0
:
# normal attention.
# block tables are empty if the prompt does not have a cached
# prefix.
...
...
vllm/worker/embedding_model_runner.py
View file @
a9b15c60
...
...
@@ -97,7 +97,13 @@ class EmbeddingModelRunner(
model_executable
=
self
.
model
num_layers
=
self
.
model_config
.
get_num_layers
(
self
.
parallel_config
)
kv_caches
=
[
None
]
*
num_layers
# use an empty tensor instead of ``None`` to force Dynamo to pass
# it by reference, rather than by specializing on the value ``None``.
# the `dtype` argument does not matter, and we use `float32` as
# a placeholder (it has wide hardware support).
kv_caches
=
[
torch
.
tensor
([],
dtype
=
torch
.
float32
,
device
=
self
.
device
)
]
*
num_layers
execute_model_kwargs
=
{
"input_ids"
:
...
...
vllm/worker/enc_dec_model_runner.py
View file @
a9b15c60
...
...
@@ -340,7 +340,13 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
# Run the model with the dummy inputs.
num_layers
=
self
.
model_config
.
get_num_layers
(
self
.
parallel_config
)
kv_caches
=
[
None
]
*
num_layers
# use an empty tensor instead of ``None`` to force Dynamo to pass
# it by reference, rather than by specializing on the value ``None``.
# the `dtype` argument does not matter, and we use `float32` as
# a placeholder (it has wide hardware support).
kv_caches
=
[
torch
.
tensor
([],
dtype
=
torch
.
float32
,
device
=
self
.
device
)
]
*
num_layers
finished_requests_ids
=
[
seq
.
request_id
for
seq
in
seqs
]
model_input
=
self
.
prepare_model_input
(
seqs
,
finished_requests_ids
=
finished_requests_ids
)
...
...
vllm/worker/model_runner.py
View file @
a9b15c60
...
...
@@ -1223,7 +1223,13 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
# Run the model with the dummy inputs.
num_layers
=
self
.
model_config
.
get_num_layers
(
self
.
parallel_config
)
kv_caches
=
[
None
]
*
num_layers
# use an empty tensor instead of ``None`` to force Dynamo to pass
# it by reference, rather than by specializing on the value ``None``.
# the `dtype` argument does not matter, and we use `float32` as
# a placeholder (it has wide hardware support).
kv_caches
=
[
torch
.
tensor
([],
dtype
=
torch
.
float32
,
device
=
self
.
device
)
]
*
num_layers
finished_requests_ids
=
[
seq
.
request_id
for
seq
in
seqs
]
model_input
=
self
.
prepare_model_input
(
seqs
,
finished_requests_ids
=
finished_requests_ids
)
...
...
vllm/worker/tpu_model_runner.py
View file @
a9b15c60
...
...
@@ -714,7 +714,7 @@ class ModelWrapper(TorchCompileWrapperWithCustomDispatcher):
t
:
torch
.
Tensor
,
p
:
torch
.
Tensor
,
num_samples
:
int
,
kv_caches
:
List
[
Tuple
[
Optional
[
torch
.
Tensor
]
,
Optional
[
torch
.
Tensor
]]
]
,
kv_caches
:
List
[
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]],
)
->
torch
.
Tensor
:
"""Executes the forward pass of the model and samples the next token.
...
...
@@ -745,7 +745,7 @@ class ModelWrapper(TorchCompileWrapperWithCustomDispatcher):
)
# Skip this in memory profiling at initialization.
if
kv_caches
[
0
][
0
]
is
not
None
:
if
kv_caches
[
0
][
0
]
.
numel
()
>
0
:
# index_copy_(slot_mapping) only works when the inserted dimension
# is 0. However, the KV cache in the Pallas backend has the shape
# [num_kv_heads, num_blocks, block_size, head_size]. To make it
...
...
vllm/worker/tpu_worker.py
View file @
a9b15c60
...
...
@@ -115,7 +115,15 @@ class TPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
head_size
=
self
.
model_config
.
get_head_size
()
num_kv_heads
=
self
.
model_config
.
get_num_kv_heads
(
self
.
parallel_config
)
kv_caches
=
[(
None
,
None
)
for
_
in
range
(
num_layers
)]
# use an empty tensor instead of ``None`` to force Dynamo to pass
# it by reference, rather than by specializing on the value ``None``.
# the `dtype` argument does not matter, and we use `float32` as
# a placeholder (it has wide hardware support).
kv_caches
=
[(
torch
.
tensor
([],
dtype
=
torch
.
float32
,
device
=
self
.
device
),
torch
.
tensor
([],
dtype
=
torch
.
float32
,
device
=
self
.
device
))
for
_
in
range
(
num_layers
)]
self
.
model_runner
.
_dummy_run
(
batch_size
=
1
,
seq_len
=
self
.
scheduler_config
.
max_num_batched_tokens
,
...
...
vllm/worker/xpu_model_runner.py
View file @
a9b15c60
...
...
@@ -464,7 +464,13 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPUWithSamplingMetadata]):
# Run the model with the dummy inputs.
num_layers
=
self
.
model_config
.
get_num_layers
(
self
.
parallel_config
)
kv_caches
=
[
None
]
*
num_layers
# use an empty tensor instead of ``None`` to force Dynamo to pass
# it by reference, rather than by specializing on the value ``None``.
# the `dtype` argument does not matter, and we use `float32` as
# a placeholder (it has wide hardware support).
kv_caches
=
[
torch
.
tensor
([],
dtype
=
torch
.
float32
,
device
=
self
.
device
)
]
*
num_layers
finished_requests_ids
=
[
seq
.
request_id
for
seq
in
seqs
]
model_input
=
self
.
prepare_model_input
(
seqs
,
finished_requests_ids
=
finished_requests_ids
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment