Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
9ad32dac
Unverified
Commit
9ad32dac
authored
Jul 16, 2024
by
Mor Zusman
Committed by
GitHub
Jul 16, 2024
Browse files
[BugFix][Model] Jamba - Handle aborted requests, Add tests and fix cleanup bug (#6425)
Co-authored-by:
Mor Zusman
<
morz@ai21.com
>
parent
d6f3b3d5
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
176 additions
and
24 deletions
+176
-24
tests/models/test_jamba.py
tests/models/test_jamba.py
+83
-0
vllm/core/scheduler.py
vllm/core/scheduler.py
+1
-0
vllm/model_executor/model_loader/loader.py
vllm/model_executor/model_loader/loader.py
+20
-12
vllm/model_executor/models/interfaces.py
vllm/model_executor/models/interfaces.py
+47
-1
vllm/model_executor/models/jamba.py
vllm/model_executor/models/jamba.py
+25
-11
No files found.
tests/models/test_jamba.py
View file @
9ad32dac
import
pytest
import
pytest
from
tests.models.utils
import
check_outputs_equal
from
vllm.worker.model_runner
import
_get_graph_batch_size
from
vllm.worker.model_runner
import
_get_graph_batch_size
MODELS
=
[
"ai21labs/Jamba-tiny-random"
]
MODELS
=
[
"ai21labs/Jamba-tiny-random"
]
...
@@ -34,6 +35,34 @@ def test_models(
...
@@ -34,6 +35,34 @@ def test_models(
f
"Test
{
i
}
:
\n
HF:
{
hf_output_ids
}
\n
vLLM:
{
vllm_output_ids
}
"
)
f
"Test
{
i
}
:
\n
HF:
{
hf_output_ids
}
\n
vLLM:
{
vllm_output_ids
}
"
)
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float"])
@pytest.mark.parametrize("max_tokens", [20])
def test_batching(
    vllm_runner,
    example_prompts,
    model: str,
    dtype: str,
    max_tokens: int,
) -> None:
    """Batched greedy generation must match prompt-by-prompt generation."""
    # Full precision is required for the small model tests to pass.
    with vllm_runner(model, dtype=dtype) as engine:
        # Reference run: feed every prompt on its own and keep the single
        # result that comes back for it.
        one_by_one = [
            engine.generate_greedy([single_prompt], max_tokens)[0]
            for single_prompt in example_prompts
        ]
        # Run under test: the same prompts submitted as one batch.
        all_at_once = engine.generate_greedy(example_prompts, max_tokens)

    check_outputs_equal(
        outputs_0_lst=one_by_one,
        outputs_1_lst=all_at_once,
        name_0="for_loop_vllm",
        name_1="batched_vllm",
    )
@
pytest
.
mark
.
parametrize
(
"model"
,
MODELS
)
@
pytest
.
mark
.
parametrize
(
"model"
,
MODELS
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
"bfloat16"
])
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
"bfloat16"
])
@
pytest
.
mark
.
parametrize
(
"max_tokens"
,
[
20
])
@
pytest
.
mark
.
parametrize
(
"max_tokens"
,
[
20
])
...
@@ -60,6 +89,60 @@ def test_mamba_cache_cg_padding(
...
@@ -60,6 +89,60 @@ def test_mamba_cache_cg_padding(
"Could be related to mamba cache not padded correctly"
)
"Could be related to mamba cache not padded correctly"
)
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float"])
@pytest.mark.parametrize("max_tokens", [20])
def test_models_preemption_recompute(
    hf_runner,
    vllm_runner,
    example_prompts,
    model: str,
    dtype: str,
    max_tokens: int,
) -> None:
    """Outputs must be identical with and without preemptions (recompute)."""
    assert dtype == "float"

    with vllm_runner(model, dtype=dtype) as vllm_model:

        def generate_with_preemption(enabled: bool):
            # Set the artificial-preemption switch on the first scheduler,
            # then run greedy generation over all example prompts.
            first_scheduler = vllm_model.model.llm_engine.scheduler[0]
            first_scheduler.ENABLE_ARTIFICIAL_PREEMPT = enabled
            return vllm_model.generate_greedy(example_prompts, max_tokens)

        # Order matters: the preempting run goes first, the plain run second.
        preempted_results = generate_with_preemption(True)
        plain_results = generate_with_preemption(False)

    check_outputs_equal(
        outputs_0_lst=preempted_results,
        outputs_1_lst=plain_results,
        name_0="vllm_preepmtions",
        name_1="vllm",
    )
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float"])
def test_fail_upon_inc_requests_and_finished_requests_lt_available_blocks(
    vllm_runner,
    model: str,
    dtype: str,
    example_prompts,
) -> None:
    """Jamba state handling must survive more requests than cache slots.

    Verifies that the Jamba inner state management doesn't collapse when the
    number of incoming requests plus finished_requests_ids is larger than the
    maximum mamba block capacity. This can happen because Jamba supports a
    statelessness mechanism in which new incoming requests can be cleaned up
    within a single step.
    """
    try:
        # 100 copies of the first prompt against max_num_seqs=10, with
        # 10 tokens generated per request.
        with vllm_runner(model, dtype=dtype, max_num_seqs=10) as vllm_model:
            vllm_model.generate_greedy([example_prompts[0]] * 100, 10)
    except ValueError:
        # The original literals were concatenated without a separating space
        # ("...betweensteps..."); keep the message readable.
        pytest.fail("Jamba inner state wasn't cleaned up properly between "
                    "steps, finished requests registered unnecessarily")
@
pytest
.
mark
.
parametrize
(
"model"
,
MODELS
)
@
pytest
.
mark
.
parametrize
(
"model"
,
MODELS
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
"float"
])
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
"float"
])
def
test_state_cleanup
(
def
test_state_cleanup
(
...
...
vllm/core/scheduler.py
View file @
9ad32dac
...
@@ -374,6 +374,7 @@ class Scheduler:
...
@@ -374,6 +374,7 @@ class Scheduler:
for
aborted_group
in
aborted_groups
:
for
aborted_group
in
aborted_groups
:
# Remove the sequence group from the state queue.
# Remove the sequence group from the state queue.
state_queue
.
remove
(
aborted_group
)
state_queue
.
remove
(
aborted_group
)
self
.
_finished_requests_ids
.
append
(
aborted_group
.
request_id
)
for
seq
in
aborted_group
.
get_seqs
():
for
seq
in
aborted_group
.
get_seqs
():
if
seq
.
is_finished
():
if
seq
.
is_finished
():
continue
continue
...
...
vllm/model_executor/model_loader/loader.py
View file @
9ad32dac
...
@@ -32,7 +32,8 @@ from vllm.model_executor.model_loader.weight_utils import (
...
@@ -32,7 +32,8 @@ from vllm.model_executor.model_loader.weight_utils import (
filter_duplicate_safetensors_files
,
filter_files_not_needed_for_inference
,
filter_duplicate_safetensors_files
,
filter_files_not_needed_for_inference
,
get_quant_config
,
initialize_dummy_weights
,
np_cache_weights_iterator
,
get_quant_config
,
initialize_dummy_weights
,
np_cache_weights_iterator
,
pt_weights_iterator
,
safetensors_weights_iterator
)
pt_weights_iterator
,
safetensors_weights_iterator
)
from
vllm.model_executor.models.interfaces
import
(
supports_lora
,
from
vllm.model_executor.models.interfaces
import
(
has_inner_state
,
supports_lora
,
supports_vision
)
supports_vision
)
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
...
@@ -66,10 +67,10 @@ def _get_quantization_config(
...
@@ -66,10 +67,10 @@ def _get_quantization_config(
def
_get_model_initialization_kwargs
(
def
_get_model_initialization_kwargs
(
model_class
:
Type
[
nn
.
Module
],
model_class
:
Type
[
nn
.
Module
],
lora_config
:
Optional
[
LoRAConfig
],
lora_config
:
Optional
[
LoRAConfig
],
multimodal_config
:
Optional
[
MultiModalConfig
],
multimodal_config
:
Optional
[
MultiModalConfig
],
)
->
Dict
[
str
,
Any
]:
scheduler_config
:
Optional
[
SchedulerConfig
]
=
None
)
->
Dict
[
str
,
Any
]:
"""Get extra kwargs for model initialization."""
"""Get extra kwargs for model initialization."""
extra_kwargs
:
Dict
[
str
,
Any
]
=
{}
extra_kwargs
:
Dict
[
str
,
Any
]
=
{}
...
@@ -90,13 +91,19 @@ def _get_model_initialization_kwargs(
...
@@ -90,13 +91,19 @@ def _get_model_initialization_kwargs(
extra_kwargs
[
"multimodal_config"
]
=
multimodal_config
extra_kwargs
[
"multimodal_config"
]
=
multimodal_config
if
has_inner_state
(
model_class
)
and
scheduler_config
:
extra_kwargs
[
"scheduler_config"
]
=
scheduler_config
return
extra_kwargs
return
extra_kwargs
def
_initialize_model
(
model_config
:
ModelConfig
,
load_config
:
LoadConfig
,
def
_initialize_model
(
lora_config
:
Optional
[
LoRAConfig
],
model_config
:
ModelConfig
,
multimodal_config
:
Optional
[
MultiModalConfig
],
load_config
:
LoadConfig
,
cache_config
:
CacheConfig
)
->
nn
.
Module
:
lora_config
:
Optional
[
LoRAConfig
],
multimodal_config
:
Optional
[
MultiModalConfig
],
cache_config
:
CacheConfig
,
scheduler_config
:
Optional
[
SchedulerConfig
]
=
None
)
->
nn
.
Module
:
"""Initialize a model with the given configurations."""
"""Initialize a model with the given configurations."""
model_class
=
get_model_architecture
(
model_config
)[
0
]
model_class
=
get_model_architecture
(
model_config
)[
0
]
quant_config
=
_get_quantization_config
(
model_config
,
load_config
)
quant_config
=
_get_quantization_config
(
model_config
,
load_config
)
...
@@ -105,7 +112,8 @@ def _initialize_model(model_config: ModelConfig, load_config: LoadConfig,
...
@@ -105,7 +112,8 @@ def _initialize_model(model_config: ModelConfig, load_config: LoadConfig,
cache_config
=
cache_config
,
cache_config
=
cache_config
,
quant_config
=
quant_config
,
quant_config
=
quant_config
,
**
_get_model_initialization_kwargs
(
**
_get_model_initialization_kwargs
(
model_class
,
lora_config
,
multimodal_config
))
model_class
,
lora_config
,
multimodal_config
,
scheduler_config
))
class
BaseModelLoader
(
ABC
):
class
BaseModelLoader
(
ABC
):
...
@@ -266,7 +274,7 @@ class DefaultModelLoader(BaseModelLoader):
...
@@ -266,7 +274,7 @@ class DefaultModelLoader(BaseModelLoader):
with
torch
.
device
(
device_config
.
device
):
with
torch
.
device
(
device_config
.
device
):
model
=
_initialize_model
(
model_config
,
self
.
load_config
,
model
=
_initialize_model
(
model_config
,
self
.
load_config
,
lora_config
,
multimodal_config
,
lora_config
,
multimodal_config
,
cache_config
)
cache_config
,
scheduler_config
)
model
.
load_weights
(
model
.
load_weights
(
self
.
_get_weights_iterator
(
model_config
.
model
,
self
.
_get_weights_iterator
(
model_config
.
model
,
model_config
.
revision
,
model_config
.
revision
,
...
@@ -302,7 +310,7 @@ class DummyModelLoader(BaseModelLoader):
...
@@ -302,7 +310,7 @@ class DummyModelLoader(BaseModelLoader):
with
torch
.
device
(
device_config
.
device
):
with
torch
.
device
(
device_config
.
device
):
model
=
_initialize_model
(
model_config
,
self
.
load_config
,
model
=
_initialize_model
(
model_config
,
self
.
load_config
,
lora_config
,
multimodal_config
,
lora_config
,
multimodal_config
,
cache_config
)
cache_config
,
scheduler_config
)
# NOTE(woosuk): For accurate performance evaluation, we assign
# NOTE(woosuk): For accurate performance evaluation, we assign
# random values to the weights.
# random values to the weights.
initialize_dummy_weights
(
model
)
initialize_dummy_weights
(
model
)
...
...
vllm/model_executor/models/interfaces.py
View file @
9ad32dac
...
@@ -3,7 +3,7 @@ from typing import (ClassVar, Dict, List, Literal, Optional, Protocol, Type,
...
@@ -3,7 +3,7 @@ from typing import (ClassVar, Dict, List, Literal, Optional, Protocol, Type,
from
typing_extensions
import
TypeGuard
from
typing_extensions
import
TypeGuard
from
vllm.config
import
LoRAConfig
,
MultiModalConfig
from
vllm.config
import
LoRAConfig
,
MultiModalConfig
,
SchedulerConfig
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -142,3 +142,49 @@ def _supports_lora(
...
@@ -142,3 +142,49 @@ def _supports_lora(
return
isinstance
(
model
,
_SupportsLoRAType
)
return
isinstance
(
model
,
_SupportsLoRAType
)
return
isinstance
(
model
,
SupportsLoRA
)
return
isinstance
(
model
,
SupportsLoRA
)
@runtime_checkable
class HasInnerState(Protocol):
    """The interface required for all models that have inner state."""

    # Class-level marker; it defaults to True, so a class that lists this
    # protocol among its bases carries the flag without redeclaring it.
    has_inner_state: ClassVar[Literal[True]] = True
    """
        A flag that indicates this model has inner state.
        Models that has inner state usually need access to the scheduler_config
        for max_num_seqs ,etc... (Currently only used by Jamba)
    """

    # scheduler_config is keyword-only and optional (defaults to None).
    def __init__(self,
                 *,
                 scheduler_config: Optional[SchedulerConfig] = None) -> None:
        ...
@runtime_checkable
class _HasInnerStateType(Protocol):
    """Protocol that `has_inner_state()` checks class objects against.

    Declares the same members as `HasInnerState`, except that the
    `has_inner_state` flag is declared here without a default value.
    """

    # Declaration only: no value is assigned on this protocol.
    has_inner_state: ClassVar[Literal[True]]

    # scheduler_config is keyword-only and optional (defaults to None).
    def __init__(self,
                 *,
                 scheduler_config: Optional[SchedulerConfig] = None) -> None:
        ...
@overload
def has_inner_state(model: object) -> TypeGuard[HasInnerState]:
    ...


@overload
def has_inner_state(model: Type[object]) -> TypeGuard[Type[HasInnerState]]:
    ...


def has_inner_state(
    model: Union[Type[object], object]
) -> Union[TypeGuard[Type[HasInnerState]], TypeGuard[HasInnerState]]:
    """Report whether a model, or a model class, satisfies the protocol.

    Args:
        model: Either a class object or an instance.

    Returns:
        The result of a runtime ``isinstance`` check: class objects are
        checked against ``_HasInnerStateType`` and everything else against
        ``HasInnerState``.
    """
    # Classes and instances are matched against different protocols, so pick
    # the protocol first and run a single isinstance check.
    expected_protocol = (_HasInnerStateType
                         if isinstance(model, type) else HasInnerState)
    return isinstance(model, expected_protocol)
vllm/model_executor/models/jamba.py
View file @
9ad32dac
...
@@ -13,7 +13,7 @@ from transformers import JambaConfig
...
@@ -13,7 +13,7 @@ from transformers import JambaConfig
from
vllm.attention.backends.abstract
import
AttentionMetadata
from
vllm.attention.backends.abstract
import
AttentionMetadata
from
vllm.attention.layer
import
Attention
from
vllm.attention.layer
import
Attention
from
vllm.config
import
CacheConfig
,
LoRAConfig
from
vllm.config
import
CacheConfig
,
LoRAConfig
,
SchedulerConfig
from
vllm.distributed
import
(
get_tensor_model_parallel_rank
,
from
vllm.distributed
import
(
get_tensor_model_parallel_rank
,
get_tensor_model_parallel_world_size
,
get_tensor_model_parallel_world_size
,
tensor_model_parallel_all_reduce
)
tensor_model_parallel_all_reduce
)
...
@@ -32,10 +32,12 @@ from vllm.model_executor.layers.sampler import Sampler
...
@@ -32,10 +32,12 @@ from vllm.model_executor.layers.sampler import Sampler
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
DEFAULT_VOCAB_PADDING_SIZE
,
ParallelLMHead
,
VocabParallelEmbedding
)
DEFAULT_VOCAB_PADDING_SIZE
,
ParallelLMHead
,
VocabParallelEmbedding
)
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.models.interfaces
import
HasInnerState
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.sequence
import
IntermediateTensors
,
SamplerOutput
from
vllm.sequence
import
IntermediateTensors
,
SamplerOutput
from
vllm.worker.model_runner
import
_BATCH_SIZES_TO_CAPTURE
from
vllm.worker.model_runner
import
(
_BATCH_SIZES_TO_CAPTURE
,
_get_graph_batch_size
)
KVCache
=
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
KVCache
=
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
...
@@ -612,7 +614,7 @@ class JambaModel(nn.Module):
...
@@ -612,7 +614,7 @@ class JambaModel(nn.Module):
return
hidden_states
return
hidden_states
class
JambaForCausalLM
(
nn
.
Module
):
class
JambaForCausalLM
(
nn
.
Module
,
HasInnerState
):
packed_modules_mapping
=
{
packed_modules_mapping
=
{
"qkv_proj"
:
[
"qkv_proj"
:
[
"q_proj"
,
"q_proj"
,
...
@@ -640,9 +642,11 @@ class JambaForCausalLM(nn.Module):
...
@@ -640,9 +642,11 @@ class JambaForCausalLM(nn.Module):
cache_config
:
Optional
[
CacheConfig
]
=
None
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
lora_config
:
Optional
[
LoRAConfig
]
=
None
,
lora_config
:
Optional
[
LoRAConfig
]
=
None
,
scheduler_config
:
Optional
[
SchedulerConfig
]
=
None
,
)
->
None
:
)
->
None
:
super
().
__init__
()
super
().
__init__
()
self
.
config
=
config
self
.
config
=
config
self
.
scheduler_config
=
scheduler_config
self
.
model
=
JambaModel
(
config
,
self
.
model
=
JambaModel
(
config
,
cache_config
=
cache_config
,
cache_config
=
cache_config
,
quant_config
=
quant_config
,
quant_config
=
quant_config
,
...
@@ -689,6 +693,8 @@ class JambaForCausalLM(nn.Module):
...
@@ -689,6 +693,8 @@ class JambaForCausalLM(nn.Module):
for
key
in
[
"request_ids_to_seq_ids"
,
"finished_requests_ids"
])
for
key
in
[
"request_ids_to_seq_ids"
,
"finished_requests_ids"
])
request_ids_to_seq_ids
=
kwargs
[
"request_ids_to_seq_ids"
]
request_ids_to_seq_ids
=
kwargs
[
"request_ids_to_seq_ids"
]
finished_requests_ids
=
kwargs
[
"finished_requests_ids"
]
self
.
_release_mamba_cache
(
finished_requests_ids
)
batch_size
=
input_ids
.
shape
[
0
]
batch_size
=
input_ids
.
shape
[
0
]
if
attn_metadata
.
prefill_metadata
:
if
attn_metadata
.
prefill_metadata
:
batch_size
=
len
(
request_ids_to_seq_ids
)
batch_size
=
len
(
request_ids_to_seq_ids
)
...
@@ -696,9 +702,8 @@ class JambaForCausalLM(nn.Module):
...
@@ -696,9 +702,8 @@ class JambaForCausalLM(nn.Module):
current_seqlen_agnostic_cache
,
current_seqlen_agnostic_cache
,
indices
,
indices
,
)
=
self
.
_prepare_current_run_mamba_cache
(
request_ids_to_seq_ids
,
)
=
self
.
_prepare_current_run_mamba_cache
(
request_ids_to_seq_ids
,
batch_size
)
batch_size
,
finished_requests_ids
=
kwargs
[
"finished_requests_ids"
]
finished_requests_ids
)
self
.
_release_mamba_cache
(
finished_requests_ids
)
else
:
else
:
# CUDA graph capturing runs
# CUDA graph capturing runs
current_seqlen_agnostic_cache
,
indices
=
(
current_seqlen_agnostic_cache
,
indices
=
(
...
@@ -760,10 +765,15 @@ class JambaForCausalLM(nn.Module):
...
@@ -760,10 +765,15 @@ class JambaForCausalLM(nn.Module):
return
indices_for_current_run
return
indices_for_current_run
def
_prepare_current_run_mamba_cache
(
def
_prepare_current_run_mamba_cache
(
self
,
request_ids_to_seq_ids
:
Dict
[
str
,
list
[
int
]],
batch_size
:
int
self
,
request_ids_to_seq_ids
:
Dict
[
str
,
list
[
int
]],
batch_size
:
int
,
finished_requests_ids
:
List
[
str
]
)
->
Tuple
[
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
],
List
[
int
]]:
)
->
Tuple
[
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
],
List
[
int
]]:
indices_for_current_run
=
[]
indices_for_current_run
=
[]
for
request_id
,
seqs_id
in
request_ids_to_seq_ids
.
items
():
for
request_id
,
seqs_id
in
request_ids_to_seq_ids
.
items
():
if
request_id
in
finished_requests_ids
:
# Do not allocate cache for requests that run
# and finish right after
continue
indices_for_current_run
+=
self
.
_assign_seq_id_to_mamba_cache
(
indices_for_current_run
+=
self
.
_assign_seq_id_to_mamba_cache
(
request_id
,
seqs_id
)
request_id
,
seqs_id
)
## Pad the batch in case of running batch that was not captured via CG
## Pad the batch in case of running batch that was not captured via CG
...
@@ -787,16 +797,17 @@ class JambaForCausalLM(nn.Module):
...
@@ -787,16 +797,17 @@ class JambaForCausalLM(nn.Module):
assert
all
(
assert
all
(
key
in
kwargs
key
in
kwargs
for
key
in
[
"request_ids_to_seq_ids"
,
"finished_requests_ids"
])
for
key
in
[
"request_ids_to_seq_ids"
,
"finished_requests_ids"
])
finished_requests_ids
=
kwargs
[
"finished_requests_ids"
]
self
.
_release_mamba_cache
(
finished_requests_ids
)
request_ids_to_seq_ids
=
kwargs
[
"request_ids_to_seq_ids"
]
request_ids_to_seq_ids
=
kwargs
[
"request_ids_to_seq_ids"
]
cg_batch_size
=
input_buffers
[
'input_ids'
].
shape
[
0
]
cg_batch_size
=
input_buffers
[
'input_ids'
].
shape
[
0
]
(
(
current_mamba_cache
,
current_mamba_cache
,
indices
,
indices
,
)
=
self
.
_prepare_current_run_mamba_cache
(
request_ids_to_seq_ids
,
)
=
self
.
_prepare_current_run_mamba_cache
(
request_ids_to_seq_ids
,
cg_batch_size
)
cg_batch_size
,
finished_requests_ids
)
self
.
current_indices
=
indices
self
.
current_indices
=
indices
finished_requests_ids
=
kwargs
[
"finished_requests_ids"
]
self
.
_release_mamba_cache
(
finished_requests_ids
)
for
input_buffer
,
current_cache_buffer
in
zip
(
for
input_buffer
,
current_cache_buffer
in
zip
(
input_buffers
[
"seqlen_agnostic_capture_inputs"
],
input_buffers
[
"seqlen_agnostic_capture_inputs"
],
...
@@ -860,9 +871,12 @@ class JambaForCausalLM(nn.Module):
...
@@ -860,9 +871,12 @@ class JambaForCausalLM(nn.Module):
layers_type
=
self
.
config
.
layers_block_type
layers_type
=
self
.
config
.
layers_block_type
mamba_layers
=
sum
(
mamba_layers
=
sum
(
[
layer_type
==
"mamba"
for
layer_type
in
layers_type
])
[
layer_type
==
"mamba"
for
layer_type
in
layers_type
])
max_batch_size
=
_BATCH_SIZES_TO_CAPTURE
[
-
1
]
+
10
max_batch_size
=
(
_get_graph_batch_size
(
self
.
scheduler_config
.
max_num_seqs
)
if
self
.
scheduler_config
else
max
(
_BATCH_SIZES_TO_CAPTURE
))
+
10
conv_state_shape
,
temporal_state_shape
=
self
.
_get_mamba_cache_shape
()
conv_state_shape
,
temporal_state_shape
=
self
.
_get_mamba_cache_shape
()
assert
conv_state_shape
is
not
None
and
temporal_state_shape
is
not
None
assert
conv_state_shape
is
not
None
and
temporal_state_shape
is
not
None
for
buffername
in
[
"mamba_cache"
,
"mamba_gc_cache_buffer"
]:
for
buffername
in
[
"mamba_cache"
,
"mamba_gc_cache_buffer"
]:
buffer
=
(
torch
.
empty
(
size
=
(
mamba_layers
,
max_batch_size
)
+
buffer
=
(
torch
.
empty
(
size
=
(
mamba_layers
,
max_batch_size
)
+
conv_state_shape
,
conv_state_shape
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment