Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
9aaf14c6
Unverified
Commit
9aaf14c6
authored
Oct 03, 2024
by
youkaichao
Committed by
GitHub
Oct 03, 2024
Browse files
[misc] add forward context for attention (#9029)
parent
63e39937
Changes
8
Show whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
250 additions
and
334 deletions
+250
-334
tests/kernels/test_flash_attn.py
tests/kernels/test_flash_attn.py
+7
-49
vllm/attention/backends/flash_attn.py
vllm/attention/backends/flash_attn.py
+178
-251
vllm/attention/backends/flashinfer.py
vllm/attention/backends/flashinfer.py
+2
-2
vllm/forward_context.py
vllm/forward_context.py
+22
-0
vllm/spec_decode/draft_model_runner.py
vllm/spec_decode/draft_model_runner.py
+12
-10
vllm/worker/embedding_model_runner.py
vllm/worker/embedding_model_runner.py
+3
-1
vllm/worker/enc_dec_model_runner.py
vllm/worker/enc_dec_model_runner.py
+13
-11
vllm/worker/model_runner.py
vllm/worker/model_runner.py
+13
-10
No files found.
tests/kernels/test_flash_attn.py
View file @
9aaf14c6
...
...
@@ -3,9 +3,9 @@ from typing import List, Optional, Tuple
import
pytest
import
torch
import
vllm.attention.backends.flash_attn
# noqa: F401
from
tests.kernels.utils
import
opcheck
from
vllm.utils
import
seed_everything
from
vllm.vllm_flash_attn
import
(
flash_attn_varlen_func
,
flash_attn_with_kvcache
)
NUM_HEADS
=
[(
4
,
4
),
(
8
,
2
),
(
16
,
2
)]
HEAD_SIZES
=
[
128
,
256
]
...
...
@@ -112,10 +112,10 @@ def test_flash_attn_with_paged_kv(
(
num_seqs
,
max_num_blocks_per_seq
),
dtype
=
torch
.
int32
)
output
=
torch
.
ops
.
vllm
.
flash_attn_with_kvcache
(
decode_query
=
query
.
unsqueeze
(
1
),
k
ey
_cache
=
key_cache
,
v
alue
_cache
=
value_cache
,
output
=
flash_attn_with_kvcache
(
q
=
query
.
unsqueeze
(
1
),
k_cache
=
key_cache
,
v_cache
=
value_cache
,
softmax_scale
=
scale
,
causal
=
True
,
block_table
=
block_tables
,
...
...
@@ -123,25 +123,6 @@ def test_flash_attn_with_paged_kv(
softcap
=
soft_cap
if
soft_cap
is
not
None
else
0
,
).
squeeze
(
1
)
if
num_blocks
<=
2048
:
test_utils
=
[
"test_faketensor"
,
"test_schema"
]
else
:
test_utils
=
[
"test_faketensor"
]
opcheck
(
torch
.
ops
.
vllm
.
flash_attn_with_kvcache
,
args
=
tuple
(),
kwargs
=
dict
(
decode_query
=
query
.
unsqueeze
(
1
),
key_cache
=
key_cache
,
value_cache
=
value_cache
,
softmax_scale
=
scale
,
causal
=
True
,
block_table
=
block_tables
,
cache_seqlens
=
kv_lens_tensor
,
softcap
=
soft_cap
if
soft_cap
is
not
None
else
0
,
),
test_utils
=
test_utils
)
ref_output
=
ref_paged_attn
(
query
=
query
,
key_cache
=
key_cache
,
...
...
@@ -213,7 +194,7 @@ def test_varlen_with_paged_kv(
(
num_seqs
,
max_num_blocks_per_seq
),
dtype
=
torch
.
int32
)
output
=
torch
.
ops
.
vllm
.
flash_attn_varlen_func
(
output
=
flash_attn_varlen_func
(
q
=
query
,
k
=
key_cache
,
v
=
value_cache
,
...
...
@@ -228,29 +209,6 @@ def test_varlen_with_paged_kv(
softcap
=
soft_cap
if
soft_cap
is
not
None
else
0
,
)
if
num_blocks
<=
2048
:
test_utils
=
[
"test_faketensor"
,
"test_schema"
]
else
:
test_utils
=
[
"test_faketensor"
]
opcheck
(
torch
.
ops
.
vllm
.
flash_attn_varlen_func
,
args
=
tuple
(),
kwargs
=
dict
(
q
=
query
,
k
=
key_cache
,
v
=
value_cache
,
cu_seqlens_q
=
cu_query_lens
,
cu_seqlens_k
=
cu_kv_lens
,
max_seqlen_q
=
max_query_len
,
max_seqlen_k
=
max_kv_len
,
softmax_scale
=
scale
,
causal
=
True
,
window_size
=
window_size
,
block_table
=
block_tables
,
softcap
=
soft_cap
if
soft_cap
is
not
None
else
0
,
),
test_utils
=
test_utils
)
ref_output
=
ref_paged_attn
(
query
=
query
,
key_cache
=
key_cache
,
...
...
vllm/attention/backends/flash_attn.py
View file @
9aaf14c6
...
...
@@ -13,152 +13,15 @@ from vllm.attention.backends.utils import (PAD_SLOT_ID, CommonAttentionState,
compute_slot_mapping
,
compute_slot_mapping_start_idx
,
is_block_tables_empty
)
from
vllm.forward_context
import
get_forward_context
from
vllm.utils
import
async_tensor_h2d
,
make_tensor_with_pad
if
TYPE_CHECKING
:
from
vllm.worker.model_runner
import
(
ModelInputForGPUBuilder
,
ModelInputForGPUWithSamplingMetadata
)
# yapf: disable
from
vllm.vllm_flash_attn
import
(
flash_attn_varlen_func
as
_flash_attn_varlen_func
)
from
vllm.vllm_flash_attn
import
(
flash_attn_with_kvcache
as
_flash_attn_with_kvcache
)
# yapf: enable
@
torch
.
library
.
custom_op
(
"vllm::flash_attn_varlen_func"
,
mutates_args
=
[])
def
flash_attn_varlen_func
(
q
:
torch
.
Tensor
,
k
:
torch
.
Tensor
,
v
:
torch
.
Tensor
,
cu_seqlens_q
:
torch
.
Tensor
,
cu_seqlens_k
:
torch
.
Tensor
,
max_seqlen_q
:
int
,
max_seqlen_k
:
int
,
softmax_scale
:
Optional
[
float
]
=
None
,
causal
:
bool
=
False
,
window_size
:
Optional
[
List
[
int
]]
=
None
,
softcap
:
float
=
0.0
,
alibi_slopes
:
Optional
[
torch
.
Tensor
]
=
None
,
block_table
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
# custom op does not support tuple input
real_window_size
:
Tuple
[
int
,
int
]
if
window_size
is
None
:
real_window_size
=
(
-
1
,
-
1
)
else
:
assert
len
(
window_size
)
==
2
real_window_size
=
(
window_size
[
0
],
window_size
[
1
])
return
_flash_attn_varlen_func
(
q
=
q
,
k
=
k
,
v
=
v
,
cu_seqlens_q
=
cu_seqlens_q
,
cu_seqlens_k
=
cu_seqlens_k
,
max_seqlen_q
=
max_seqlen_q
,
max_seqlen_k
=
max_seqlen_k
,
softmax_scale
=
softmax_scale
,
causal
=
causal
,
window_size
=
real_window_size
,
softcap
=
softcap
,
alibi_slopes
=
alibi_slopes
,
block_table
=
block_table
,
)
@
flash_attn_varlen_func
.
register_fake
# type: ignore
def
_
(
q
:
torch
.
Tensor
,
k
:
torch
.
Tensor
,
v
:
torch
.
Tensor
,
cu_seqlens_q
:
torch
.
Tensor
,
cu_seqlens_k
:
torch
.
Tensor
,
max_seqlen_q
:
int
,
max_seqlen_k
:
int
,
softmax_scale
:
Optional
[
float
]
=
None
,
causal
:
bool
=
False
,
window_size
:
Optional
[
List
[
int
]]
=
None
,
softcap
:
float
=
0.0
,
alibi_slopes
:
Optional
[
torch
.
Tensor
]
=
None
,
block_table
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
return
torch
.
empty_like
(
q
)
@
torch
.
library
.
custom_op
(
"vllm::flash_attn_with_kvcache"
,
mutates_args
=
[])
def
flash_attn_with_kvcache
(
decode_query
:
torch
.
Tensor
,
key_cache
:
torch
.
Tensor
,
value_cache
:
torch
.
Tensor
,
cache_seqlens
:
Optional
[
torch
.
Tensor
]
=
None
,
block_table
:
Optional
[
torch
.
Tensor
]
=
None
,
softmax_scale
:
Optional
[
float
]
=
None
,
causal
:
bool
=
False
,
alibi_slopes
:
Optional
[
torch
.
Tensor
]
=
None
,
softcap
:
float
=
0.0
,
)
->
torch
.
Tensor
:
return
_flash_attn_with_kvcache
(
decode_query
,
key_cache
,
value_cache
,
cache_seqlens
=
cache_seqlens
,
block_table
=
block_table
,
softmax_scale
=
softmax_scale
,
causal
=
causal
,
alibi_slopes
=
alibi_slopes
,
softcap
=
softcap
,
)
@
flash_attn_with_kvcache
.
register_fake
# type: ignore
def
_
(
decode_query
:
torch
.
Tensor
,
key_cache
:
torch
.
Tensor
,
value_cache
:
torch
.
Tensor
,
cache_seqlens
:
Optional
[
torch
.
Tensor
]
=
None
,
block_table
:
Optional
[
torch
.
Tensor
]
=
None
,
softmax_scale
:
Optional
[
float
]
=
None
,
causal
:
bool
=
False
,
alibi_slopes
:
Optional
[
torch
.
Tensor
]
=
None
,
softcap
:
float
=
0.0
,
)
->
torch
.
Tensor
:
return
torch
.
empty_like
(
decode_query
)
@
torch
.
library
.
custom_op
(
"vllm::reshape_and_cache_flash"
,
mutates_args
=
[
"kv_cache"
])
def
reshape_and_cache_flash
(
key
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
slot_mapping
:
torch
.
Tensor
,
kv_cache_dtype
:
str
,
k_scale
:
float
,
v_scale
:
float
,
)
->
None
:
"""Inductor cannot deal with inplace operations on views.
See https://github.com/pytorch/pytorch/issues/131192
and https://github.com/pytorch/pytorch/issues/130174
This is a workaround to hide the view operation from the inductor.
"""
return
torch
.
ops
.
_C_cache_ops
.
reshape_and_cache_flash
(
key
,
value
,
kv_cache
[
0
],
kv_cache
[
1
],
slot_mapping
,
kv_cache_dtype
,
k_scale
,
v_scale
)
@
reshape_and_cache_flash
.
register_fake
# type: ignore
def
_
(
key
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
slot_mapping
:
torch
.
Tensor
,
kv_cache_dtype
:
str
,
k_scale
:
float
,
v_scale
:
float
,
)
->
None
:
pass
from
vllm.vllm_flash_attn
import
(
flash_attn_varlen_func
,
flash_attn_with_kvcache
)
class
FlashAttentionBackend
(
AttentionBackend
):
...
...
@@ -721,11 +584,55 @@ class FlashAttentionImpl(AttentionImpl):
assert
k_scale
==
1.0
and
v_scale
==
1.0
,
(
"key/v_scale is not supported in FlashAttention."
)
output
=
torch
.
ops
.
vllm
.
unified_flash_attention
(
query
,
key
,
value
,
self
.
num_heads
,
self
.
head_size
,
self
.
num_kv_heads
,
kv_cache
,
self
.
kv_cache_dtype
,
k_scale
,
v_scale
,
self
.
scale
,
self
.
sliding_window
,
self
.
alibi_slopes
,
self
.
logits_soft_cap
,
)
return
output
@
torch
.
library
.
custom_op
(
"vllm::unified_flash_attention"
,
mutates_args
=
[
"kv_cache"
])
def
unified_flash_attention
(
query
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
num_heads
:
int
,
head_size
:
int
,
num_kv_heads
:
int
,
kv_cache
:
torch
.
Tensor
,
kv_cache_dtype
:
str
,
k_scale
:
float
,
v_scale
:
float
,
softmax_scale
:
float
,
window_size
:
Optional
[
List
[
int
]]
=
None
,
alibi_slopes
:
Optional
[
torch
.
Tensor
]
=
None
,
logits_soft_cap
:
Optional
[
float
]
=
None
,
)
->
torch
.
Tensor
:
current_metadata
=
get_forward_context
()
assert
current_metadata
is
not
None
assert
isinstance
(
current_metadata
,
FlashAttentionMetadata
)
attn_metadata
:
FlashAttentionMetadata
=
current_metadata
num_tokens
,
hidden_size
=
query
.
shape
# Reshape the query, key, and value tensors.
query
=
query
.
view
(
-
1
,
self
.
num_heads
,
self
.
head_size
)
key
=
key
.
view
(
-
1
,
self
.
num_kv_heads
,
self
.
head_size
)
value
=
value
.
view
(
-
1
,
self
.
num_kv_heads
,
self
.
head_size
)
query
=
query
.
view
(
-
1
,
num_heads
,
head_size
)
key
=
key
.
view
(
-
1
,
num_kv_heads
,
head_size
)
value
=
value
.
view
(
-
1
,
num_kv_heads
,
head_size
)
if
kv_cache
.
numel
()
>
0
:
key_cache
=
kv_cache
[
0
]
...
...
@@ -734,12 +641,13 @@ class FlashAttentionImpl(AttentionImpl):
# Reshape the input keys and values and store them in the cache.
# If kv_cache is not provided, the new key and value tensors are
# not cached. This happens during the initial memory profiling run.
torch
.
ops
.
vllm
.
reshape_and_cache_flash
(
torch
.
ops
.
_C_cache_ops
.
reshape_and_cache_flash
(
key
,
value
,
kv_cache
,
kv_cache
[
0
],
kv_cache
[
1
],
attn_metadata
.
slot_mapping
.
flatten
(),
self
.
kv_cache_dtype
,
kv_cache_dtype
,
k_scale
,
v_scale
,
)
...
...
@@ -771,7 +679,7 @@ class FlashAttentionImpl(AttentionImpl):
# normal attention
# When block_tables are not filled, it means q and k are the
# prompt, and they have the same length.
prefill_output
=
torch
.
ops
.
vllm
.
flash_attn_varlen_func
(
prefill_output
=
flash_attn_varlen_func
(
q
=
query
,
k
=
key
,
v
=
value
,
...
...
@@ -779,17 +687,17 @@ class FlashAttentionImpl(AttentionImpl):
cu_seqlens_k
=
prefill_meta
.
seq_start_loc
,
max_seqlen_q
=
prefill_meta
.
max_prefill_seq_len
,
max_seqlen_k
=
prefill_meta
.
max_prefill_seq_len
,
softmax_scale
=
s
elf
.
scale
,
softmax_scale
=
s
oftmax_
scale
,
causal
=
True
,
window_size
=
self
.
sliding_
window
,
alibi_slopes
=
self
.
alibi_slopes
,
softcap
=
self
.
logits_soft_cap
,
window_size
=
window
_size
,
alibi_slopes
=
alibi_slopes
,
softcap
=
logits_soft_cap
,
)
else
:
# prefix-enabled attention
assert
prefill_meta
.
seq_lens
is
not
None
max_seq_len
=
max
(
prefill_meta
.
seq_lens
)
prefill_output
=
torch
.
ops
.
vllm
.
flash_attn_varlen_func
(
# noqa
prefill_output
=
flash_attn_varlen_func
(
# noqa
q
=
query
,
k
=
key_cache
,
v
=
value_cache
,
...
...
@@ -797,30 +705,29 @@ class FlashAttentionImpl(AttentionImpl):
max_seqlen_q
=
prefill_meta
.
max_query_len
,
cu_seqlens_k
=
prefill_meta
.
seq_start_loc
,
max_seqlen_k
=
max_seq_len
,
softmax_scale
=
s
elf
.
scale
,
softmax_scale
=
s
oftmax_
scale
,
causal
=
True
,
alibi_slopes
=
self
.
alibi_slopes
,
alibi_slopes
=
alibi_slopes
,
block_table
=
prefill_meta
.
block_tables
,
softcap
=
self
.
logits_soft_cap
,
softcap
=
logits_soft_cap
,
)
if
decode_meta
:
=
attn_metadata
.
decode_metadata
:
# Decoding run.
_
,
num_head
,
head_dim
=
decode_query
.
shape
decode_query
=
decode_query
.
reshape
(
-
1
,
decode_meta
.
decode_query_len
,
decode_query
=
decode_query
.
reshape
(
-
1
,
decode_meta
.
decode_query_len
,
num_head
,
head_dim
)
decode_output
=
torch
.
ops
.
vllm
.
flash_attn_with_kvcache
(
decode_query
,
key_cache
,
value_cache
,
decode_output
=
flash_attn_with_kvcache
(
q
=
decode_query
,
k_cache
=
key_cache
,
v_cache
=
value_cache
,
block_table
=
decode_meta
.
block_tables
,
cache_seqlens
=
decode_meta
.
seq_lens_tensor
,
softmax_scale
=
s
elf
.
scale
,
softmax_scale
=
s
oftmax_
scale
,
causal
=
True
,
alibi_slopes
=
self
.
alibi_slopes
,
softcap
=
self
.
logits_soft_cap
,
)
alibi_slopes
=
alibi_slopes
,
softcap
=
logits_soft_cap
,
).
squeeze
(
1
)
if
prefill_output
is
None
:
assert
decode_output
is
not
None
...
...
@@ -836,3 +743,23 @@ class FlashAttentionImpl(AttentionImpl):
decode_output
=
decode_output
.
squeeze
(
1
)
output
=
torch
.
cat
([
prefill_output
,
decode_output
],
dim
=
0
)
return
output
.
view
(
num_tokens
,
hidden_size
)
@
unified_flash_attention
.
register_fake
def
_
(
query
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
num_heads
:
int
,
head_size
:
int
,
num_kv_heads
:
int
,
kv_cache
:
torch
.
Tensor
,
kv_cache_dtype
:
str
,
k_scale
:
float
,
v_scale
:
float
,
softmax_scale
:
float
,
window_size
:
Optional
[
List
[
int
]]
=
None
,
alibi_slopes
:
Optional
[
torch
.
Tensor
]
=
None
,
logits_soft_cap
:
Optional
[
float
]
=
None
,
)
->
torch
.
Tensor
:
return
torch
.
empty_like
(
query
)
vllm/attention/backends/flashinfer.py
View file @
9aaf14c6
...
...
@@ -7,7 +7,7 @@ try:
from
flashinfer.decode
import
CUDAGraphBatchDecodeWithPagedKVCacheWrapper
from
flashinfer.prefill
import
BatchPrefillWithPagedKVCacheWrapper
import
vllm.attention.backends.flash_attn
# noqa
from
vllm.vllm_flash_attn
import
flash_attn_varlen_func
FLASHINFER_WORKSPACE_BUFFER_SIZE
=
256
*
1024
*
1024
except
ImportError
:
BatchDecodeWithPagedKVCacheWrapper
=
None
...
...
@@ -799,7 +799,7 @@ class FlashInferImpl(AttentionImpl):
# This happens when vllm runs the profiling to
# determine the number of blocks.
if
kv_cache
.
numel
()
==
0
:
output
=
torch
.
ops
.
vllm
.
flash_attn_varlen_func
(
output
=
flash_attn_varlen_func
(
q
=
query
,
k
=
key
,
v
=
value
,
...
...
vllm/forward_context.py
0 → 100644
View file @
9aaf14c6
from
contextlib
import
contextmanager
from
typing
import
Any
_forward_context
:
Any
=
None
def
get_forward_context
()
->
Any
:
"""Get the current forward context."""
return
_forward_context
@
contextmanager
def
set_forward_context
(
context
:
Any
):
"""A context manager that stores the current forward context,
can be attention metadata, etc."""
global
_forward_context
prev_context
=
_forward_context
_forward_context
=
context
try
:
yield
finally
:
_forward_context
=
prev_context
vllm/spec_decode/draft_model_runner.py
View file @
9aaf14c6
...
...
@@ -2,6 +2,7 @@ from typing import List, Optional
import
torch
from
vllm.forward_context
import
set_forward_context
from
vllm.model_executor.layers.sampler
import
SamplerOutput
try
:
...
...
@@ -291,6 +292,7 @@ class TP1DraftModelRunner(ModelRunner):
if
previous_hidden_states
is
not
None
else
{}
# Run model
with
set_forward_context
(
model_input
.
attn_metadata
):
hidden_states
=
model_executable
(
input_ids
=
model_input
.
input_tokens
,
positions
=
model_input
.
input_positions
,
...
...
vllm/worker/embedding_model_runner.py
View file @
9aaf14c6
...
...
@@ -6,6 +6,7 @@ import torch
from
vllm.config
import
(
CacheConfig
,
DeviceConfig
,
LoadConfig
,
LoRAConfig
,
ModelConfig
,
ObservabilityConfig
,
ParallelConfig
,
PromptAdapterConfig
,
SchedulerConfig
)
from
vllm.forward_context
import
set_forward_context
from
vllm.logger
import
init_logger
from
vllm.model_executor.pooling_metadata
import
PoolingMetadata
from
vllm.multimodal
import
MultiModalInputs
...
...
@@ -119,6 +120,7 @@ class EmbeddingModelRunner(
device
=
self
.
device
),
}
with
set_forward_context
(
model_input
.
attn_metadata
):
hidden_states
=
model_executable
(
**
execute_model_kwargs
)
# Only perform pooling in the driver worker.
...
...
vllm/worker/enc_dec_model_runner.py
View file @
9aaf14c6
...
...
@@ -14,6 +14,7 @@ from vllm.attention.selector import (_Backend, get_env_variable_attn_backend,
from
vllm.config
import
(
CacheConfig
,
DeviceConfig
,
LoadConfig
,
LoRAConfig
,
ModelConfig
,
ObservabilityConfig
,
ParallelConfig
,
PromptAdapterConfig
,
SchedulerConfig
)
from
vllm.forward_context
import
set_forward_context
from
vllm.inputs
import
INPUT_REGISTRY
,
InputRegistry
from
vllm.logger
import
init_logger
from
vllm.model_executor
import
SamplingMetadata
...
...
@@ -198,6 +199,7 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
}
if
self
.
has_seqlen_agnostic
else
{}
multi_modal_kwargs
=
model_input
.
multi_modal_kwargs
or
{}
with
set_forward_context
(
model_input
.
attn_metadata
):
hidden_or_intermediate_states
=
model_executable
(
input_ids
=
model_input
.
input_tokens
,
positions
=
model_input
.
input_positions
,
...
...
vllm/worker/model_runner.py
View file @
9aaf14c6
...
...
@@ -24,6 +24,7 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
from
vllm.core.scheduler
import
SchedulerOutputs
from
vllm.distributed
import
get_pp_group
from
vllm.distributed.parallel_state
import
graph_capture
from
vllm.forward_context
import
set_forward_context
from
vllm.inputs
import
INPUT_REGISTRY
,
InputRegistry
from
vllm.logger
import
init_logger
from
vllm.lora.layers
import
LoRAMapping
...
...
@@ -1499,6 +1500,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
self
.
_update_inputs_to_capture_for_enc_dec_model
(
capture_inputs
)
with
set_forward_context
(
attn_metadata
):
graph_runner
.
capture
(
**
capture_inputs
)
self
.
graph_memory_pool
=
graph_runner
.
graph
.
pool
()
self
.
graph_runners
[
virtual_engine
][
batch_size
]
=
(
...
...
@@ -1641,6 +1643,7 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
model_forward_end
=
torch
.
cuda
.
Event
(
enable_timing
=
True
)
model_forward_start
.
record
()
with
set_forward_context
(
model_input
.
attn_metadata
):
hidden_or_intermediate_states
=
model_executable
(
input_ids
=
model_input
.
input_tokens
,
positions
=
model_input
.
input_positions
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment