Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
fb60ae9b
Unverified
Commit
fb60ae9b
authored
Oct 17, 2024
by
Mor Zusman
Committed by
GitHub
Oct 16, 2024
Browse files
[Kernel][Model] Improve continuous batching for Jamba and Mamba (#9189)
parent
415f76a9
Changes
15
Hide whitespace changes
Inline
Side-by-side
Showing
15 changed files
with
511 additions
and
439 deletions
+511
-439
csrc/mamba/causal_conv1d/causal_conv1d.cu
csrc/mamba/causal_conv1d/causal_conv1d.cu
+25
-12
csrc/mamba/causal_conv1d/causal_conv1d.h
csrc/mamba/causal_conv1d/causal_conv1d.h
+1
-0
csrc/mamba/mamba_ssm/selective_scan.h
csrc/mamba/mamba_ssm/selective_scan.h
+1
-0
csrc/mamba/mamba_ssm/selective_scan_fwd.cu
csrc/mamba/mamba_ssm/selective_scan_fwd.cu
+14
-10
csrc/ops.h
csrc/ops.h
+17
-15
csrc/torch_bindings.cpp
csrc/torch_bindings.cpp
+6
-3
tests/kernels/test_causal_conv1d.py
tests/kernels/test_causal_conv1d.py
+97
-94
tests/kernels/test_mamba_ssm.py
tests/kernels/test_mamba_ssm.py
+89
-35
tests/models/decoder_only/language/test_jamba.py
tests/models/decoder_only/language/test_jamba.py
+25
-0
vllm/_custom_ops.py
vllm/_custom_ops.py
+40
-33
vllm/model_executor/layers/mamba/ops/causal_conv1d.py
vllm/model_executor/layers/mamba/ops/causal_conv1d.py
+33
-20
vllm/model_executor/layers/mamba/ops/mamba_ssm.py
vllm/model_executor/layers/mamba/ops/mamba_ssm.py
+43
-27
vllm/model_executor/models/jamba.py
vllm/model_executor/models/jamba.py
+31
-40
vllm/model_executor/models/mamba.py
vllm/model_executor/models/mamba.py
+28
-25
vllm/model_executor/models/mamba_cache.py
vllm/model_executor/models/mamba_cache.py
+61
-125
No files found.
csrc/mamba/causal_conv1d/causal_conv1d.cu
View file @
fb60ae9b
...
...
@@ -55,6 +55,7 @@ void set_conv_params_fwd(ConvParamsBase &params,
const
at
::
Tensor
out
,
const
c10
::
optional
<
at
::
Tensor
>&
bias
,
bool
silu_activation
,
int64_t
pad_slot_id
,
const
c10
::
optional
<
at
::
Tensor
>&
query_start_loc
=
std
::
nullopt
,
const
c10
::
optional
<
at
::
Tensor
>&
cache_indices
=
std
::
nullopt
,
const
c10
::
optional
<
at
::
Tensor
>&
has_initial_state
=
std
::
nullopt
)
{
...
...
@@ -66,6 +67,7 @@ void set_conv_params_fwd(ConvParamsBase &params,
params
.
dim
=
dim
;
params
.
seqlen
=
seqlen
;
params
.
width
=
width
;
params
.
pad_slot_id
=
pad_slot_id
;
params
.
silu_activation
=
silu_activation
;
...
...
@@ -90,14 +92,16 @@ void set_conv_params_fwd(ConvParamsBase &params,
}
at
::
Tensor
causal_conv1d_fwd
(
const
at
::
Tensor
&
x
,
const
at
::
Tensor
&
weight
,
void
causal_conv1d_fwd
(
const
at
::
Tensor
&
x
,
const
at
::
Tensor
&
weight
,
const
c10
::
optional
<
at
::
Tensor
>
&
bias_
,
const
c10
::
optional
<
at
::
Tensor
>
&
conv_states
,
const
c10
::
optional
<
at
::
Tensor
>
&
query_start_loc
,
const
c10
::
optional
<
at
::
Tensor
>
&
cache_indices
,
const
c10
::
optional
<
at
::
Tensor
>
&
has_initial_state
,
bool
silu_activation
)
{
bool
silu_activation
,
// used to identify padding entries if cache_indices provided
// in case of padding, the kernel will return early
int64_t
pad_slot_id
)
{
auto
input_type
=
x
.
scalar_type
();
auto
weight_type
=
weight
.
scalar_type
();
TORCH_CHECK
(
input_type
==
at
::
ScalarType
::
Float
||
input_type
==
at
::
ScalarType
::
Half
||
input_type
==
at
::
ScalarType
::
BFloat16
);
...
...
@@ -153,12 +157,13 @@ causal_conv1d_fwd(const at::Tensor &x, const at::Tensor &weight,
CHECK_SHAPE
(
cache_indices_
,
batch_size
);
}
at
::
Tensor
out
=
torch
::
empty_like
(
x
)
;
at
::
Tensor
out
=
x
;
ConvParamsBase
params
;
set_conv_params_fwd
(
params
,
batch_size
,
dim
,
seqlen
,
width
,
x
,
weight
,
out
,
bias_
,
silu_activation
,
pad_slot_id
,
query_start_loc
,
cache_indices
,
has_initial_state
...
...
@@ -183,18 +188,19 @@ causal_conv1d_fwd(const at::Tensor &x, const at::Tensor &weight,
DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16
(
x
.
scalar_type
(),
"causal_conv1d_fwd"
,
[
&
]
{
causal_conv1d_fwd_cuda
<
input_t
,
weight_t
>
(
params
,
stream
);
});
return
out
;
}
at
::
Tensor
causal_conv1d_update
(
const
at
::
Tensor
&
x
,
void
causal_conv1d_update
(
const
at
::
Tensor
&
x
,
const
at
::
Tensor
&
conv_state
,
const
at
::
Tensor
&
weight
,
const
c10
::
optional
<
at
::
Tensor
>
&
bias_
,
bool
silu_activation
,
const
c10
::
optional
<
at
::
Tensor
>
&
cache_seqlens_
,
const
c10
::
optional
<
at
::
Tensor
>
&
conv_state_indices_
)
{
const
c10
::
optional
<
at
::
Tensor
>
&
conv_state_indices_
,
// used to identify padding entries if cache_indices provided
// in case of padding, the kernel will return early
int64_t
pad_slot_id
)
{
auto
input_type
=
x
.
scalar_type
();
auto
weight_type
=
weight
.
scalar_type
();
TORCH_CHECK
(
input_type
==
at
::
ScalarType
::
Float
||
input_type
==
at
::
ScalarType
::
Half
||
input_type
==
at
::
ScalarType
::
BFloat16
);
...
...
@@ -227,12 +233,13 @@ causal_conv1d_update(const at::Tensor &x,
CHECK_SHAPE
(
bias
,
dim
);
}
at
::
Tensor
out
=
torch
::
empty_like
(
x
)
;
at
::
Tensor
out
=
x
;
ConvParamsBase
params
;
set_conv_params_fwd
(
params
,
batch_size
,
dim
,
seqlen
,
width
,
x
,
weight
,
out
,
bias_
,
silu_activation
);
silu_activation
,
pad_slot_id
);
params
.
conv_state_ptr
=
conv_state
.
data_ptr
();
params
.
conv_state_len
=
conv_state_len
;
// All stride are in elements, not bytes.
...
...
@@ -274,7 +281,6 @@ causal_conv1d_update(const at::Tensor &x,
DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16
(
x
.
scalar_type
(),
"causal_conv1d_update"
,
[
&
]
{
causal_conv1d_update_cuda
<
input_t
,
weight_t
>
(
params
,
stream
);
});
return
out
;
}
template
<
int
kNThreads_
,
int
kWidth_
,
bool
kIsVecLoad_
,
typename
input_t_
,
typename
weight_t_
>
...
...
@@ -340,7 +346,10 @@ void causal_conv1d_fwd_kernel(ConvParamsBase params) {
int
*
cache_indices
=
params
.
cache_indices_ptr
==
nullptr
?
nullptr
:
reinterpret_cast
<
int
*>
(
params
.
cache_indices_ptr
);
int
cache_index
=
cache_indices
==
nullptr
?
batch_id
:
cache_indices
[
batch_id
];
// cache_index == params.pad_slot_id is defined as padding, so we exit early
if
(
cache_index
==
params
.
pad_slot_id
){
return
;
}
input_t
*
conv_states
=
params
.
conv_states_ptr
==
nullptr
?
nullptr
:
reinterpret_cast
<
input_t
*>
(
params
.
conv_states_ptr
)
+
cache_index
*
params
.
conv_states_batch_stride
+
channel_id
*
params
.
conv_states_c_stride
;
...
...
@@ -528,6 +537,10 @@ void causal_conv1d_update_kernel(ConvParamsBase params) {
const
int
conv_state_batch_coord
=
params
.
conv_state_indices_ptr
==
nullptr
?
batch_id
:
params
.
conv_state_indices_ptr
[
batch_id
];
// conv_state_batch_coord == params.pad_slot_id is defined as padding so we exit early
if
(
conv_state_batch_coord
==
params
.
pad_slot_id
){
return
;
}
input_t
*
conv_state
=
reinterpret_cast
<
input_t
*>
(
params
.
conv_state_ptr
)
+
conv_state_batch_coord
*
params
.
conv_state_batch_stride
+
channel_id
*
params
.
conv_state_c_stride
;
...
...
csrc/mamba/causal_conv1d/causal_conv1d.h
View file @
fb60ae9b
...
...
@@ -13,6 +13,7 @@ struct ConvParamsBase {
using
index_t
=
uint32_t
;
int
batch
,
dim
,
seqlen
,
width
;
int64_t
pad_slot_id
;
bool
silu_activation
;
index_t
x_batch_stride
;
...
...
csrc/mamba/mamba_ssm/selective_scan.h
View file @
fb60ae9b
...
...
@@ -21,6 +21,7 @@ struct SSMParamsBase {
int
dim_ngroups_ratio
;
bool
is_variable_B
;
bool
is_variable_C
;
int64_t
pad_slot_id
;
bool
delta_softplus
;
...
...
csrc/mamba/mamba_ssm/selective_scan_fwd.cu
View file @
fb60ae9b
...
...
@@ -115,6 +115,10 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
const
int
*
cache_indices
=
params
.
cache_indices_ptr
==
nullptr
?
nullptr
:
reinterpret_cast
<
int
*>
(
params
.
cache_indices_ptr
);
const
int
cache_index
=
cache_indices
==
nullptr
?
batch_id
:
cache_indices
[
batch_id
];
// cache_index == params.pad_slot_id is defined as padding, so we exit early
if
(
cache_index
==
params
.
pad_slot_id
){
return
;
}
input_t
*
u
=
reinterpret_cast
<
input_t
*>
(
params
.
u_ptr
)
+
sequence_start_index
*
params
.
u_batch_stride
+
dim_id
*
kNRows
*
params
.
u_d_stride
;
input_t
*
delta
=
reinterpret_cast
<
input_t
*>
(
params
.
delta_ptr
)
+
sequence_start_index
*
params
.
delta_batch_stride
...
...
@@ -387,7 +391,6 @@ void set_ssm_params_fwd(SSMParamsBase &params,
const
size_t
seqlen
,
const
size_t
dstate
,
const
size_t
n_groups
,
const
size_t
n_chunks
,
const
bool
is_variable_B
,
const
bool
is_variable_C
,
// device pointers
...
...
@@ -407,7 +410,8 @@ void set_ssm_params_fwd(SSMParamsBase &params,
const
c10
::
optional
<
at
::
Tensor
>&
query_start_loc
,
const
c10
::
optional
<
at
::
Tensor
>&
cache_indices
,
const
c10
::
optional
<
at
::
Tensor
>&
has_initial_state
,
bool
varlen
)
{
bool
varlen
,
int64_t
pad_slot_id
)
{
// Reset the parameters
memset
(
&
params
,
0
,
sizeof
(
params
));
...
...
@@ -417,8 +421,8 @@ void set_ssm_params_fwd(SSMParamsBase &params,
params
.
seqlen
=
seqlen
;
params
.
dstate
=
dstate
;
params
.
n_groups
=
n_groups
;
params
.
n_chunks
=
n_chunks
;
params
.
dim_ngroups_ratio
=
dim
/
n_groups
;
params
.
pad_slot_id
=
pad_slot_id
;
params
.
delta_softplus
=
delta_softplus
;
...
...
@@ -507,7 +511,10 @@ void selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta,
const
c10
::
optional
<
torch
::
Tensor
>
&
query_start_loc
,
const
c10
::
optional
<
torch
::
Tensor
>
&
cache_indices
,
const
c10
::
optional
<
torch
::
Tensor
>
&
has_initial_state
,
const
torch
::
Tensor
&
ssm_states
)
{
const
torch
::
Tensor
&
ssm_states
,
// used to identify padding entries if cache_indices provided
// in case of padding, the kernel will return early
int64_t
pad_slot_id
)
{
auto
input_type
=
u
.
scalar_type
();
auto
weight_type
=
A
.
scalar_type
();
TORCH_CHECK
(
input_type
==
at
::
ScalarType
::
Float
||
input_type
==
at
::
ScalarType
::
Half
||
input_type
==
at
::
ScalarType
::
BFloat16
);
...
...
@@ -618,18 +625,14 @@ void selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta,
out_z
=
z
;
const
int
n_chunks
=
(
seqlen
+
2048
-
1
)
/
2048
;
// const int n_chunks = (seqlen + 1024 - 1) / 1024;
// at::Tensor out = torch::empty_like(u);
// Right now u has BHL layout and delta has HBL layout, and we want out to have HBL layout
at
::
Tensor
out
=
delta
;
TORCH_CHECK
(
ssm_states
.
scalar_type
()
==
input_type
);
TORCH_CHECK
(
ssm_states
.
is_cuda
());
TORCH_CHECK
(
ssm_states
.
stride
(
-
1
)
==
1
);
CHECK_SHAPE
(
ssm_states
,
batch_size
,
dim
,
dstate
);
SSMParamsBase
params
;
set_ssm_params_fwd
(
params
,
batch_size
,
dim
,
seqlen
,
dstate
,
n_groups
,
n_chunks
,
is_variable_B
,
is_variable_C
,
set_ssm_params_fwd
(
params
,
batch_size
,
dim
,
seqlen
,
dstate
,
n_groups
,
is_variable_B
,
is_variable_C
,
u
,
delta
,
A
,
B
,
C
,
out
,
z
,
out_z
,
D_
,
delta_bias_
,
...
...
@@ -639,7 +642,8 @@ void selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta,
query_start_loc
,
cache_indices
,
has_initial_state
,
varlen
varlen
,
pad_slot_id
);
...
...
csrc/ops.h
View file @
fb60ae9b
...
...
@@ -157,21 +157,23 @@ void selective_scan_fwd(const torch::Tensor& u, const torch::Tensor& delta,
const
c10
::
optional
<
torch
::
Tensor
>&
query_start_loc
,
const
c10
::
optional
<
torch
::
Tensor
>&
cache_indices
,
const
c10
::
optional
<
torch
::
Tensor
>&
has_initial_state
,
const
torch
::
Tensor
&
ssm_states
);
at
::
Tensor
causal_conv1d_update
(
const
at
::
Tensor
&
x
,
const
at
::
Tensor
&
conv_state
,
const
at
::
Tensor
&
weight
,
const
c10
::
optional
<
at
::
Tensor
>&
bias_
,
bool
silu_activation
,
const
c10
::
optional
<
at
::
Tensor
>&
cache_seqlens_
,
const
c10
::
optional
<
at
::
Tensor
>&
conv_state_indices_
);
at
::
Tensor
causal_conv1d_fwd
(
const
at
::
Tensor
&
x
,
const
at
::
Tensor
&
weight
,
const
c10
::
optional
<
at
::
Tensor
>&
bias_
,
const
c10
::
optional
<
at
::
Tensor
>&
conv_states
,
const
c10
::
optional
<
at
::
Tensor
>&
query_start_loc
,
const
c10
::
optional
<
at
::
Tensor
>&
cache_indices
,
const
c10
::
optional
<
at
::
Tensor
>&
has_initial_state
,
bool
silu_activation
);
const
torch
::
Tensor
&
ssm_states
,
int64_t
pad_slot_id
);
void
causal_conv1d_update
(
const
at
::
Tensor
&
x
,
const
at
::
Tensor
&
conv_state
,
const
at
::
Tensor
&
weight
,
const
c10
::
optional
<
at
::
Tensor
>&
bias_
,
bool
silu_activation
,
const
c10
::
optional
<
at
::
Tensor
>&
cache_seqlens_
,
const
c10
::
optional
<
at
::
Tensor
>&
conv_state_indices_
,
int64_t
pad_slot_id
);
void
causal_conv1d_fwd
(
const
at
::
Tensor
&
x
,
const
at
::
Tensor
&
weight
,
const
c10
::
optional
<
at
::
Tensor
>&
bias_
,
const
c10
::
optional
<
at
::
Tensor
>&
conv_states
,
const
c10
::
optional
<
at
::
Tensor
>&
query_start_loc
,
const
c10
::
optional
<
at
::
Tensor
>&
cache_indices
,
const
c10
::
optional
<
at
::
Tensor
>&
has_initial_state
,
bool
silu_activation
,
int64_t
pad_slot_id
);
#ifndef USE_ROCM
using
fptr_t
=
int64_t
;
...
...
csrc/torch_bindings.cpp
View file @
fb60ae9b
...
...
@@ -278,7 +278,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
"Tensor? query_start_loc,"
"Tensor? cache_indices,"
"Tensor? has_initial_state,"
"Tensor! ssm_states) -> ()"
);
"Tensor! ssm_states,"
"int pad_slot_id) -> ()"
);
ops
.
impl
(
"selective_scan_fwd"
,
torch
::
kCUDA
,
&
selective_scan_fwd
);
ops
.
def
(
...
...
@@ -288,7 +289,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
"Tensor? bias_,"
"bool silu_activation,"
"Tensor? cache_seqlens_,"
"Tensor? conv_state_indices) -> Tensor"
);
"Tensor? conv_state_indices,"
"int pad_slot_id) -> ()"
);
ops
.
impl
(
"causal_conv1d_update"
,
torch
::
kCUDA
,
&
causal_conv1d_update
);
ops
.
def
(
...
...
@@ -298,7 +300,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
"Tensor? query_start_loc,"
"Tensor? cache_indices,"
"Tensor? has_initial_state,"
"bool silu_activation) -> Tensor"
);
"bool silu_activation,"
"int pad_slot_id) -> ()"
);
ops
.
impl
(
"causal_conv1d_fwd"
,
torch
::
kCUDA
,
&
causal_conv1d_fwd
);
#endif
...
...
tests/kernels/test_causal_conv1d.py
View file @
fb60ae9b
...
...
@@ -6,6 +6,7 @@ import torch.nn.functional as F
from
tests.kernels.utils
import
opcheck
from
vllm
import
_custom_ops
as
ops
# noqa: F401
from
vllm.attention.backends.utils
import
PAD_SLOT_ID
from
vllm.model_executor.layers.mamba.ops.causal_conv1d
import
(
causal_conv1d_fn
,
causal_conv1d_update
)
from
vllm.utils
import
seed_everything
...
...
@@ -114,16 +115,15 @@ def causal_conv1d_update_ref(x,
@
pytest
.
mark
.
parametrize
(
"itype"
,
[
torch
.
bfloat16
,
torch
.
float
])
@
pytest
.
mark
.
parametrize
(
"silu_activation"
,
[
True
])
@
pytest
.
mark
.
parametrize
(
"has_bias"
,
[
True
])
def
causal_conv1d_opcheck_fn
(
x
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
,
cu_seq_len
:
Optional
[
torch
.
Tensor
]
=
None
,
cache_indices
:
Optional
[
torch
.
Tensor
]
=
None
,
has_initial_state
:
Optional
[
torch
.
Tensor
]
=
None
,
conv_states
:
Optional
[
torch
.
Tensor
]
=
None
,
activation
:
Optional
[
str
]
=
"silu"
,
):
def
causal_conv1d_opcheck_fn
(
x
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
,
cu_seq_len
:
Optional
[
torch
.
Tensor
]
=
None
,
cache_indices
:
Optional
[
torch
.
Tensor
]
=
None
,
has_initial_state
:
Optional
[
torch
.
Tensor
]
=
None
,
conv_states
:
Optional
[
torch
.
Tensor
]
=
None
,
activation
:
Optional
[
str
]
=
"silu"
,
pad_slot_id
:
int
=
PAD_SLOT_ID
):
"""
x: (batch, dim, seqlen)
weight: (dim, width)
...
...
@@ -141,16 +141,9 @@ def causal_conv1d_opcheck_fn(
x
=
x
.
contiguous
()
bias
=
bias
.
contiguous
()
if
bias
is
not
None
else
None
opcheck
(
torch
.
ops
.
_C
.
causal_conv1d_fwd
,
(
x
,
weight
,
bias
,
conv_states
,
cu_seq_len
,
cache_indices
,
has_initial_state
,
activation
in
[
"silu"
,
"swish"
],
))
opcheck
(
torch
.
ops
.
_C
.
causal_conv1d_fwd
,
(
x
,
weight
,
bias
,
conv_states
,
cu_seq_len
,
cache_indices
,
has_initial_state
,
activation
in
[
"silu"
,
"swish"
],
pad_slot_id
))
@
pytest
.
mark
.
parametrize
(
"itype"
,
[
torch
.
bfloat16
,
torch
.
float
])
...
...
@@ -233,17 +226,11 @@ def test_causal_conv1d_update(dim, width, seqlen, has_bias, silu_activation,
seed_everything
(
0
)
batch
=
2
x
=
torch
.
randn
(
batch
,
dim
,
seqlen
,
device
=
device
,
dtype
=
itype
)
x_ref
=
x
.
clone
()
conv_state
=
torch
.
randn
(
batch
,
dim
,
width
-
1
,
device
=
device
,
dtype
=
itype
)
weight
=
torch
.
randn
(
dim
,
width
,
device
=
device
,
dtype
=
itype
,
requires_grad
=
True
)
if
has_bias
:
bias
=
torch
.
randn
(
dim
,
device
=
device
,
dtype
=
itype
,
requires_grad
=
True
)
else
:
bias
=
None
weight
=
torch
.
randn
(
dim
,
width
,
device
=
device
,
dtype
=
itype
)
bias
=
torch
.
randn
(
dim
,
device
=
device
,
dtype
=
itype
)
if
has_bias
else
None
conv_state_ref
=
conv_state
.
detach
().
clone
()
activation
=
None
if
not
silu_activation
else
"silu"
out
=
causal_conv1d_update
(
x
,
...
...
@@ -251,7 +238,7 @@ def test_causal_conv1d_update(dim, width, seqlen, has_bias, silu_activation,
weight
,
bias
,
activation
=
activation
)
out_ref
=
causal_conv1d_update_ref
(
x
,
out_ref
=
causal_conv1d_update_ref
(
x
_ref
,
conv_state_ref
,
weight
,
bias
,
...
...
@@ -260,15 +247,9 @@ def test_causal_conv1d_update(dim, width, seqlen, has_bias, silu_activation,
assert
torch
.
equal
(
conv_state
,
conv_state_ref
)
assert
torch
.
allclose
(
out
,
out_ref
,
rtol
=
rtol
,
atol
=
atol
)
opcheck
(
torch
.
ops
.
_C
.
causal_conv1d_update
,
(
x
,
conv_state
,
weight
,
bias
,
activation
in
[
"silu"
,
"swish"
],
None
,
None
,
))
opcheck
(
torch
.
ops
.
_C
.
causal_conv1d_update
,
(
x
,
conv_state
,
weight
,
bias
,
activation
in
[
"silu"
,
"swish"
],
None
,
None
,
PAD_SLOT_ID
))
@
pytest
.
mark
.
parametrize
(
"itype"
,
...
...
@@ -278,37 +259,48 @@ def test_causal_conv1d_update(dim, width, seqlen, has_bias, silu_activation,
@
pytest
.
mark
.
parametrize
(
"seqlen"
,
[
1
,
4
,
5
])
@
pytest
.
mark
.
parametrize
(
"width"
,
[
2
,
3
,
4
])
@
pytest
.
mark
.
parametrize
(
"dim"
,
[
2048
,
2048
+
16
,
4096
])
def
test_causal_conv1d_update_with_batch_gather
(
dim
,
width
,
seqlen
,
has_bias
,
# tests correctness in case subset of the sequences are padded
@
pytest
.
mark
.
parametrize
(
"with_padding"
,
[
True
,
False
])
def
test_causal_conv1d_update_with_batch_gather
(
with_padding
,
dim
,
width
,
seqlen
,
has_bias
,
silu_activation
,
itype
):
device
=
"cuda"
rtol
,
atol
=
(
3e-4
,
1e-3
)
if
itype
==
torch
.
float32
else
(
3e-3
,
5e-3
)
if
itype
==
torch
.
bfloat16
:
rtol
,
atol
=
1e-2
,
5e-2
# set
)
seed
# set seed
seed_everything
(
0
)
batch
=
64
x
=
torch
.
randn
(
batch
,
dim
,
1
,
device
=
device
,
dtype
=
itype
)
batch_size
=
3
padding
=
5
if
with_padding
else
0
padded_batch_size
=
batch_size
+
padding
total_entries
=
10
*
batch_size
total_entries
=
10
*
batch
x
=
torch
.
randn
(
padded_batch_size
,
dim
,
1
,
device
=
device
,
dtype
=
itype
)
x_ref
=
x
.
clone
()
conv_state_indices
=
torch
.
randperm
(
total_entries
)[:
batch_size
].
to
(
dtype
=
torch
.
int32
,
device
=
device
)
unused_states_bool
=
torch
.
ones
(
total_entries
,
dtype
=
torch
.
bool
,
device
=
device
)
unused_states_bool
[
conv_state_indices
]
=
False
padded_state_indices
=
torch
.
concat
([
conv_state_indices
,
torch
.
as_tensor
(
[
PAD_SLOT_ID
]
*
padding
,
dtype
=
torch
.
int32
,
device
=
device
)
],
dim
=
0
)
conv_state
=
torch
.
randn
(
total_entries
,
dim
,
width
-
1
,
device
=
device
,
dtype
=
itype
)
conv_state_indices
=
torch
.
randperm
(
total_entries
)[:
batch
].
to
(
dtype
=
torch
.
int32
,
device
=
device
)
conv_state_for_padding_test
=
conv_state
.
clone
()
weight
=
torch
.
randn
(
dim
,
width
,
device
=
device
,
dtype
=
itype
,
requires_grad
=
True
)
if
has_bias
:
bias
=
torch
.
randn
(
dim
,
device
=
device
,
dtype
=
itype
,
requires_grad
=
True
)
else
:
bias
=
None
weight
=
torch
.
randn
(
dim
,
width
,
device
=
device
,
dtype
=
itype
)
bias
=
torch
.
randn
(
dim
,
device
=
device
,
dtype
=
itype
)
if
has_bias
else
None
conv_state_ref
=
conv_state
[
conv_state_indices
,
:].
detach
().
clone
()
activation
=
None
if
not
silu_activation
else
"silu"
out
=
causal_conv1d_update
(
x
,
...
...
@@ -316,45 +308,50 @@ def test_causal_conv1d_update_with_batch_gather(dim, width, seqlen, has_bias,
weight
,
bias
,
activation
=
activation
,
conv_state_indices
=
conv_state_indices
)
out_ref
=
causal_conv1d_update_ref
(
x
,
conv_state_indices
=
padded_state_indices
,
pad_slot_id
=
PAD_SLOT_ID
)
out_ref
=
causal_conv1d_update_ref
(
x_ref
[:
batch_size
],
conv_state_ref
,
weight
,
bias
,
activation
=
activation
)
assert
torch
.
equal
(
conv_state
[
conv_state_indices
,
:],
conv_state_ref
)
assert
torch
.
allclose
(
out
,
out_ref
,
rtol
=
rtol
,
atol
=
atol
)
assert
torch
.
allclose
(
out
[:
batch_size
],
out_ref
,
rtol
=
rtol
,
atol
=
atol
)
assert
torch
.
equal
(
conv_state
[
unused_states_bool
],
conv_state_for_padding_test
[
unused_states_bool
])
opcheck
(
torch
.
ops
.
_C
.
causal_conv1d_update
,
(
x
,
conv_state
,
weight
,
bias
,
activation
in
[
"silu"
,
"swish"
],
None
,
conv_state_indices
,
))
opcheck
(
torch
.
ops
.
_C
.
causal_conv1d_update
,
(
x
,
conv_state
,
weight
,
bias
,
activation
in
[
"silu"
,
"swish"
],
None
,
padded_state_indices
,
PAD_SLOT_ID
))
@
pytest
.
mark
.
parametrize
(
"itype"
,
[
torch
.
bfloat16
])
@
pytest
.
mark
.
parametrize
(
"silu_activation"
,
[
True
])
@
pytest
.
mark
.
parametrize
(
"has_bias"
,
[
True
])
@
pytest
.
mark
.
parametrize
(
"width"
,
[
4
])
@
pytest
.
mark
.
parametrize
(
'seqlen'
,
[
8
,
16
,
32
,
64
,
128
,
256
,
512
,
784
,
1024
,
2048
,
4096
])
@
pytest
.
mark
.
parametrize
(
'seqlen'
,
[
8
,
16
,
32
,
64
,
128
,
256
,
512
,
784
,
1024
,
2048
,
2049
,
4096
])
@
pytest
.
mark
.
parametrize
(
'dim'
,
[
64
,
4096
])
def
test_causal_conv1d_varlen
(
dim
,
seqlen
,
width
,
has_bias
,
silu_activation
,
itype
):
# tests correctness in case subset of the sequences are padded
@
pytest
.
mark
.
parametrize
(
'with_padding'
,
[
True
,
False
])
def
test_causal_conv1d_varlen
(
with_padding
,
dim
,
seqlen
,
width
,
has_bias
,
silu_activation
,
itype
):
device
=
"cuda"
torch
.
cuda
.
empty_cache
()
rtol
,
atol
=
(
3e-4
,
1e-3
)
if
itype
==
torch
.
float32
else
(
3e-3
,
5e-3
)
if
itype
==
torch
.
bfloat16
:
rtol
,
atol
=
1e-2
,
5e-2
# set seed
seed_everything
(
0
)
batch
=
1
seqlens
=
[]
nsplits
=
3
batch_size
=
4
if
seqlen
<
10
:
batch_size
=
1
padding
=
3
if
with_padding
else
0
padded_batch_size
=
batch_size
+
padding
nsplits
=
padded_batch_size
-
1
eos_pos
=
torch
.
randperm
(
seqlen
-
1
)[:
nsplits
].
sort
().
values
seqlens
.
append
(
torch
.
diff
(
...
...
@@ -364,10 +361,11 @@ def test_causal_conv1d_varlen(dim, seqlen, width, has_bias, silu_activation,
assert
sum
(
seqlens
[
-
1
])
==
seqlen
assert
all
(
s
>
0
for
s
in
seqlens
[
-
1
])
total_entries
=
batch_size
*
10
cumsum
=
torch
.
cumsum
(
torch
.
tensor
(
seqlens
[
0
]),
dim
=
0
).
to
(
torch
.
int32
)
cumsum
=
torch
.
concat
([
torch
.
tensor
([
0
],
dtype
=
torch
.
int32
),
cumsum
],
dim
=
0
)
x
=
torch
.
randn
(
batch
,
4096
+
dim
+
64
,
seqlen
,
device
=
device
,
x
=
torch
.
randn
(
1
,
4096
+
dim
+
64
,
seqlen
,
device
=
device
,
dtype
=
itype
)[:,
4096
:
4096
+
dim
,
:]
weight
=
torch
.
randn
(
dim
,
width
,
device
=
device
,
dtype
=
itype
)
bias
=
torch
.
randn
(
dim
,
device
=
device
,
dtype
=
itype
)
if
has_bias
else
None
...
...
@@ -375,7 +373,7 @@ def test_causal_conv1d_varlen(dim, seqlen, width, has_bias, silu_activation,
weight_ref
=
weight
.
clone
()
bias_ref
=
bias
.
clone
()
if
bias
is
not
None
else
None
activation
=
None
if
not
silu_activation
else
"silu"
final_states
=
torch
.
randn
(
nsplits
+
1
,
final_states
=
torch
.
randn
(
total_entries
,
dim
,
width
-
1
,
device
=
x
.
device
,
...
...
@@ -385,18 +383,27 @@ def test_causal_conv1d_varlen(dim, seqlen, width, has_bias, silu_activation,
2
,
(
cumsum
.
shape
[
0
]
-
1
,
),
dtype
=
torch
.
bool
,
device
=
x
.
device
)
cach
e_indices
=
torch
.
randperm
(
cumsum
.
shape
[
0
]
-
1
,
stat
e_indices
=
torch
.
randperm
(
total_entries
,
dtype
=
torch
.
int32
,
device
=
x
.
device
)
device
=
x
.
device
)[:
batch_size
]
padded_state_indices
=
torch
.
concat
([
state_indices
,
torch
.
as_tensor
(
[
PAD_SLOT_ID
]
*
padding
,
dtype
=
torch
.
int32
,
device
=
device
),
],
dim
=-
1
)
out
=
causal_conv1d_fn
(
x
.
squeeze
(
0
),
weight
,
bias
,
cumsum
.
cuda
(),
cach
e_indices
,
has_initial_states
,
final_states
,
activation
)
padded_stat
e_indices
,
has_initial_states
,
final_states
,
activation
,
PAD_SLOT_ID
)
out_ref
=
[]
out_ref_b
=
[]
splits
=
[
torch
.
split
(
var
,
seqlens
[
0
],
dim
=-
1
)
for
var
in
(
x_ref
)]
for
i
in
range
(
len
(
seqlens
[
0
])):
x_s
=
[
v
[
i
].
unsqueeze
(
0
)
for
v
in
splits
][
0
]
if
padded_state_indices
[
i
]
==
PAD_SLOT_ID
:
continue
out_ref_b
.
append
(
causal_conv1d_ref
(
x_s
,
...
...
@@ -404,21 +411,17 @@ def test_causal_conv1d_varlen(dim, seqlen, width, has_bias, silu_activation,
bias_ref
,
activation
=
activation
,
return_final_states
=
True
,
final_states_out
=
final_states_ref
[
cache_indices
[
i
]].
unsqueeze
(
0
),
initial_states
=
final_states_ref
[
cach
e_indices
[
i
]].
unsqueeze
(
0
)
if
has_initial_states
[
i
]
else
None
))
final_states_out
=
final_states_ref
[
padded_state_indices
[
i
]].
unsqueeze
(
0
),
initial_states
=
final_states_ref
[
padded_stat
e_indices
[
i
]].
unsqueeze
(
0
)
if
has_initial_states
[
i
]
else
None
))
out_ref
.
append
(
torch
.
cat
([
t
[
0
]
for
t
in
out_ref_b
],
dim
=
2
))
out_ref
=
torch
.
cat
(
out_ref
,
dim
=
0
)
print
(
f
"Output max diff:
{
(
out
-
out_ref
).
abs
().
max
().
item
()
}
"
)
print
(
f
"Output mean diff:
{
(
out
-
out_ref
).
abs
().
mean
().
item
()
}
"
)
print
(
"Output state max diff"
f
":
{
(
final_states
-
final_states_ref
).
abs
().
max
()
}
"
)
print
(
"Output state mean diff"
f
":
{
(
final_states
-
final_states_ref
).
abs
().
mean
()
}
"
)
assert
torch
.
allclose
(
out
,
out_ref
,
rtol
=
rtol
,
atol
=
atol
)
out_ref_tensor
=
torch
.
cat
(
out_ref
,
dim
=
0
)
unpadded_out
=
out
[:,
:
out_ref_tensor
.
shape
[
-
1
]]
assert
torch
.
allclose
(
unpadded_out
,
out_ref_tensor
,
rtol
=
rtol
,
atol
=
atol
)
assert
torch
.
allclose
(
final_states
,
final_states_ref
,
rtol
=
rtol
,
atol
=
atol
)
causal_conv1d_opcheck_fn
(
x
.
squeeze
(
0
),
weight
,
bias
,
cumsum
.
cuda
(),
cach
e_indices
,
has_initial_states
,
final_states
,
activation
)
padded_stat
e_indices
,
has_initial_states
,
final_states
,
activation
)
tests/kernels/test_mamba_ssm.py
View file @
fb60ae9b
...
...
@@ -5,6 +5,7 @@ from einops import rearrange, repeat
from
tests.kernels.utils
import
opcheck
from
vllm
import
_custom_ops
as
ops
# noqa: F401
from
vllm.attention.backends.utils
import
PAD_SLOT_ID
from
vllm.model_executor.layers.mamba.ops.mamba_ssm
import
(
selective_scan_fn
,
selective_state_update
)
from
vllm.utils
import
seed_everything
...
...
@@ -174,7 +175,8 @@ def selective_scan_opcheck_fn(u,
cu_seq_len
=
None
,
cache_indices
=
None
,
has_initial_state
=
None
,
ssm_states
=
None
):
ssm_states
=
None
,
pad_slot_id
=
PAD_SLOT_ID
):
"""if return_last_state is True, returns (out, last_state)
last_state has shape (batch, dim, dstate).
"""
...
...
@@ -203,7 +205,7 @@ def selective_scan_opcheck_fn(u,
# a bogus error.
opcheck
(
torch
.
ops
.
_C
.
selective_scan_fwd
,
(
u
,
delta
,
A
,
B
,
C
,
D
,
z
,
delta_bias
,
delta_softplus
,
cu_seq_len
,
cache_indices
,
has_initial_state
,
ssm_states
),
cache_indices
,
has_initial_state
,
ssm_states
,
pad_slot_id
),
test_utils
=
[
"test_schema"
,
"test_faketensor"
])
...
...
@@ -404,9 +406,12 @@ def test_selective_state_update(dim, dstate, has_z, itype):
@
pytest
.
mark
.
parametrize
(
"varBC_groups"
,
[
1
,
2
])
@
pytest
.
mark
.
parametrize
(
"is_variable_C"
,
[
True
])
@
pytest
.
mark
.
parametrize
(
"is_variable_B"
,
[
True
])
def
test_selective_scan_varlen
(
is_variable_B
,
is_variable_C
,
varBC_groups
,
has_D
,
has_z
,
has_delta_bias
,
delta_softplus
,
return_last_state
,
seqlen
,
itype
,
wtype
):
# tests correctness in case subset of the sequences are padded
@
pytest
.
mark
.
parametrize
(
"with_padding"
,
[
False
,
True
])
def
test_selective_scan_varlen
(
with_padding
,
is_variable_B
,
is_variable_C
,
varBC_groups
,
has_D
,
has_z
,
has_delta_bias
,
delta_softplus
,
return_last_state
,
seqlen
,
itype
,
wtype
):
if
varBC_groups
>
1
and
(
not
is_variable_B
or
not
is_variable_C
):
pytest
.
skip
()
# This config is not applicable
device
=
'cuda'
...
...
@@ -420,18 +425,27 @@ def test_selective_scan_varlen(is_variable_B, is_variable_C, varBC_groups,
# set seed
torch
.
random
.
manual_seed
(
0
)
seqlens
=
[]
nsplits
=
3
batch_size
=
4
if
seqlen
<
10
:
nsplits
=
0
batch_size
=
1
padding
=
3
if
with_padding
else
0
padded_batch_size
=
batch_size
+
padding
if
with_padding
and
seqlen
<
padded_batch_size
:
pytest
.
skip
()
nsplits
=
padded_batch_size
-
1
eos_pos
=
torch
.
randperm
(
seqlen
-
1
)[:
nsplits
].
sort
().
values
seqlens
.
append
(
torch
.
diff
(
torch
.
cat
(
[
torch
.
tensor
([
-
1
]),
eos_pos
,
torch
.
tensor
([
seqlen
-
1
])])).
tolist
())
assert
sum
(
seqlens
[
-
1
])
==
seqlen
assert
all
(
s
>
0
for
s
in
seqlens
[
-
1
])
total_entries
=
batch_size
*
10
cumsum
=
torch
.
cumsum
(
torch
.
tensor
(
seqlens
[
0
]),
dim
=
0
).
to
(
torch
.
int32
)
cumsum
=
torch
.
concat
([
torch
.
tensor
([
0
],
dtype
=
torch
.
int32
),
cumsum
],
dim
=
0
).
cuda
()
...
...
@@ -462,22 +476,33 @@ def test_selective_scan_varlen(is_variable_B, is_variable_C, varBC_groups,
delta_ref
=
delta
.
clone
()
out
=
None
out_ref
=
None
prev_state_shape
=
(
cumsum
.
shape
[
0
]
-
1
,
u
.
shape
[
0
],
int
(
A
.
shape
[
1
]))
prev_state_shape
=
(
total_entries
,
u
.
shape
[
0
],
int
(
A
.
shape
[
1
]))
prev_state
=
torch
.
randn
(
prev_state_shape
,
device
=
u
.
device
,
dtype
=
itype
,
requires_grad
=
False
)
prev_state_ref
=
prev_state
.
clone
()
cach
e_indices
=
torch
.
randperm
(
cumsum
.
shape
[
0
]
-
1
,
stat
e_indices
=
torch
.
randperm
(
total_entries
,
dtype
=
torch
.
int32
,
device
=
u
.
device
)
device
=
u
.
device
)[:
batch_size
]
unused_states_bool
=
torch
.
ones
(
total_entries
,
dtype
=
torch
.
bool
,
device
=
device
)
unused_states_bool
[
state_indices
]
=
False
padded_state_indices
=
torch
.
concat
([
state_indices
,
torch
.
as_tensor
(
[
PAD_SLOT_ID
]
*
padding
,
dtype
=
torch
.
int32
,
device
=
device
),
],
dim
=-
1
)
has_initial_state
=
torch
.
randint
(
0
,
2
,
(
cumsum
.
shape
[
0
]
-
1
,
),
dtype
=
torch
.
bool
,
device
=
u
.
device
)
out
=
selective_scan_fn
(
u
,
prev_state
,
delta
,
A
,
B
,
C
,
D
,
z
,
delta_bias
,
delta_softplus
,
cumsum
,
cach
e_indices
,
delta_softplus
,
cumsum
,
padded_stat
e_indices
,
has_initial_state
)
outs_ref
=
[]
splits
=
[
...
...
@@ -486,6 +511,8 @@ def test_selective_scan_varlen(is_variable_B, is_variable_C, varBC_groups,
]
for
i
in
range
(
len
(
seqlens
[
0
])):
u_s
,
delta_s
,
B_s
,
C_s
,
z_s
=
[
v
[
i
].
unsqueeze
(
0
)
for
v
in
splits
]
if
padded_state_indices
[
i
]
==
PAD_SLOT_ID
:
continue
out_ref_s
,
_
=
selective_scan_ref
(
u_s
,
delta_s
,
...
...
@@ -497,21 +524,22 @@ def test_selective_scan_varlen(is_variable_B, is_variable_C, varBC_groups,
delta_bias
=
delta_bias
,
delta_softplus
=
delta_softplus
,
return_last_state
=
return_last_state
,
prev_state
=
prev_state_ref
[
cach
e_indices
[
i
]].
unsqueeze
(
0
)
prev_state
=
prev_state_ref
[
padded_stat
e_indices
[
i
]].
unsqueeze
(
0
)
if
has_initial_state
[
i
]
else
None
,
final_state_out
=
prev_state_ref
[
cache_indices
[
i
]].
unsqueeze
(
0
))
final_state_out
=
prev_state_ref
[
padded_state_indices
[
i
]].
unsqueeze
(
0
))
outs_ref
.
append
(
out_ref_s
)
out_ref
=
torch
.
cat
(
outs_ref
,
dim
=-
1
)
if
len
(
outs_ref
)
>
1
else
outs_ref
[
0
]
out_ref
=
torch
.
cat
(
outs_ref
,
dim
=-
1
)[
0
]
print
(
"Output diff max"
,
(
out
-
out_ref
[
0
]).
max
())
print
(
"Output diff mean"
,
(
out
-
out_ref
[
0
]).
mean
())
unpadded_out
=
out
[:,
:
out_ref
[
0
].
shape
[
-
1
]]
print
(
"Output diff max"
,
(
unpadded_out
-
out_ref
).
max
())
print
(
"Output diff mean"
,
(
unpadded_out
-
out_ref
).
mean
())
print
(
"Output state diff max"
,
(
prev_state
-
prev_state_ref
).
max
())
print
(
"Output state diff mean"
,
(
prev_state
-
prev_state_ref
).
mean
())
assert
torch
.
allclose
(
prev_state
,
prev_state_ref
,
rtol
=
rtol
,
atol
=
atol
)
assert
torch
.
allclose
(
out
,
out_ref
[
0
],
rtol
=
rtol
,
atol
=
atol
)
assert
torch
.
allclose
(
unpadded_out
,
out_ref
,
rtol
=
rtol
,
atol
=
atol
)
selective_scan_opcheck_fn
(
u
,
delta
,
A
,
B
,
C
,
D
,
z
,
delta_bias
,
delta_softplus
,
cumsum
,
cach
e_indices
,
delta_softplus
,
cumsum
,
padded_stat
e_indices
,
has_initial_state
,
prev_state
)
...
...
@@ -520,7 +548,10 @@ def test_selective_scan_varlen(is_variable_B, is_variable_C, varBC_groups,
@
pytest
.
mark
.
parametrize
(
"has_z"
,
[
True
])
@
pytest
.
mark
.
parametrize
(
"dstate"
,
[
16
,
32
,
64
])
@
pytest
.
mark
.
parametrize
(
"dim"
,
[
2048
,
2048
+
16
,
4096
])
def
test_selective_state_update_with_batch_indices
(
dim
,
dstate
,
has_z
,
itype
):
# tests correctness in case subset of the sequences are padded
@
pytest
.
mark
.
parametrize
(
"with_padding"
,
[
True
,
False
])
def
test_selective_state_update_with_batch_indices
(
with_padding
,
dim
,
dstate
,
has_z
,
itype
):
device
=
"cuda"
rtol
,
atol
=
(
3e-4
,
1e-3
)
if
itype
==
torch
.
float32
else
(
5e-3
,
1e-2
)
if
itype
==
torch
.
bfloat16
:
...
...
@@ -530,21 +561,32 @@ def test_selective_state_update_with_batch_indices(dim, dstate, has_z, itype):
# set seed
torch
.
random
.
manual_seed
(
0
)
batch_size
=
3
padding
=
5
if
with_padding
else
0
padded_batch_size
=
batch_size
+
padding
total_entries
=
10
*
batch_size
state
=
torch
.
randn
(
total_entries
,
dim
,
dstate
,
dtype
=
itype
,
device
=
device
)
state_indices
=
torch
.
randperm
(
total_entries
)[:
batch_size
].
to
(
dtype
=
torch
.
int32
,
device
=
device
)
x
=
torch
.
randn
(
batch_size
,
dim
,
device
=
device
,
dtype
=
itype
)
dt
=
torch
.
randn
(
batch_size
,
dim
,
device
=
device
,
dtype
=
itype
)
unused_states_bool
=
torch
.
ones
(
total_entries
,
dtype
=
torch
.
bool
,
device
=
device
)
unused_states_bool
[
state_indices
]
=
False
padded_state_indices
=
torch
.
concat
([
state_indices
,
torch
.
as_tensor
(
[
PAD_SLOT_ID
]
*
padding
,
dtype
=
torch
.
int32
,
device
=
device
)
],
dim
=
0
)
x
=
torch
.
randn
(
padded_batch_size
,
dim
,
device
=
device
,
dtype
=
itype
)
dt
=
torch
.
randn
(
padded_batch_size
,
dim
,
device
=
device
,
dtype
=
itype
)
dt_bias
=
torch
.
rand
(
dim
,
device
=
device
)
-
4.0
A
=
-
torch
.
rand
(
dim
,
dstate
,
device
=
device
)
-
1.0
B
=
torch
.
randn
(
batch_size
,
dstate
,
device
=
device
)
C
=
torch
.
randn
(
batch_size
,
dstate
,
device
=
device
)
B
=
torch
.
randn
(
padded_
batch_size
,
dstate
,
device
=
device
)
C
=
torch
.
randn
(
padded_
batch_size
,
dstate
,
device
=
device
)
D
=
torch
.
randn
(
dim
,
device
=
device
)
z
=
torch
.
randn_like
(
x
)
if
has_z
else
None
state_ref
=
state
[
state_indices
,
:].
detach
().
clone
()
state_ref
=
state
[
state_indices
,
:].
clone
()
state_before
=
state
.
clone
()
out
=
selective_state_update
(
state
,
x
,
dt
,
...
...
@@ -555,15 +597,16 @@ def test_selective_state_update_with_batch_indices(dim, dstate, has_z, itype):
z
=
z
,
dt_bias
=
dt_bias
,
dt_softplus
=
True
,
state_batch_indices
=
state_indices
)
state_batch_indices
=
padded_state_indices
,
pad_slot_id
=
PAD_SLOT_ID
)
out_ref
=
selective_state_update_ref
(
state_ref
,
x
,
dt
,
x
[:
batch_size
]
,
dt
[:
batch_size
]
,
A
,
B
,
C
,
B
[:
batch_size
]
,
C
[:
batch_size
]
,
D
=
D
,
z
=
z
,
z
=
z
[:
batch_size
]
,
dt_bias
=
dt_bias
,
dt_softplus
=
True
)
...
...
@@ -572,11 +615,21 @@ def test_selective_state_update_with_batch_indices(dim, dstate, has_z, itype):
print
(
"Output state diff max"
,
(
state
[
state_indices
,
:]
-
state_ref
).
max
())
print
(
"Output state diff mean"
,
(
state
[
state_indices
,
:]
-
state_ref
).
mean
())
# test padded entries stay the same
if
with_padding
:
assert
torch
.
equal
(
state_before
[
unused_states_bool
],
state
[
unused_states_bool
])
assert
torch
.
equal
(
x
[
batch_size
+
1
:],
x
[
batch_size
+
1
:])
assert
torch
.
equal
(
dt
[
batch_size
+
1
:],
dt
[
batch_size
+
1
:])
assert
torch
.
equal
(
B
[
batch_size
+
1
:],
B
[
batch_size
+
1
:])
assert
torch
.
equal
(
C
[
batch_size
+
1
:],
C
[
batch_size
+
1
:])
# test "real" entries
assert
torch
.
allclose
(
state
[
state_indices
,
:],
state_ref
,
rtol
=
rtol
,
atol
=
atol
)
assert
torch
.
allclose
(
out
,
out_ref
,
rtol
=
rtol
,
atol
=
atol
)
assert
torch
.
allclose
(
out
[:
batch_size
]
,
out_ref
,
rtol
=
rtol
,
atol
=
atol
)
@
pytest
.
mark
.
parametrize
(
"itype"
,
...
...
@@ -645,7 +698,8 @@ def test_selective_state_update_with_heads_with_batch_indices(
z
=
z
,
dt_bias
=
dt_bias
,
dt_softplus
=
True
,
state_batch_indices
=
state_indices
)
state_batch_indices
=
state_indices
,
pad_slot_id
=
PAD_SLOT_ID
)
out_ref
=
selective_state_update_ref
(
state_ref
,
x
,
dt
,
...
...
tests/models/decoder_only/language/test_jamba.py
View file @
fb60ae9b
import
pytest
from
tests.utils
import
multi_gpu_test
from
vllm.sampling_params
import
SamplingParams
from
vllm.worker.model_runner
import
_get_graph_batch_size
...
...
@@ -270,6 +271,30 @@ def test_state_cleanup(
"could be related to finished_requests_ids"
)
@
multi_gpu_test
(
num_gpus
=
2
)
@
pytest
.
mark
.
parametrize
(
"model"
,
MODELS
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
"float"
])
@
pytest
.
mark
.
parametrize
(
"max_tokens"
,
[
64
])
def
test_jamba_distributed_produces_identical_generation
(
vllm_runner
,
model
:
str
,
dtype
:
str
,
max_tokens
:
int
,
example_prompts
)
->
None
:
with
vllm_runner
(
model
,
dtype
=
dtype
,
tensor_parallel_size
=
2
)
as
vllm_model
:
vllm_outputs_tp_2
=
vllm_model
.
generate_greedy
(
example_prompts
,
max_tokens
)
with
vllm_runner
(
model
,
dtype
=
dtype
,
tensor_parallel_size
=
1
)
as
vllm_model
:
vllm_outputs_tp_1
=
vllm_model
.
generate_greedy
(
example_prompts
,
max_tokens
)
check_outputs_equal
(
outputs_0_lst
=
vllm_outputs_tp_1
,
outputs_1_lst
=
vllm_outputs_tp_2
,
name_0
=
"vllm_tp_1"
,
name_1
=
"vllm_tp_2"
,
)
@
pytest
.
mark
.
parametrize
(
"model"
,
MODELS
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
"float"
])
def
test_model_print
(
...
...
vllm/_custom_ops.py
View file @
fb60ae9b
...
...
@@ -464,16 +464,18 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
cu_seq_len
:
Optional
[
torch
.
Tensor
],
cache_indices
:
Optional
[
torch
.
Tensor
],
has_initial_state
:
Optional
[
torch
.
Tensor
],
silu_activation
:
bool
)
->
torch
.
Tensor
:
return
torch
.
empty_like
(
x
)
silu_activation
:
bool
,
pad_slot_id
:
int
)
:
return
None
@
register_fake
(
"_C::causal_conv1d_update"
)
def
causal_conv1d_update_fake
(
x
:
torch
.
Tensor
,
conv_state
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
bias_
:
Optional
[
torch
.
Tensor
],
silu_activation
:
bool
,
cache_seqlens
:
Optional
[
torch
.
Tensor
],
conv_state_indices
:
Optional
[
torch
.
Tensor
])
->
torch
.
Tensor
:
return
torch
.
empty_like
(
x
)
def
causal_conv1d_update_fake
(
x
:
torch
.
Tensor
,
conv_state
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
bias_
:
Optional
[
torch
.
Tensor
],
silu_activation
:
bool
,
cache_seqlens
:
Optional
[
torch
.
Tensor
],
conv_state_indices
:
Optional
[
torch
.
Tensor
],
pad_slot_id
:
int
)
->
None
:
return
None
@
register_fake
(
"_C::selective_scan_fwd"
)
def
selective_scan_fwd_fake
(
u
:
torch
.
Tensor
,
delta
:
torch
.
Tensor
,
...
...
@@ -485,7 +487,8 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
cu_seq_len
:
Optional
[
torch
.
Tensor
],
cache_indices
:
Optional
[
torch
.
Tensor
],
has_initial_state
:
Optional
[
torch
.
Tensor
],
ssm_states
:
Optional
[
torch
.
Tensor
])
->
None
:
ssm_states
:
Optional
[
torch
.
Tensor
],
pad_slot_id
:
int
)
->
None
:
return
None
...
...
@@ -800,33 +803,37 @@ def causal_conv1d_fwd(x: torch.Tensor, weight: torch.Tensor,
query_start_loc
:
Optional
[
torch
.
Tensor
],
cache_indices
:
Optional
[
torch
.
Tensor
],
has_initial_state
:
Optional
[
torch
.
Tensor
],
silu_activation
:
bool
)
->
torch
.
Tensor
:
return
torch
.
ops
.
_C
.
causal_conv1d_fwd
(
x
,
weight
,
bias_
,
conv_states
,
query_start_loc
,
cache_indices
,
has_initial_state
,
silu_activation
)
def
causal_conv1d_update
(
x
:
torch
.
Tensor
,
conv_state
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
bias_
:
Optional
[
torch
.
Tensor
],
silu_activation
:
bool
,
cache_seqlens
:
Optional
[
torch
.
Tensor
],
conv_state_indices
:
Optional
[
torch
.
Tensor
])
->
torch
.
Tensor
:
return
torch
.
ops
.
_C
.
causal_conv1d_update
(
x
,
conv_state
,
weight
,
bias_
,
silu_activation
,
cache_seqlens
,
conv_state_indices
)
def
selective_scan_fwd
(
u
:
torch
.
Tensor
,
delta
:
torch
.
Tensor
,
A
:
torch
.
Tensor
,
B
:
torch
.
Tensor
,
C
:
torch
.
Tensor
,
D_
:
Optional
[
torch
.
Tensor
],
z_
:
Optional
[
torch
.
Tensor
],
delta_bias_
:
Optional
[
torch
.
Tensor
],
delta_softplus
:
bool
,
query_start_loc
:
Optional
[
torch
.
Tensor
],
cache_indices
:
Optional
[
torch
.
Tensor
],
has_initial_state
:
Optional
[
torch
.
Tensor
],
ssm_states
:
torch
.
Tensor
):
silu_activation
:
bool
,
pad_slot_id
:
int
):
torch
.
ops
.
_C
.
causal_conv1d_fwd
(
x
,
weight
,
bias_
,
conv_states
,
query_start_loc
,
cache_indices
,
has_initial_state
,
silu_activation
,
pad_slot_id
)
def
causal_conv1d_update
(
x
:
torch
.
Tensor
,
conv_state
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
bias_
:
Optional
[
torch
.
Tensor
],
silu_activation
:
bool
,
cache_seqlens
:
Optional
[
torch
.
Tensor
],
conv_state_indices
:
Optional
[
torch
.
Tensor
],
pad_slot_id
:
int
):
torch
.
ops
.
_C
.
causal_conv1d_update
(
x
,
conv_state
,
weight
,
bias_
,
silu_activation
,
cache_seqlens
,
conv_state_indices
,
pad_slot_id
)
def
selective_scan_fwd
(
u
:
torch
.
Tensor
,
delta
:
torch
.
Tensor
,
A
:
torch
.
Tensor
,
B
:
torch
.
Tensor
,
C
:
torch
.
Tensor
,
D_
:
Optional
[
torch
.
Tensor
],
z_
:
Optional
[
torch
.
Tensor
],
delta_bias_
:
Optional
[
torch
.
Tensor
],
delta_softplus
:
bool
,
query_start_loc
:
Optional
[
torch
.
Tensor
],
cache_indices
:
Optional
[
torch
.
Tensor
],
has_initial_state
:
Optional
[
torch
.
Tensor
],
ssm_states
:
torch
.
Tensor
,
pad_slot_id
:
int
):
torch
.
ops
.
_C
.
selective_scan_fwd
(
u
,
delta
,
A
,
B
,
C
,
D_
,
z_
,
delta_bias_
,
delta_softplus
,
query_start_loc
,
cache_indices
,
has_initial_state
,
ssm_states
)
ssm_states
,
pad_slot_id
)
# moe
...
...
vllm/model_executor/layers/mamba/ops/causal_conv1d.py
View file @
fb60ae9b
...
...
@@ -6,18 +6,18 @@ from typing import Optional
import
torch
from
vllm
import
_custom_ops
as
ops
from
vllm.attention.backends.utils
import
PAD_SLOT_ID
def
causal_conv1d_fn
(
x
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
,
query_start_loc
:
Optional
[
torch
.
Tensor
]
=
None
,
cache_indices
:
Optional
[
torch
.
Tensor
]
=
None
,
has_initial_state
:
Optional
[
torch
.
Tensor
]
=
None
,
conv_states
:
Optional
[
torch
.
Tensor
]
=
None
,
activation
:
Optional
[
str
]
=
"silu"
,
):
def
causal_conv1d_fn
(
x
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
,
query_start_loc
:
Optional
[
torch
.
Tensor
]
=
None
,
cache_indices
:
Optional
[
torch
.
Tensor
]
=
None
,
has_initial_state
:
Optional
[
torch
.
Tensor
]
=
None
,
conv_states
:
Optional
[
torch
.
Tensor
]
=
None
,
activation
:
Optional
[
str
]
=
"silu"
,
pad_slot_id
:
int
=
PAD_SLOT_ID
):
"""
x: (batch, dim, seqlen) or (dim,cu_seq_len) for varlen
sequences are concatenated from left to right for varlen
...
...
@@ -37,6 +37,13 @@ def causal_conv1d_fn(
conv_states: (...,dim,width - 1) itype
updated inplace if provided
activation: either None or "silu" or "swish"
pad_slot_id: int
if cache_indices is passed, lets the kernel identify padded
entries that will not be processed,
for example: cache_indices = [pad_slot_id, 1, 20, pad_slot_id]
in this case, the kernel will not process entries at
indices 0 and 3
out: (batch, dim, seqlen)
"""
...
...
@@ -46,10 +53,10 @@ def causal_conv1d_fn(
x
=
x
.
contiguous
()
bias
=
bias
.
contiguous
()
if
bias
is
not
None
else
None
out
=
ops
.
causal_conv1d_fwd
(
x
,
weight
,
bias
,
conv_states
,
query_start_loc
,
cache_indices
,
has_initial_state
,
activation
in
[
"silu"
,
"swish"
])
return
out
ops
.
causal_conv1d_fwd
(
x
,
weight
,
bias
,
conv_states
,
query_start_loc
,
cache_indices
,
has_initial_state
,
activation
in
[
"silu"
,
"swish"
]
,
pad_slot_id
)
return
x
def
causal_conv1d_update
(
x
:
torch
.
Tensor
,
...
...
@@ -58,7 +65,8 @@ def causal_conv1d_update(x: torch.Tensor,
bias
:
Optional
[
torch
.
Tensor
]
=
None
,
activation
:
Optional
[
str
]
=
None
,
cache_seqlens
:
Optional
[
torch
.
Tensor
]
=
None
,
conv_state_indices
:
Optional
[
torch
.
Tensor
]
=
None
):
conv_state_indices
:
Optional
[
torch
.
Tensor
]
=
None
,
pad_slot_id
:
int
=
PAD_SLOT_ID
):
"""
x: (batch, dim) or (batch, dim, seqlen)
conv_state: (batch, dim, state_len), where state_len >= width - 1
...
...
@@ -73,7 +81,12 @@ def causal_conv1d_update(x: torch.Tensor,
If not None, the conv_state is a larger tensor along the batch dim,
and we are selecting the batch coords specified by conv_state_indices.
Useful for a continuous batching scenario.
pad_slot_id: int
if cache_indices is passed, lets the kernel identify padded
entries that will not be processed,
for example: cache_indices = [pad_slot_id, 1 ,20 ,pad_slot_id]
in this case, the kernel will not process entries at
indices 0 and 3
out: (batch, dim) or (batch, dim, seqlen)
"""
if
activation
not
in
[
None
,
"silu"
,
"swish"
]:
...
...
@@ -82,8 +95,8 @@ def causal_conv1d_update(x: torch.Tensor,
unsqueeze
=
x
.
dim
()
==
2
if
unsqueeze
:
x
=
x
.
unsqueeze
(
-
1
)
out
=
ops
.
causal_conv1d_update
(
x
,
conv_state
,
weight
,
bias
,
activation_val
,
cache_seqlens
,
conv_state_indices
)
ops
.
causal_conv1d_update
(
x
,
conv_state
,
weight
,
bias
,
activation_val
,
cache_seqlens
,
conv_state_indices
,
pad_slot_id
)
if
unsqueeze
:
out
=
out
.
squeeze
(
-
1
)
return
out
x
=
x
.
squeeze
(
-
1
)
return
x
vllm/model_executor/layers/mamba/ops/mamba_ssm.py
View file @
fb60ae9b
# Copyright (c) 2024, Tri Dao, Albert Gu.
# Adapted from https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/triton/selective_state_update.py
from
typing
import
Tuple
import
torch
import
triton
import
triton.language
as
tl
from
packaging
import
version
from
vllm
import
_custom_ops
as
ops
from
vllm.attention.backends.utils
import
PAD_SLOT_ID
TRITON3
=
version
.
parse
(
triton
.
__version__
)
>=
version
.
parse
(
"3.0.0"
)
...
...
@@ -50,6 +49,7 @@ def _selective_scan_update_kernel(
z_ptr
,
out_ptr
,
state_batch_indices_ptr
,
pad_slot_id
,
# Matrix dimensions
batch
,
nheads
,
...
...
@@ -143,10 +143,11 @@ def _selective_scan_update_kernel(
if
HAS_Z
:
z_ptrs
=
z_ptr
+
offs_m
*
stride_z_dim
out_ptrs
=
out_ptr
+
offs_m
*
stride_out_dim
mask
=
(
offs_m
[:,
None
]
<
dim
)
&
(
offs_n
[
None
,
:]
<
dstate
)
if
HAS_STATE_BATCH_INDICES
:
mask
&=
(
state_batch_idx
!=
pad_slot_id
)
state
=
tl
.
load
(
state_ptrs
,
mask
=
mask
,
other
=
0.0
)
state
=
tl
.
load
(
state_ptrs
,
mask
=
(
offs_m
[:,
None
]
<
dim
)
&
(
offs_n
[
None
,
:]
<
dstate
),
other
=
0.0
)
x
=
tl
.
load
(
x_ptrs
,
mask
=
offs_m
<
dim
,
other
=
0.0
).
to
(
tl
.
float32
)
if
not
TIE_HDIM
:
dt
=
tl
.
load
(
dt_ptrs
,
mask
=
offs_m
<
dim
,
other
=
0.0
).
to
(
tl
.
float32
)
...
...
@@ -177,9 +178,11 @@ def _selective_scan_update_kernel(
dB
=
B
[
None
,
:]
*
dt
[:,
None
]
if
not
TIE_HDIM
else
B
*
dt
state
=
state
*
dA
+
dB
*
x
[:,
None
]
tl
.
store
(
state_ptrs
,
state
,
mask
=
(
offs_m
[:,
None
]
<
dim
)
&
(
offs_n
[
None
,
:]
<
dstate
))
mask
=
(
offs_m
[:,
None
]
<
dim
)
&
(
offs_n
[
None
,
:]
<
dstate
)
if
HAS_STATE_BATCH_INDICES
:
mask
&=
(
state_batch_idx
!=
pad_slot_id
)
tl
.
store
(
state_ptrs
,
state
,
mask
=
mask
)
out
=
tl
.
sum
(
state
*
C
[
None
,
:],
axis
=
1
)
if
HAS_D
:
out
+=
x
*
D
...
...
@@ -198,7 +201,8 @@ def selective_state_update(state,
z
=
None
,
dt_bias
=
None
,
dt_softplus
=
False
,
state_batch_indices
=
None
):
state_batch_indices
=
None
,
pad_slot_id
=
PAD_SLOT_ID
):
"""
Argument:
state: (batch, dim, dstate) or (batch, nheads, dim, dstate)
...
...
@@ -210,6 +214,12 @@ def selective_state_update(state,
D: (dim,) or (nheads, dim)
z: (batch, dim) or (batch, nheads, dim)
dt_bias: (dim,) or (nheads, dim)
pad_slot_id: int
if cache_indices is passed, lets the kernel identify padded
entries that will not be processed,
for example: cache_indices = [pad_slot_id, 1, 20, pad_slot_id]
in this case, the kernel will not process entries at
indices 0 and 3
Return:
out: (batch, dim) or (batch, nheads, dim)
"""
...
...
@@ -276,6 +286,7 @@ def selective_state_update(state,
z
,
out
,
state_batch_indices
,
pad_slot_id
,
batch
,
nheads
,
dim
,
...
...
@@ -319,22 +330,25 @@ def selective_state_update(state,
return
out
def
selective_scan_fn
(
u
,
ssm_states
,
delta
,
A
,
B
,
C
,
D
=
None
,
z
=
None
,
delta_
bias
=
Non
e
,
delta_softplus
=
Fals
e
,
query_start_loc
=
None
,
cache_indices
=
None
,
has_initial_state
=
None
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
:
def
selective_scan_fn
(
u
,
ssm_states
,
delta
,
A
,
B
,
C
,
D
=
None
,
z
=
None
,
delta_bias
=
None
,
delta_
softplus
=
Fals
e
,
query_start_loc
=
Non
e
,
cache_indices
=
None
,
has_initial_state
=
None
,
pad_slot_id
=
PAD_SLOT_ID
)
->
torch
.
Tensor
:
"""
u: (dim, total_length) for varlen or (batch, dim, seqlen)
applies changes in place.
ssm_states: (batch, dim, dstate) or (batch, nheads, dim, dstate)
applies changes in place.
delta: (dim, total_length) for varlen or (batch, dim, seqlen)
A: (dim, dstate)
B: (ngroups, dstate, total_length) for varlen or
...
...
@@ -357,12 +371,14 @@ def selective_scan_fn(
indicate if the ssm_state at the corresponding index should be
used as initial state. Not providing argument assumes
there's no initial state
pad_slot_id: int
if cache_indices is passed, lets the kernel identify padding entries
that will not be processed,
for example: cache_indices = [pad_slot_id, 1 ,20 ,pad_slot_id]
in this case, the kernel will not process entries at indices 0 and 3
returns
output: (dim, total_length) for varlen or (batch, dim, seqlen)
supports inplace replacement
last_state has shape (batch, dim, dstate).
supports inplace replacement if ssm_state was provided
"""
if
u
.
stride
(
-
1
)
!=
1
:
u
=
u
.
contiguous
()
...
...
@@ -387,7 +403,7 @@ def selective_scan_fn(
ops
.
selective_scan_fwd
(
u
,
delta
,
A
,
B
,
C
,
D
,
z
,
delta_bias
,
delta_softplus
,
query_start_loc
,
cache_indices
,
has_initial_state
,
ssm_states
)
ssm_states
,
pad_slot_id
)
if
z
is
None
:
return
delta
# output written inplace to delta
...
...
vllm/model_executor/models/jamba.py
View file @
fb60ae9b
# coding=utf-8
"""Inference-only Jamba model."""
from
dataclasses
import
dataclass
from
typing
import
Iterable
,
List
,
Optional
,
Tuple
import
torch
...
...
@@ -29,7 +28,8 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE
,
ParallelLMHead
,
VocabParallelEmbedding
)
from
vllm.model_executor.model_loader.weight_utils
import
(
composed_weight_loader
,
default_weight_loader
,
sharded_weight_loader
)
from
vllm.model_executor.models.mamba_cache
import
MambaCacheManager
from
vllm.model_executor.models.mamba_cache
import
(
MambaCacheManager
,
MambaCacheParams
)
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.sequence
import
IntermediateTensors
...
...
@@ -41,13 +41,6 @@ from .interfaces import HasInnerState, SupportsLoRA
KVCache
=
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
@
dataclass
class
MambaCacheParams
:
is_prompt
:
bool
=
False
conv_state
:
torch
.
Tensor
=
torch
.
Tensor
()
ssm_state
:
torch
.
Tensor
=
torch
.
Tensor
()
# Adapted from transformers.models.mamba.modeling_mamba.MambaMixer
class
JambaMambaMixer
(
nn
.
Module
):
"""
...
...
@@ -60,10 +53,9 @@ class JambaMambaMixer(nn.Module):
**selective** state spaces)
"""
def
__init__
(
self
,
config
:
JambaConfig
,
layer_idx
):
def
__init__
(
self
,
config
:
JambaConfig
):
super
().
__init__
()
self
.
config
=
config
self
.
layer_idx
=
layer_idx
self
.
hidden_size
=
config
.
hidden_size
self
.
ssm_state_size
=
config
.
mamba_d_state
self
.
conv_kernel_size
=
config
.
mamba_d_conv
...
...
@@ -129,8 +121,8 @@ class JambaMambaMixer(nn.Module):
eps
=
config
.
rms_norm_eps
)
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
conv_state
:
torch
.
Tensor
,
ssm_state
:
torch
.
Tensor
):
attn_metadata
:
AttentionMetadata
,
mamba_cache_params
:
MambaCacheParams
):
# 1. Gated MLP's linear projection
projected_states
=
self
.
in_proj
(
hidden_states
)[
0
].
transpose
(
-
2
,
-
1
)
...
...
@@ -153,17 +145,18 @@ class JambaMambaMixer(nn.Module):
conv_weights
,
self
.
conv1d
.
bias
,
activation
=
self
.
activation
,
conv_states
=
conv_state
,
conv_states
=
mamba_cache_params
.
conv_state
,
has_initial_state
=
attn_metadata
.
context_lens_tensor
>
0
,
cache_indices
=
mamba_cache_params
.
state_indices_tensor
,
query_start_loc
=
attn_metadata
.
query_start_loc
)
else
:
hidden_states
=
causal_conv1d_update
(
hidden_states
.
transpose
(
0
,
1
),
conv_state
,
mamba_cache_params
.
conv_state
,
conv_weights
,
self
.
conv1d
.
bias
,
self
.
activation
,
)
conv_state_indices
=
mamba_cache_params
.
state_indices_tensor
)
hidden_states
=
hidden_states
.
transpose
(
0
,
1
)
# 3. State Space Model sequence transformation
...
...
@@ -188,7 +181,7 @@ class JambaMambaMixer(nn.Module):
and
attn_metadata
.
context_lens_tensor
is
not
None
:
scan_outputs
=
selective_scan_fn
(
hidden_states
,
ssm_state
,
mamba_cache_params
.
ssm_state
,
discrete_time_step
,
self
.
A
,
B
.
transpose
(
-
2
,
-
1
),
...
...
@@ -197,11 +190,12 @@ class JambaMambaMixer(nn.Module):
gate
,
time_proj_bias
,
delta_softplus
=
True
,
cache_indices
=
mamba_cache_params
.
state_indices_tensor
,
has_initial_state
=
attn_metadata
.
context_lens_tensor
>
0
,
query_start_loc
=
attn_metadata
.
query_start_loc
)
else
:
scan_outputs
=
selective_state_update
(
ssm_state
,
mamba_cache_params
.
ssm_state
,
hidden_states
.
transpose
(
0
,
1
),
discrete_time_step
.
transpose
(
0
,
1
),
self
.
A
,
...
...
@@ -211,7 +205,7 @@ class JambaMambaMixer(nn.Module):
gate
.
transpose
(
0
,
1
),
time_proj_bias
,
dt_softplus
=
True
,
)
state_batch_indices
=
mamba_cache_params
.
state_indices_tensor
)
scan_outputs
=
scan_outputs
.
transpose
(
0
,
1
)
# 4. Final linear projection
...
...
@@ -292,7 +286,7 @@ class JambaMambaDecoderLayer(nn.Module):
super
().
__init__
()
self
.
layer_idx
=
layer_idx
self
.
config
=
config
self
.
mamba
=
JambaMambaMixer
(
config
,
layer_idx
)
self
.
mamba
=
JambaMambaMixer
(
config
)
num_experts
=
config
.
layers_num_experts
[
layer_idx
]
ffn_layer_class
=
JambaMoE
if
num_experts
>
1
else
JambaMLP
...
...
@@ -307,8 +301,7 @@ class JambaMambaDecoderLayer(nn.Module):
hidden_states
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
residual
:
Optional
[
torch
.
Tensor
],
conv_state
:
torch
.
Tensor
,
ssm_state
:
torch
.
Tensor
,
mamba_cache_params
:
MambaCacheParams
,
**
kwargs
,
):
if
residual
is
None
:
...
...
@@ -318,8 +311,8 @@ class JambaMambaDecoderLayer(nn.Module):
hidden_states
,
residual
=
self
.
input_layernorm
(
hidden_states
,
residual
)
hidden_states
=
self
.
mamba
(
hidden_states
,
attn_metadata
,
conv_state
,
ssm_state
)
hidden_states
=
self
.
mamba
(
hidden_states
,
attn_metadata
,
mamba_cache_params
)
# Fully Connected
hidden_states
,
residual
=
self
.
pre_ff_layernorm
(
hidden_states
,
residual
)
...
...
@@ -476,17 +469,14 @@ class JambaModel(nn.Module):
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
conv_state
:
torch
.
Tensor
,
ssm_state
:
torch
.
Tensor
,
mamba_cache_params
:
MambaCacheParams
,
)
->
torch
.
Tensor
:
hidden_states
=
self
.
embed_tokens
(
input_ids
)
residual
=
None
for
i
in
range
(
len
(
self
.
layers
)):
layer
=
self
.
layers
[
i
]
kv_cache
=
None
current_ssm_state
=
None
current_conv_state
=
None
layer_mamba_cache_params
=
None
if
isinstance
(
layer
,
JambaAttentionDecoderLayer
):
kv_cache
=
kv_caches
[(
i
-
self
.
config
.
attn_layer_offset
)
//
self
.
config
.
attn_layer_period
]
...
...
@@ -494,8 +484,8 @@ class JambaModel(nn.Module):
current_state_layer
=
i
-
(
1
+
(
i
-
self
.
config
.
attn_layer_offset
)
//
self
.
config
.
attn_layer_period
)
current_ssm_state
=
ssm_state
[
current_st
at
e
_layer
]
current_conv_state
=
conv_state
[
current_state_layer
]
layer_mamba_cache_params
=
mamba_cache_params
.
at_layer
_idx
(
current_state_layer
)
hidden_states
,
residual
=
layer
(
positions
=
positions
,
...
...
@@ -503,9 +493,7 @@ class JambaModel(nn.Module):
kv_cache
=
kv_cache
,
attn_metadata
=
attn_metadata
,
residual
=
residual
,
conv_state
=
current_conv_state
,
ssm_state
=
current_ssm_state
,
)
mamba_cache_params
=
layer_mamba_cache_params
)
hidden_states
,
_
=
self
.
final_layernorm
(
hidden_states
,
residual
)
return
hidden_states
...
...
@@ -588,13 +576,16 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA):
self
.
mamba_cache
=
MambaCacheManager
(
self
.
lm_head
.
weight
.
dtype
,
num_mamba_layers
,
max_batch_size
,
*
self
.
_get_mamba_cache_shape
())
mamba_cache_tensors
=
self
.
mamba_cache
.
current_run_tensors
(
input_ids
,
attn_metadata
,
**
kwargs
)
(
mamba_cache_tensors
,
state_indices_tensor
,
)
=
self
.
mamba_cache
.
current_run_tensors
(
input_ids
,
attn_metadata
,
**
kwargs
)
mamba_cache_params
=
MambaCacheParams
(
mamba_cache_tensors
[
0
],
mamba_cache_tensors
[
1
],
state_indices_tensor
)
hidden_states
=
self
.
model
(
input_ids
,
positions
,
kv_caches
,
attn_metadata
,
mamba_cache_tensors
[
0
],
mamba_cache_tensors
[
1
])
attn_metadata
,
mamba_cache_params
)
return
hidden_states
def
copy_inputs_before_cuda_graphs
(
self
,
input_buffers
,
**
kwargs
):
...
...
vllm/model_executor/models/mamba.py
View file @
fb60ae9b
...
...
@@ -27,7 +27,8 @@ from vllm.model_executor.model_loader.weight_utils import (
composed_weight_loader
,
default_weight_loader
,
sharded_weight_loader
)
from
vllm.model_executor.models.interfaces
import
(
HasInnerState
,
IsAttentionFree
)
from
vllm.model_executor.models.mamba_cache
import
MambaCacheManager
from
vllm.model_executor.models.mamba_cache
import
(
MambaCacheManager
,
MambaCacheParams
)
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.sequence
import
IntermediateTensors
...
...
@@ -110,8 +111,8 @@ class MambaMixer(nn.Module):
self
.
activation
=
config
.
hidden_act
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
conv_state
:
torch
.
Tensor
,
ssm_state
:
torch
.
Tensor
):
attn_metadata
:
AttentionMetadata
,
mamba_cache_params
:
MambaCacheParams
):
# 1. Gated MLP's linear projection
projected_states
=
self
.
in_proj
(
hidden_states
)[
0
].
transpose
(
-
2
,
-
1
)
...
...
@@ -134,17 +135,18 @@ class MambaMixer(nn.Module):
conv_weights
,
self
.
conv1d
.
bias
,
activation
=
self
.
activation
,
conv_states
=
conv_state
,
conv_states
=
mamba_cache_params
.
conv_state
,
has_initial_state
=
attn_metadata
.
context_lens_tensor
>
0
,
cache_indices
=
mamba_cache_params
.
state_indices_tensor
,
query_start_loc
=
attn_metadata
.
query_start_loc
)
else
:
hidden_states
=
causal_conv1d_update
(
hidden_states
.
transpose
(
0
,
1
),
conv_state
,
mamba_cache_params
.
conv_state
,
conv_weights
,
self
.
conv1d
.
bias
,
self
.
activation
,
)
conv_state_indices
=
mamba_cache_params
.
state_indices_tensor
)
hidden_states
=
hidden_states
.
transpose
(
0
,
1
)
# 3. State Space Model sequence transformation
...
...
@@ -168,7 +170,7 @@ class MambaMixer(nn.Module):
and
attn_metadata
.
context_lens_tensor
is
not
None
:
scan_outputs
=
selective_scan_fn
(
hidden_states
,
ssm_state
,
mamba_cache_params
.
ssm_state
,
discrete_time_step
,
self
.
A
,
B
.
transpose
(
-
2
,
-
1
),
...
...
@@ -177,11 +179,12 @@ class MambaMixer(nn.Module):
gate
,
time_proj_bias
,
delta_softplus
=
True
,
cache_indices
=
mamba_cache_params
.
state_indices_tensor
,
has_initial_state
=
attn_metadata
.
context_lens_tensor
>
0
,
query_start_loc
=
attn_metadata
.
query_start_loc
)
else
:
scan_outputs
=
selective_state_update
(
ssm_state
,
mamba_cache_params
.
ssm_state
,
hidden_states
.
transpose
(
0
,
1
),
discrete_time_step
.
transpose
(
0
,
1
),
self
.
A
,
...
...
@@ -191,7 +194,7 @@ class MambaMixer(nn.Module):
gate
.
transpose
(
0
,
1
),
time_proj_bias
,
dt_softplus
=
True
,
)
state_batch_indices
=
mamba_cache_params
.
state_indices_tensor
)
scan_outputs
=
scan_outputs
.
transpose
(
0
,
1
)
# 4. Final linear projection
...
...
@@ -221,8 +224,7 @@ class MambaDecoderLayer(nn.Module):
hidden_states
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
residual
:
Optional
[
torch
.
Tensor
],
conv_state
:
torch
.
Tensor
,
ssm_state
:
torch
.
Tensor
,
mamba_cache_params
:
MambaCacheParams
,
**
kwargs
,
):
if
residual
is
None
:
...
...
@@ -231,8 +233,8 @@ class MambaDecoderLayer(nn.Module):
else
:
hidden_states
,
residual
=
self
.
norm
(
hidden_states
,
residual
)
hidden_states
=
self
.
mixer
(
hidden_states
,
attn_metadata
,
conv_state
,
ssm_state
)
hidden_states
=
self
.
mixer
(
hidden_states
,
attn_metadata
,
mamba_cache_params
)
return
hidden_states
,
residual
...
...
@@ -275,25 +277,20 @@ class MambaModel(nn.Module):
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
conv_state
:
torch
.
Tensor
,
ssm_state
:
torch
.
Tensor
,
mamba_cache_params
:
MambaCacheParams
,
)
->
torch
.
Tensor
:
hidden_states
=
self
.
embeddings
(
input_ids
)
residual
=
None
for
i
in
range
(
len
(
self
.
layers
)):
layer
=
self
.
layers
[
i
]
current_ssm_state
=
ssm_state
[
i
]
current_conv_state
=
conv_state
[
i
]
hidden_states
,
residual
=
layer
(
positions
=
positions
,
hidden_states
=
hidden_states
,
attn_metadata
=
attn_metadata
,
residual
=
residual
,
conv_state
=
current_conv_state
,
ssm_state
=
current_ssm_state
,
)
mamba_cache_params
=
mamba_cache_params
.
at_layer_idx
(
i
))
hidden_states
,
_
=
self
.
norm_f
(
hidden_states
,
residual
)
return
hidden_states
...
...
@@ -347,12 +344,18 @@ class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree):
self
.
lm_head
.
weight
.
dtype
,
self
.
config
.
num_hidden_layers
,
max_batch_size
,
*
self
.
_get_mamba_cache_shape
())
mamba_cache_tensors
=
self
.
mamba_cache
.
current_run_tensors
(
input_ids
,
attn_metadata
,
**
kwargs
)
(
mamba_cache_tensors
,
state_indices_tensor
,
)
=
self
.
mamba_cache
.
current_run_tensors
(
input_ids
,
attn_metadata
,
**
kwargs
)
mamba_cache_params
=
MambaCacheParams
(
mamba_cache_tensors
[
0
],
mamba_cache_tensors
[
1
],
state_indices_tensor
)
hidden_states
=
self
.
backbone
(
input_ids
,
positions
,
attn_metadata
,
mamba_cache_tensors
[
0
],
mamba_cache_tensors
[
1
])
mamba_cache_params
)
return
hidden_states
...
...
vllm/model_executor/models/mamba_cache.py
View file @
fb60ae9b
from
typing
import
Dict
,
List
,
Optional
from
dataclasses
import
dataclass
from
typing
import
Dict
,
List
import
torch
from
vllm.attention.backends.abstract
import
AttentionMetadata
from
vllm.attention.backends.utils
import
PAD_SLOT_ID
@
dataclass
class
MambaCacheParams
:
conv_state
:
torch
.
Tensor
=
torch
.
Tensor
()
ssm_state
:
torch
.
Tensor
=
torch
.
Tensor
()
state_indices_tensor
:
torch
.
Tensor
=
torch
.
Tensor
()
def
at_layer_idx
(
self
,
layer_idx
):
return
MambaCacheParams
(
self
.
conv_state
[
layer_idx
],
self
.
ssm_state
[
layer_idx
],
self
.
state_indices_tensor
)
class
MambaCacheManager
:
...
...
@@ -24,6 +38,7 @@ class MambaCacheManager:
# Maps between the request id and a dict that maps between the seq_id
# and its index inside the self.mamba_cache
self
.
mamba_cache_indices_mapping
:
Dict
[
str
,
Dict
[
int
,
int
]]
=
{}
self
.
free_cache_indices
=
list
(
range
(
max_batch_size
))
def
current_run_tensors
(
self
,
input_ids
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
**
kwargs
):
...
...
@@ -36,30 +51,43 @@ class MambaCacheManager:
finished_requests_ids
=
kwargs
[
"finished_requests_ids"
]
self
.
_release_finished_requests
(
finished_requests_ids
)
mamba_cache_tensor
s
=
self
.
_prepare_current_run_mamba_cache
(
state_indice
s
=
self
.
_prepare_current_run_mamba_cache
(
request_ids_to_seq_ids
,
finished_requests_ids
)
state_indices_tensor
=
torch
.
as_tensor
(
state_indices
,
dtype
=
torch
.
int32
,
device
=
"cuda"
)
mamba_cache_tensors
=
self
.
mamba_cache
else
:
# CUDA graph capturing runs
mamba_cache_tensors
=
kwargs
[
"seqlen_agnostic_capture_inputs"
]
(
mamba_cache_tensors
,
state_indices_tensor
)
=
kwargs
[
"seqlen_agnostic_capture_inputs"
]
return
mamba_cache_tensors
return
(
mamba_cache_tensors
,
state_indices_tensor
)
def
copy_inputs_before_cuda_graphs
(
self
,
input_buffers
,
**
kwargs
):
"""
Copy the relevant Mamba cache into the CUDA graph input buffer
that was provided during the capture runs
(JambaForCausalLM.mamba_gc_cache_buffer).
Copy the relevant state_indices into the CUDA graph input buffer
"""
assert
all
(
key
in
kwargs
for
key
in
[
"request_ids_to_seq_ids"
,
"finished_requests_ids"
])
finished_requests_ids
=
kwargs
[
"finished_requests_ids"
]
request_ids_to_seq_ids
=
kwargs
[
"request_ids_to_seq_ids"
]
assert
"seqlen_agnostic_capture_inputs"
in
input_buffers
_
,
input_state_indices_buffer
=
input_buffers
[
"seqlen_agnostic_capture_inputs"
]
self
.
_release_finished_requests
(
finished_requests_ids
)
self
.
_prepare_current_run_mamba_cache
(
request_ids_to_seq_ids
,
finished_requests_ids
)
state_indices
=
self
.
_prepare_current_run_mamba_cache
(
request_ids_to_seq_ids
,
finished_requests_ids
)
cuda_graph_pad_len
=
input_state_indices_buffer
.
shape
[
0
]
-
len
(
state_indices
)
state_indices
.
extend
([
PAD_SLOT_ID
]
*
cuda_graph_pad_len
)
input_state_indices_buffer
.
copy_
(
torch
.
as_tensor
(
state_indices
,
dtype
=
torch
.
int32
,
device
=
"cuda"
))
def
get_seqlen_agnostic_capture_inputs
(
self
,
batch_size
:
int
):
"""
...
...
@@ -67,13 +95,10 @@ class MambaCacheManager:
The buffer is used to maintain the Mamba Cache during the CUDA graph
replay runs.
"""
return
tuple
(
buffer
[:,
:
batch_size
]
for
buffer
in
self
.
mamba_cache
)
def
_swap_mamba_cache
(
self
,
from_index
:
int
,
to_index
:
int
):
assert
len
(
self
.
mamba_cache
)
>
0
for
cache_t
in
self
.
mamba_cache
:
cache_t
[:,
[
to_index
,
from_index
]]
=
\
cache_t
[:,
[
from_index
,
to_index
]]
state_indices_tensor
=
torch
.
as_tensor
([
PAD_SLOT_ID
]
*
batch_size
,
dtype
=
torch
.
int32
,
device
=
"cuda"
)
return
(
self
.
mamba_cache
,
state_indices_tensor
)
def
_copy_mamba_cache
(
self
,
from_index
:
int
,
to_index
:
int
):
assert
len
(
self
.
mamba_cache
)
>
0
...
...
@@ -81,142 +106,53 @@ class MambaCacheManager:
cache_t
[:,
to_index
].
copy_
(
cache_t
[:,
from_index
],
non_blocking
=
True
)
def
_move_out_if_already_occupied
(
self
,
index
:
int
,
all_occupied_indices
:
List
[
int
]):
if
index
in
all_occupied_indices
:
first_free_index
=
self
.
_first_free_index_in_mamba_cache
()
# In case occupied, move the occupied to a new empty block
self
.
_move_cache_index_and_mappings
(
from_index
=
index
,
to_index
=
first_free_index
)
def
_assign_seq_id_to_mamba_cache_in_specific_dest
(
self
,
cur_rid
:
str
,
seq_id
:
int
,
destination_index
:
int
):
def
_assign_seq_id_to_cache_index
(
self
,
cur_rid
:
str
,
seq_id
:
int
,
finished_requests_ids
)
->
int
:
"""
Assign (req_id,seq_id) pair to a `destination_index` index, if
already occupied, move the occupying index to a free index.
"""
all_occupied_indices
=
self
.
_get_all_occupied_indices
()
if
cur_rid
not
in
self
.
mamba_cache_indices_mapping
:
self
.
_move_out_if_already_occupied
(
index
=
destination_index
,
all_occupied_indices
=
all_occupied
_indices
)
if
cur_rid
in
finished_requests_ids
:
# set as pad, do not allocate destination index
return
PAD_SLOT_ID
elif
cur_rid
not
in
self
.
mamba_cache_indices_mapping
:
destination_index
=
self
.
free_cache
_indices
.
pop
(
)
self
.
mamba_cache_indices_mapping
[
cur_rid
]
=
{
seq_id
:
destination_index
}
return
destination_index
elif
seq_id
not
in
(
seq_ids2indices
:
=
self
.
mamba_cache_indices_mapping
[
cur_rid
]):
# parallel sampling , where n > 1, assume prefill have
# already happened
now we only need to copy the already
# already happened
, so we copy the
# existing cache into the siblings seq_ids caches
self
.
_move_out_if_already_occupied
(
index
=
destination_index
,
all_occupied_indices
=
all_occupied_indices
)
index_exists
=
list
(
seq_ids2indices
.
values
())[
0
]
index_exists
=
next
(
iter
(
seq_ids2indices
.
values
()))
# case of decoding n>1, copy prefill cache to decoding indices
destination_index
=
self
.
free_cache_indices
.
pop
()
self
.
_copy_mamba_cache
(
from_index
=
index_exists
,
to_index
=
destination_index
)
self
.
mamba_cache_indices_mapping
[
cur_rid
][
seq_id
]
=
destination_index
return
destination_index
else
:
# already exists
cache_index_already_exists
=
self
.
mamba_cache_indices_mapping
[
cur_rid
][
seq_id
]
if
cache_index_already_exists
!=
destination_index
:
# In case the seq id already exists but not in
# the right destination, swap it with what's occupying it
self
.
_swap_pair_indices_and_mappings
(
from_index
=
cache_index_already_exists
,
to_index
=
destination_index
)
return
self
.
mamba_cache_indices_mapping
[
cur_rid
][
seq_id
]
def
_prepare_current_run_mamba_cache
(
self
,
request_ids_to_seq_ids
:
Dict
[
str
,
list
[
int
]],
finished_requests_ids
:
List
[
str
]):
r
unning_indices
=
[
]
request
_id
s
_to_
seq_ids_flatten
=
[
(
req_id
,
seq
_id
)
finished_requests_ids
:
List
[
str
])
->
List
[
int
]
:
r
eturn
[
self
.
_assign_seq
_id_to_
cache_index
(
req_id
,
seq_id
,
finished_requests
_id
s
)
for
req_id
,
seq_ids
in
request_ids_to_seq_ids
.
items
()
for
seq_id
in
seq_ids
]
batch_size
=
len
(
request_ids_to_seq_ids_flatten
)
for
dest_index
,
(
request_id
,
seq_id
)
in
enumerate
(
request_ids_to_seq_ids_flatten
):
if
request_id
in
finished_requests_ids
:
# Do not allocate cache index for requests that run
# and finish right after
continue
self
.
_assign_seq_id_to_mamba_cache_in_specific_dest
(
request_id
,
seq_id
,
dest_index
)
running_indices
.
append
(
dest_index
)
self
.
_clean_up_first_bs_blocks
(
batch_size
,
running_indices
)
conv_state
=
self
.
mamba_cache
[
0
][:,
:
batch_size
]
temporal_state
=
self
.
mamba_cache
[
1
][:,
:
batch_size
]
return
(
conv_state
,
temporal_state
)
def
_get_all_occupied_indices
(
self
):
return
[
cache_idx
for
seq_ids2indices
in
self
.
mamba_cache_indices_mapping
.
values
()
for
cache_idx
in
seq_ids2indices
.
values
()
]
def
_clean_up_first_bs_blocks
(
self
,
batch_size
:
int
,
indices_for_current_run
:
List
[
int
]):
# move out all of the occupied but currently not running blocks
# outside of the first n blocks
destination_indices
=
range
(
batch_size
)
max_possible_batch_size
=
self
.
mamba_cache
[
0
].
shape
[
1
]
for
destination_index
in
destination_indices
:
if
destination_index
in
self
.
_get_all_occupied_indices
()
and
\
destination_index
not
in
indices_for_current_run
:
# move not running indices outside of the batch
all_other_indices
=
list
(
range
(
batch_size
,
max_possible_batch_size
))
first_avail_index
=
self
.
_first_free_index_in_mamba_cache
(
all_other_indices
)
self
.
_swap_indices
(
from_index
=
destination_index
,
to_index
=
first_avail_index
)
def
_move_cache_index_and_mappings
(
self
,
from_index
:
int
,
to_index
:
int
):
self
.
_copy_mamba_cache
(
from_index
=
from_index
,
to_index
=
to_index
)
self
.
_update_mapping_index
(
from_index
=
from_index
,
to_index
=
to_index
)
def
_swap_pair_indices_and_mappings
(
self
,
from_index
:
int
,
to_index
:
int
):
self
.
_swap_mamba_cache
(
from_index
=
from_index
,
to_index
=
to_index
)
self
.
_swap_mapping_index
(
from_index
=
from_index
,
to_index
=
to_index
)
def
_swap_mapping_index
(
self
,
from_index
:
int
,
to_index
:
int
):
for
seq_ids2index
in
self
.
mamba_cache_indices_mapping
.
values
():
for
seq_id
,
index
in
seq_ids2index
.
items
():
if
from_index
==
index
:
seq_ids2index
.
update
({
seq_id
:
to_index
})
elif
to_index
==
index
:
seq_ids2index
.
update
({
seq_id
:
from_index
})
def
_update_mapping_index
(
self
,
from_index
:
int
,
to_index
:
int
):
for
seq_ids2index
in
self
.
mamba_cache_indices_mapping
.
values
():
for
seq_id
,
index
in
seq_ids2index
.
items
():
if
from_index
==
index
:
seq_ids2index
.
update
({
seq_id
:
to_index
})
return
def
_release_finished_requests
(
self
,
finished_seq_groups_req_ids
:
List
[
str
]):
for
req_id
in
finished_seq_groups_req_ids
:
if
req_id
in
self
.
mamba_cache_indices_mapping
:
for
seq_id
in
self
.
mamba_cache_indices_mapping
[
req_id
]:
self
.
free_cache_indices
.
append
(
self
.
mamba_cache_indices_mapping
[
req_id
][
seq_id
])
self
.
mamba_cache_indices_mapping
.
pop
(
req_id
)
def
_first_free_index_in_mamba_cache
(
self
,
indices_range
:
Optional
[
List
[
int
]]
=
None
)
->
int
:
assert
self
.
mamba_cache
is
not
None
if
indices_range
is
None
:
max_possible_batch_size
=
self
.
mamba_cache
[
0
].
shape
[
1
]
indices_range
=
list
(
range
(
max_possible_batch_size
))
all_occupied_indices
=
self
.
_get_all_occupied_indices
()
for
i
in
indices_range
:
if
i
not
in
all_occupied_indices
:
return
i
raise
Exception
(
"Couldn't find a free spot in the mamba cache! This"
"should never happen"
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment