Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
fb60ae9b
Unverified
Commit
fb60ae9b
authored
Oct 17, 2024
by
Mor Zusman
Committed by
GitHub
Oct 16, 2024
Browse files
[Kernel][Model] Improve continuous batching for Jamba and Mamba (#9189)
parent
415f76a9
Changes
15
Show whitespace changes
Inline
Side-by-side
Showing
15 changed files
with
511 additions
and
439 deletions
+511
-439
csrc/mamba/causal_conv1d/causal_conv1d.cu
csrc/mamba/causal_conv1d/causal_conv1d.cu
+25
-12
csrc/mamba/causal_conv1d/causal_conv1d.h
csrc/mamba/causal_conv1d/causal_conv1d.h
+1
-0
csrc/mamba/mamba_ssm/selective_scan.h
csrc/mamba/mamba_ssm/selective_scan.h
+1
-0
csrc/mamba/mamba_ssm/selective_scan_fwd.cu
csrc/mamba/mamba_ssm/selective_scan_fwd.cu
+14
-10
csrc/ops.h
csrc/ops.h
+17
-15
csrc/torch_bindings.cpp
csrc/torch_bindings.cpp
+6
-3
tests/kernels/test_causal_conv1d.py
tests/kernels/test_causal_conv1d.py
+97
-94
tests/kernels/test_mamba_ssm.py
tests/kernels/test_mamba_ssm.py
+89
-35
tests/models/decoder_only/language/test_jamba.py
tests/models/decoder_only/language/test_jamba.py
+25
-0
vllm/_custom_ops.py
vllm/_custom_ops.py
+40
-33
vllm/model_executor/layers/mamba/ops/causal_conv1d.py
vllm/model_executor/layers/mamba/ops/causal_conv1d.py
+33
-20
vllm/model_executor/layers/mamba/ops/mamba_ssm.py
vllm/model_executor/layers/mamba/ops/mamba_ssm.py
+43
-27
vllm/model_executor/models/jamba.py
vllm/model_executor/models/jamba.py
+31
-40
vllm/model_executor/models/mamba.py
vllm/model_executor/models/mamba.py
+28
-25
vllm/model_executor/models/mamba_cache.py
vllm/model_executor/models/mamba_cache.py
+61
-125
No files found.
csrc/mamba/causal_conv1d/causal_conv1d.cu
View file @
fb60ae9b
...
@@ -55,6 +55,7 @@ void set_conv_params_fwd(ConvParamsBase ¶ms,
...
@@ -55,6 +55,7 @@ void set_conv_params_fwd(ConvParamsBase ¶ms,
const
at
::
Tensor
out
,
const
at
::
Tensor
out
,
const
c10
::
optional
<
at
::
Tensor
>&
bias
,
const
c10
::
optional
<
at
::
Tensor
>&
bias
,
bool
silu_activation
,
bool
silu_activation
,
int64_t
pad_slot_id
,
const
c10
::
optional
<
at
::
Tensor
>&
query_start_loc
=
std
::
nullopt
,
const
c10
::
optional
<
at
::
Tensor
>&
query_start_loc
=
std
::
nullopt
,
const
c10
::
optional
<
at
::
Tensor
>&
cache_indices
=
std
::
nullopt
,
const
c10
::
optional
<
at
::
Tensor
>&
cache_indices
=
std
::
nullopt
,
const
c10
::
optional
<
at
::
Tensor
>&
has_initial_state
=
std
::
nullopt
)
{
const
c10
::
optional
<
at
::
Tensor
>&
has_initial_state
=
std
::
nullopt
)
{
...
@@ -66,6 +67,7 @@ void set_conv_params_fwd(ConvParamsBase ¶ms,
...
@@ -66,6 +67,7 @@ void set_conv_params_fwd(ConvParamsBase ¶ms,
params
.
dim
=
dim
;
params
.
dim
=
dim
;
params
.
seqlen
=
seqlen
;
params
.
seqlen
=
seqlen
;
params
.
width
=
width
;
params
.
width
=
width
;
params
.
pad_slot_id
=
pad_slot_id
;
params
.
silu_activation
=
silu_activation
;
params
.
silu_activation
=
silu_activation
;
...
@@ -90,14 +92,16 @@ void set_conv_params_fwd(ConvParamsBase ¶ms,
...
@@ -90,14 +92,16 @@ void set_conv_params_fwd(ConvParamsBase ¶ms,
}
}
at
::
Tensor
void
causal_conv1d_fwd
(
const
at
::
Tensor
&
x
,
const
at
::
Tensor
&
weight
,
causal_conv1d_fwd
(
const
at
::
Tensor
&
x
,
const
at
::
Tensor
&
weight
,
const
c10
::
optional
<
at
::
Tensor
>
&
bias_
,
const
c10
::
optional
<
at
::
Tensor
>
&
bias_
,
const
c10
::
optional
<
at
::
Tensor
>
&
conv_states
,
const
c10
::
optional
<
at
::
Tensor
>
&
conv_states
,
const
c10
::
optional
<
at
::
Tensor
>
&
query_start_loc
,
const
c10
::
optional
<
at
::
Tensor
>
&
query_start_loc
,
const
c10
::
optional
<
at
::
Tensor
>
&
cache_indices
,
const
c10
::
optional
<
at
::
Tensor
>
&
cache_indices
,
const
c10
::
optional
<
at
::
Tensor
>
&
has_initial_state
,
const
c10
::
optional
<
at
::
Tensor
>
&
has_initial_state
,
bool
silu_activation
)
{
bool
silu_activation
,
// used to identify padding entries if cache_indices provided
// in case of padding, the kernel will return early
int64_t
pad_slot_id
)
{
auto
input_type
=
x
.
scalar_type
();
auto
input_type
=
x
.
scalar_type
();
auto
weight_type
=
weight
.
scalar_type
();
auto
weight_type
=
weight
.
scalar_type
();
TORCH_CHECK
(
input_type
==
at
::
ScalarType
::
Float
||
input_type
==
at
::
ScalarType
::
Half
||
input_type
==
at
::
ScalarType
::
BFloat16
);
TORCH_CHECK
(
input_type
==
at
::
ScalarType
::
Float
||
input_type
==
at
::
ScalarType
::
Half
||
input_type
==
at
::
ScalarType
::
BFloat16
);
...
@@ -153,12 +157,13 @@ causal_conv1d_fwd(const at::Tensor &x, const at::Tensor &weight,
...
@@ -153,12 +157,13 @@ causal_conv1d_fwd(const at::Tensor &x, const at::Tensor &weight,
CHECK_SHAPE
(
cache_indices_
,
batch_size
);
CHECK_SHAPE
(
cache_indices_
,
batch_size
);
}
}
at
::
Tensor
out
=
torch
::
empty_like
(
x
)
;
at
::
Tensor
out
=
x
;
ConvParamsBase
params
;
ConvParamsBase
params
;
set_conv_params_fwd
(
params
,
batch_size
,
dim
,
seqlen
,
width
,
x
,
weight
,
out
,
set_conv_params_fwd
(
params
,
batch_size
,
dim
,
seqlen
,
width
,
x
,
weight
,
out
,
bias_
,
bias_
,
silu_activation
,
silu_activation
,
pad_slot_id
,
query_start_loc
,
query_start_loc
,
cache_indices
,
cache_indices
,
has_initial_state
has_initial_state
...
@@ -183,18 +188,19 @@ causal_conv1d_fwd(const at::Tensor &x, const at::Tensor &weight,
...
@@ -183,18 +188,19 @@ causal_conv1d_fwd(const at::Tensor &x, const at::Tensor &weight,
DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16
(
x
.
scalar_type
(),
"causal_conv1d_fwd"
,
[
&
]
{
DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16
(
x
.
scalar_type
(),
"causal_conv1d_fwd"
,
[
&
]
{
causal_conv1d_fwd_cuda
<
input_t
,
weight_t
>
(
params
,
stream
);
causal_conv1d_fwd_cuda
<
input_t
,
weight_t
>
(
params
,
stream
);
});
});
return
out
;
}
}
at
::
Tensor
void
causal_conv1d_update
(
const
at
::
Tensor
&
x
,
causal_conv1d_update
(
const
at
::
Tensor
&
x
,
const
at
::
Tensor
&
conv_state
,
const
at
::
Tensor
&
conv_state
,
const
at
::
Tensor
&
weight
,
const
at
::
Tensor
&
weight
,
const
c10
::
optional
<
at
::
Tensor
>
&
bias_
,
const
c10
::
optional
<
at
::
Tensor
>
&
bias_
,
bool
silu_activation
,
bool
silu_activation
,
const
c10
::
optional
<
at
::
Tensor
>
&
cache_seqlens_
,
const
c10
::
optional
<
at
::
Tensor
>
&
cache_seqlens_
,
const
c10
::
optional
<
at
::
Tensor
>
&
conv_state_indices_
)
{
const
c10
::
optional
<
at
::
Tensor
>
&
conv_state_indices_
,
// used to identify padding entries if cache_indices provided
// in case of padding, the kernel will return early
int64_t
pad_slot_id
)
{
auto
input_type
=
x
.
scalar_type
();
auto
input_type
=
x
.
scalar_type
();
auto
weight_type
=
weight
.
scalar_type
();
auto
weight_type
=
weight
.
scalar_type
();
TORCH_CHECK
(
input_type
==
at
::
ScalarType
::
Float
||
input_type
==
at
::
ScalarType
::
Half
||
input_type
==
at
::
ScalarType
::
BFloat16
);
TORCH_CHECK
(
input_type
==
at
::
ScalarType
::
Float
||
input_type
==
at
::
ScalarType
::
Half
||
input_type
==
at
::
ScalarType
::
BFloat16
);
...
@@ -227,12 +233,13 @@ causal_conv1d_update(const at::Tensor &x,
...
@@ -227,12 +233,13 @@ causal_conv1d_update(const at::Tensor &x,
CHECK_SHAPE
(
bias
,
dim
);
CHECK_SHAPE
(
bias
,
dim
);
}
}
at
::
Tensor
out
=
torch
::
empty_like
(
x
)
;
at
::
Tensor
out
=
x
;
ConvParamsBase
params
;
ConvParamsBase
params
;
set_conv_params_fwd
(
params
,
batch_size
,
dim
,
seqlen
,
width
,
x
,
weight
,
out
,
set_conv_params_fwd
(
params
,
batch_size
,
dim
,
seqlen
,
width
,
x
,
weight
,
out
,
bias_
,
bias_
,
silu_activation
);
silu_activation
,
pad_slot_id
);
params
.
conv_state_ptr
=
conv_state
.
data_ptr
();
params
.
conv_state_ptr
=
conv_state
.
data_ptr
();
params
.
conv_state_len
=
conv_state_len
;
params
.
conv_state_len
=
conv_state_len
;
// All stride are in elements, not bytes.
// All stride are in elements, not bytes.
...
@@ -274,7 +281,6 @@ causal_conv1d_update(const at::Tensor &x,
...
@@ -274,7 +281,6 @@ causal_conv1d_update(const at::Tensor &x,
DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16
(
x
.
scalar_type
(),
"causal_conv1d_update"
,
[
&
]
{
DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16
(
x
.
scalar_type
(),
"causal_conv1d_update"
,
[
&
]
{
causal_conv1d_update_cuda
<
input_t
,
weight_t
>
(
params
,
stream
);
causal_conv1d_update_cuda
<
input_t
,
weight_t
>
(
params
,
stream
);
});
});
return
out
;
}
}
template
<
int
kNThreads_
,
int
kWidth_
,
bool
kIsVecLoad_
,
typename
input_t_
,
typename
weight_t_
>
template
<
int
kNThreads_
,
int
kWidth_
,
bool
kIsVecLoad_
,
typename
input_t_
,
typename
weight_t_
>
...
@@ -340,7 +346,10 @@ void causal_conv1d_fwd_kernel(ConvParamsBase params) {
...
@@ -340,7 +346,10 @@ void causal_conv1d_fwd_kernel(ConvParamsBase params) {
int
*
cache_indices
=
params
.
cache_indices_ptr
==
nullptr
?
nullptr
int
*
cache_indices
=
params
.
cache_indices_ptr
==
nullptr
?
nullptr
:
reinterpret_cast
<
int
*>
(
params
.
cache_indices_ptr
);
:
reinterpret_cast
<
int
*>
(
params
.
cache_indices_ptr
);
int
cache_index
=
cache_indices
==
nullptr
?
batch_id
:
cache_indices
[
batch_id
];
int
cache_index
=
cache_indices
==
nullptr
?
batch_id
:
cache_indices
[
batch_id
];
// cache_index == params.pad_slot_id is defined as padding, so we exit early
if
(
cache_index
==
params
.
pad_slot_id
){
return
;
}
input_t
*
conv_states
=
params
.
conv_states_ptr
==
nullptr
?
nullptr
input_t
*
conv_states
=
params
.
conv_states_ptr
==
nullptr
?
nullptr
:
reinterpret_cast
<
input_t
*>
(
params
.
conv_states_ptr
)
+
cache_index
*
params
.
conv_states_batch_stride
+
channel_id
*
params
.
conv_states_c_stride
;
:
reinterpret_cast
<
input_t
*>
(
params
.
conv_states_ptr
)
+
cache_index
*
params
.
conv_states_batch_stride
+
channel_id
*
params
.
conv_states_c_stride
;
...
@@ -528,6 +537,10 @@ void causal_conv1d_update_kernel(ConvParamsBase params) {
...
@@ -528,6 +537,10 @@ void causal_conv1d_update_kernel(ConvParamsBase params) {
const
int
conv_state_batch_coord
=
params
.
conv_state_indices_ptr
==
nullptr
const
int
conv_state_batch_coord
=
params
.
conv_state_indices_ptr
==
nullptr
?
batch_id
?
batch_id
:
params
.
conv_state_indices_ptr
[
batch_id
];
:
params
.
conv_state_indices_ptr
[
batch_id
];
// conv_state_batch_coord == params.pad_slot_id is defined as padding so we exit early
if
(
conv_state_batch_coord
==
params
.
pad_slot_id
){
return
;
}
input_t
*
conv_state
=
reinterpret_cast
<
input_t
*>
(
params
.
conv_state_ptr
)
input_t
*
conv_state
=
reinterpret_cast
<
input_t
*>
(
params
.
conv_state_ptr
)
+
conv_state_batch_coord
*
params
.
conv_state_batch_stride
+
conv_state_batch_coord
*
params
.
conv_state_batch_stride
+
channel_id
*
params
.
conv_state_c_stride
;
+
channel_id
*
params
.
conv_state_c_stride
;
...
...
csrc/mamba/causal_conv1d/causal_conv1d.h
View file @
fb60ae9b
...
@@ -13,6 +13,7 @@ struct ConvParamsBase {
...
@@ -13,6 +13,7 @@ struct ConvParamsBase {
using
index_t
=
uint32_t
;
using
index_t
=
uint32_t
;
int
batch
,
dim
,
seqlen
,
width
;
int
batch
,
dim
,
seqlen
,
width
;
int64_t
pad_slot_id
;
bool
silu_activation
;
bool
silu_activation
;
index_t
x_batch_stride
;
index_t
x_batch_stride
;
...
...
csrc/mamba/mamba_ssm/selective_scan.h
View file @
fb60ae9b
...
@@ -21,6 +21,7 @@ struct SSMParamsBase {
...
@@ -21,6 +21,7 @@ struct SSMParamsBase {
int
dim_ngroups_ratio
;
int
dim_ngroups_ratio
;
bool
is_variable_B
;
bool
is_variable_B
;
bool
is_variable_C
;
bool
is_variable_C
;
int64_t
pad_slot_id
;
bool
delta_softplus
;
bool
delta_softplus
;
...
...
csrc/mamba/mamba_ssm/selective_scan_fwd.cu
View file @
fb60ae9b
...
@@ -115,6 +115,10 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
...
@@ -115,6 +115,10 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
const
int
*
cache_indices
=
params
.
cache_indices_ptr
==
nullptr
?
nullptr
const
int
*
cache_indices
=
params
.
cache_indices_ptr
==
nullptr
?
nullptr
:
reinterpret_cast
<
int
*>
(
params
.
cache_indices_ptr
);
:
reinterpret_cast
<
int
*>
(
params
.
cache_indices_ptr
);
const
int
cache_index
=
cache_indices
==
nullptr
?
batch_id
:
cache_indices
[
batch_id
];
const
int
cache_index
=
cache_indices
==
nullptr
?
batch_id
:
cache_indices
[
batch_id
];
// cache_index == params.pad_slot_id is defined as padding, so we exit early
if
(
cache_index
==
params
.
pad_slot_id
){
return
;
}
input_t
*
u
=
reinterpret_cast
<
input_t
*>
(
params
.
u_ptr
)
+
sequence_start_index
*
params
.
u_batch_stride
input_t
*
u
=
reinterpret_cast
<
input_t
*>
(
params
.
u_ptr
)
+
sequence_start_index
*
params
.
u_batch_stride
+
dim_id
*
kNRows
*
params
.
u_d_stride
;
+
dim_id
*
kNRows
*
params
.
u_d_stride
;
input_t
*
delta
=
reinterpret_cast
<
input_t
*>
(
params
.
delta_ptr
)
+
sequence_start_index
*
params
.
delta_batch_stride
input_t
*
delta
=
reinterpret_cast
<
input_t
*>
(
params
.
delta_ptr
)
+
sequence_start_index
*
params
.
delta_batch_stride
...
@@ -387,7 +391,6 @@ void set_ssm_params_fwd(SSMParamsBase ¶ms,
...
@@ -387,7 +391,6 @@ void set_ssm_params_fwd(SSMParamsBase ¶ms,
const
size_t
seqlen
,
const
size_t
seqlen
,
const
size_t
dstate
,
const
size_t
dstate
,
const
size_t
n_groups
,
const
size_t
n_groups
,
const
size_t
n_chunks
,
const
bool
is_variable_B
,
const
bool
is_variable_B
,
const
bool
is_variable_C
,
const
bool
is_variable_C
,
// device pointers
// device pointers
...
@@ -407,7 +410,8 @@ void set_ssm_params_fwd(SSMParamsBase ¶ms,
...
@@ -407,7 +410,8 @@ void set_ssm_params_fwd(SSMParamsBase ¶ms,
const
c10
::
optional
<
at
::
Tensor
>&
query_start_loc
,
const
c10
::
optional
<
at
::
Tensor
>&
query_start_loc
,
const
c10
::
optional
<
at
::
Tensor
>&
cache_indices
,
const
c10
::
optional
<
at
::
Tensor
>&
cache_indices
,
const
c10
::
optional
<
at
::
Tensor
>&
has_initial_state
,
const
c10
::
optional
<
at
::
Tensor
>&
has_initial_state
,
bool
varlen
)
{
bool
varlen
,
int64_t
pad_slot_id
)
{
// Reset the parameters
// Reset the parameters
memset
(
&
params
,
0
,
sizeof
(
params
));
memset
(
&
params
,
0
,
sizeof
(
params
));
...
@@ -417,8 +421,8 @@ void set_ssm_params_fwd(SSMParamsBase ¶ms,
...
@@ -417,8 +421,8 @@ void set_ssm_params_fwd(SSMParamsBase ¶ms,
params
.
seqlen
=
seqlen
;
params
.
seqlen
=
seqlen
;
params
.
dstate
=
dstate
;
params
.
dstate
=
dstate
;
params
.
n_groups
=
n_groups
;
params
.
n_groups
=
n_groups
;
params
.
n_chunks
=
n_chunks
;
params
.
dim_ngroups_ratio
=
dim
/
n_groups
;
params
.
dim_ngroups_ratio
=
dim
/
n_groups
;
params
.
pad_slot_id
=
pad_slot_id
;
params
.
delta_softplus
=
delta_softplus
;
params
.
delta_softplus
=
delta_softplus
;
...
@@ -507,7 +511,10 @@ void selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta,
...
@@ -507,7 +511,10 @@ void selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta,
const
c10
::
optional
<
torch
::
Tensor
>
&
query_start_loc
,
const
c10
::
optional
<
torch
::
Tensor
>
&
query_start_loc
,
const
c10
::
optional
<
torch
::
Tensor
>
&
cache_indices
,
const
c10
::
optional
<
torch
::
Tensor
>
&
cache_indices
,
const
c10
::
optional
<
torch
::
Tensor
>
&
has_initial_state
,
const
c10
::
optional
<
torch
::
Tensor
>
&
has_initial_state
,
const
torch
::
Tensor
&
ssm_states
)
{
const
torch
::
Tensor
&
ssm_states
,
// used to identify padding entries if cache_indices provided
// in case of padding, the kernel will return early
int64_t
pad_slot_id
)
{
auto
input_type
=
u
.
scalar_type
();
auto
input_type
=
u
.
scalar_type
();
auto
weight_type
=
A
.
scalar_type
();
auto
weight_type
=
A
.
scalar_type
();
TORCH_CHECK
(
input_type
==
at
::
ScalarType
::
Float
||
input_type
==
at
::
ScalarType
::
Half
||
input_type
==
at
::
ScalarType
::
BFloat16
);
TORCH_CHECK
(
input_type
==
at
::
ScalarType
::
Float
||
input_type
==
at
::
ScalarType
::
Half
||
input_type
==
at
::
ScalarType
::
BFloat16
);
...
@@ -618,18 +625,14 @@ void selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta,
...
@@ -618,18 +625,14 @@ void selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta,
out_z
=
z
;
out_z
=
z
;
const
int
n_chunks
=
(
seqlen
+
2048
-
1
)
/
2048
;
// const int n_chunks = (seqlen + 1024 - 1) / 1024;
// at::Tensor out = torch::empty_like(u);
// Right now u has BHL layout and delta has HBL layout, and we want out to have HBL layout
// Right now u has BHL layout and delta has HBL layout, and we want out to have HBL layout
at
::
Tensor
out
=
delta
;
at
::
Tensor
out
=
delta
;
TORCH_CHECK
(
ssm_states
.
scalar_type
()
==
input_type
);
TORCH_CHECK
(
ssm_states
.
scalar_type
()
==
input_type
);
TORCH_CHECK
(
ssm_states
.
is_cuda
());
TORCH_CHECK
(
ssm_states
.
is_cuda
());
TORCH_CHECK
(
ssm_states
.
stride
(
-
1
)
==
1
);
TORCH_CHECK
(
ssm_states
.
stride
(
-
1
)
==
1
);
CHECK_SHAPE
(
ssm_states
,
batch_size
,
dim
,
dstate
);
SSMParamsBase
params
;
SSMParamsBase
params
;
set_ssm_params_fwd
(
params
,
batch_size
,
dim
,
seqlen
,
dstate
,
n_groups
,
n_chunks
,
is_variable_B
,
is_variable_C
,
set_ssm_params_fwd
(
params
,
batch_size
,
dim
,
seqlen
,
dstate
,
n_groups
,
is_variable_B
,
is_variable_C
,
u
,
delta
,
A
,
B
,
C
,
out
,
z
,
out_z
,
u
,
delta
,
A
,
B
,
C
,
out
,
z
,
out_z
,
D_
,
D_
,
delta_bias_
,
delta_bias_
,
...
@@ -639,7 +642,8 @@ void selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta,
...
@@ -639,7 +642,8 @@ void selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta,
query_start_loc
,
query_start_loc
,
cache_indices
,
cache_indices
,
has_initial_state
,
has_initial_state
,
varlen
varlen
,
pad_slot_id
);
);
...
...
csrc/ops.h
View file @
fb60ae9b
...
@@ -157,21 +157,23 @@ void selective_scan_fwd(const torch::Tensor& u, const torch::Tensor& delta,
...
@@ -157,21 +157,23 @@ void selective_scan_fwd(const torch::Tensor& u, const torch::Tensor& delta,
const
c10
::
optional
<
torch
::
Tensor
>&
query_start_loc
,
const
c10
::
optional
<
torch
::
Tensor
>&
query_start_loc
,
const
c10
::
optional
<
torch
::
Tensor
>&
cache_indices
,
const
c10
::
optional
<
torch
::
Tensor
>&
cache_indices
,
const
c10
::
optional
<
torch
::
Tensor
>&
has_initial_state
,
const
c10
::
optional
<
torch
::
Tensor
>&
has_initial_state
,
const
torch
::
Tensor
&
ssm_states
);
const
torch
::
Tensor
&
ssm_states
,
int64_t
pad_slot_id
);
at
::
Tensor
causal_conv1d_update
(
void
causal_conv1d_update
(
const
at
::
Tensor
&
x
,
const
at
::
Tensor
&
conv_state
,
const
at
::
Tensor
&
x
,
const
at
::
Tensor
&
conv_state
,
const
at
::
Tensor
&
weight
,
const
at
::
Tensor
&
weight
,
const
c10
::
optional
<
at
::
Tensor
>&
bias_
,
bool
silu_activation
,
const
c10
::
optional
<
at
::
Tensor
>&
bias_
,
bool
silu_activation
,
const
c10
::
optional
<
at
::
Tensor
>&
cache_seqlens_
,
const
c10
::
optional
<
at
::
Tensor
>&
cache_seqlens_
,
const
c10
::
optional
<
at
::
Tensor
>&
conv_state_indices_
);
const
c10
::
optional
<
at
::
Tensor
>&
conv_state_indices_
,
int64_t
pad_slot_id
);
at
::
Tensor
causal_conv1d_fwd
(
const
at
::
Tensor
&
x
,
const
at
::
Tensor
&
weight
,
void
causal_conv1d_fwd
(
const
at
::
Tensor
&
x
,
const
at
::
Tensor
&
weight
,
const
c10
::
optional
<
at
::
Tensor
>&
bias_
,
const
c10
::
optional
<
at
::
Tensor
>&
bias_
,
const
c10
::
optional
<
at
::
Tensor
>&
conv_states
,
const
c10
::
optional
<
at
::
Tensor
>&
conv_states
,
const
c10
::
optional
<
at
::
Tensor
>&
query_start_loc
,
const
c10
::
optional
<
at
::
Tensor
>&
query_start_loc
,
const
c10
::
optional
<
at
::
Tensor
>&
cache_indices
,
const
c10
::
optional
<
at
::
Tensor
>&
cache_indices
,
const
c10
::
optional
<
at
::
Tensor
>&
has_initial_state
,
const
c10
::
optional
<
at
::
Tensor
>&
has_initial_state
,
bool
silu_activation
);
bool
silu_activation
,
int64_t
pad_slot_id
);
#ifndef USE_ROCM
#ifndef USE_ROCM
using
fptr_t
=
int64_t
;
using
fptr_t
=
int64_t
;
...
...
csrc/torch_bindings.cpp
View file @
fb60ae9b
...
@@ -278,7 +278,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
...
@@ -278,7 +278,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
"Tensor? query_start_loc,"
"Tensor? query_start_loc,"
"Tensor? cache_indices,"
"Tensor? cache_indices,"
"Tensor? has_initial_state,"
"Tensor? has_initial_state,"
"Tensor! ssm_states) -> ()"
);
"Tensor! ssm_states,"
"int pad_slot_id) -> ()"
);
ops
.
impl
(
"selective_scan_fwd"
,
torch
::
kCUDA
,
&
selective_scan_fwd
);
ops
.
impl
(
"selective_scan_fwd"
,
torch
::
kCUDA
,
&
selective_scan_fwd
);
ops
.
def
(
ops
.
def
(
...
@@ -288,7 +289,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
...
@@ -288,7 +289,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
"Tensor? bias_,"
"Tensor? bias_,"
"bool silu_activation,"
"bool silu_activation,"
"Tensor? cache_seqlens_,"
"Tensor? cache_seqlens_,"
"Tensor? conv_state_indices) -> Tensor"
);
"Tensor? conv_state_indices,"
"int pad_slot_id) -> ()"
);
ops
.
impl
(
"causal_conv1d_update"
,
torch
::
kCUDA
,
&
causal_conv1d_update
);
ops
.
impl
(
"causal_conv1d_update"
,
torch
::
kCUDA
,
&
causal_conv1d_update
);
ops
.
def
(
ops
.
def
(
...
@@ -298,7 +300,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
...
@@ -298,7 +300,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
"Tensor? query_start_loc,"
"Tensor? query_start_loc,"
"Tensor? cache_indices,"
"Tensor? cache_indices,"
"Tensor? has_initial_state,"
"Tensor? has_initial_state,"
"bool silu_activation) -> Tensor"
);
"bool silu_activation,"
"int pad_slot_id) -> ()"
);
ops
.
impl
(
"causal_conv1d_fwd"
,
torch
::
kCUDA
,
&
causal_conv1d_fwd
);
ops
.
impl
(
"causal_conv1d_fwd"
,
torch
::
kCUDA
,
&
causal_conv1d_fwd
);
#endif
#endif
...
...
tests/kernels/test_causal_conv1d.py
View file @
fb60ae9b
...
@@ -6,6 +6,7 @@ import torch.nn.functional as F
...
@@ -6,6 +6,7 @@ import torch.nn.functional as F
from
tests.kernels.utils
import
opcheck
from
tests.kernels.utils
import
opcheck
from
vllm
import
_custom_ops
as
ops
# noqa: F401
from
vllm
import
_custom_ops
as
ops
# noqa: F401
from
vllm.attention.backends.utils
import
PAD_SLOT_ID
from
vllm.model_executor.layers.mamba.ops.causal_conv1d
import
(
from
vllm.model_executor.layers.mamba.ops.causal_conv1d
import
(
causal_conv1d_fn
,
causal_conv1d_update
)
causal_conv1d_fn
,
causal_conv1d_update
)
from
vllm.utils
import
seed_everything
from
vllm.utils
import
seed_everything
...
@@ -114,8 +115,7 @@ def causal_conv1d_update_ref(x,
...
@@ -114,8 +115,7 @@ def causal_conv1d_update_ref(x,
@
pytest
.
mark
.
parametrize
(
"itype"
,
[
torch
.
bfloat16
,
torch
.
float
])
@
pytest
.
mark
.
parametrize
(
"itype"
,
[
torch
.
bfloat16
,
torch
.
float
])
@
pytest
.
mark
.
parametrize
(
"silu_activation"
,
[
True
])
@
pytest
.
mark
.
parametrize
(
"silu_activation"
,
[
True
])
@
pytest
.
mark
.
parametrize
(
"has_bias"
,
[
True
])
@
pytest
.
mark
.
parametrize
(
"has_bias"
,
[
True
])
def
causal_conv1d_opcheck_fn
(
def
causal_conv1d_opcheck_fn
(
x
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
,
cu_seq_len
:
Optional
[
torch
.
Tensor
]
=
None
,
cu_seq_len
:
Optional
[
torch
.
Tensor
]
=
None
,
...
@@ -123,7 +123,7 @@ def causal_conv1d_opcheck_fn(
...
@@ -123,7 +123,7 @@ def causal_conv1d_opcheck_fn(
has_initial_state
:
Optional
[
torch
.
Tensor
]
=
None
,
has_initial_state
:
Optional
[
torch
.
Tensor
]
=
None
,
conv_states
:
Optional
[
torch
.
Tensor
]
=
None
,
conv_states
:
Optional
[
torch
.
Tensor
]
=
None
,
activation
:
Optional
[
str
]
=
"silu"
,
activation
:
Optional
[
str
]
=
"silu"
,
):
pad_slot_id
:
int
=
PAD_SLOT_ID
):
"""
"""
x: (batch, dim, seqlen)
x: (batch, dim, seqlen)
weight: (dim, width)
weight: (dim, width)
...
@@ -141,16 +141,9 @@ def causal_conv1d_opcheck_fn(
...
@@ -141,16 +141,9 @@ def causal_conv1d_opcheck_fn(
x
=
x
.
contiguous
()
x
=
x
.
contiguous
()
bias
=
bias
.
contiguous
()
if
bias
is
not
None
else
None
bias
=
bias
.
contiguous
()
if
bias
is
not
None
else
None
opcheck
(
torch
.
ops
.
_C
.
causal_conv1d_fwd
,
(
opcheck
(
torch
.
ops
.
_C
.
causal_conv1d_fwd
,
x
,
(
x
,
weight
,
bias
,
conv_states
,
cu_seq_len
,
cache_indices
,
weight
,
has_initial_state
,
activation
in
[
"silu"
,
"swish"
],
pad_slot_id
))
bias
,
conv_states
,
cu_seq_len
,
cache_indices
,
has_initial_state
,
activation
in
[
"silu"
,
"swish"
],
))
@
pytest
.
mark
.
parametrize
(
"itype"
,
[
torch
.
bfloat16
,
torch
.
float
])
@
pytest
.
mark
.
parametrize
(
"itype"
,
[
torch
.
bfloat16
,
torch
.
float
])
...
@@ -233,17 +226,11 @@ def test_causal_conv1d_update(dim, width, seqlen, has_bias, silu_activation,
...
@@ -233,17 +226,11 @@ def test_causal_conv1d_update(dim, width, seqlen, has_bias, silu_activation,
seed_everything
(
0
)
seed_everything
(
0
)
batch
=
2
batch
=
2
x
=
torch
.
randn
(
batch
,
dim
,
seqlen
,
device
=
device
,
dtype
=
itype
)
x
=
torch
.
randn
(
batch
,
dim
,
seqlen
,
device
=
device
,
dtype
=
itype
)
x_ref
=
x
.
clone
()
conv_state
=
torch
.
randn
(
batch
,
dim
,
width
-
1
,
device
=
device
,
dtype
=
itype
)
conv_state
=
torch
.
randn
(
batch
,
dim
,
width
-
1
,
device
=
device
,
dtype
=
itype
)
weight
=
torch
.
randn
(
dim
,
weight
=
torch
.
randn
(
dim
,
width
,
device
=
device
,
dtype
=
itype
)
width
,
bias
=
torch
.
randn
(
dim
,
device
=
device
,
dtype
=
itype
)
if
has_bias
else
None
device
=
device
,
dtype
=
itype
,
requires_grad
=
True
)
if
has_bias
:
bias
=
torch
.
randn
(
dim
,
device
=
device
,
dtype
=
itype
,
requires_grad
=
True
)
else
:
bias
=
None
conv_state_ref
=
conv_state
.
detach
().
clone
()
conv_state_ref
=
conv_state
.
detach
().
clone
()
activation
=
None
if
not
silu_activation
else
"silu"
activation
=
None
if
not
silu_activation
else
"silu"
out
=
causal_conv1d_update
(
x
,
out
=
causal_conv1d_update
(
x
,
...
@@ -251,7 +238,7 @@ def test_causal_conv1d_update(dim, width, seqlen, has_bias, silu_activation,
...
@@ -251,7 +238,7 @@ def test_causal_conv1d_update(dim, width, seqlen, has_bias, silu_activation,
weight
,
weight
,
bias
,
bias
,
activation
=
activation
)
activation
=
activation
)
out_ref
=
causal_conv1d_update_ref
(
x
,
out_ref
=
causal_conv1d_update_ref
(
x
_ref
,
conv_state_ref
,
conv_state_ref
,
weight
,
weight
,
bias
,
bias
,
...
@@ -260,15 +247,9 @@ def test_causal_conv1d_update(dim, width, seqlen, has_bias, silu_activation,
...
@@ -260,15 +247,9 @@ def test_causal_conv1d_update(dim, width, seqlen, has_bias, silu_activation,
assert
torch
.
equal
(
conv_state
,
conv_state_ref
)
assert
torch
.
equal
(
conv_state
,
conv_state_ref
)
assert
torch
.
allclose
(
out
,
out_ref
,
rtol
=
rtol
,
atol
=
atol
)
assert
torch
.
allclose
(
out
,
out_ref
,
rtol
=
rtol
,
atol
=
atol
)
opcheck
(
torch
.
ops
.
_C
.
causal_conv1d_update
,
(
opcheck
(
torch
.
ops
.
_C
.
causal_conv1d_update
,
x
,
(
x
,
conv_state
,
weight
,
bias
,
activation
conv_state
,
in
[
"silu"
,
"swish"
],
None
,
None
,
PAD_SLOT_ID
))
weight
,
bias
,
activation
in
[
"silu"
,
"swish"
],
None
,
None
,
))
@
pytest
.
mark
.
parametrize
(
"itype"
,
@
pytest
.
mark
.
parametrize
(
"itype"
,
...
@@ -278,37 +259,48 @@ def test_causal_conv1d_update(dim, width, seqlen, has_bias, silu_activation,
...
@@ -278,37 +259,48 @@ def test_causal_conv1d_update(dim, width, seqlen, has_bias, silu_activation,
@
pytest
.
mark
.
parametrize
(
"seqlen"
,
[
1
,
4
,
5
])
@
pytest
.
mark
.
parametrize
(
"seqlen"
,
[
1
,
4
,
5
])
@
pytest
.
mark
.
parametrize
(
"width"
,
[
2
,
3
,
4
])
@
pytest
.
mark
.
parametrize
(
"width"
,
[
2
,
3
,
4
])
@
pytest
.
mark
.
parametrize
(
"dim"
,
[
2048
,
2048
+
16
,
4096
])
@
pytest
.
mark
.
parametrize
(
"dim"
,
[
2048
,
2048
+
16
,
4096
])
def
test_causal_conv1d_update_with_batch_gather
(
dim
,
width
,
seqlen
,
has_bias
,
# tests correctness in case subset of the sequences are padded
@
pytest
.
mark
.
parametrize
(
"with_padding"
,
[
True
,
False
])
def
test_causal_conv1d_update_with_batch_gather
(
with_padding
,
dim
,
width
,
seqlen
,
has_bias
,
silu_activation
,
itype
):
silu_activation
,
itype
):
device
=
"cuda"
device
=
"cuda"
rtol
,
atol
=
(
3e-4
,
1e-3
)
if
itype
==
torch
.
float32
else
(
3e-3
,
5e-3
)
rtol
,
atol
=
(
3e-4
,
1e-3
)
if
itype
==
torch
.
float32
else
(
3e-3
,
5e-3
)
if
itype
==
torch
.
bfloat16
:
if
itype
==
torch
.
bfloat16
:
rtol
,
atol
=
1e-2
,
5e-2
rtol
,
atol
=
1e-2
,
5e-2
# set
)
seed
# set seed
seed_everything
(
0
)
seed_everything
(
0
)
batch
=
64
x
=
torch
.
randn
(
batch
,
dim
,
1
,
device
=
device
,
dtype
=
itype
)
batch_size
=
3
padding
=
5
if
with_padding
else
0
padded_batch_size
=
batch_size
+
padding
total_entries
=
10
*
batch_size
x
=
torch
.
randn
(
padded_batch_size
,
dim
,
1
,
device
=
device
,
dtype
=
itype
)
x_ref
=
x
.
clone
()
total_entries
=
10
*
batch
conv_state_indices
=
torch
.
randperm
(
total_entries
)[:
batch_size
].
to
(
dtype
=
torch
.
int32
,
device
=
device
)
unused_states_bool
=
torch
.
ones
(
total_entries
,
dtype
=
torch
.
bool
,
device
=
device
)
unused_states_bool
[
conv_state_indices
]
=
False
padded_state_indices
=
torch
.
concat
([
conv_state_indices
,
torch
.
as_tensor
(
[
PAD_SLOT_ID
]
*
padding
,
dtype
=
torch
.
int32
,
device
=
device
)
],
dim
=
0
)
conv_state
=
torch
.
randn
(
total_entries
,
conv_state
=
torch
.
randn
(
total_entries
,
dim
,
dim
,
width
-
1
,
width
-
1
,
device
=
device
,
device
=
device
,
dtype
=
itype
)
dtype
=
itype
)
conv_state_indices
=
torch
.
randperm
(
total_entries
)[:
batch
].
to
(
conv_state_for_padding_test
=
conv_state
.
clone
()
dtype
=
torch
.
int32
,
device
=
device
)
weight
=
torch
.
randn
(
dim
,
weight
=
torch
.
randn
(
dim
,
width
,
device
=
device
,
dtype
=
itype
)
width
,
bias
=
torch
.
randn
(
dim
,
device
=
device
,
dtype
=
itype
)
if
has_bias
else
None
device
=
device
,
dtype
=
itype
,
requires_grad
=
True
)
if
has_bias
:
bias
=
torch
.
randn
(
dim
,
device
=
device
,
dtype
=
itype
,
requires_grad
=
True
)
else
:
bias
=
None
conv_state_ref
=
conv_state
[
conv_state_indices
,
:].
detach
().
clone
()
conv_state_ref
=
conv_state
[
conv_state_indices
,
:].
detach
().
clone
()
activation
=
None
if
not
silu_activation
else
"silu"
activation
=
None
if
not
silu_activation
else
"silu"
out
=
causal_conv1d_update
(
x
,
out
=
causal_conv1d_update
(
x
,
...
@@ -316,45 +308,50 @@ def test_causal_conv1d_update_with_batch_gather(dim, width, seqlen, has_bias,
...
@@ -316,45 +308,50 @@ def test_causal_conv1d_update_with_batch_gather(dim, width, seqlen, has_bias,
weight
,
weight
,
bias
,
bias
,
activation
=
activation
,
activation
=
activation
,
conv_state_indices
=
conv_state_indices
)
conv_state_indices
=
padded_state_indices
,
out_ref
=
causal_conv1d_update_ref
(
x
,
pad_slot_id
=
PAD_SLOT_ID
)
out_ref
=
causal_conv1d_update_ref
(
x_ref
[:
batch_size
],
conv_state_ref
,
conv_state_ref
,
weight
,
weight
,
bias
,
bias
,
activation
=
activation
)
activation
=
activation
)
assert
torch
.
equal
(
conv_state
[
conv_state_indices
,
:],
conv_state_ref
)
assert
torch
.
equal
(
conv_state
[
conv_state_indices
,
:],
conv_state_ref
)
assert
torch
.
allclose
(
out
,
out_ref
,
rtol
=
rtol
,
atol
=
atol
)
assert
torch
.
allclose
(
out
[:
batch_size
],
out_ref
,
rtol
=
rtol
,
atol
=
atol
)
assert
torch
.
equal
(
conv_state
[
unused_states_bool
],
conv_state_for_padding_test
[
unused_states_bool
])
opcheck
(
torch
.
ops
.
_C
.
causal_conv1d_update
,
(
opcheck
(
torch
.
ops
.
_C
.
causal_conv1d_update
,
x
,
(
x
,
conv_state
,
weight
,
bias
,
activation
conv_state
,
in
[
"silu"
,
"swish"
],
None
,
padded_state_indices
,
PAD_SLOT_ID
))
weight
,
bias
,
activation
in
[
"silu"
,
"swish"
],
None
,
conv_state_indices
,
))
@
pytest
.
mark
.
parametrize
(
"itype"
,
[
torch
.
bfloat16
])
@
pytest
.
mark
.
parametrize
(
"itype"
,
[
torch
.
bfloat16
])
@
pytest
.
mark
.
parametrize
(
"silu_activation"
,
[
True
])
@
pytest
.
mark
.
parametrize
(
"silu_activation"
,
[
True
])
@
pytest
.
mark
.
parametrize
(
"has_bias"
,
[
True
])
@
pytest
.
mark
.
parametrize
(
"has_bias"
,
[
True
])
@
pytest
.
mark
.
parametrize
(
"width"
,
[
4
])
@
pytest
.
mark
.
parametrize
(
"width"
,
[
4
])
@
pytest
.
mark
.
parametrize
(
'seqlen'
,
@
pytest
.
mark
.
parametrize
(
[
8
,
16
,
32
,
64
,
128
,
256
,
512
,
784
,
1024
,
2048
,
4096
])
'seqlen'
,
[
8
,
16
,
32
,
64
,
128
,
256
,
512
,
784
,
1024
,
2048
,
2049
,
4096
])
@
pytest
.
mark
.
parametrize
(
'dim'
,
[
64
,
4096
])
@
pytest
.
mark
.
parametrize
(
'dim'
,
[
64
,
4096
])
def
test_causal_conv1d_varlen
(
dim
,
seqlen
,
width
,
has_bias
,
silu_activation
,
# tests correctness in case subset of the sequences are padded
itype
):
@
pytest
.
mark
.
parametrize
(
'with_padding'
,
[
True
,
False
])
def
test_causal_conv1d_varlen
(
with_padding
,
dim
,
seqlen
,
width
,
has_bias
,
silu_activation
,
itype
):
device
=
"cuda"
device
=
"cuda"
torch
.
cuda
.
empty_cache
()
rtol
,
atol
=
(
3e-4
,
1e-3
)
if
itype
==
torch
.
float32
else
(
3e-3
,
5e-3
)
rtol
,
atol
=
(
3e-4
,
1e-3
)
if
itype
==
torch
.
float32
else
(
3e-3
,
5e-3
)
if
itype
==
torch
.
bfloat16
:
if
itype
==
torch
.
bfloat16
:
rtol
,
atol
=
1e-2
,
5e-2
rtol
,
atol
=
1e-2
,
5e-2
# set seed
# set seed
seed_everything
(
0
)
seed_everything
(
0
)
batch
=
1
seqlens
=
[]
seqlens
=
[]
nsplits
=
3
batch_size
=
4
if
seqlen
<
10
:
batch_size
=
1
padding
=
3
if
with_padding
else
0
padded_batch_size
=
batch_size
+
padding
nsplits
=
padded_batch_size
-
1
eos_pos
=
torch
.
randperm
(
seqlen
-
1
)[:
nsplits
].
sort
().
values
eos_pos
=
torch
.
randperm
(
seqlen
-
1
)[:
nsplits
].
sort
().
values
seqlens
.
append
(
seqlens
.
append
(
torch
.
diff
(
torch
.
diff
(
...
@@ -364,10 +361,11 @@ def test_causal_conv1d_varlen(dim, seqlen, width, has_bias, silu_activation,
...
@@ -364,10 +361,11 @@ def test_causal_conv1d_varlen(dim, seqlen, width, has_bias, silu_activation,
assert
sum
(
seqlens
[
-
1
])
==
seqlen
assert
sum
(
seqlens
[
-
1
])
==
seqlen
assert
all
(
s
>
0
for
s
in
seqlens
[
-
1
])
assert
all
(
s
>
0
for
s
in
seqlens
[
-
1
])
total_entries
=
batch_size
*
10
cumsum
=
torch
.
cumsum
(
torch
.
tensor
(
seqlens
[
0
]),
dim
=
0
).
to
(
torch
.
int32
)
cumsum
=
torch
.
cumsum
(
torch
.
tensor
(
seqlens
[
0
]),
dim
=
0
).
to
(
torch
.
int32
)
cumsum
=
torch
.
concat
([
torch
.
tensor
([
0
],
dtype
=
torch
.
int32
),
cumsum
],
cumsum
=
torch
.
concat
([
torch
.
tensor
([
0
],
dtype
=
torch
.
int32
),
cumsum
],
dim
=
0
)
dim
=
0
)
x
=
torch
.
randn
(
batch
,
4096
+
dim
+
64
,
seqlen
,
device
=
device
,
x
=
torch
.
randn
(
1
,
4096
+
dim
+
64
,
seqlen
,
device
=
device
,
dtype
=
itype
)[:,
4096
:
4096
+
dim
,
:]
dtype
=
itype
)[:,
4096
:
4096
+
dim
,
:]
weight
=
torch
.
randn
(
dim
,
width
,
device
=
device
,
dtype
=
itype
)
weight
=
torch
.
randn
(
dim
,
width
,
device
=
device
,
dtype
=
itype
)
bias
=
torch
.
randn
(
dim
,
device
=
device
,
dtype
=
itype
)
if
has_bias
else
None
bias
=
torch
.
randn
(
dim
,
device
=
device
,
dtype
=
itype
)
if
has_bias
else
None
...
@@ -375,7 +373,7 @@ def test_causal_conv1d_varlen(dim, seqlen, width, has_bias, silu_activation,
...
@@ -375,7 +373,7 @@ def test_causal_conv1d_varlen(dim, seqlen, width, has_bias, silu_activation,
weight_ref
=
weight
.
clone
()
weight_ref
=
weight
.
clone
()
bias_ref
=
bias
.
clone
()
if
bias
is
not
None
else
None
bias_ref
=
bias
.
clone
()
if
bias
is
not
None
else
None
activation
=
None
if
not
silu_activation
else
"silu"
activation
=
None
if
not
silu_activation
else
"silu"
final_states
=
torch
.
randn
(
nsplits
+
1
,
final_states
=
torch
.
randn
(
total_entries
,
dim
,
dim
,
width
-
1
,
width
-
1
,
device
=
x
.
device
,
device
=
x
.
device
,
...
@@ -385,18 +383,27 @@ def test_causal_conv1d_varlen(dim, seqlen, width, has_bias, silu_activation,
...
@@ -385,18 +383,27 @@ def test_causal_conv1d_varlen(dim, seqlen, width, has_bias, silu_activation,
2
,
(
cumsum
.
shape
[
0
]
-
1
,
),
2
,
(
cumsum
.
shape
[
0
]
-
1
,
),
dtype
=
torch
.
bool
,
dtype
=
torch
.
bool
,
device
=
x
.
device
)
device
=
x
.
device
)
cach
e_indices
=
torch
.
randperm
(
cumsum
.
shape
[
0
]
-
1
,
stat
e_indices
=
torch
.
randperm
(
total_entries
,
dtype
=
torch
.
int32
,
dtype
=
torch
.
int32
,
device
=
x
.
device
)
device
=
x
.
device
)[:
batch_size
]
padded_state_indices
=
torch
.
concat
([
state_indices
,
torch
.
as_tensor
(
[
PAD_SLOT_ID
]
*
padding
,
dtype
=
torch
.
int32
,
device
=
device
),
],
dim
=-
1
)
out
=
causal_conv1d_fn
(
x
.
squeeze
(
0
),
weight
,
bias
,
cumsum
.
cuda
(),
out
=
causal_conv1d_fn
(
x
.
squeeze
(
0
),
weight
,
bias
,
cumsum
.
cuda
(),
cach
e_indices
,
has_initial_states
,
final_states
,
padded_stat
e_indices
,
has_initial_states
,
activation
)
final_states
,
activation
,
PAD_SLOT_ID
)
out_ref
=
[]
out_ref
=
[]
out_ref_b
=
[]
out_ref_b
=
[]
splits
=
[
torch
.
split
(
var
,
seqlens
[
0
],
dim
=-
1
)
for
var
in
(
x_ref
)]
splits
=
[
torch
.
split
(
var
,
seqlens
[
0
],
dim
=-
1
)
for
var
in
(
x_ref
)]
for
i
in
range
(
len
(
seqlens
[
0
])):
for
i
in
range
(
len
(
seqlens
[
0
])):
x_s
=
[
v
[
i
].
unsqueeze
(
0
)
for
v
in
splits
][
0
]
x_s
=
[
v
[
i
].
unsqueeze
(
0
)
for
v
in
splits
][
0
]
if
padded_state_indices
[
i
]
==
PAD_SLOT_ID
:
continue
out_ref_b
.
append
(
out_ref_b
.
append
(
causal_conv1d_ref
(
causal_conv1d_ref
(
x_s
,
x_s
,
...
@@ -404,21 +411,17 @@ def test_causal_conv1d_varlen(dim, seqlen, width, has_bias, silu_activation,
...
@@ -404,21 +411,17 @@ def test_causal_conv1d_varlen(dim, seqlen, width, has_bias, silu_activation,
bias_ref
,
bias_ref
,
activation
=
activation
,
activation
=
activation
,
return_final_states
=
True
,
return_final_states
=
True
,
final_states_out
=
final_states_ref
[
cache_indices
[
i
]].
unsqueeze
(
final_states_out
=
final_states_ref
[
0
),
padded_state_indices
[
i
]].
unsqueeze
(
0
),
initial_states
=
final_states_ref
[
cach
e_indices
[
i
]].
unsqueeze
(
0
)
initial_states
=
final_states_ref
[
padded_stat
e_indices
[
i
]].
if
has_initial_states
[
i
]
else
None
))
unsqueeze
(
0
)
if
has_initial_states
[
i
]
else
None
))
out_ref
.
append
(
torch
.
cat
([
t
[
0
]
for
t
in
out_ref_b
],
dim
=
2
))
out_ref
.
append
(
torch
.
cat
([
t
[
0
]
for
t
in
out_ref_b
],
dim
=
2
))
out_ref
=
torch
.
cat
(
out_ref
,
dim
=
0
)
out_ref_tensor
=
torch
.
cat
(
out_ref
,
dim
=
0
)
print
(
f
"Output max diff:
{
(
out
-
out_ref
).
abs
().
max
().
item
()
}
"
)
unpadded_out
=
out
[:,
:
out_ref_tensor
.
shape
[
-
1
]]
print
(
f
"Output mean diff:
{
(
out
-
out_ref
).
abs
().
mean
().
item
()
}
"
)
assert
torch
.
allclose
(
unpadded_out
,
out_ref_tensor
,
rtol
=
rtol
,
atol
=
atol
)
print
(
"Output state max diff"
f
":
{
(
final_states
-
final_states_ref
).
abs
().
max
()
}
"
)
print
(
"Output state mean diff"
f
":
{
(
final_states
-
final_states_ref
).
abs
().
mean
()
}
"
)
assert
torch
.
allclose
(
out
,
out_ref
,
rtol
=
rtol
,
atol
=
atol
)
assert
torch
.
allclose
(
final_states
,
final_states_ref
,
rtol
=
rtol
,
atol
=
atol
)
assert
torch
.
allclose
(
final_states
,
final_states_ref
,
rtol
=
rtol
,
atol
=
atol
)
causal_conv1d_opcheck_fn
(
x
.
squeeze
(
0
),
weight
,
bias
,
cumsum
.
cuda
(),
causal_conv1d_opcheck_fn
(
x
.
squeeze
(
0
),
weight
,
bias
,
cumsum
.
cuda
(),
cach
e_indices
,
has_initial_states
,
final_states
,
padded_stat
e_indices
,
has_initial_states
,
activation
)
final_states
,
activation
)
tests/kernels/test_mamba_ssm.py
View file @
fb60ae9b
...
@@ -5,6 +5,7 @@ from einops import rearrange, repeat
...
@@ -5,6 +5,7 @@ from einops import rearrange, repeat
from
tests.kernels.utils
import
opcheck
from
tests.kernels.utils
import
opcheck
from
vllm
import
_custom_ops
as
ops
# noqa: F401
from
vllm
import
_custom_ops
as
ops
# noqa: F401
from
vllm.attention.backends.utils
import
PAD_SLOT_ID
from
vllm.model_executor.layers.mamba.ops.mamba_ssm
import
(
from
vllm.model_executor.layers.mamba.ops.mamba_ssm
import
(
selective_scan_fn
,
selective_state_update
)
selective_scan_fn
,
selective_state_update
)
from
vllm.utils
import
seed_everything
from
vllm.utils
import
seed_everything
...
@@ -174,7 +175,8 @@ def selective_scan_opcheck_fn(u,
...
@@ -174,7 +175,8 @@ def selective_scan_opcheck_fn(u,
cu_seq_len
=
None
,
cu_seq_len
=
None
,
cache_indices
=
None
,
cache_indices
=
None
,
has_initial_state
=
None
,
has_initial_state
=
None
,
ssm_states
=
None
):
ssm_states
=
None
,
pad_slot_id
=
PAD_SLOT_ID
):
"""if return_last_state is True, returns (out, last_state)
"""if return_last_state is True, returns (out, last_state)
last_state has shape (batch, dim, dstate).
last_state has shape (batch, dim, dstate).
"""
"""
...
@@ -203,7 +205,7 @@ def selective_scan_opcheck_fn(u,
...
@@ -203,7 +205,7 @@ def selective_scan_opcheck_fn(u,
# a bogus error.
# a bogus error.
opcheck
(
torch
.
ops
.
_C
.
selective_scan_fwd
,
opcheck
(
torch
.
ops
.
_C
.
selective_scan_fwd
,
(
u
,
delta
,
A
,
B
,
C
,
D
,
z
,
delta_bias
,
delta_softplus
,
cu_seq_len
,
(
u
,
delta
,
A
,
B
,
C
,
D
,
z
,
delta_bias
,
delta_softplus
,
cu_seq_len
,
cache_indices
,
has_initial_state
,
ssm_states
),
cache_indices
,
has_initial_state
,
ssm_states
,
pad_slot_id
),
test_utils
=
[
"test_schema"
,
"test_faketensor"
])
test_utils
=
[
"test_schema"
,
"test_faketensor"
])
...
@@ -404,9 +406,12 @@ def test_selective_state_update(dim, dstate, has_z, itype):
...
@@ -404,9 +406,12 @@ def test_selective_state_update(dim, dstate, has_z, itype):
@
pytest
.
mark
.
parametrize
(
"varBC_groups"
,
[
1
,
2
])
@
pytest
.
mark
.
parametrize
(
"varBC_groups"
,
[
1
,
2
])
@
pytest
.
mark
.
parametrize
(
"is_variable_C"
,
[
True
])
@
pytest
.
mark
.
parametrize
(
"is_variable_C"
,
[
True
])
@
pytest
.
mark
.
parametrize
(
"is_variable_B"
,
[
True
])
@
pytest
.
mark
.
parametrize
(
"is_variable_B"
,
[
True
])
def
test_selective_scan_varlen
(
is_variable_B
,
is_variable_C
,
varBC_groups
,
# tests correctness in case subset of the sequences are padded
has_D
,
has_z
,
has_delta_bias
,
delta_softplus
,
@
pytest
.
mark
.
parametrize
(
"with_padding"
,
[
False
,
True
])
return_last_state
,
seqlen
,
itype
,
wtype
):
def
test_selective_scan_varlen
(
with_padding
,
is_variable_B
,
is_variable_C
,
varBC_groups
,
has_D
,
has_z
,
has_delta_bias
,
delta_softplus
,
return_last_state
,
seqlen
,
itype
,
wtype
):
if
varBC_groups
>
1
and
(
not
is_variable_B
or
not
is_variable_C
):
if
varBC_groups
>
1
and
(
not
is_variable_B
or
not
is_variable_C
):
pytest
.
skip
()
# This config is not applicable
pytest
.
skip
()
# This config is not applicable
device
=
'cuda'
device
=
'cuda'
...
@@ -420,18 +425,27 @@ def test_selective_scan_varlen(is_variable_B, is_variable_C, varBC_groups,
...
@@ -420,18 +425,27 @@ def test_selective_scan_varlen(is_variable_B, is_variable_C, varBC_groups,
# set seed
# set seed
torch
.
random
.
manual_seed
(
0
)
torch
.
random
.
manual_seed
(
0
)
seqlens
=
[]
seqlens
=
[]
nsplits
=
3
batch_size
=
4
if
seqlen
<
10
:
if
seqlen
<
10
:
nsplits
=
0
batch_size
=
1
padding
=
3
if
with_padding
else
0
padded_batch_size
=
batch_size
+
padding
if
with_padding
and
seqlen
<
padded_batch_size
:
pytest
.
skip
()
nsplits
=
padded_batch_size
-
1
eos_pos
=
torch
.
randperm
(
seqlen
-
1
)[:
nsplits
].
sort
().
values
eos_pos
=
torch
.
randperm
(
seqlen
-
1
)[:
nsplits
].
sort
().
values
seqlens
.
append
(
seqlens
.
append
(
torch
.
diff
(
torch
.
diff
(
torch
.
cat
(
torch
.
cat
(
[
torch
.
tensor
([
-
1
]),
eos_pos
,
[
torch
.
tensor
([
-
1
]),
eos_pos
,
torch
.
tensor
([
seqlen
-
1
])])).
tolist
())
torch
.
tensor
([
seqlen
-
1
])])).
tolist
())
assert
sum
(
seqlens
[
-
1
])
==
seqlen
assert
sum
(
seqlens
[
-
1
])
==
seqlen
assert
all
(
s
>
0
for
s
in
seqlens
[
-
1
])
assert
all
(
s
>
0
for
s
in
seqlens
[
-
1
])
total_entries
=
batch_size
*
10
cumsum
=
torch
.
cumsum
(
torch
.
tensor
(
seqlens
[
0
]),
dim
=
0
).
to
(
torch
.
int32
)
cumsum
=
torch
.
cumsum
(
torch
.
tensor
(
seqlens
[
0
]),
dim
=
0
).
to
(
torch
.
int32
)
cumsum
=
torch
.
concat
([
torch
.
tensor
([
0
],
dtype
=
torch
.
int32
),
cumsum
],
cumsum
=
torch
.
concat
([
torch
.
tensor
([
0
],
dtype
=
torch
.
int32
),
cumsum
],
dim
=
0
).
cuda
()
dim
=
0
).
cuda
()
...
@@ -462,22 +476,33 @@ def test_selective_scan_varlen(is_variable_B, is_variable_C, varBC_groups,
...
@@ -462,22 +476,33 @@ def test_selective_scan_varlen(is_variable_B, is_variable_C, varBC_groups,
delta_ref
=
delta
.
clone
()
delta_ref
=
delta
.
clone
()
out
=
None
out
=
None
out_ref
=
None
out_ref
=
None
prev_state_shape
=
(
cumsum
.
shape
[
0
]
-
1
,
u
.
shape
[
0
],
int
(
A
.
shape
[
1
]))
prev_state_shape
=
(
total_entries
,
u
.
shape
[
0
],
int
(
A
.
shape
[
1
]))
prev_state
=
torch
.
randn
(
prev_state_shape
,
prev_state
=
torch
.
randn
(
prev_state_shape
,
device
=
u
.
device
,
device
=
u
.
device
,
dtype
=
itype
,
dtype
=
itype
,
requires_grad
=
False
)
requires_grad
=
False
)
prev_state_ref
=
prev_state
.
clone
()
prev_state_ref
=
prev_state
.
clone
()
cach
e_indices
=
torch
.
randperm
(
cumsum
.
shape
[
0
]
-
1
,
stat
e_indices
=
torch
.
randperm
(
total_entries
,
dtype
=
torch
.
int32
,
dtype
=
torch
.
int32
,
device
=
u
.
device
)
device
=
u
.
device
)[:
batch_size
]
unused_states_bool
=
torch
.
ones
(
total_entries
,
dtype
=
torch
.
bool
,
device
=
device
)
unused_states_bool
[
state_indices
]
=
False
padded_state_indices
=
torch
.
concat
([
state_indices
,
torch
.
as_tensor
(
[
PAD_SLOT_ID
]
*
padding
,
dtype
=
torch
.
int32
,
device
=
device
),
],
dim
=-
1
)
has_initial_state
=
torch
.
randint
(
0
,
has_initial_state
=
torch
.
randint
(
0
,
2
,
(
cumsum
.
shape
[
0
]
-
1
,
),
2
,
(
cumsum
.
shape
[
0
]
-
1
,
),
dtype
=
torch
.
bool
,
dtype
=
torch
.
bool
,
device
=
u
.
device
)
device
=
u
.
device
)
out
=
selective_scan_fn
(
u
,
prev_state
,
delta
,
A
,
B
,
C
,
D
,
z
,
delta_bias
,
out
=
selective_scan_fn
(
u
,
prev_state
,
delta
,
A
,
B
,
C
,
D
,
z
,
delta_bias
,
delta_softplus
,
cumsum
,
cach
e_indices
,
delta_softplus
,
cumsum
,
padded_stat
e_indices
,
has_initial_state
)
has_initial_state
)
outs_ref
=
[]
outs_ref
=
[]
splits
=
[
splits
=
[
...
@@ -486,6 +511,8 @@ def test_selective_scan_varlen(is_variable_B, is_variable_C, varBC_groups,
...
@@ -486,6 +511,8 @@ def test_selective_scan_varlen(is_variable_B, is_variable_C, varBC_groups,
]
]
for
i
in
range
(
len
(
seqlens
[
0
])):
for
i
in
range
(
len
(
seqlens
[
0
])):
u_s
,
delta_s
,
B_s
,
C_s
,
z_s
=
[
v
[
i
].
unsqueeze
(
0
)
for
v
in
splits
]
u_s
,
delta_s
,
B_s
,
C_s
,
z_s
=
[
v
[
i
].
unsqueeze
(
0
)
for
v
in
splits
]
if
padded_state_indices
[
i
]
==
PAD_SLOT_ID
:
continue
out_ref_s
,
_
=
selective_scan_ref
(
out_ref_s
,
_
=
selective_scan_ref
(
u_s
,
u_s
,
delta_s
,
delta_s
,
...
@@ -497,21 +524,22 @@ def test_selective_scan_varlen(is_variable_B, is_variable_C, varBC_groups,
...
@@ -497,21 +524,22 @@ def test_selective_scan_varlen(is_variable_B, is_variable_C, varBC_groups,
delta_bias
=
delta_bias
,
delta_bias
=
delta_bias
,
delta_softplus
=
delta_softplus
,
delta_softplus
=
delta_softplus
,
return_last_state
=
return_last_state
,
return_last_state
=
return_last_state
,
prev_state
=
prev_state_ref
[
cach
e_indices
[
i
]].
unsqueeze
(
0
)
prev_state
=
prev_state_ref
[
padded_stat
e_indices
[
i
]].
unsqueeze
(
0
)
if
has_initial_state
[
i
]
else
None
,
if
has_initial_state
[
i
]
else
None
,
final_state_out
=
prev_state_ref
[
cache_indices
[
i
]].
unsqueeze
(
0
))
final_state_out
=
prev_state_ref
[
padded_state_indices
[
i
]].
unsqueeze
(
0
))
outs_ref
.
append
(
out_ref_s
)
outs_ref
.
append
(
out_ref_s
)
out_ref
=
torch
.
cat
(
outs_ref
,
dim
=-
1
)
if
len
(
outs_ref
)
>
1
else
outs_ref
[
0
]
out_ref
=
torch
.
cat
(
outs_ref
,
dim
=-
1
)[
0
]
print
(
"Output diff max"
,
(
out
-
out_ref
[
0
]).
max
())
unpadded_out
=
out
[:,
:
out_ref
[
0
].
shape
[
-
1
]]
print
(
"Output diff mean"
,
(
out
-
out_ref
[
0
]).
mean
())
print
(
"Output diff max"
,
(
unpadded_out
-
out_ref
).
max
())
print
(
"Output diff mean"
,
(
unpadded_out
-
out_ref
).
mean
())
print
(
"Output state diff max"
,
(
prev_state
-
prev_state_ref
).
max
())
print
(
"Output state diff max"
,
(
prev_state
-
prev_state_ref
).
max
())
print
(
"Output state diff mean"
,
(
prev_state
-
prev_state_ref
).
mean
())
print
(
"Output state diff mean"
,
(
prev_state
-
prev_state_ref
).
mean
())
assert
torch
.
allclose
(
prev_state
,
prev_state_ref
,
rtol
=
rtol
,
atol
=
atol
)
assert
torch
.
allclose
(
prev_state
,
prev_state_ref
,
rtol
=
rtol
,
atol
=
atol
)
assert
torch
.
allclose
(
out
,
out_ref
[
0
],
rtol
=
rtol
,
atol
=
atol
)
assert
torch
.
allclose
(
unpadded_out
,
out_ref
,
rtol
=
rtol
,
atol
=
atol
)
selective_scan_opcheck_fn
(
u
,
delta
,
A
,
B
,
C
,
D
,
z
,
delta_bias
,
selective_scan_opcheck_fn
(
u
,
delta
,
A
,
B
,
C
,
D
,
z
,
delta_bias
,
delta_softplus
,
cumsum
,
cach
e_indices
,
delta_softplus
,
cumsum
,
padded_stat
e_indices
,
has_initial_state
,
prev_state
)
has_initial_state
,
prev_state
)
...
@@ -520,7 +548,10 @@ def test_selective_scan_varlen(is_variable_B, is_variable_C, varBC_groups,
...
@@ -520,7 +548,10 @@ def test_selective_scan_varlen(is_variable_B, is_variable_C, varBC_groups,
@
pytest
.
mark
.
parametrize
(
"has_z"
,
[
True
])
@
pytest
.
mark
.
parametrize
(
"has_z"
,
[
True
])
@
pytest
.
mark
.
parametrize
(
"dstate"
,
[
16
,
32
,
64
])
@
pytest
.
mark
.
parametrize
(
"dstate"
,
[
16
,
32
,
64
])
@
pytest
.
mark
.
parametrize
(
"dim"
,
[
2048
,
2048
+
16
,
4096
])
@
pytest
.
mark
.
parametrize
(
"dim"
,
[
2048
,
2048
+
16
,
4096
])
def
test_selective_state_update_with_batch_indices
(
dim
,
dstate
,
has_z
,
itype
):
# tests correctness in case subset of the sequences are padded
@
pytest
.
mark
.
parametrize
(
"with_padding"
,
[
True
,
False
])
def
test_selective_state_update_with_batch_indices
(
with_padding
,
dim
,
dstate
,
has_z
,
itype
):
device
=
"cuda"
device
=
"cuda"
rtol
,
atol
=
(
3e-4
,
1e-3
)
if
itype
==
torch
.
float32
else
(
5e-3
,
1e-2
)
rtol
,
atol
=
(
3e-4
,
1e-3
)
if
itype
==
torch
.
float32
else
(
5e-3
,
1e-2
)
if
itype
==
torch
.
bfloat16
:
if
itype
==
torch
.
bfloat16
:
...
@@ -530,21 +561,32 @@ def test_selective_state_update_with_batch_indices(dim, dstate, has_z, itype):
...
@@ -530,21 +561,32 @@ def test_selective_state_update_with_batch_indices(dim, dstate, has_z, itype):
# set seed
# set seed
torch
.
random
.
manual_seed
(
0
)
torch
.
random
.
manual_seed
(
0
)
batch_size
=
3
batch_size
=
3
padding
=
5
if
with_padding
else
0
padded_batch_size
=
batch_size
+
padding
total_entries
=
10
*
batch_size
total_entries
=
10
*
batch_size
state
=
torch
.
randn
(
total_entries
,
dim
,
dstate
,
dtype
=
itype
,
device
=
device
)
state
=
torch
.
randn
(
total_entries
,
dim
,
dstate
,
dtype
=
itype
,
device
=
device
)
state_indices
=
torch
.
randperm
(
total_entries
)[:
batch_size
].
to
(
state_indices
=
torch
.
randperm
(
total_entries
)[:
batch_size
].
to
(
dtype
=
torch
.
int32
,
device
=
device
)
dtype
=
torch
.
int32
,
device
=
device
)
unused_states_bool
=
torch
.
ones
(
total_entries
,
x
=
torch
.
randn
(
batch_size
,
dim
,
device
=
device
,
dtype
=
itype
)
dtype
=
torch
.
bool
,
dt
=
torch
.
randn
(
batch_size
,
dim
,
device
=
device
,
dtype
=
itype
)
device
=
device
)
unused_states_bool
[
state_indices
]
=
False
padded_state_indices
=
torch
.
concat
([
state_indices
,
torch
.
as_tensor
(
[
PAD_SLOT_ID
]
*
padding
,
dtype
=
torch
.
int32
,
device
=
device
)
],
dim
=
0
)
x
=
torch
.
randn
(
padded_batch_size
,
dim
,
device
=
device
,
dtype
=
itype
)
dt
=
torch
.
randn
(
padded_batch_size
,
dim
,
device
=
device
,
dtype
=
itype
)
dt_bias
=
torch
.
rand
(
dim
,
device
=
device
)
-
4.0
dt_bias
=
torch
.
rand
(
dim
,
device
=
device
)
-
4.0
A
=
-
torch
.
rand
(
dim
,
dstate
,
device
=
device
)
-
1.0
A
=
-
torch
.
rand
(
dim
,
dstate
,
device
=
device
)
-
1.0
B
=
torch
.
randn
(
batch_size
,
dstate
,
device
=
device
)
B
=
torch
.
randn
(
padded_
batch_size
,
dstate
,
device
=
device
)
C
=
torch
.
randn
(
batch_size
,
dstate
,
device
=
device
)
C
=
torch
.
randn
(
padded_
batch_size
,
dstate
,
device
=
device
)
D
=
torch
.
randn
(
dim
,
device
=
device
)
D
=
torch
.
randn
(
dim
,
device
=
device
)
z
=
torch
.
randn_like
(
x
)
if
has_z
else
None
z
=
torch
.
randn_like
(
x
)
if
has_z
else
None
state_ref
=
state
[
state_indices
,
:].
detach
().
clone
()
state_ref
=
state
[
state_indices
,
:].
clone
()
state_before
=
state
.
clone
()
out
=
selective_state_update
(
state
,
out
=
selective_state_update
(
state
,
x
,
x
,
dt
,
dt
,
...
@@ -555,15 +597,16 @@ def test_selective_state_update_with_batch_indices(dim, dstate, has_z, itype):
...
@@ -555,15 +597,16 @@ def test_selective_state_update_with_batch_indices(dim, dstate, has_z, itype):
z
=
z
,
z
=
z
,
dt_bias
=
dt_bias
,
dt_bias
=
dt_bias
,
dt_softplus
=
True
,
dt_softplus
=
True
,
state_batch_indices
=
state_indices
)
state_batch_indices
=
padded_state_indices
,
pad_slot_id
=
PAD_SLOT_ID
)
out_ref
=
selective_state_update_ref
(
state_ref
,
out_ref
=
selective_state_update_ref
(
state_ref
,
x
,
x
[:
batch_size
]
,
dt
,
dt
[:
batch_size
]
,
A
,
A
,
B
,
B
[:
batch_size
]
,
C
,
C
[:
batch_size
]
,
D
=
D
,
D
=
D
,
z
=
z
,
z
=
z
[:
batch_size
]
,
dt_bias
=
dt_bias
,
dt_bias
=
dt_bias
,
dt_softplus
=
True
)
dt_softplus
=
True
)
...
@@ -572,11 +615,21 @@ def test_selective_state_update_with_batch_indices(dim, dstate, has_z, itype):
...
@@ -572,11 +615,21 @@ def test_selective_state_update_with_batch_indices(dim, dstate, has_z, itype):
print
(
"Output state diff max"
,
(
state
[
state_indices
,
:]
-
state_ref
).
max
())
print
(
"Output state diff max"
,
(
state
[
state_indices
,
:]
-
state_ref
).
max
())
print
(
"Output state diff mean"
,
print
(
"Output state diff mean"
,
(
state
[
state_indices
,
:]
-
state_ref
).
mean
())
(
state
[
state_indices
,
:]
-
state_ref
).
mean
())
# test padded entries stay the same
if
with_padding
:
assert
torch
.
equal
(
state_before
[
unused_states_bool
],
state
[
unused_states_bool
])
assert
torch
.
equal
(
x
[
batch_size
+
1
:],
x
[
batch_size
+
1
:])
assert
torch
.
equal
(
dt
[
batch_size
+
1
:],
dt
[
batch_size
+
1
:])
assert
torch
.
equal
(
B
[
batch_size
+
1
:],
B
[
batch_size
+
1
:])
assert
torch
.
equal
(
C
[
batch_size
+
1
:],
C
[
batch_size
+
1
:])
# test "real" entries
assert
torch
.
allclose
(
state
[
state_indices
,
:],
assert
torch
.
allclose
(
state
[
state_indices
,
:],
state_ref
,
state_ref
,
rtol
=
rtol
,
rtol
=
rtol
,
atol
=
atol
)
atol
=
atol
)
assert
torch
.
allclose
(
out
,
out_ref
,
rtol
=
rtol
,
atol
=
atol
)
assert
torch
.
allclose
(
out
[:
batch_size
]
,
out_ref
,
rtol
=
rtol
,
atol
=
atol
)
@
pytest
.
mark
.
parametrize
(
"itype"
,
@
pytest
.
mark
.
parametrize
(
"itype"
,
...
@@ -645,7 +698,8 @@ def test_selective_state_update_with_heads_with_batch_indices(
...
@@ -645,7 +698,8 @@ def test_selective_state_update_with_heads_with_batch_indices(
z
=
z
,
z
=
z
,
dt_bias
=
dt_bias
,
dt_bias
=
dt_bias
,
dt_softplus
=
True
,
dt_softplus
=
True
,
state_batch_indices
=
state_indices
)
state_batch_indices
=
state_indices
,
pad_slot_id
=
PAD_SLOT_ID
)
out_ref
=
selective_state_update_ref
(
state_ref
,
out_ref
=
selective_state_update_ref
(
state_ref
,
x
,
x
,
dt
,
dt
,
...
...
tests/models/decoder_only/language/test_jamba.py
View file @
fb60ae9b
import
pytest
import
pytest
from
tests.utils
import
multi_gpu_test
from
vllm.sampling_params
import
SamplingParams
from
vllm.sampling_params
import
SamplingParams
from
vllm.worker.model_runner
import
_get_graph_batch_size
from
vllm.worker.model_runner
import
_get_graph_batch_size
...
@@ -270,6 +271,30 @@ def test_state_cleanup(
...
@@ -270,6 +271,30 @@ def test_state_cleanup(
"could be related to finished_requests_ids"
)
"could be related to finished_requests_ids"
)
@
multi_gpu_test
(
num_gpus
=
2
)
@
pytest
.
mark
.
parametrize
(
"model"
,
MODELS
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
"float"
])
@
pytest
.
mark
.
parametrize
(
"max_tokens"
,
[
64
])
def
test_jamba_distributed_produces_identical_generation
(
vllm_runner
,
model
:
str
,
dtype
:
str
,
max_tokens
:
int
,
example_prompts
)
->
None
:
with
vllm_runner
(
model
,
dtype
=
dtype
,
tensor_parallel_size
=
2
)
as
vllm_model
:
vllm_outputs_tp_2
=
vllm_model
.
generate_greedy
(
example_prompts
,
max_tokens
)
with
vllm_runner
(
model
,
dtype
=
dtype
,
tensor_parallel_size
=
1
)
as
vllm_model
:
vllm_outputs_tp_1
=
vllm_model
.
generate_greedy
(
example_prompts
,
max_tokens
)
check_outputs_equal
(
outputs_0_lst
=
vllm_outputs_tp_1
,
outputs_1_lst
=
vllm_outputs_tp_2
,
name_0
=
"vllm_tp_1"
,
name_1
=
"vllm_tp_2"
,
)
@
pytest
.
mark
.
parametrize
(
"model"
,
MODELS
)
@
pytest
.
mark
.
parametrize
(
"model"
,
MODELS
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
"float"
])
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
"float"
])
def
test_model_print
(
def
test_model_print
(
...
...
vllm/_custom_ops.py
View file @
fb60ae9b
...
@@ -464,16 +464,18 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
...
@@ -464,16 +464,18 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
cu_seq_len
:
Optional
[
torch
.
Tensor
],
cu_seq_len
:
Optional
[
torch
.
Tensor
],
cache_indices
:
Optional
[
torch
.
Tensor
],
cache_indices
:
Optional
[
torch
.
Tensor
],
has_initial_state
:
Optional
[
torch
.
Tensor
],
has_initial_state
:
Optional
[
torch
.
Tensor
],
silu_activation
:
bool
)
->
torch
.
Tensor
:
silu_activation
:
bool
,
pad_slot_id
:
int
)
:
return
torch
.
empty_like
(
x
)
return
None
@
register_fake
(
"_C::causal_conv1d_update"
)
@
register_fake
(
"_C::causal_conv1d_update"
)
def
causal_conv1d_update_fake
(
def
causal_conv1d_update_fake
(
x
:
torch
.
Tensor
,
conv_state
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
conv_state
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
bias_
:
Optional
[
torch
.
Tensor
],
silu_activation
:
bool
,
bias_
:
Optional
[
torch
.
Tensor
],
silu_activation
:
bool
,
cache_seqlens
:
Optional
[
torch
.
Tensor
],
cache_seqlens
:
Optional
[
torch
.
Tensor
],
conv_state_indices
:
Optional
[
torch
.
Tensor
])
->
torch
.
Tensor
:
conv_state_indices
:
Optional
[
torch
.
Tensor
],
return
torch
.
empty_like
(
x
)
pad_slot_id
:
int
)
->
None
:
return
None
@
register_fake
(
"_C::selective_scan_fwd"
)
@
register_fake
(
"_C::selective_scan_fwd"
)
def
selective_scan_fwd_fake
(
u
:
torch
.
Tensor
,
delta
:
torch
.
Tensor
,
def
selective_scan_fwd_fake
(
u
:
torch
.
Tensor
,
delta
:
torch
.
Tensor
,
...
@@ -485,7 +487,8 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
...
@@ -485,7 +487,8 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
cu_seq_len
:
Optional
[
torch
.
Tensor
],
cu_seq_len
:
Optional
[
torch
.
Tensor
],
cache_indices
:
Optional
[
torch
.
Tensor
],
cache_indices
:
Optional
[
torch
.
Tensor
],
has_initial_state
:
Optional
[
torch
.
Tensor
],
has_initial_state
:
Optional
[
torch
.
Tensor
],
ssm_states
:
Optional
[
torch
.
Tensor
])
->
None
:
ssm_states
:
Optional
[
torch
.
Tensor
],
pad_slot_id
:
int
)
->
None
:
return
None
return
None
...
@@ -800,33 +803,37 @@ def causal_conv1d_fwd(x: torch.Tensor, weight: torch.Tensor,
...
@@ -800,33 +803,37 @@ def causal_conv1d_fwd(x: torch.Tensor, weight: torch.Tensor,
query_start_loc
:
Optional
[
torch
.
Tensor
],
query_start_loc
:
Optional
[
torch
.
Tensor
],
cache_indices
:
Optional
[
torch
.
Tensor
],
cache_indices
:
Optional
[
torch
.
Tensor
],
has_initial_state
:
Optional
[
torch
.
Tensor
],
has_initial_state
:
Optional
[
torch
.
Tensor
],
silu_activation
:
bool
)
->
torch
.
Tensor
:
silu_activation
:
bool
,
pad_slot_id
:
int
)
:
return
torch
.
ops
.
_C
.
causal_conv1d_fwd
(
x
,
weight
,
bias_
,
conv_states
,
torch
.
ops
.
_C
.
causal_conv1d_fwd
(
x
,
weight
,
bias_
,
conv_states
,
query_start_loc
,
cache_indices
,
query_start_loc
,
cache_indices
,
has_initial_state
,
silu_activation
)
has_initial_state
,
silu_activation
,
pad_slot_id
)
def
causal_conv1d_update
(
def
causal_conv1d_update
(
x
:
torch
.
Tensor
,
conv_state
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
conv_state
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
bias_
:
Optional
[
torch
.
Tensor
]
,
bias_
:
Optional
[
torch
.
Tensor
],
silu_activation
:
bool
,
silu_activation
:
bool
,
cache_seqlens
:
Optional
[
torch
.
Tensor
],
cache_seqlens
:
Optional
[
torch
.
Tensor
],
conv_state_indices
:
Optional
[
torch
.
Tensor
])
->
torch
.
Tensor
:
conv_state_indices
:
Optional
[
torch
.
Tensor
],
return
torch
.
ops
.
_C
.
causal_conv1d_update
(
x
,
conv_state
,
weight
,
bias_
,
pad_slot_id
:
int
):
torch
.
ops
.
_C
.
causal_conv1d_update
(
x
,
conv_state
,
weight
,
bias_
,
silu_activation
,
cache_seqlens
,
silu_activation
,
cache_seqlens
,
conv_state_indices
)
conv_state_indices
,
pad_slot_id
)
def
selective_scan_fwd
(
def
selective_scan_fwd
(
u
:
torch
.
Tensor
,
delta
:
torch
.
Tensor
,
A
:
torch
.
Tensor
,
u
:
torch
.
Tensor
,
delta
:
torch
.
Tensor
,
A
:
torch
.
Tensor
,
B
:
torch
.
Tensor
,
B
:
torch
.
Tensor
,
C
:
torch
.
Tensor
,
C
:
torch
.
Tensor
,
D_
:
Optional
[
torch
.
Tensor
],
D_
:
Optional
[
torch
.
Tensor
],
z_
:
Optional
[
torch
.
Tensor
],
z_
:
Optional
[
torch
.
Tensor
],
delta_bias_
:
Optional
[
torch
.
Tensor
],
delta_bias_
:
Optional
[
torch
.
Tensor
],
delta_softplus
:
bool
,
query_start_loc
:
Optional
[
torch
.
Tensor
],
delta_softplus
:
bool
,
query_start_loc
:
Optional
[
torch
.
Tensor
],
cache_indices
:
Optional
[
torch
.
Tensor
],
cache_indices
:
Optional
[
torch
.
Tensor
],
has_initial_state
:
Optional
[
torch
.
Tensor
],
ssm_states
:
torch
.
Tensor
):
has_initial_state
:
Optional
[
torch
.
Tensor
],
ssm_states
:
torch
.
Tensor
,
pad_slot_id
:
int
):
torch
.
ops
.
_C
.
selective_scan_fwd
(
u
,
delta
,
A
,
B
,
C
,
D_
,
z_
,
delta_bias_
,
torch
.
ops
.
_C
.
selective_scan_fwd
(
u
,
delta
,
A
,
B
,
C
,
D_
,
z_
,
delta_bias_
,
delta_softplus
,
query_start_loc
,
delta_softplus
,
query_start_loc
,
cache_indices
,
has_initial_state
,
cache_indices
,
has_initial_state
,
ssm_states
)
ssm_states
,
pad_slot_id
)
# moe
# moe
...
...
vllm/model_executor/layers/mamba/ops/causal_conv1d.py
View file @
fb60ae9b
...
@@ -6,10 +6,10 @@ from typing import Optional
...
@@ -6,10 +6,10 @@ from typing import Optional
import
torch
import
torch
from
vllm
import
_custom_ops
as
ops
from
vllm
import
_custom_ops
as
ops
from
vllm.attention.backends.utils
import
PAD_SLOT_ID
def
causal_conv1d_fn
(
def
causal_conv1d_fn
(
x
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
,
query_start_loc
:
Optional
[
torch
.
Tensor
]
=
None
,
query_start_loc
:
Optional
[
torch
.
Tensor
]
=
None
,
...
@@ -17,7 +17,7 @@ def causal_conv1d_fn(
...
@@ -17,7 +17,7 @@ def causal_conv1d_fn(
has_initial_state
:
Optional
[
torch
.
Tensor
]
=
None
,
has_initial_state
:
Optional
[
torch
.
Tensor
]
=
None
,
conv_states
:
Optional
[
torch
.
Tensor
]
=
None
,
conv_states
:
Optional
[
torch
.
Tensor
]
=
None
,
activation
:
Optional
[
str
]
=
"silu"
,
activation
:
Optional
[
str
]
=
"silu"
,
):
pad_slot_id
:
int
=
PAD_SLOT_ID
):
"""
"""
x: (batch, dim, seqlen) or (dim,cu_seq_len) for varlen
x: (batch, dim, seqlen) or (dim,cu_seq_len) for varlen
sequences are concatenated from left to right for varlen
sequences are concatenated from left to right for varlen
...
@@ -37,6 +37,13 @@ def causal_conv1d_fn(
...
@@ -37,6 +37,13 @@ def causal_conv1d_fn(
conv_states: (...,dim,width - 1) itype
conv_states: (...,dim,width - 1) itype
updated inplace if provided
updated inplace if provided
activation: either None or "silu" or "swish"
activation: either None or "silu" or "swish"
pad_slot_id: int
if cache_indices is passed, lets the kernel identify padded
entries that will not be processed,
for example: cache_indices = [pad_slot_id, 1, 20, pad_slot_id]
in this case, the kernel will not process entries at
indices 0 and 3
out: (batch, dim, seqlen)
out: (batch, dim, seqlen)
"""
"""
...
@@ -46,10 +53,10 @@ def causal_conv1d_fn(
...
@@ -46,10 +53,10 @@ def causal_conv1d_fn(
x
=
x
.
contiguous
()
x
=
x
.
contiguous
()
bias
=
bias
.
contiguous
()
if
bias
is
not
None
else
None
bias
=
bias
.
contiguous
()
if
bias
is
not
None
else
None
out
=
ops
.
causal_conv1d_fwd
(
x
,
weight
,
bias
,
conv_states
,
query_start_loc
,
ops
.
causal_conv1d_fwd
(
x
,
weight
,
bias
,
conv_states
,
query_start_loc
,
cache_indices
,
has_initial_state
,
activation
cache_indices
,
has_initial_state
,
activation
in
[
"silu"
,
"swish"
])
in
[
"silu"
,
"swish"
]
,
pad_slot_id
)
return
out
return
x
def
causal_conv1d_update
(
x
:
torch
.
Tensor
,
def
causal_conv1d_update
(
x
:
torch
.
Tensor
,
...
@@ -58,7 +65,8 @@ def causal_conv1d_update(x: torch.Tensor,
...
@@ -58,7 +65,8 @@ def causal_conv1d_update(x: torch.Tensor,
bias
:
Optional
[
torch
.
Tensor
]
=
None
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
,
activation
:
Optional
[
str
]
=
None
,
activation
:
Optional
[
str
]
=
None
,
cache_seqlens
:
Optional
[
torch
.
Tensor
]
=
None
,
cache_seqlens
:
Optional
[
torch
.
Tensor
]
=
None
,
conv_state_indices
:
Optional
[
torch
.
Tensor
]
=
None
):
conv_state_indices
:
Optional
[
torch
.
Tensor
]
=
None
,
pad_slot_id
:
int
=
PAD_SLOT_ID
):
"""
"""
x: (batch, dim) or (batch, dim, seqlen)
x: (batch, dim) or (batch, dim, seqlen)
conv_state: (batch, dim, state_len), where state_len >= width - 1
conv_state: (batch, dim, state_len), where state_len >= width - 1
...
@@ -73,7 +81,12 @@ def causal_conv1d_update(x: torch.Tensor,
...
@@ -73,7 +81,12 @@ def causal_conv1d_update(x: torch.Tensor,
If not None, the conv_state is a larger tensor along the batch dim,
If not None, the conv_state is a larger tensor along the batch dim,
and we are selecting the batch coords specified by conv_state_indices.
and we are selecting the batch coords specified by conv_state_indices.
Useful for a continuous batching scenario.
Useful for a continuous batching scenario.
pad_slot_id: int
if cache_indices is passed, lets the kernel identify padded
entries that will not be processed,
for example: cache_indices = [pad_slot_id, 1 ,20 ,pad_slot_id]
in this case, the kernel will not process entries at
indices 0 and 3
out: (batch, dim) or (batch, dim, seqlen)
out: (batch, dim) or (batch, dim, seqlen)
"""
"""
if
activation
not
in
[
None
,
"silu"
,
"swish"
]:
if
activation
not
in
[
None
,
"silu"
,
"swish"
]:
...
@@ -82,8 +95,8 @@ def causal_conv1d_update(x: torch.Tensor,
...
@@ -82,8 +95,8 @@ def causal_conv1d_update(x: torch.Tensor,
unsqueeze
=
x
.
dim
()
==
2
unsqueeze
=
x
.
dim
()
==
2
if
unsqueeze
:
if
unsqueeze
:
x
=
x
.
unsqueeze
(
-
1
)
x
=
x
.
unsqueeze
(
-
1
)
out
=
ops
.
causal_conv1d_update
(
x
,
conv_state
,
weight
,
bias
,
activation_val
,
ops
.
causal_conv1d_update
(
x
,
conv_state
,
weight
,
bias
,
activation_val
,
cache_seqlens
,
conv_state_indices
)
cache_seqlens
,
conv_state_indices
,
pad_slot_id
)
if
unsqueeze
:
if
unsqueeze
:
out
=
out
.
squeeze
(
-
1
)
x
=
x
.
squeeze
(
-
1
)
return
out
return
x
vllm/model_executor/layers/mamba/ops/mamba_ssm.py
View file @
fb60ae9b
# Copyright (c) 2024, Tri Dao, Albert Gu.
# Copyright (c) 2024, Tri Dao, Albert Gu.
# Adapted from https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/triton/selective_state_update.py
# Adapted from https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/triton/selective_state_update.py
from
typing
import
Tuple
import
torch
import
torch
import
triton
import
triton
import
triton.language
as
tl
import
triton.language
as
tl
from
packaging
import
version
from
packaging
import
version
from
vllm
import
_custom_ops
as
ops
from
vllm
import
_custom_ops
as
ops
from
vllm.attention.backends.utils
import
PAD_SLOT_ID
TRITON3
=
version
.
parse
(
triton
.
__version__
)
>=
version
.
parse
(
"3.0.0"
)
TRITON3
=
version
.
parse
(
triton
.
__version__
)
>=
version
.
parse
(
"3.0.0"
)
...
@@ -50,6 +49,7 @@ def _selective_scan_update_kernel(
...
@@ -50,6 +49,7 @@ def _selective_scan_update_kernel(
z_ptr
,
z_ptr
,
out_ptr
,
out_ptr
,
state_batch_indices_ptr
,
state_batch_indices_ptr
,
pad_slot_id
,
# Matrix dimensions
# Matrix dimensions
batch
,
batch
,
nheads
,
nheads
,
...
@@ -143,10 +143,11 @@ def _selective_scan_update_kernel(
...
@@ -143,10 +143,11 @@ def _selective_scan_update_kernel(
if
HAS_Z
:
if
HAS_Z
:
z_ptrs
=
z_ptr
+
offs_m
*
stride_z_dim
z_ptrs
=
z_ptr
+
offs_m
*
stride_z_dim
out_ptrs
=
out_ptr
+
offs_m
*
stride_out_dim
out_ptrs
=
out_ptr
+
offs_m
*
stride_out_dim
mask
=
(
offs_m
[:,
None
]
<
dim
)
&
(
offs_n
[
None
,
:]
<
dstate
)
if
HAS_STATE_BATCH_INDICES
:
mask
&=
(
state_batch_idx
!=
pad_slot_id
)
state
=
tl
.
load
(
state_ptrs
,
mask
=
mask
,
other
=
0.0
)
state
=
tl
.
load
(
state_ptrs
,
mask
=
(
offs_m
[:,
None
]
<
dim
)
&
(
offs_n
[
None
,
:]
<
dstate
),
other
=
0.0
)
x
=
tl
.
load
(
x_ptrs
,
mask
=
offs_m
<
dim
,
other
=
0.0
).
to
(
tl
.
float32
)
x
=
tl
.
load
(
x_ptrs
,
mask
=
offs_m
<
dim
,
other
=
0.0
).
to
(
tl
.
float32
)
if
not
TIE_HDIM
:
if
not
TIE_HDIM
:
dt
=
tl
.
load
(
dt_ptrs
,
mask
=
offs_m
<
dim
,
other
=
0.0
).
to
(
tl
.
float32
)
dt
=
tl
.
load
(
dt_ptrs
,
mask
=
offs_m
<
dim
,
other
=
0.0
).
to
(
tl
.
float32
)
...
@@ -177,9 +178,11 @@ def _selective_scan_update_kernel(
...
@@ -177,9 +178,11 @@ def _selective_scan_update_kernel(
dB
=
B
[
None
,
:]
*
dt
[:,
None
]
if
not
TIE_HDIM
else
B
*
dt
dB
=
B
[
None
,
:]
*
dt
[:,
None
]
if
not
TIE_HDIM
else
B
*
dt
state
=
state
*
dA
+
dB
*
x
[:,
None
]
state
=
state
*
dA
+
dB
*
x
[:,
None
]
tl
.
store
(
state_ptrs
,
state
,
mask
=
(
offs_m
[:,
None
]
<
dim
)
&
(
offs_n
[
None
,
:]
<
dstate
)
mask
=
(
offs_m
[:,
None
]
<
dim
)
&
(
offs_n
[
None
,
:]
<
dstate
))
if
HAS_STATE_BATCH_INDICES
:
mask
&=
(
state_batch_idx
!=
pad_slot_id
)
tl
.
store
(
state_ptrs
,
state
,
mask
=
mask
)
out
=
tl
.
sum
(
state
*
C
[
None
,
:],
axis
=
1
)
out
=
tl
.
sum
(
state
*
C
[
None
,
:],
axis
=
1
)
if
HAS_D
:
if
HAS_D
:
out
+=
x
*
D
out
+=
x
*
D
...
@@ -198,7 +201,8 @@ def selective_state_update(state,
...
@@ -198,7 +201,8 @@ def selective_state_update(state,
z
=
None
,
z
=
None
,
dt_bias
=
None
,
dt_bias
=
None
,
dt_softplus
=
False
,
dt_softplus
=
False
,
state_batch_indices
=
None
):
state_batch_indices
=
None
,
pad_slot_id
=
PAD_SLOT_ID
):
"""
"""
Argument:
Argument:
state: (batch, dim, dstate) or (batch, nheads, dim, dstate)
state: (batch, dim, dstate) or (batch, nheads, dim, dstate)
...
@@ -210,6 +214,12 @@ def selective_state_update(state,
...
@@ -210,6 +214,12 @@ def selective_state_update(state,
D: (dim,) or (nheads, dim)
D: (dim,) or (nheads, dim)
z: (batch, dim) or (batch, nheads, dim)
z: (batch, dim) or (batch, nheads, dim)
dt_bias: (dim,) or (nheads, dim)
dt_bias: (dim,) or (nheads, dim)
pad_slot_id: int
if cache_indices is passed, lets the kernel identify padded
entries that will not be processed,
for example: cache_indices = [pad_slot_id, 1, 20, pad_slot_id]
in this case, the kernel will not process entries at
indices 0 and 3
Return:
Return:
out: (batch, dim) or (batch, nheads, dim)
out: (batch, dim) or (batch, nheads, dim)
"""
"""
...
@@ -276,6 +286,7 @@ def selective_state_update(state,
...
@@ -276,6 +286,7 @@ def selective_state_update(state,
z
,
z
,
out
,
out
,
state_batch_indices
,
state_batch_indices
,
pad_slot_id
,
batch
,
batch
,
nheads
,
nheads
,
dim
,
dim
,
...
@@ -319,8 +330,7 @@ def selective_state_update(state,
...
@@ -319,8 +330,7 @@ def selective_state_update(state,
return
out
return
out
def
selective_scan_fn
(
def
selective_scan_fn
(
u
,
u
,
ssm_states
,
ssm_states
,
delta
,
delta
,
A
,
A
,
...
@@ -332,9 +342,13 @@ def selective_scan_fn(
...
@@ -332,9 +342,13 @@ def selective_scan_fn(
delta_softplus
=
False
,
delta_softplus
=
False
,
query_start_loc
=
None
,
query_start_loc
=
None
,
cache_indices
=
None
,
cache_indices
=
None
,
has_initial_state
=
None
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
has_initial_state
=
None
,
pad_slot_id
=
PAD_SLOT_ID
)
->
torch
.
Tensor
:
"""
"""
u: (dim, total_length) for varlen or (batch, dim, seqlen)
u: (dim, total_length) for varlen or (batch, dim, seqlen)
applies changes in place.
ssm_states: (batch, dim, dstate) or (batch, nheads, dim, dstate)
applies changes in place.
delta: (dim, total_length) for varlen or (batch, dim, seqlen)
delta: (dim, total_length) for varlen or (batch, dim, seqlen)
A: (dim, dstate)
A: (dim, dstate)
B: (ngroups, dstate, total_length) for varlen or
B: (ngroups, dstate, total_length) for varlen or
...
@@ -357,12 +371,14 @@ def selective_scan_fn(
...
@@ -357,12 +371,14 @@ def selective_scan_fn(
indicate if the ssm_state at the corresponding index should be
indicate if the ssm_state at the corresponding index should be
used as initial state. Not providing argument assumes
used as initial state. Not providing argument assumes
there's no initial state
there's no initial state
pad_slot_id: int
if cache_indices is passed, lets the kernel identify padding entries
that will not be processed,
for example: cache_indices = [pad_slot_id, 1 ,20 ,pad_slot_id]
in this case, the kernel will not process entries at indices 0 and 3
returns
returns
output: (dim, total_length) for varlen or (batch, dim, seqlen)
output: (dim, total_length) for varlen or (batch, dim, seqlen)
supports inplace replacement
supports inplace replacement
last_state has shape (batch, dim, dstate).
supports inplace replacement if ssm_state was provided
"""
"""
if
u
.
stride
(
-
1
)
!=
1
:
if
u
.
stride
(
-
1
)
!=
1
:
u
=
u
.
contiguous
()
u
=
u
.
contiguous
()
...
@@ -387,7 +403,7 @@ def selective_scan_fn(
...
@@ -387,7 +403,7 @@ def selective_scan_fn(
ops
.
selective_scan_fwd
(
u
,
delta
,
A
,
B
,
C
,
D
,
z
,
delta_bias
,
delta_softplus
,
ops
.
selective_scan_fwd
(
u
,
delta
,
A
,
B
,
C
,
D
,
z
,
delta_bias
,
delta_softplus
,
query_start_loc
,
cache_indices
,
has_initial_state
,
query_start_loc
,
cache_indices
,
has_initial_state
,
ssm_states
)
ssm_states
,
pad_slot_id
)
if
z
is
None
:
if
z
is
None
:
return
delta
# output written inplace to delta
return
delta
# output written inplace to delta
...
...
vllm/model_executor/models/jamba.py
View file @
fb60ae9b
# coding=utf-8
# coding=utf-8
"""Inference-only Jamba model."""
"""Inference-only Jamba model."""
from
dataclasses
import
dataclass
from
typing
import
Iterable
,
List
,
Optional
,
Tuple
from
typing
import
Iterable
,
List
,
Optional
,
Tuple
import
torch
import
torch
...
@@ -29,7 +28,8 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
...
@@ -29,7 +28,8 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE
,
ParallelLMHead
,
VocabParallelEmbedding
)
DEFAULT_VOCAB_PADDING_SIZE
,
ParallelLMHead
,
VocabParallelEmbedding
)
from
vllm.model_executor.model_loader.weight_utils
import
(
from
vllm.model_executor.model_loader.weight_utils
import
(
composed_weight_loader
,
default_weight_loader
,
sharded_weight_loader
)
composed_weight_loader
,
default_weight_loader
,
sharded_weight_loader
)
from
vllm.model_executor.models.mamba_cache
import
MambaCacheManager
from
vllm.model_executor.models.mamba_cache
import
(
MambaCacheManager
,
MambaCacheParams
)
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.sequence
import
IntermediateTensors
from
vllm.sequence
import
IntermediateTensors
...
@@ -41,13 +41,6 @@ from .interfaces import HasInnerState, SupportsLoRA
...
@@ -41,13 +41,6 @@ from .interfaces import HasInnerState, SupportsLoRA
KVCache
=
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
KVCache
=
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
@
dataclass
class
MambaCacheParams
:
is_prompt
:
bool
=
False
conv_state
:
torch
.
Tensor
=
torch
.
Tensor
()
ssm_state
:
torch
.
Tensor
=
torch
.
Tensor
()
# Adapted from transformers.models.mamba.modeling_mamba.MambaMixer
# Adapted from transformers.models.mamba.modeling_mamba.MambaMixer
class
JambaMambaMixer
(
nn
.
Module
):
class
JambaMambaMixer
(
nn
.
Module
):
"""
"""
...
@@ -60,10 +53,9 @@ class JambaMambaMixer(nn.Module):
...
@@ -60,10 +53,9 @@ class JambaMambaMixer(nn.Module):
**selective** state spaces)
**selective** state spaces)
"""
"""
def
__init__
(
self
,
config
:
JambaConfig
,
layer_idx
):
def
__init__
(
self
,
config
:
JambaConfig
):
super
().
__init__
()
super
().
__init__
()
self
.
config
=
config
self
.
config
=
config
self
.
layer_idx
=
layer_idx
self
.
hidden_size
=
config
.
hidden_size
self
.
hidden_size
=
config
.
hidden_size
self
.
ssm_state_size
=
config
.
mamba_d_state
self
.
ssm_state_size
=
config
.
mamba_d_state
self
.
conv_kernel_size
=
config
.
mamba_d_conv
self
.
conv_kernel_size
=
config
.
mamba_d_conv
...
@@ -129,8 +121,8 @@ class JambaMambaMixer(nn.Module):
...
@@ -129,8 +121,8 @@ class JambaMambaMixer(nn.Module):
eps
=
config
.
rms_norm_eps
)
eps
=
config
.
rms_norm_eps
)
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
,
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
conv_state
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
ssm_state
:
torch
.
Tensor
):
mamba_cache_params
:
MambaCacheParams
):
# 1. Gated MLP's linear projection
# 1. Gated MLP's linear projection
projected_states
=
self
.
in_proj
(
hidden_states
)[
0
].
transpose
(
-
2
,
-
1
)
projected_states
=
self
.
in_proj
(
hidden_states
)[
0
].
transpose
(
-
2
,
-
1
)
...
@@ -153,17 +145,18 @@ class JambaMambaMixer(nn.Module):
...
@@ -153,17 +145,18 @@ class JambaMambaMixer(nn.Module):
conv_weights
,
conv_weights
,
self
.
conv1d
.
bias
,
self
.
conv1d
.
bias
,
activation
=
self
.
activation
,
activation
=
self
.
activation
,
conv_states
=
conv_state
,
conv_states
=
mamba_cache_params
.
conv_state
,
has_initial_state
=
attn_metadata
.
context_lens_tensor
>
0
,
has_initial_state
=
attn_metadata
.
context_lens_tensor
>
0
,
cache_indices
=
mamba_cache_params
.
state_indices_tensor
,
query_start_loc
=
attn_metadata
.
query_start_loc
)
query_start_loc
=
attn_metadata
.
query_start_loc
)
else
:
else
:
hidden_states
=
causal_conv1d_update
(
hidden_states
=
causal_conv1d_update
(
hidden_states
.
transpose
(
0
,
1
),
hidden_states
.
transpose
(
0
,
1
),
conv_state
,
mamba_cache_params
.
conv_state
,
conv_weights
,
conv_weights
,
self
.
conv1d
.
bias
,
self
.
conv1d
.
bias
,
self
.
activation
,
self
.
activation
,
)
conv_state_indices
=
mamba_cache_params
.
state_indices_tensor
)
hidden_states
=
hidden_states
.
transpose
(
0
,
1
)
hidden_states
=
hidden_states
.
transpose
(
0
,
1
)
# 3. State Space Model sequence transformation
# 3. State Space Model sequence transformation
...
@@ -188,7 +181,7 @@ class JambaMambaMixer(nn.Module):
...
@@ -188,7 +181,7 @@ class JambaMambaMixer(nn.Module):
and
attn_metadata
.
context_lens_tensor
is
not
None
:
and
attn_metadata
.
context_lens_tensor
is
not
None
:
scan_outputs
=
selective_scan_fn
(
scan_outputs
=
selective_scan_fn
(
hidden_states
,
hidden_states
,
ssm_state
,
mamba_cache_params
.
ssm_state
,
discrete_time_step
,
discrete_time_step
,
self
.
A
,
self
.
A
,
B
.
transpose
(
-
2
,
-
1
),
B
.
transpose
(
-
2
,
-
1
),
...
@@ -197,11 +190,12 @@ class JambaMambaMixer(nn.Module):
...
@@ -197,11 +190,12 @@ class JambaMambaMixer(nn.Module):
gate
,
gate
,
time_proj_bias
,
time_proj_bias
,
delta_softplus
=
True
,
delta_softplus
=
True
,
cache_indices
=
mamba_cache_params
.
state_indices_tensor
,
has_initial_state
=
attn_metadata
.
context_lens_tensor
>
0
,
has_initial_state
=
attn_metadata
.
context_lens_tensor
>
0
,
query_start_loc
=
attn_metadata
.
query_start_loc
)
query_start_loc
=
attn_metadata
.
query_start_loc
)
else
:
else
:
scan_outputs
=
selective_state_update
(
scan_outputs
=
selective_state_update
(
ssm_state
,
mamba_cache_params
.
ssm_state
,
hidden_states
.
transpose
(
0
,
1
),
hidden_states
.
transpose
(
0
,
1
),
discrete_time_step
.
transpose
(
0
,
1
),
discrete_time_step
.
transpose
(
0
,
1
),
self
.
A
,
self
.
A
,
...
@@ -211,7 +205,7 @@ class JambaMambaMixer(nn.Module):
...
@@ -211,7 +205,7 @@ class JambaMambaMixer(nn.Module):
gate
.
transpose
(
0
,
1
),
gate
.
transpose
(
0
,
1
),
time_proj_bias
,
time_proj_bias
,
dt_softplus
=
True
,
dt_softplus
=
True
,
)
state_batch_indices
=
mamba_cache_params
.
state_indices_tensor
)
scan_outputs
=
scan_outputs
.
transpose
(
0
,
1
)
scan_outputs
=
scan_outputs
.
transpose
(
0
,
1
)
# 4. Final linear projection
# 4. Final linear projection
...
@@ -292,7 +286,7 @@ class JambaMambaDecoderLayer(nn.Module):
...
@@ -292,7 +286,7 @@ class JambaMambaDecoderLayer(nn.Module):
super
().
__init__
()
super
().
__init__
()
self
.
layer_idx
=
layer_idx
self
.
layer_idx
=
layer_idx
self
.
config
=
config
self
.
config
=
config
self
.
mamba
=
JambaMambaMixer
(
config
,
layer_idx
)
self
.
mamba
=
JambaMambaMixer
(
config
)
num_experts
=
config
.
layers_num_experts
[
layer_idx
]
num_experts
=
config
.
layers_num_experts
[
layer_idx
]
ffn_layer_class
=
JambaMoE
if
num_experts
>
1
else
JambaMLP
ffn_layer_class
=
JambaMoE
if
num_experts
>
1
else
JambaMLP
...
@@ -307,8 +301,7 @@ class JambaMambaDecoderLayer(nn.Module):
...
@@ -307,8 +301,7 @@ class JambaMambaDecoderLayer(nn.Module):
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
attn_metadata
:
AttentionMetadata
,
residual
:
Optional
[
torch
.
Tensor
],
residual
:
Optional
[
torch
.
Tensor
],
conv_state
:
torch
.
Tensor
,
mamba_cache_params
:
MambaCacheParams
,
ssm_state
:
torch
.
Tensor
,
**
kwargs
,
**
kwargs
,
):
):
if
residual
is
None
:
if
residual
is
None
:
...
@@ -318,8 +311,8 @@ class JambaMambaDecoderLayer(nn.Module):
...
@@ -318,8 +311,8 @@ class JambaMambaDecoderLayer(nn.Module):
hidden_states
,
residual
=
self
.
input_layernorm
(
hidden_states
,
residual
=
self
.
input_layernorm
(
hidden_states
,
residual
)
hidden_states
,
residual
)
hidden_states
=
self
.
mamba
(
hidden_states
,
attn_metadata
,
conv_state
,
hidden_states
=
self
.
mamba
(
hidden_states
,
attn_metadata
,
ssm_state
)
mamba_cache_params
)
# Fully Connected
# Fully Connected
hidden_states
,
residual
=
self
.
pre_ff_layernorm
(
hidden_states
,
residual
=
self
.
pre_ff_layernorm
(
hidden_states
,
residual
)
hidden_states
,
residual
)
...
@@ -476,17 +469,14 @@ class JambaModel(nn.Module):
...
@@ -476,17 +469,14 @@ class JambaModel(nn.Module):
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
attn_metadata
:
AttentionMetadata
,
conv_state
:
torch
.
Tensor
,
mamba_cache_params
:
MambaCacheParams
,
ssm_state
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
hidden_states
=
self
.
embed_tokens
(
input_ids
)
hidden_states
=
self
.
embed_tokens
(
input_ids
)
residual
=
None
residual
=
None
for
i
in
range
(
len
(
self
.
layers
)):
for
i
in
range
(
len
(
self
.
layers
)):
layer
=
self
.
layers
[
i
]
layer
=
self
.
layers
[
i
]
kv_cache
=
None
kv_cache
=
None
current_ssm_state
=
None
layer_mamba_cache_params
=
None
current_conv_state
=
None
if
isinstance
(
layer
,
JambaAttentionDecoderLayer
):
if
isinstance
(
layer
,
JambaAttentionDecoderLayer
):
kv_cache
=
kv_caches
[(
i
-
self
.
config
.
attn_layer_offset
)
//
kv_cache
=
kv_caches
[(
i
-
self
.
config
.
attn_layer_offset
)
//
self
.
config
.
attn_layer_period
]
self
.
config
.
attn_layer_period
]
...
@@ -494,8 +484,8 @@ class JambaModel(nn.Module):
...
@@ -494,8 +484,8 @@ class JambaModel(nn.Module):
current_state_layer
=
i
-
(
1
+
current_state_layer
=
i
-
(
1
+
(
i
-
self
.
config
.
attn_layer_offset
)
(
i
-
self
.
config
.
attn_layer_offset
)
//
self
.
config
.
attn_layer_period
)
//
self
.
config
.
attn_layer_period
)
current_ssm_state
=
ssm_state
[
current_st
at
e
_layer
]
layer_mamba_cache_params
=
mamba_cache_params
.
at_layer
_idx
(
current_conv_state
=
conv_state
[
current_state_layer
]
current_state_layer
)
hidden_states
,
residual
=
layer
(
hidden_states
,
residual
=
layer
(
positions
=
positions
,
positions
=
positions
,
...
@@ -503,9 +493,7 @@ class JambaModel(nn.Module):
...
@@ -503,9 +493,7 @@ class JambaModel(nn.Module):
kv_cache
=
kv_cache
,
kv_cache
=
kv_cache
,
attn_metadata
=
attn_metadata
,
attn_metadata
=
attn_metadata
,
residual
=
residual
,
residual
=
residual
,
conv_state
=
current_conv_state
,
mamba_cache_params
=
layer_mamba_cache_params
)
ssm_state
=
current_ssm_state
,
)
hidden_states
,
_
=
self
.
final_layernorm
(
hidden_states
,
residual
)
hidden_states
,
_
=
self
.
final_layernorm
(
hidden_states
,
residual
)
return
hidden_states
return
hidden_states
...
@@ -588,13 +576,16 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA):
...
@@ -588,13 +576,16 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA):
self
.
mamba_cache
=
MambaCacheManager
(
self
.
mamba_cache
=
MambaCacheManager
(
self
.
lm_head
.
weight
.
dtype
,
num_mamba_layers
,
max_batch_size
,
self
.
lm_head
.
weight
.
dtype
,
num_mamba_layers
,
max_batch_size
,
*
self
.
_get_mamba_cache_shape
())
*
self
.
_get_mamba_cache_shape
())
(
mamba_cache_tensors
=
self
.
mamba_cache
.
current_run_tensors
(
mamba_cache_tensors
,
input_ids
,
attn_metadata
,
**
kwargs
)
state_indices_tensor
,
)
=
self
.
mamba_cache
.
current_run_tensors
(
input_ids
,
attn_metadata
,
**
kwargs
)
mamba_cache_params
=
MambaCacheParams
(
mamba_cache_tensors
[
0
],
mamba_cache_tensors
[
1
],
state_indices_tensor
)
hidden_states
=
self
.
model
(
input_ids
,
positions
,
kv_caches
,
hidden_states
=
self
.
model
(
input_ids
,
positions
,
kv_caches
,
attn_metadata
,
mamba_cache_tensors
[
0
],
attn_metadata
,
mamba_cache_params
)
mamba_cache_tensors
[
1
])
return
hidden_states
return
hidden_states
def
copy_inputs_before_cuda_graphs
(
self
,
input_buffers
,
**
kwargs
):
def
copy_inputs_before_cuda_graphs
(
self
,
input_buffers
,
**
kwargs
):
...
...
vllm/model_executor/models/mamba.py
View file @
fb60ae9b
...
@@ -27,7 +27,8 @@ from vllm.model_executor.model_loader.weight_utils import (
...
@@ -27,7 +27,8 @@ from vllm.model_executor.model_loader.weight_utils import (
composed_weight_loader
,
default_weight_loader
,
sharded_weight_loader
)
composed_weight_loader
,
default_weight_loader
,
sharded_weight_loader
)
from
vllm.model_executor.models.interfaces
import
(
HasInnerState
,
from
vllm.model_executor.models.interfaces
import
(
HasInnerState
,
IsAttentionFree
)
IsAttentionFree
)
from
vllm.model_executor.models.mamba_cache
import
MambaCacheManager
from
vllm.model_executor.models.mamba_cache
import
(
MambaCacheManager
,
MambaCacheParams
)
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.sequence
import
IntermediateTensors
from
vllm.sequence
import
IntermediateTensors
...
@@ -110,8 +111,8 @@ class MambaMixer(nn.Module):
...
@@ -110,8 +111,8 @@ class MambaMixer(nn.Module):
self
.
activation
=
config
.
hidden_act
self
.
activation
=
config
.
hidden_act
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
,
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
conv_state
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
ssm_state
:
torch
.
Tensor
):
mamba_cache_params
:
MambaCacheParams
):
# 1. Gated MLP's linear projection
# 1. Gated MLP's linear projection
projected_states
=
self
.
in_proj
(
hidden_states
)[
0
].
transpose
(
-
2
,
-
1
)
projected_states
=
self
.
in_proj
(
hidden_states
)[
0
].
transpose
(
-
2
,
-
1
)
...
@@ -134,17 +135,18 @@ class MambaMixer(nn.Module):
...
@@ -134,17 +135,18 @@ class MambaMixer(nn.Module):
conv_weights
,
conv_weights
,
self
.
conv1d
.
bias
,
self
.
conv1d
.
bias
,
activation
=
self
.
activation
,
activation
=
self
.
activation
,
conv_states
=
conv_state
,
conv_states
=
mamba_cache_params
.
conv_state
,
has_initial_state
=
attn_metadata
.
context_lens_tensor
>
0
,
has_initial_state
=
attn_metadata
.
context_lens_tensor
>
0
,
cache_indices
=
mamba_cache_params
.
state_indices_tensor
,
query_start_loc
=
attn_metadata
.
query_start_loc
)
query_start_loc
=
attn_metadata
.
query_start_loc
)
else
:
else
:
hidden_states
=
causal_conv1d_update
(
hidden_states
=
causal_conv1d_update
(
hidden_states
.
transpose
(
0
,
1
),
hidden_states
.
transpose
(
0
,
1
),
conv_state
,
mamba_cache_params
.
conv_state
,
conv_weights
,
conv_weights
,
self
.
conv1d
.
bias
,
self
.
conv1d
.
bias
,
self
.
activation
,
self
.
activation
,
)
conv_state_indices
=
mamba_cache_params
.
state_indices_tensor
)
hidden_states
=
hidden_states
.
transpose
(
0
,
1
)
hidden_states
=
hidden_states
.
transpose
(
0
,
1
)
# 3. State Space Model sequence transformation
# 3. State Space Model sequence transformation
...
@@ -168,7 +170,7 @@ class MambaMixer(nn.Module):
...
@@ -168,7 +170,7 @@ class MambaMixer(nn.Module):
and
attn_metadata
.
context_lens_tensor
is
not
None
:
and
attn_metadata
.
context_lens_tensor
is
not
None
:
scan_outputs
=
selective_scan_fn
(
scan_outputs
=
selective_scan_fn
(
hidden_states
,
hidden_states
,
ssm_state
,
mamba_cache_params
.
ssm_state
,
discrete_time_step
,
discrete_time_step
,
self
.
A
,
self
.
A
,
B
.
transpose
(
-
2
,
-
1
),
B
.
transpose
(
-
2
,
-
1
),
...
@@ -177,11 +179,12 @@ class MambaMixer(nn.Module):
...
@@ -177,11 +179,12 @@ class MambaMixer(nn.Module):
gate
,
gate
,
time_proj_bias
,
time_proj_bias
,
delta_softplus
=
True
,
delta_softplus
=
True
,
cache_indices
=
mamba_cache_params
.
state_indices_tensor
,
has_initial_state
=
attn_metadata
.
context_lens_tensor
>
0
,
has_initial_state
=
attn_metadata
.
context_lens_tensor
>
0
,
query_start_loc
=
attn_metadata
.
query_start_loc
)
query_start_loc
=
attn_metadata
.
query_start_loc
)
else
:
else
:
scan_outputs
=
selective_state_update
(
scan_outputs
=
selective_state_update
(
ssm_state
,
mamba_cache_params
.
ssm_state
,
hidden_states
.
transpose
(
0
,
1
),
hidden_states
.
transpose
(
0
,
1
),
discrete_time_step
.
transpose
(
0
,
1
),
discrete_time_step
.
transpose
(
0
,
1
),
self
.
A
,
self
.
A
,
...
@@ -191,7 +194,7 @@ class MambaMixer(nn.Module):
...
@@ -191,7 +194,7 @@ class MambaMixer(nn.Module):
gate
.
transpose
(
0
,
1
),
gate
.
transpose
(
0
,
1
),
time_proj_bias
,
time_proj_bias
,
dt_softplus
=
True
,
dt_softplus
=
True
,
)
state_batch_indices
=
mamba_cache_params
.
state_indices_tensor
)
scan_outputs
=
scan_outputs
.
transpose
(
0
,
1
)
scan_outputs
=
scan_outputs
.
transpose
(
0
,
1
)
# 4. Final linear projection
# 4. Final linear projection
...
@@ -221,8 +224,7 @@ class MambaDecoderLayer(nn.Module):
...
@@ -221,8 +224,7 @@ class MambaDecoderLayer(nn.Module):
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
attn_metadata
:
AttentionMetadata
,
residual
:
Optional
[
torch
.
Tensor
],
residual
:
Optional
[
torch
.
Tensor
],
conv_state
:
torch
.
Tensor
,
mamba_cache_params
:
MambaCacheParams
,
ssm_state
:
torch
.
Tensor
,
**
kwargs
,
**
kwargs
,
):
):
if
residual
is
None
:
if
residual
is
None
:
...
@@ -231,8 +233,8 @@ class MambaDecoderLayer(nn.Module):
...
@@ -231,8 +233,8 @@ class MambaDecoderLayer(nn.Module):
else
:
else
:
hidden_states
,
residual
=
self
.
norm
(
hidden_states
,
residual
)
hidden_states
,
residual
=
self
.
norm
(
hidden_states
,
residual
)
hidden_states
=
self
.
mixer
(
hidden_states
,
attn_metadata
,
conv_state
,
hidden_states
=
self
.
mixer
(
hidden_states
,
attn_metadata
,
ssm_state
)
mamba_cache_params
)
return
hidden_states
,
residual
return
hidden_states
,
residual
...
@@ -275,25 +277,20 @@ class MambaModel(nn.Module):
...
@@ -275,25 +277,20 @@ class MambaModel(nn.Module):
input_ids
:
torch
.
Tensor
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
attn_metadata
:
AttentionMetadata
,
conv_state
:
torch
.
Tensor
,
mamba_cache_params
:
MambaCacheParams
,
ssm_state
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
hidden_states
=
self
.
embeddings
(
input_ids
)
hidden_states
=
self
.
embeddings
(
input_ids
)
residual
=
None
residual
=
None
for
i
in
range
(
len
(
self
.
layers
)):
for
i
in
range
(
len
(
self
.
layers
)):
layer
=
self
.
layers
[
i
]
layer
=
self
.
layers
[
i
]
current_ssm_state
=
ssm_state
[
i
]
current_conv_state
=
conv_state
[
i
]
hidden_states
,
residual
=
layer
(
hidden_states
,
residual
=
layer
(
positions
=
positions
,
positions
=
positions
,
hidden_states
=
hidden_states
,
hidden_states
=
hidden_states
,
attn_metadata
=
attn_metadata
,
attn_metadata
=
attn_metadata
,
residual
=
residual
,
residual
=
residual
,
conv_state
=
current_conv_state
,
mamba_cache_params
=
mamba_cache_params
.
at_layer_idx
(
i
))
ssm_state
=
current_ssm_state
,
)
hidden_states
,
_
=
self
.
norm_f
(
hidden_states
,
residual
)
hidden_states
,
_
=
self
.
norm_f
(
hidden_states
,
residual
)
return
hidden_states
return
hidden_states
...
@@ -347,12 +344,18 @@ class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree):
...
@@ -347,12 +344,18 @@ class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree):
self
.
lm_head
.
weight
.
dtype
,
self
.
config
.
num_hidden_layers
,
self
.
lm_head
.
weight
.
dtype
,
self
.
config
.
num_hidden_layers
,
max_batch_size
,
*
self
.
_get_mamba_cache_shape
())
max_batch_size
,
*
self
.
_get_mamba_cache_shape
())
mamba_cache_tensors
=
self
.
mamba_cache
.
current_run_tensors
(
(
input_ids
,
attn_metadata
,
**
kwargs
)
mamba_cache_tensors
,
state_indices_tensor
,
)
=
self
.
mamba_cache
.
current_run_tensors
(
input_ids
,
attn_metadata
,
**
kwargs
)
mamba_cache_params
=
MambaCacheParams
(
mamba_cache_tensors
[
0
],
mamba_cache_tensors
[
1
],
state_indices_tensor
)
hidden_states
=
self
.
backbone
(
input_ids
,
positions
,
attn_metadata
,
hidden_states
=
self
.
backbone
(
input_ids
,
positions
,
attn_metadata
,
mamba_cache_tensors
[
0
],
mamba_cache_params
)
mamba_cache_tensors
[
1
])
return
hidden_states
return
hidden_states
...
...
vllm/model_executor/models/mamba_cache.py
View file @
fb60ae9b
from
typing
import
Dict
,
List
,
Optional
from
dataclasses
import
dataclass
from
typing
import
Dict
,
List
import
torch
import
torch
from
vllm.attention.backends.abstract
import
AttentionMetadata
from
vllm.attention.backends.abstract
import
AttentionMetadata
from
vllm.attention.backends.utils
import
PAD_SLOT_ID
@
dataclass
class
MambaCacheParams
:
conv_state
:
torch
.
Tensor
=
torch
.
Tensor
()
ssm_state
:
torch
.
Tensor
=
torch
.
Tensor
()
state_indices_tensor
:
torch
.
Tensor
=
torch
.
Tensor
()
def
at_layer_idx
(
self
,
layer_idx
):
return
MambaCacheParams
(
self
.
conv_state
[
layer_idx
],
self
.
ssm_state
[
layer_idx
],
self
.
state_indices_tensor
)
class
MambaCacheManager
:
class
MambaCacheManager
:
...
@@ -24,6 +38,7 @@ class MambaCacheManager:
...
@@ -24,6 +38,7 @@ class MambaCacheManager:
# Maps between the request id and a dict that maps between the seq_id
# Maps between the request id and a dict that maps between the seq_id
# and its index inside the self.mamba_cache
# and its index inside the self.mamba_cache
self
.
mamba_cache_indices_mapping
:
Dict
[
str
,
Dict
[
int
,
int
]]
=
{}
self
.
mamba_cache_indices_mapping
:
Dict
[
str
,
Dict
[
int
,
int
]]
=
{}
self
.
free_cache_indices
=
list
(
range
(
max_batch_size
))
def
current_run_tensors
(
self
,
input_ids
:
torch
.
Tensor
,
def
current_run_tensors
(
self
,
input_ids
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
**
kwargs
):
attn_metadata
:
AttentionMetadata
,
**
kwargs
):
...
@@ -36,30 +51,43 @@ class MambaCacheManager:
...
@@ -36,30 +51,43 @@ class MambaCacheManager:
finished_requests_ids
=
kwargs
[
"finished_requests_ids"
]
finished_requests_ids
=
kwargs
[
"finished_requests_ids"
]
self
.
_release_finished_requests
(
finished_requests_ids
)
self
.
_release_finished_requests
(
finished_requests_ids
)
mamba_cache_tensor
s
=
self
.
_prepare_current_run_mamba_cache
(
state_indice
s
=
self
.
_prepare_current_run_mamba_cache
(
request_ids_to_seq_ids
,
finished_requests_ids
)
request_ids_to_seq_ids
,
finished_requests_ids
)
state_indices_tensor
=
torch
.
as_tensor
(
state_indices
,
dtype
=
torch
.
int32
,
device
=
"cuda"
)
mamba_cache_tensors
=
self
.
mamba_cache
else
:
else
:
# CUDA graph capturing runs
# CUDA graph capturing runs
mamba_cache_tensors
=
kwargs
[
"seqlen_agnostic_capture_inputs"
]
(
mamba_cache_tensors
,
state_indices_tensor
)
=
kwargs
[
"seqlen_agnostic_capture_inputs"
]
return
mamba_cache_tensors
return
(
mamba_cache_tensors
,
state_indices_tensor
)
def
copy_inputs_before_cuda_graphs
(
self
,
input_buffers
,
**
kwargs
):
def
copy_inputs_before_cuda_graphs
(
self
,
input_buffers
,
**
kwargs
):
"""
"""
Copy the relevant Mamba cache into the CUDA graph input buffer
Copy the relevant state_indices into the CUDA graph input buffer
that was provided during the capture runs
(JambaForCausalLM.mamba_gc_cache_buffer).
"""
"""
assert
all
(
assert
all
(
key
in
kwargs
key
in
kwargs
for
key
in
[
"request_ids_to_seq_ids"
,
"finished_requests_ids"
])
for
key
in
[
"request_ids_to_seq_ids"
,
"finished_requests_ids"
])
finished_requests_ids
=
kwargs
[
"finished_requests_ids"
]
finished_requests_ids
=
kwargs
[
"finished_requests_ids"
]
request_ids_to_seq_ids
=
kwargs
[
"request_ids_to_seq_ids"
]
request_ids_to_seq_ids
=
kwargs
[
"request_ids_to_seq_ids"
]
assert
"seqlen_agnostic_capture_inputs"
in
input_buffers
_
,
input_state_indices_buffer
=
input_buffers
[
"seqlen_agnostic_capture_inputs"
]
self
.
_release_finished_requests
(
finished_requests_ids
)
self
.
_release_finished_requests
(
finished_requests_ids
)
self
.
_prepare_current_run_mamba_cache
(
request_ids_to_seq_ids
,
state_indices
=
self
.
_prepare_current_run_mamba_cache
(
finished_requests_ids
)
request_ids_to_seq_ids
,
finished_requests_ids
)
cuda_graph_pad_len
=
input_state_indices_buffer
.
shape
[
0
]
-
len
(
state_indices
)
state_indices
.
extend
([
PAD_SLOT_ID
]
*
cuda_graph_pad_len
)
input_state_indices_buffer
.
copy_
(
torch
.
as_tensor
(
state_indices
,
dtype
=
torch
.
int32
,
device
=
"cuda"
))
def
get_seqlen_agnostic_capture_inputs
(
self
,
batch_size
:
int
):
def
get_seqlen_agnostic_capture_inputs
(
self
,
batch_size
:
int
):
"""
"""
...
@@ -67,13 +95,10 @@ class MambaCacheManager:
...
@@ -67,13 +95,10 @@ class MambaCacheManager:
The buffer is used to maintain the Mamba Cache during the CUDA graph
The buffer is used to maintain the Mamba Cache during the CUDA graph
replay runs.
replay runs.
"""
"""
return
tuple
(
buffer
[:,
:
batch_size
]
for
buffer
in
self
.
mamba_cache
)
state_indices_tensor
=
torch
.
as_tensor
([
PAD_SLOT_ID
]
*
batch_size
,
dtype
=
torch
.
int32
,
def
_swap_mamba_cache
(
self
,
from_index
:
int
,
to_index
:
int
):
device
=
"cuda"
)
assert
len
(
self
.
mamba_cache
)
>
0
return
(
self
.
mamba_cache
,
state_indices_tensor
)
for
cache_t
in
self
.
mamba_cache
:
cache_t
[:,
[
to_index
,
from_index
]]
=
\
cache_t
[:,
[
from_index
,
to_index
]]
def
_copy_mamba_cache
(
self
,
from_index
:
int
,
to_index
:
int
):
def
_copy_mamba_cache
(
self
,
from_index
:
int
,
to_index
:
int
):
assert
len
(
self
.
mamba_cache
)
>
0
assert
len
(
self
.
mamba_cache
)
>
0
...
@@ -81,142 +106,53 @@ class MambaCacheManager:
...
@@ -81,142 +106,53 @@ class MambaCacheManager:
cache_t
[:,
to_index
].
copy_
(
cache_t
[:,
from_index
],
cache_t
[:,
to_index
].
copy_
(
cache_t
[:,
from_index
],
non_blocking
=
True
)
non_blocking
=
True
)
def
_move_out_if_already_occupied
(
self
,
index
:
int
,
def
_assign_seq_id_to_cache_index
(
self
,
cur_rid
:
str
,
seq_id
:
int
,
all_occupied_indices
:
List
[
int
]):
finished_requests_ids
)
->
int
:
if
index
in
all_occupied_indices
:
first_free_index
=
self
.
_first_free_index_in_mamba_cache
()
# In case occupied, move the occupied to a new empty block
self
.
_move_cache_index_and_mappings
(
from_index
=
index
,
to_index
=
first_free_index
)
def
_assign_seq_id_to_mamba_cache_in_specific_dest
(
self
,
cur_rid
:
str
,
seq_id
:
int
,
destination_index
:
int
):
"""
"""
Assign (req_id,seq_id) pair to a `destination_index` index, if
Assign (req_id,seq_id) pair to a `destination_index` index, if
already occupied, move the occupying index to a free index.
already occupied, move the occupying index to a free index.
"""
"""
all_occupied_indices
=
self
.
_get_all_occupied_indices
()
if
cur_rid
in
finished_requests_ids
:
if
cur_rid
not
in
self
.
mamba_cache_indices_mapping
:
# set as pad, do not allocate destination index
self
.
_move_out_if_already_occupied
(
return
PAD_SLOT_ID
index
=
destination_index
,
elif
cur_rid
not
in
self
.
mamba_cache_indices_mapping
:
all_occupied_indices
=
all_occupied
_indices
)
destination_index
=
self
.
free_cache
_indices
.
pop
(
)
self
.
mamba_cache_indices_mapping
[
cur_rid
]
=
{
self
.
mamba_cache_indices_mapping
[
cur_rid
]
=
{
seq_id
:
destination_index
seq_id
:
destination_index
}
}
return
destination_index
elif
seq_id
not
in
(
seq_ids2indices
:
=
elif
seq_id
not
in
(
seq_ids2indices
:
=
self
.
mamba_cache_indices_mapping
[
cur_rid
]):
self
.
mamba_cache_indices_mapping
[
cur_rid
]):
# parallel sampling , where n > 1, assume prefill have
# parallel sampling , where n > 1, assume prefill have
# already happened
now we only need to copy the already
# already happened
, so we copy the
# existing cache into the siblings seq_ids caches
# existing cache into the siblings seq_ids caches
self
.
_move_out_if_already_occupied
(
index_exists
=
next
(
iter
(
seq_ids2indices
.
values
()))
index
=
destination_index
,
all_occupied_indices
=
all_occupied_indices
)
index_exists
=
list
(
seq_ids2indices
.
values
())[
0
]
# case of decoding n>1, copy prefill cache to decoding indices
# case of decoding n>1, copy prefill cache to decoding indices
destination_index
=
self
.
free_cache_indices
.
pop
()
self
.
_copy_mamba_cache
(
from_index
=
index_exists
,
self
.
_copy_mamba_cache
(
from_index
=
index_exists
,
to_index
=
destination_index
)
to_index
=
destination_index
)
self
.
mamba_cache_indices_mapping
[
cur_rid
][
self
.
mamba_cache_indices_mapping
[
cur_rid
][
seq_id
]
=
destination_index
seq_id
]
=
destination_index
return
destination_index
else
:
else
:
# already exists
# already exists
cache_index_already_exists
=
self
.
mamba_cache_indices_mapping
[
return
self
.
mamba_cache_indices_mapping
[
cur_rid
][
seq_id
]
cur_rid
][
seq_id
]
if
cache_index_already_exists
!=
destination_index
:
# In case the seq id already exists but not in
# the right destination, swap it with what's occupying it
self
.
_swap_pair_indices_and_mappings
(
from_index
=
cache_index_already_exists
,
to_index
=
destination_index
)
def
_prepare_current_run_mamba_cache
(
def
_prepare_current_run_mamba_cache
(
self
,
request_ids_to_seq_ids
:
Dict
[
str
,
list
[
int
]],
self
,
request_ids_to_seq_ids
:
Dict
[
str
,
list
[
int
]],
finished_requests_ids
:
List
[
str
]):
finished_requests_ids
:
List
[
str
])
->
List
[
int
]
:
r
unning_indices
=
[
]
r
eturn
[
request
_id
s
_to_
seq_ids_flatten
=
[
self
.
_assign_seq
_id_to_
cache_index
(
req_id
,
seq_id
,
(
req_id
,
seq
_id
)
finished_requests
_id
s
)
for
req_id
,
seq_ids
in
request_ids_to_seq_ids
.
items
()
for
req_id
,
seq_ids
in
request_ids_to_seq_ids
.
items
()
for
seq_id
in
seq_ids
for
seq_id
in
seq_ids
]
]
batch_size
=
len
(
request_ids_to_seq_ids_flatten
)
for
dest_index
,
(
request_id
,
seq_id
)
in
enumerate
(
request_ids_to_seq_ids_flatten
):
if
request_id
in
finished_requests_ids
:
# Do not allocate cache index for requests that run
# and finish right after
continue
self
.
_assign_seq_id_to_mamba_cache_in_specific_dest
(
request_id
,
seq_id
,
dest_index
)
running_indices
.
append
(
dest_index
)
self
.
_clean_up_first_bs_blocks
(
batch_size
,
running_indices
)
conv_state
=
self
.
mamba_cache
[
0
][:,
:
batch_size
]
temporal_state
=
self
.
mamba_cache
[
1
][:,
:
batch_size
]
return
(
conv_state
,
temporal_state
)
def
_get_all_occupied_indices
(
self
):
return
[
cache_idx
for
seq_ids2indices
in
self
.
mamba_cache_indices_mapping
.
values
()
for
cache_idx
in
seq_ids2indices
.
values
()
]
def
_clean_up_first_bs_blocks
(
self
,
batch_size
:
int
,
indices_for_current_run
:
List
[
int
]):
# move out all of the occupied but currently not running blocks
# outside of the first n blocks
destination_indices
=
range
(
batch_size
)
max_possible_batch_size
=
self
.
mamba_cache
[
0
].
shape
[
1
]
for
destination_index
in
destination_indices
:
if
destination_index
in
self
.
_get_all_occupied_indices
()
and
\
destination_index
not
in
indices_for_current_run
:
# move not running indices outside of the batch
all_other_indices
=
list
(
range
(
batch_size
,
max_possible_batch_size
))
first_avail_index
=
self
.
_first_free_index_in_mamba_cache
(
all_other_indices
)
self
.
_swap_indices
(
from_index
=
destination_index
,
to_index
=
first_avail_index
)
def
_move_cache_index_and_mappings
(
self
,
from_index
:
int
,
to_index
:
int
):
self
.
_copy_mamba_cache
(
from_index
=
from_index
,
to_index
=
to_index
)
self
.
_update_mapping_index
(
from_index
=
from_index
,
to_index
=
to_index
)
def
_swap_pair_indices_and_mappings
(
self
,
from_index
:
int
,
to_index
:
int
):
self
.
_swap_mamba_cache
(
from_index
=
from_index
,
to_index
=
to_index
)
self
.
_swap_mapping_index
(
from_index
=
from_index
,
to_index
=
to_index
)
def
_swap_mapping_index
(
self
,
from_index
:
int
,
to_index
:
int
):
for
seq_ids2index
in
self
.
mamba_cache_indices_mapping
.
values
():
for
seq_id
,
index
in
seq_ids2index
.
items
():
if
from_index
==
index
:
seq_ids2index
.
update
({
seq_id
:
to_index
})
elif
to_index
==
index
:
seq_ids2index
.
update
({
seq_id
:
from_index
})
def
_update_mapping_index
(
self
,
from_index
:
int
,
to_index
:
int
):
for
seq_ids2index
in
self
.
mamba_cache_indices_mapping
.
values
():
for
seq_id
,
index
in
seq_ids2index
.
items
():
if
from_index
==
index
:
seq_ids2index
.
update
({
seq_id
:
to_index
})
return
def
_release_finished_requests
(
self
,
def
_release_finished_requests
(
self
,
finished_seq_groups_req_ids
:
List
[
str
]):
finished_seq_groups_req_ids
:
List
[
str
]):
for
req_id
in
finished_seq_groups_req_ids
:
for
req_id
in
finished_seq_groups_req_ids
:
if
req_id
in
self
.
mamba_cache_indices_mapping
:
if
req_id
in
self
.
mamba_cache_indices_mapping
:
for
seq_id
in
self
.
mamba_cache_indices_mapping
[
req_id
]:
self
.
free_cache_indices
.
append
(
self
.
mamba_cache_indices_mapping
[
req_id
][
seq_id
])
self
.
mamba_cache_indices_mapping
.
pop
(
req_id
)
self
.
mamba_cache_indices_mapping
.
pop
(
req_id
)
def
_first_free_index_in_mamba_cache
(
self
,
indices_range
:
Optional
[
List
[
int
]]
=
None
)
->
int
:
assert
self
.
mamba_cache
is
not
None
if
indices_range
is
None
:
max_possible_batch_size
=
self
.
mamba_cache
[
0
].
shape
[
1
]
indices_range
=
list
(
range
(
max_possible_batch_size
))
all_occupied_indices
=
self
.
_get_all_occupied_indices
()
for
i
in
indices_range
:
if
i
not
in
all_occupied_indices
:
return
i
raise
Exception
(
"Couldn't find a free spot in the mamba cache! This"
"should never happen"
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment