Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
07ab1607
Unverified
Commit
07ab1607
authored
Aug 09, 2024
by
Mor Zusman
Committed by
GitHub
Aug 09, 2024
Browse files
[Model][Jamba] Mamba cache single buffer (#6739)
Co-authored-by:
Mor Zusman
<
morz@ai21.com
>
parent
b4e9528f
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
148 additions
and
124 deletions
+148
-124
vllm/model_executor/models/jamba.py
vllm/model_executor/models/jamba.py
+148
-121
vllm/worker/model_runner.py
vllm/worker/model_runner.py
+0
-3
No files found.
vllm/model_executor/models/jamba.py
View file @
07ab1607
...
@@ -609,12 +609,8 @@ class JambaForCausalLM(nn.Module, HasInnerState):
...
@@ -609,12 +609,8 @@ class JambaForCausalLM(nn.Module, HasInnerState):
# compatibility
# compatibility
if
not
lora_config
else
lora_config
.
lora_vocab_padding_size
,
if
not
lora_config
else
lora_config
.
lora_vocab_padding_size
,
)
)
# Current step used indices
self
.
current_indices
:
List
[
int
]
=
[]
# Used to track and store by the Mamba cache between steps.
# Used to track and store by the Mamba cache between steps.
self
.
mamba_cache
:
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
=
tuple
()
self
.
mamba_cache
:
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
=
tuple
()
# Used as an input_buffer for the CUDA graph runs.
self
.
mamba_gc_cache_buffer
:
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
=
tuple
()
# Maps between the request id and a dict that maps between the seq_id
# Maps between the request id and a dict that maps between the seq_id
# and its index inside the self.mamba_cache
# and its index inside the self.mamba_cache
self
.
mamba_cache_indices_mapping
:
Dict
[
str
,
Dict
[
int
,
int
]]
=
{}
self
.
mamba_cache_indices_mapping
:
Dict
[
str
,
Dict
[
int
,
int
]]
=
{}
...
@@ -644,95 +640,148 @@ class JambaForCausalLM(nn.Module, HasInnerState):
...
@@ -644,95 +640,148 @@ class JambaForCausalLM(nn.Module, HasInnerState):
batch_size
=
input_ids
.
shape
[
0
]
batch_size
=
input_ids
.
shape
[
0
]
if
attn_metadata
.
prefill_metadata
:
if
attn_metadata
.
prefill_metadata
:
batch_size
=
len
(
request_ids_to_seq_ids
)
batch_size
=
len
(
request_ids_to_seq_ids
)
(
mamba_cache
=
self
.
_prepare_current_run_mamba_cache
(
current_seqlen_agnostic_cache
,
request_ids_to_seq_ids
,
batch_size
,
finished_requests_ids
)
indices
,
)
=
self
.
_prepare_current_run_mamba_cache
(
request_ids_to_seq_ids
,
batch_size
,
finished_requests_ids
)
else
:
else
:
# CUDA graph capturing runs
# CUDA graph capturing runs
current_seqlen_agnostic_cache
,
indices
=
(
mamba_cache
=
kwargs
[
"seqlen_agnostic_capture_inputs"
]
kwargs
[
"seqlen_agnostic_capture_inputs"
],
[],
)
self
.
current_indices
=
indices
hidden_states
=
self
.
model
(
input_ids
,
positions
,
kv_caches
,
hidden_states
=
self
.
model
(
input_ids
,
positions
,
kv_caches
,
attn_metadata
,
attn_metadata
,
mamba_cache
[
0
],
current_seqlen_agnostic_cache
[
0
],
mamba_cache
[
1
])
current_seqlen_agnostic_cache
[
1
])
if
"seqlen_agnostic_capture_inputs"
not
in
kwargs
:
self
.
_copy_mamba_cache_by_indices
(
self
.
current_indices
,
current_seqlen_agnostic_cache
)
return
hidden_states
return
hidden_states
def
_
copy
_mamba_cache
_by_indices
(
def
_
swap
_mamba_cache
(
self
,
from_index
:
int
,
to_index
:
int
):
self
,
indices
:
List
[
int
],
assert
len
(
self
.
mamba_cache
)
>
0
current_seqlen_agnostic_cache
:
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
])
:
for
cache_t
in
self
.
mamba_cache
:
for
i
,
offset
in
enumerate
(
indices
):
cache_t
[:,
[
to_index
,
from_index
]]
=
\
self
.
_copy_mamba_cache
(
offset
,
i
,
current_seqlen_agnostic_cache
)
cache_t
[:,
[
from_index
,
to_index
]]
def
_copy_mamba_cache
(
self
,
index_to
:
int
,
index_from
:
int
,
def
_copy_mamba_cache
(
self
,
from_index
:
int
,
to_index
:
int
):
from_buffer
:
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]):
assert
len
(
self
.
mamba_cache
)
>
0
assert
len
(
self
.
mamba_cache
)
>
0
for
(
cache_t
,
from_buffer_t
)
in
zip
(
self
.
mamba_cache
,
from_buffer
)
:
for
cache_t
in
self
.
mamba_cache
:
cache_t
[:,
index
_to
].
copy_
(
from_buffer_t
[:,
index_from
],
cache_t
[:,
to_
index
].
copy_
(
cache_t
[:,
from_index
],
non_blocking
=
True
)
non_blocking
=
True
)
def
_assign_seq_id_to_mamba_cache
(
self
,
cur_rid
:
str
,
def
_move_out_if_already_occupied
(
self
,
index
:
int
,
seqs_id
:
List
[
int
])
->
List
[
int
]:
all_occupied_indices
:
List
[
int
]):
indices_for_current_run
=
[]
if
index
in
all_occupied_indices
:
for
seq_id
in
seqs_id
:
first_free_index
=
self
.
_first_free_index_in_mamba_cache
()
if
cur_rid
not
in
self
.
mamba_cache_indices_mapping
:
# In case occupied, move the occupied to a new empty block
self
.
mamba_cache_indices_mapping
[
cur_rid
]
=
{}
self
.
_move_cache_index_and_mappings
(
from_index
=
index
,
first_free_index
=
self
.
_first_free_index_in_mamba_cache
()
to_index
=
first_free_index
)
self
.
mamba_cache_indices_mapping
[
cur_rid
][
seq_id
]
=
first_free_index
def
_assign_seq_id_to_mamba_cache_in_specific_dest
(
self
,
cur_rid
:
str
,
index_for_current_run
=
first_free_index
seq_id
:
int
,
## case of decoding n>1, copy prefill cache to decoding indices
destination_index
:
int
):
elif
seq_id
not
in
(
seq_ids2indices
:
=
"""
self
.
mamba_cache_indices_mapping
[
cur_rid
]):
Assign (req_id,seq_id) pair to a `destination_index` index, if
first_free_index
=
self
.
_first_free_index_in_mamba_cache
()
already occupied, move the occupying index to a free index.
index_exist
=
list
(
seq_ids2indices
.
values
())[
0
]
"""
self
.
_copy_mamba_cache
(
index_from
=
index_exist
,
all_occupied_indices
=
self
.
_get_all_occupied_indices
()
index_to
=
first_free_index
,
if
cur_rid
not
in
self
.
mamba_cache_indices_mapping
:
from_buffer
=
self
.
mamba_cache
)
self
.
_move_out_if_already_occupied
(
self
.
mamba_cache_indices_mapping
[
cur_rid
][
index
=
destination_index
,
seq_id
]
=
first_free_index
all_occupied_indices
=
all_occupied_indices
)
index_for_current_run
=
first_free_index
self
.
mamba_cache_indices_mapping
[
cur_rid
]
=
{
else
:
seq_id
:
destination_index
index_for_current_run
=
self
.
mamba_cache_indices_mapping
[
}
cur_rid
][
seq_id
]
elif
seq_id
not
in
(
seq_ids2indices
:
=
self
.
mamba_cache_indices_mapping
[
cur_rid
]):
indices_for_current_run
.
append
(
index_for_current_run
)
# parallel sampling , where n > 1, assume prefill have
return
indices_for_current_run
# already happened now we only need to copy the already
# existing cache into the siblings seq_ids caches
self
.
_move_out_if_already_occupied
(
index
=
destination_index
,
all_occupied_indices
=
all_occupied_indices
)
index_exists
=
list
(
seq_ids2indices
.
values
())[
0
]
# case of decoding n>1, copy prefill cache to decoding indices
self
.
_copy_mamba_cache
(
from_index
=
index_exists
,
to_index
=
destination_index
)
self
.
mamba_cache_indices_mapping
[
cur_rid
][
seq_id
]
=
destination_index
else
:
# already exists
cache_index_already_exists
=
self
.
mamba_cache_indices_mapping
[
cur_rid
][
seq_id
]
if
cache_index_already_exists
!=
destination_index
:
# In case the seq id already exists but not in
# the right destination, swap it with what's occupying it
self
.
_swap_pair_indices_and_mappings
(
from_index
=
cache_index_already_exists
,
to_index
=
destination_index
)
def
_prepare_current_run_mamba_cache
(
def
_prepare_current_run_mamba_cache
(
self
,
request_ids_to_seq_ids
:
Dict
[
str
,
list
[
int
]],
batch_size
:
int
,
self
,
request_ids_to_seq_ids
:
Dict
[
str
,
list
[
int
]],
finished_requests_ids
:
List
[
str
]
batch_size
:
int
,
finished_requests_ids
:
List
[
str
]):
)
->
Tuple
[
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
],
List
[
int
]]:
running_indices
=
[]
indices_for_current_run
=
[]
request_ids_to_seq_ids_flatten
=
[
for
request_id
,
seqs_id
in
request_ids_to_seq_ids
.
items
():
(
req_id
,
seq_id
)
for
req_id
,
seq_ids
in
request_ids_to_seq_ids
.
items
()
for
seq_id
in
seq_ids
]
for
dest_index
,
(
request_id
,
seq_id
)
in
enumerate
(
request_ids_to_seq_ids_flatten
):
if
request_id
in
finished_requests_ids
:
if
request_id
in
finished_requests_ids
:
# Do not allocate cache for requests that run
# Do not allocate cache
index
for requests that run
# and finish right after
# and finish right after
continue
continue
indices_for_current_run
+=
self
.
_assign_seq_id_to_mamba_cache
(
self
.
_assign_seq_id_to_mamba_cache_in_specific_dest
(
request_id
,
seqs_id
)
request_id
,
seq_id
,
dest_index
)
## Pad the batch in case of running batch that was not captured via CG
running_indices
.
append
(
dest_index
)
padded_indices
=
indices_for_current_run
.
copy
()
pad_index
=
self
.
_first_free_index_in_mamba_cache
()
for
_
in
range
(
batch_size
-
len
(
indices_for_current_run
)):
self
.
_clean_up_first_bs_blocks
(
batch_size
,
running_indices
)
padded_indices
.
append
(
pad_index
)
conv_state
=
self
.
mamba_cache
[
0
][:,
:
batch_size
]
temporal_state
=
self
.
mamba_cache
[
1
][:,
:
batch_size
]
conv_state
=
self
.
mamba_cache
[
0
][:,
padded_indices
]
return
(
conv_state
,
temporal_state
)
temporal_state
=
self
.
mamba_cache
[
1
][:,
padded_indices
]
return
(
conv_state
,
temporal_state
),
indices_for_current_run
def
_get_all_occupied_indices
(
self
):
return
[
cache_idx
for
seq_ids2indices
in
self
.
mamba_cache_indices_mapping
.
values
()
for
cache_idx
in
seq_ids2indices
.
values
()
]
def
_clean_up_first_bs_blocks
(
self
,
batch_size
:
int
,
indices_for_current_run
:
List
[
int
]):
# move out all of the occupied but currently not running blocks
# outside of the first n blocks
destination_indices
=
set
([
range
(
batch_size
)])
max_possible_batch_size
=
self
.
mamba_cache
[
0
].
shape
[
1
]
for
destination_index
in
destination_indices
:
if
destination_index
in
self
.
_get_all_occupied_indices
()
and
\
destination_index
not
in
indices_for_current_run
:
# move not running indices outside of the batch
all_other_indices
=
list
(
range
(
batch_size
,
max_possible_batch_size
))
first_avail_index
=
self
.
_first_free_index_in_mamba_cache
(
all_other_indices
)
self
.
_swap_indices
(
from_index
=
destination_index
,
to_index
=
first_avail_index
)
def
_move_cache_index_and_mappings
(
self
,
from_index
:
int
,
to_index
:
int
):
self
.
_copy_mamba_cache
(
from_index
=
from_index
,
to_index
=
to_index
)
self
.
_update_mapping_index
(
from_index
=
from_index
,
to_index
=
to_index
)
def
_swap_pair_indices_and_mappings
(
self
,
from_index
:
int
,
to_index
:
int
):
self
.
_swap_mamba_cache
(
from_index
=
from_index
,
to_index
=
to_index
)
self
.
_swap_mapping_index
(
from_index
=
from_index
,
to_index
=
to_index
)
def
_swap_mapping_index
(
self
,
from_index
:
int
,
to_index
:
int
):
for
seq_ids2index
in
self
.
mamba_cache_indices_mapping
.
values
():
for
seq_id
,
index
in
seq_ids2index
.
items
():
if
from_index
==
index
:
seq_ids2index
.
update
({
seq_id
:
to_index
})
elif
to_index
==
index
:
seq_ids2index
.
update
({
seq_id
:
from_index
})
def
_update_mapping_index
(
self
,
from_index
:
int
,
to_index
:
int
):
for
seq_ids2index
in
self
.
mamba_cache_indices_mapping
.
values
():
for
seq_id
,
index
in
seq_ids2index
.
items
():
if
from_index
==
index
:
seq_ids2index
.
update
({
seq_id
:
to_index
})
return
def
copy_inputs_before_cuda_graphs
(
self
,
input_buffers
,
**
kwargs
):
def
copy_inputs_before_cuda_graphs
(
self
,
input_buffers
,
**
kwargs
):
"""
"""
...
@@ -747,28 +796,9 @@ class JambaForCausalLM(nn.Module, HasInnerState):
...
@@ -747,28 +796,9 @@ class JambaForCausalLM(nn.Module, HasInnerState):
self
.
_release_mamba_cache
(
finished_requests_ids
)
self
.
_release_mamba_cache
(
finished_requests_ids
)
request_ids_to_seq_ids
=
kwargs
[
"request_ids_to_seq_ids"
]
request_ids_to_seq_ids
=
kwargs
[
"request_ids_to_seq_ids"
]
cg_batch_size
=
input_buffers
[
'input_ids'
].
shape
[
0
]
cg_batch_size
=
input_buffers
[
'input_ids'
].
shape
[
0
]
(
self
.
_prepare_current_run_mamba_cache
(
request_ids_to_seq_ids
,
current_mamba_cache
,
cg_batch_size
,
indices
,
finished_requests_ids
)
)
=
self
.
_prepare_current_run_mamba_cache
(
request_ids_to_seq_ids
,
cg_batch_size
,
finished_requests_ids
)
self
.
current_indices
=
indices
for
input_buffer
,
current_cache_buffer
in
zip
(
input_buffers
[
"seqlen_agnostic_capture_inputs"
],
current_mamba_cache
):
input_buffer
.
copy_
(
current_cache_buffer
,
non_blocking
=
True
)
def
copy_outputs_after_cuda_graphs
(
self
,
input_buffers
,
**
kwargs
):
"""
Copy the relevant Mamba cache from the CUDA graph input_buffers
back to the JambaForCausalLM.mamba_cache after CUDA
graph replay run is done.
"""
self
.
_copy_mamba_cache_by_indices
(
self
.
current_indices
,
input_buffers
[
"seqlen_agnostic_capture_inputs"
])
def
get_seqlen_agnostic_capture_inputs
(
self
,
batch_size
:
int
):
def
get_seqlen_agnostic_capture_inputs
(
self
,
batch_size
:
int
):
"""
"""
...
@@ -776,26 +806,25 @@ class JambaForCausalLM(nn.Module, HasInnerState):
...
@@ -776,26 +806,25 @@ class JambaForCausalLM(nn.Module, HasInnerState):
The buffer is used to maintain the Mamba Cache during the CUDA graph
The buffer is used to maintain the Mamba Cache during the CUDA graph
replay runs.
replay runs.
"""
"""
return
tuple
(
buffer
[:,
:
batch_size
]
return
tuple
(
buffer
[:,
:
batch_size
]
for
buffer
in
self
.
mamba_cache
)
for
buffer
in
self
.
mamba_gc_cache_buffer
)
def
_release_mamba_cache
(
self
,
finished_seq_groups_req_ids
:
List
[
str
]):
def
_release_mamba_cache
(
self
,
finished_seq_groups_req_ids
:
List
[
str
]):
for
req_id
in
finished_seq_groups_req_ids
:
for
req_id
in
finished_seq_groups_req_ids
:
if
req_id
in
self
.
mamba_cache_indices_mapping
:
if
req_id
in
self
.
mamba_cache_indices_mapping
:
self
.
mamba_cache_indices_mapping
.
pop
(
req_id
)
self
.
mamba_cache_indices_mapping
.
pop
(
req_id
)
def
_first_free_index_in_mamba_cache
(
self
)
->
int
:
def
_first_free_index_in_mamba_cache
(
if
self
.
mamba_cache
:
self
,
indices_range
:
Optional
[
List
[
int
]]
=
None
)
->
int
:
assert
self
.
mamba_cache
is
not
None
if
indices_range
is
None
:
max_possible_batch_size
=
self
.
mamba_cache
[
0
].
shape
[
1
]
max_possible_batch_size
=
self
.
mamba_cache
[
0
].
shape
[
1
]
occupied
=
[
indices_range
=
list
(
range
(
max_possible_batch_size
))
id
for
seq_ids
in
self
.
mamba_cache_indices_mapping
.
values
()
all_occupied_indices
=
self
.
_get_all_occupied_indices
()
for
id
in
seq_ids
.
values
()
for
i
in
indices_range
:
]
if
i
not
in
all_occupied_indices
:
first_free_index
=
[
return
i
i
not
in
occupied
for
i
in
range
(
max_possible_batch_size
)
raise
Exception
(
"Couldn't find a free spot in the mamba cache! This"
].
index
(
True
)
"should never happen"
)
return
first_free_index
return
0
def
_get_mamba_cache_shape
(
def
_get_mamba_cache_shape
(
self
self
...
@@ -819,20 +848,18 @@ class JambaForCausalLM(nn.Module, HasInnerState):
...
@@ -819,20 +848,18 @@ class JambaForCausalLM(nn.Module, HasInnerState):
[
layer_type
==
"mamba"
for
layer_type
in
layers_type
])
[
layer_type
==
"mamba"
for
layer_type
in
layers_type
])
max_batch_size
=
(
_get_graph_batch_size
(
max_batch_size
=
(
_get_graph_batch_size
(
self
.
scheduler_config
.
max_num_seqs
)
if
self
.
scheduler_config
else
self
.
scheduler_config
.
max_num_seqs
)
if
self
.
scheduler_config
else
max
(
_BATCH_SIZES_TO_CAPTURE
)
)
+
10
max
(
_BATCH_SIZES_TO_CAPTURE
)
+
2
)
conv_state_shape
,
temporal_state_shape
=
self
.
_get_mamba_cache_shape
()
conv_state_shape
,
temporal_state_shape
=
self
.
_get_mamba_cache_shape
()
assert
conv_state_shape
is
not
None
and
temporal_state_shape
is
not
None
assert
conv_state_shape
is
not
None
and
temporal_state_shape
is
not
None
for
buffername
in
[
"mamba_cache"
,
"mamba_gc_cache_buffer"
]:
self
.
mamba_cache
=
(
torch
.
empty
(
size
=
(
mamba_layers
,
max_batch_size
)
+
buffer
=
(
torch
.
empty
(
size
=
(
mamba_layers
,
max_batch_size
)
+
conv_state_shape
,
conv_state_shape
,
dtype
=
dtype
,
dtype
=
dtype
,
device
=
"cuda"
),
device
=
"cuda"
),
torch
.
empty
(
size
=
(
mamba_layers
,
max_batch_size
)
+
torch
.
empty
(
size
=
(
mamba_layers
,
max_batch_size
)
+
temporal_state_shape
,
temporal_state_shape
,
dtype
=
dtype
,
dtype
=
dtype
,
device
=
"cuda"
))
device
=
"cuda"
))
setattr
(
self
,
buffername
,
buffer
)
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
)
->
torch
.
Tensor
:
sampling_metadata
:
SamplingMetadata
)
->
torch
.
Tensor
:
...
...
vllm/worker/model_runner.py
View file @
07ab1607
...
@@ -1711,9 +1711,6 @@ class CUDAGraphRunner:
...
@@ -1711,9 +1711,6 @@ class CUDAGraphRunner:
non_blocking
=
True
)
non_blocking
=
True
)
# Run the graph.
# Run the graph.
self
.
graph
.
replay
()
self
.
graph
.
replay
()
if
"seqlen_agnostic_capture_inputs"
in
self
.
input_buffers
:
self
.
model
.
copy_outputs_after_cuda_graphs
(
self
.
input_buffers
,
**
kwargs
)
# Return the output tensor.
# Return the output tensor.
if
get_pp_group
().
is_last_rank
:
if
get_pp_group
().
is_last_rank
:
return
self
.
output_buffers
[
"hidden_states"
]
return
self
.
output_buffers
[
"hidden_states"
]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment