Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
5faedf1b
Unverified
Commit
5faedf1b
authored
Sep 10, 2024
by
Kevin Lin
Committed by
GitHub
Sep 10, 2024
Browse files
[Spec Decode] Move ops.advance_step to flash attn advance_step (#8224)
parent
02751a7a
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
23 additions
and
33 deletions
+23
-33
vllm/attention/backends/flash_attn.py
vllm/attention/backends/flash_attn.py
+15
-6
vllm/spec_decode/draft_model_runner.py
vllm/spec_decode/draft_model_runner.py
+3
-13
vllm/worker/multi_step_model_runner.py
vllm/worker/multi_step_model_runner.py
+5
-14
No files found.
vllm/attention/backends/flash_attn.py
View file @
5faedf1b
...
@@ -16,7 +16,8 @@ from vllm.attention.backends.utils import (PAD_SLOT_ID, CommonAttentionState,
...
@@ -16,7 +16,8 @@ from vllm.attention.backends.utils import (PAD_SLOT_ID, CommonAttentionState,
from
vllm.utils
import
async_tensor_h2d
,
make_tensor_with_pad
from
vllm.utils
import
async_tensor_h2d
,
make_tensor_with_pad
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
from
vllm.worker.model_runner
import
ModelInputForGPUBuilder
from
vllm.worker.model_runner
import
(
ModelInputForGPUBuilder
,
ModelInputForGPUWithSamplingMetadata
)
from
vllm_flash_attn
import
flash_attn_varlen_func
as
_flash_attn_varlen_func
from
vllm_flash_attn
import
flash_attn_varlen_func
as
_flash_attn_varlen_func
from
vllm_flash_attn
import
flash_attn_with_kvcache
as
_flash_attn_with_kvcache
from
vllm_flash_attn
import
flash_attn_with_kvcache
as
_flash_attn_with_kvcache
...
@@ -302,14 +303,12 @@ class FlashAttentionMetadata(AttentionMetadata):
...
@@ -302,14 +303,12 @@ class FlashAttentionMetadata(AttentionMetadata):
)
)
return
self
.
_cached_decode_metadata
return
self
.
_cached_decode_metadata
def
advance_step
(
self
,
num_seqs
:
int
,
num_queries
:
int
):
def
advance_step
(
self
,
model_input
:
"ModelInputForGPUWithSamplingMetadata"
,
sampled_token_ids
:
Optional
[
torch
.
Tensor
],
block_size
:
int
,
num_seqs
:
int
,
num_queries
:
int
):
"""
"""
Update metadata in-place to advance one decode step.
Update metadata in-place to advance one decode step.
"""
"""
# GPU in-place update is currently called separately through
# custom_ops.advance_step(). See draft_model_runner. TODO(will): Move
# this logic to the backend.
# When using cudagraph, the num_seqs is padded to the next captured
# When using cudagraph, the num_seqs is padded to the next captured
# batch sized, but num_queries tracks the actual number of requests in
# batch sized, but num_queries tracks the actual number of requests in
# the batch. For --enforce-eager mode, num_seqs == num_queries
# the batch. For --enforce-eager mode, num_seqs == num_queries
...
@@ -347,6 +346,16 @@ class FlashAttentionMetadata(AttentionMetadata):
...
@@ -347,6 +346,16 @@ class FlashAttentionMetadata(AttentionMetadata):
self
.
seq_lens
[
i
]
+=
1
self
.
seq_lens
[
i
]
+=
1
self
.
max_decode_seq_len
=
max
(
self
.
seq_lens
)
self
.
max_decode_seq_len
=
max
(
self
.
seq_lens
)
ops
.
advance_step
(
num_seqs
=
num_seqs
,
num_queries
=
num_queries
,
block_size
=
block_size
,
input_tokens
=
model_input
.
input_tokens
,
sampled_token_ids
=
sampled_token_ids
,
input_positions
=
model_input
.
input_positions
,
seq_lens
=
self
.
seq_lens_tensor
,
slot_mapping
=
self
.
slot_mapping
,
block_tables
=
self
.
block_tables
)
class
FlashAttentionMetadataBuilder
(
class
FlashAttentionMetadataBuilder
(
AttentionMetadataBuilder
[
FlashAttentionMetadata
]):
AttentionMetadataBuilder
[
FlashAttentionMetadata
]):
...
...
vllm/spec_decode/draft_model_runner.py
View file @
5faedf1b
...
@@ -2,7 +2,6 @@ from typing import List, Optional
...
@@ -2,7 +2,6 @@ from typing import List, Optional
import
torch
import
torch
from
vllm
import
_custom_ops
as
ops
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.model_executor.layers.sampler
import
SamplerOutput
try
:
try
:
...
@@ -116,18 +115,9 @@ class TP1DraftModelRunner(ModelRunner):
...
@@ -116,18 +115,9 @@ class TP1DraftModelRunner(ModelRunner):
# Update attn_metadata
# Update attn_metadata
attn_metadata
=
model_input
.
attn_metadata
attn_metadata
=
model_input
.
attn_metadata
assert
isinstance
(
attn_metadata
,
FlashAttentionMetadata
)
assert
isinstance
(
attn_metadata
,
FlashAttentionMetadata
)
attn_metadata
.
advance_step
(
num_seqs
,
num_queries
)
attn_metadata
.
advance_step
(
model_input
,
sampled_token_ids
,
# Update GPU tensors
self
.
block_size
,
num_seqs
,
num_queries
)
ops
.
advance_step
(
num_seqs
=
num_seqs
,
num_queries
=
num_queries
,
block_size
=
self
.
block_size
,
input_tokens
=
model_input
.
input_tokens
,
sampled_token_ids
=
sampled_token_ids
,
input_positions
=
model_input
.
input_positions
,
seq_lens
=
attn_metadata
.
seq_lens_tensor
,
slot_mapping
=
attn_metadata
.
slot_mapping
,
block_tables
=
attn_metadata
.
block_tables
)
# Update sampling_metadata
# Update sampling_metadata
sampling_metadata
=
model_input
.
sampling_metadata
sampling_metadata
=
model_input
.
sampling_metadata
...
...
vllm/worker/multi_step_model_runner.py
View file @
5faedf1b
...
@@ -13,7 +13,6 @@ except ModuleNotFoundError:
...
@@ -13,7 +13,6 @@ except ModuleNotFoundError:
import
torch
import
torch
from
vllm
import
_custom_ops
as
ops
from
vllm.distributed
import
get_pp_group
from
vllm.distributed
import
get_pp_group
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.sampler
import
(
PromptLogprobs
,
SampleLogprobs
,
from
vllm.model_executor.layers.sampler
import
(
PromptLogprobs
,
SampleLogprobs
,
...
@@ -499,19 +498,11 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]):
...
@@ -499,19 +498,11 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]):
attn_metadata
=
frozen_model_input
.
attn_metadata
attn_metadata
=
frozen_model_input
.
attn_metadata
assert
isinstance
(
attn_metadata
,
FlashAttentionMetadata
)
assert
isinstance
(
attn_metadata
,
FlashAttentionMetadata
)
attn_metadata
.
advance_step
(
num_seqs
,
num_queries
)
attn_metadata
.
advance_step
(
# Update GPU tensors
frozen_model_input
,
ops
.
advance_step
(
model_input
.
cached_outputs
[
-
1
].
sampled_token_ids
,
self
.
block_size
,
num_seqs
=
num_seqs
,
num_seqs
,
num_queries
)
num_queries
=
num_queries
,
block_size
=
self
.
block_size
,
input_tokens
=
frozen_model_input
.
input_tokens
,
sampled_token_ids
=
model_input
.
cached_outputs
[
-
1
].
sampled_token_ids
,
input_positions
=
frozen_model_input
.
input_positions
,
seq_lens
=
attn_metadata
.
seq_lens_tensor
,
slot_mapping
=
attn_metadata
.
slot_mapping
,
block_tables
=
attn_metadata
.
block_tables
)
if
frozen_model_input
.
seq_lens
is
not
None
:
if
frozen_model_input
.
seq_lens
is
not
None
:
for
i
in
range
(
num_queries
):
for
i
in
range
(
num_queries
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment