Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
16b24e7d
Unverified
Commit
16b24e7d
authored
Oct 13, 2024
by
Tyler Michael Smith
Committed by
GitHub
Oct 13, 2024
Browse files
[Bugfix] Bandaid fix for speculative decoding tests (#9327)
parent
f519902c
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
18 additions
and
3 deletions
+18
-3
vllm/worker/model_runner.py
vllm/worker/model_runner.py
+18
-3
No files found.
vllm/worker/model_runner.py
View file @
16b24e7d
...
@@ -17,6 +17,7 @@ import torch.nn as nn
...
@@ -17,6 +17,7 @@ import torch.nn as nn
import
vllm.envs
as
envs
import
vllm.envs
as
envs
from
vllm.attention
import
AttentionMetadata
,
get_attn_backend
from
vllm.attention
import
AttentionMetadata
,
get_attn_backend
from
vllm.attention.backends.abstract
import
AttentionState
from
vllm.attention.backends.abstract
import
AttentionState
from
vllm.attention.backends.utils
import
CommonAttentionState
from
vllm.compilation.compile_context
import
set_compile_context
from
vllm.compilation.compile_context
import
set_compile_context
from
vllm.compilation.levels
import
CompilationLevel
from
vllm.compilation.levels
import
CompilationLevel
from
vllm.config
import
(
CacheConfig
,
DeviceConfig
,
LoadConfig
,
LoRAConfig
,
from
vllm.config
import
(
CacheConfig
,
DeviceConfig
,
LoadConfig
,
LoRAConfig
,
...
@@ -1001,6 +1002,17 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
...
@@ -1001,6 +1002,17 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
self
.
graph_block_tables
=
np
.
zeros
(
self
.
graph_block_tables
=
np
.
zeros
(
(
self
.
max_batchsize_to_capture
,
self
.
get_max_block_per_batch
()),
(
self
.
max_batchsize_to_capture
,
self
.
get_max_block_per_batch
()),
dtype
=
np
.
int32
)
dtype
=
np
.
int32
)
# Attention-free but stateful models like Mamba need a placeholder attn
# backend, as the attention metadata is needed to manage internal state.
# However we must bypass attention selection altogether for some models
# used for speculative decoding to avoid a divide-by-zero in
# model_config.get_head_size()
num_attn_heads
=
self
.
model_config
.
get_num_attention_heads
(
self
.
parallel_config
)
needs_attn_backend
=
(
num_attn_heads
!=
0
or
self
.
model_config
.
is_attention_free
)
self
.
attn_backend
=
get_attn_backend
(
self
.
attn_backend
=
get_attn_backend
(
self
.
model_config
.
get_head_size
(),
self
.
model_config
.
get_head_size
(),
self
.
model_config
.
get_sliding_window
(),
self
.
model_config
.
get_sliding_window
(),
...
@@ -1008,9 +1020,12 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
...
@@ -1008,9 +1020,12 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
self
.
kv_cache_dtype
,
self
.
kv_cache_dtype
,
self
.
block_size
,
self
.
block_size
,
self
.
model_config
.
is_attention_free
,
self
.
model_config
.
is_attention_free
,
)
)
if
needs_attn_backend
else
None
self
.
attn_state
=
self
.
attn_backend
.
get_state_cls
()(
if
self
.
attn_backend
:
weakref
.
proxy
(
self
))
self
.
attn_state
=
self
.
attn_backend
.
get_state_cls
()(
weakref
.
proxy
(
self
))
else
:
self
.
attn_state
=
CommonAttentionState
(
weakref
.
proxy
(
self
))
# Multi-modal data support
# Multi-modal data support
self
.
input_registry
=
input_registry
self
.
input_registry
=
input_registry
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment