Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
07ab1607
Unverified
Commit
07ab1607
authored
Aug 09, 2024
by
Mor Zusman
Committed by
GitHub
Aug 09, 2024
Browse files
[Model][Jamba] Mamba cache single buffer (#6739)
Co-authored-by:
Mor Zusman
<
morz@ai21.com
>
parent
b4e9528f
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
148 additions
and
124 deletions
+148
-124
vllm/model_executor/models/jamba.py
vllm/model_executor/models/jamba.py
+148
-121
vllm/worker/model_runner.py
vllm/worker/model_runner.py
+0
-3
No files found.
vllm/model_executor/models/jamba.py
View file @
07ab1607
...
...
@@ -609,12 +609,8 @@ class JambaForCausalLM(nn.Module, HasInnerState):
# compatibility
if
not
lora_config
else
lora_config
.
lora_vocab_padding_size
,
)
# Current step used indices
self
.
current_indices
:
List
[
int
]
=
[]
# Used to track and store by the Mamba cache between steps.
self
.
mamba_cache
:
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
=
tuple
()
# Used as an input_buffer for the CUDA graph runs.
self
.
mamba_gc_cache_buffer
:
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
=
tuple
()
# Maps between the request id and a dict that maps between the seq_id
# and its index inside the self.mamba_cache
self
.
mamba_cache_indices_mapping
:
Dict
[
str
,
Dict
[
int
,
int
]]
=
{}
...
...
@@ -644,95 +640,148 @@ class JambaForCausalLM(nn.Module, HasInnerState):
batch_size
=
input_ids
.
shape
[
0
]
if
attn_metadata
.
prefill_metadata
:
batch_size
=
len
(
request_ids_to_seq_ids
)
(
current_seqlen_agnostic_cache
,
indices
,
)
=
self
.
_prepare_current_run_mamba_cache
(
request_ids_to_seq_ids
,
batch_size
,
finished_requests_ids
)
mamba_cache
=
self
.
_prepare_current_run_mamba_cache
(
request_ids_to_seq_ids
,
batch_size
,
finished_requests_ids
)
else
:
# CUDA graph capturing runs
current_seqlen_agnostic_cache
,
indices
=
(
kwargs
[
"seqlen_agnostic_capture_inputs"
],
[],
)
self
.
current_indices
=
indices
mamba_cache
=
kwargs
[
"seqlen_agnostic_capture_inputs"
]
hidden_states
=
self
.
model
(
input_ids
,
positions
,
kv_caches
,
attn_metadata
,
current_seqlen_agnostic_cache
[
0
],
current_seqlen_agnostic_cache
[
1
])
if
"seqlen_agnostic_capture_inputs"
not
in
kwargs
:
self
.
_copy_mamba_cache_by_indices
(
self
.
current_indices
,
current_seqlen_agnostic_cache
)
attn_metadata
,
mamba_cache
[
0
],
mamba_cache
[
1
])
return
hidden_states
def
_
copy
_mamba_cache
_by_indices
(
self
,
indices
:
List
[
int
],
current_seqlen_agnostic_cache
:
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
])
:
for
i
,
offset
in
enumerate
(
indices
):
self
.
_copy_mamba_cache
(
offset
,
i
,
current_seqlen_agnostic_cache
)
def
_
swap
_mamba_cache
(
self
,
from_index
:
int
,
to_index
:
int
):
assert
len
(
self
.
mamba_cache
)
>
0
for
cache_t
in
self
.
mamba_cache
:
cache_t
[:,
[
to_index
,
from_index
]]
=
\
cache_t
[:,
[
from_index
,
to_index
]]
def
_copy_mamba_cache
(
self
,
index_to
:
int
,
index_from
:
int
,
from_buffer
:
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]):
def
_copy_mamba_cache
(
self
,
from_index
:
int
,
to_index
:
int
):
assert
len
(
self
.
mamba_cache
)
>
0
for
(
cache_t
,
from_buffer_t
)
in
zip
(
self
.
mamba_cache
,
from_buffer
)
:
cache_t
[:,
index
_to
].
copy_
(
from_buffer_t
[:,
index_from
],
for
cache_t
in
self
.
mamba_cache
:
cache_t
[:,
to_
index
].
copy_
(
cache_t
[:,
from_index
],
non_blocking
=
True
)
def
_assign_seq_id_to_mamba_cache
(
self
,
cur_rid
:
str
,
seqs_id
:
List
[
int
])
->
List
[
int
]:
indices_for_current_run
=
[]
for
seq_id
in
seqs_id
:
if
cur_rid
not
in
self
.
mamba_cache_indices_mapping
:
self
.
mamba_cache_indices_mapping
[
cur_rid
]
=
{}
def
_move_out_if_already_occupied
(
self
,
index
:
int
,
all_occupied_indices
:
List
[
int
]):
if
index
in
all_occupied_indices
:
first_free_index
=
self
.
_first_free_index_in_mamba_cache
()
self
.
mamba_cache_indices_mapping
[
cur_rid
][
seq_id
]
=
first_free_index
index_for_current_run
=
first_free_index
## case of decoding n>1, copy prefill cache to decoding indices
# In case occupied, move the occupied to a new empty block
self
.
_move_cache_index_and_mappings
(
from_index
=
index
,
to_index
=
first_free_index
)
def
_assign_seq_id_to_mamba_cache_in_specific_dest
(
self
,
cur_rid
:
str
,
seq_id
:
int
,
destination_index
:
int
):
"""
Assign (req_id,seq_id) pair to a `destination_index` index, if
already occupied, move the occupying index to a free index.
"""
all_occupied_indices
=
self
.
_get_all_occupied_indices
()
if
cur_rid
not
in
self
.
mamba_cache_indices_mapping
:
self
.
_move_out_if_already_occupied
(
index
=
destination_index
,
all_occupied_indices
=
all_occupied_indices
)
self
.
mamba_cache_indices_mapping
[
cur_rid
]
=
{
seq_id
:
destination_index
}
elif
seq_id
not
in
(
seq_ids2indices
:
=
self
.
mamba_cache_indices_mapping
[
cur_rid
]):
first_free_index
=
self
.
_first_free_index_in_mamba_cache
()
index_exist
=
list
(
seq_ids2indices
.
values
())[
0
]
self
.
_copy_mamba_cache
(
index_from
=
index_exist
,
index_to
=
first_free_index
,
from_buffer
=
self
.
mamba_cache
)
# parallel sampling , where n > 1, assume prefill have
# already happened now we only need to copy the already
# existing cache into the siblings seq_ids caches
self
.
_move_out_if_already_occupied
(
index
=
destination_index
,
all_occupied_indices
=
all_occupied_indices
)
index_exists
=
list
(
seq_ids2indices
.
values
())[
0
]
# case of decoding n>1, copy prefill cache to decoding indices
self
.
_copy_mamba_cache
(
from_index
=
index_exists
,
to_index
=
destination_index
)
self
.
mamba_cache_indices_mapping
[
cur_rid
][
seq_id
]
=
first_free_index
index_for_current_run
=
first_free_index
seq_id
]
=
destination_index
else
:
index_for_current_run
=
self
.
mamba_cache_indices_mapping
[
# already exists
cache_index_already_exists
=
self
.
mamba_cache_indices_mapping
[
cur_rid
][
seq_id
]
indices_for_current_run
.
append
(
index_for_current_run
)
return
indices_for_current_run
if
cache_index_already_exists
!=
destination_index
:
# In case the seq id already exists but not in
# the right destination, swap it with what's occupying it
self
.
_swap_pair_indices_and_mappings
(
from_index
=
cache_index_already_exists
,
to_index
=
destination_index
)
def
_prepare_current_run_mamba_cache
(
self
,
request_ids_to_seq_ids
:
Dict
[
str
,
list
[
int
]],
batch_size
:
int
,
finished_requests_ids
:
List
[
str
]
)
->
Tuple
[
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
],
List
[
int
]]:
indices_for_current_run
=
[]
for
request_id
,
seqs_id
in
request_ids_to_seq_ids
.
items
():
self
,
request_ids_to_seq_ids
:
Dict
[
str
,
list
[
int
]],
batch_size
:
int
,
finished_requests_ids
:
List
[
str
]):
running_indices
=
[]
request_ids_to_seq_ids_flatten
=
[
(
req_id
,
seq_id
)
for
req_id
,
seq_ids
in
request_ids_to_seq_ids
.
items
()
for
seq_id
in
seq_ids
]
for
dest_index
,
(
request_id
,
seq_id
)
in
enumerate
(
request_ids_to_seq_ids_flatten
):
if
request_id
in
finished_requests_ids
:
# Do not allocate cache for requests that run
# Do not allocate cache
index
for requests that run
# and finish right after
continue
indices_for_current_run
+=
self
.
_assign_seq_id_to_mamba_cache
(
request_id
,
seqs_id
)
## Pad the batch in case of running batch that was not captured via CG
padded_indices
=
indices_for_current_run
.
copy
()
pad_index
=
self
.
_first_free_index_in_mamba_cache
()
self
.
_assign_seq_id_to_mamba_cache_in_specific_dest
(
request_id
,
seq_id
,
dest_index
)
running_indices
.
append
(
dest_index
)
self
.
_clean_up_first_bs_blocks
(
batch_size
,
running_indices
)
conv_state
=
self
.
mamba_cache
[
0
][:,
:
batch_size
]
temporal_state
=
self
.
mamba_cache
[
1
][:,
:
batch_size
]
for
_
in
range
(
batch_size
-
len
(
indices_for_current_run
)):
padded_indices
.
append
(
pad_index
)
return
(
conv_state
,
temporal_state
)
conv_state
=
self
.
mamba_cache
[
0
][:,
padded_indices
]
temporal_state
=
self
.
mamba_cache
[
1
][:,
padded_indices
]
def
_get_all_occupied_indices
(
self
):
return
[
cache_idx
for
seq_ids2indices
in
self
.
mamba_cache_indices_mapping
.
values
()
for
cache_idx
in
seq_ids2indices
.
values
()
]
return
(
conv_state
,
temporal_state
),
indices_for_current_run
def
_clean_up_first_bs_blocks
(
self
,
batch_size
:
int
,
indices_for_current_run
:
List
[
int
]):
# move out all of the occupied but currently not running blocks
# outside of the first n blocks
destination_indices
=
set
([
range
(
batch_size
)])
max_possible_batch_size
=
self
.
mamba_cache
[
0
].
shape
[
1
]
for
destination_index
in
destination_indices
:
if
destination_index
in
self
.
_get_all_occupied_indices
()
and
\
destination_index
not
in
indices_for_current_run
:
# move not running indices outside of the batch
all_other_indices
=
list
(
range
(
batch_size
,
max_possible_batch_size
))
first_avail_index
=
self
.
_first_free_index_in_mamba_cache
(
all_other_indices
)
self
.
_swap_indices
(
from_index
=
destination_index
,
to_index
=
first_avail_index
)
def
_move_cache_index_and_mappings
(
self
,
from_index
:
int
,
to_index
:
int
):
self
.
_copy_mamba_cache
(
from_index
=
from_index
,
to_index
=
to_index
)
self
.
_update_mapping_index
(
from_index
=
from_index
,
to_index
=
to_index
)
def
_swap_pair_indices_and_mappings
(
self
,
from_index
:
int
,
to_index
:
int
):
self
.
_swap_mamba_cache
(
from_index
=
from_index
,
to_index
=
to_index
)
self
.
_swap_mapping_index
(
from_index
=
from_index
,
to_index
=
to_index
)
def
_swap_mapping_index
(
self
,
from_index
:
int
,
to_index
:
int
):
for
seq_ids2index
in
self
.
mamba_cache_indices_mapping
.
values
():
for
seq_id
,
index
in
seq_ids2index
.
items
():
if
from_index
==
index
:
seq_ids2index
.
update
({
seq_id
:
to_index
})
elif
to_index
==
index
:
seq_ids2index
.
update
({
seq_id
:
from_index
})
def
_update_mapping_index
(
self
,
from_index
:
int
,
to_index
:
int
):
for
seq_ids2index
in
self
.
mamba_cache_indices_mapping
.
values
():
for
seq_id
,
index
in
seq_ids2index
.
items
():
if
from_index
==
index
:
seq_ids2index
.
update
({
seq_id
:
to_index
})
return
def
copy_inputs_before_cuda_graphs
(
self
,
input_buffers
,
**
kwargs
):
"""
...
...
@@ -747,28 +796,9 @@ class JambaForCausalLM(nn.Module, HasInnerState):
self
.
_release_mamba_cache
(
finished_requests_ids
)
request_ids_to_seq_ids
=
kwargs
[
"request_ids_to_seq_ids"
]
cg_batch_size
=
input_buffers
[
'input_ids'
].
shape
[
0
]
(
current_mamba_cache
,
indices
,
)
=
self
.
_prepare_current_run_mamba_cache
(
request_ids_to_seq_ids
,
self
.
_prepare_current_run_mamba_cache
(
request_ids_to_seq_ids
,
cg_batch_size
,
finished_requests_ids
)
self
.
current_indices
=
indices
for
input_buffer
,
current_cache_buffer
in
zip
(
input_buffers
[
"seqlen_agnostic_capture_inputs"
],
current_mamba_cache
):
input_buffer
.
copy_
(
current_cache_buffer
,
non_blocking
=
True
)
def copy_outputs_after_cuda_graphs(self, input_buffers, **kwargs):
    """
    Write the Mamba state held in the CUDA graph input buffers back into
    the model's own Mamba cache once a graph replay has finished.

    The buffers are taken from
    ``input_buffers["seqlen_agnostic_capture_inputs"]`` and handed to
    ``self._copy_mamba_cache_by_indices`` together with
    ``self.current_indices``. Extra keyword arguments are accepted but
    not used.
    """
    graph_mamba_buffers = input_buffers["seqlen_agnostic_capture_inputs"]
    self._copy_mamba_cache_by_indices(self.current_indices,
                                      graph_mamba_buffers)
def
get_seqlen_agnostic_capture_inputs
(
self
,
batch_size
:
int
):
"""
...
...
@@ -776,26 +806,25 @@ class JambaForCausalLM(nn.Module, HasInnerState):
The buffer is used to maintain the Mamba Cache during the CUDA graph
replay runs.
"""
return
tuple
(
buffer
[:,
:
batch_size
]
for
buffer
in
self
.
mamba_gc_cache_buffer
)
return
tuple
(
buffer
[:,
:
batch_size
]
for
buffer
in
self
.
mamba_cache
)
def
_release_mamba_cache
(
self
,
finished_seq_groups_req_ids
:
List
[
str
]):
for
req_id
in
finished_seq_groups_req_ids
:
if
req_id
in
self
.
mamba_cache_indices_mapping
:
self
.
mamba_cache_indices_mapping
.
pop
(
req_id
)
def
_first_free_index_in_mamba_cache
(
self
)
->
int
:
if
self
.
mamba_cache
:
def
_first_free_index_in_mamba_cache
(
self
,
indices_range
:
Optional
[
List
[
int
]]
=
None
)
->
int
:
assert
self
.
mamba_cache
is
not
None
if
indices_range
is
None
:
max_possible_batch_size
=
self
.
mamba_cache
[
0
].
shape
[
1
]
occupied
=
[
id
for
seq_ids
in
self
.
mamba_cache_indices_mapping
.
values
()
for
id
in
seq_ids
.
values
()
]
first_free_index
=
[
i
not
in
occupied
for
i
in
range
(
max_possible_batch_size
)
].
index
(
True
)
return
first_free_index
return
0
indices_range
=
list
(
range
(
max_possible_batch_size
))
all_occupied_indices
=
self
.
_get_all_occupied_indices
()
for
i
in
indices_range
:
if
i
not
in
all_occupied_indices
:
return
i
raise
Exception
(
"Couldn't find a free spot in the mamba cache! This"
"should never happen"
)
def
_get_mamba_cache_shape
(
self
...
...
@@ -819,12 +848,11 @@ class JambaForCausalLM(nn.Module, HasInnerState):
[
layer_type
==
"mamba"
for
layer_type
in
layers_type
])
max_batch_size
=
(
_get_graph_batch_size
(
self
.
scheduler_config
.
max_num_seqs
)
if
self
.
scheduler_config
else
max
(
_BATCH_SIZES_TO_CAPTURE
)
)
+
10
max
(
_BATCH_SIZES_TO_CAPTURE
)
+
2
)
conv_state_shape
,
temporal_state_shape
=
self
.
_get_mamba_cache_shape
()
assert
conv_state_shape
is
not
None
and
temporal_state_shape
is
not
None
for
buffername
in
[
"mamba_cache"
,
"mamba_gc_cache_buffer"
]:
buffer
=
(
torch
.
empty
(
size
=
(
mamba_layers
,
max_batch_size
)
+
self
.
mamba_cache
=
(
torch
.
empty
(
size
=
(
mamba_layers
,
max_batch_size
)
+
conv_state_shape
,
dtype
=
dtype
,
device
=
"cuda"
),
...
...
@@ -832,7 +860,6 @@ class JambaForCausalLM(nn.Module, HasInnerState):
temporal_state_shape
,
dtype
=
dtype
,
device
=
"cuda"
))
setattr
(
self
,
buffername
,
buffer
)
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
)
->
torch
.
Tensor
:
...
...
vllm/worker/model_runner.py
View file @
07ab1607
...
...
@@ -1711,9 +1711,6 @@ class CUDAGraphRunner:
non_blocking
=
True
)
# Run the graph.
self
.
graph
.
replay
()
if
"seqlen_agnostic_capture_inputs"
in
self
.
input_buffers
:
self
.
model
.
copy_outputs_after_cuda_graphs
(
self
.
input_buffers
,
**
kwargs
)
# Return the output tensor.
if
get_pp_group
().
is_last_rank
:
return
self
.
output_buffers
[
"hidden_states"
]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment