Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
db9120cd
Unverified
Commit
db9120cd
authored
Sep 18, 2024
by
Tyler Michael Smith
Committed by
GitHub
Sep 18, 2024
Browse files
[Kernel] Change interface to Mamba selective_state_update for continuous batching (#8039)
parent
b3195bc9
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
174 additions
and
3 deletions
+174
-3
tests/kernels/test_mamba_ssm.py
tests/kernels/test_mamba_ssm.py
+146
-0
vllm/model_executor/layers/mamba/ops/mamba_ssm.py
vllm/model_executor/layers/mamba/ops/mamba_ssm.py
+28
-3
No files found.
tests/kernels/test_mamba_ssm.py
View file @
db9120cd
...
...
@@ -323,3 +323,149 @@ def test_selective_state_update(dim, dstate, has_z, itype):
assert
torch
.
allclose
(
state
,
state_ref
,
rtol
=
rtol
,
atol
=
atol
)
assert
torch
.
allclose
(
out
,
out_ref
,
rtol
=
rtol
,
atol
=
atol
)
@
pytest
.
mark
.
parametrize
(
"itype"
,
[
torch
.
float32
,
torch
.
float16
,
torch
.
bfloat16
])
@
pytest
.
mark
.
parametrize
(
"has_z"
,
[
False
,
True
])
@
pytest
.
mark
.
parametrize
(
"dstate"
,
[
16
,
32
,
64
])
@
pytest
.
mark
.
parametrize
(
"dim"
,
[
2048
,
2048
+
16
,
4096
])
def
test_selective_state_update_with_batch_indices
(
dim
,
dstate
,
has_z
,
itype
):
device
=
"cuda"
rtol
,
atol
=
(
3e-4
,
1e-3
)
if
itype
==
torch
.
float32
else
(
5e-3
,
1e-2
)
if
itype
==
torch
.
bfloat16
:
rtol
,
atol
=
7e-2
,
7e-2
if
torch
.
version
.
hip
:
atol
*=
2
# set seed
torch
.
random
.
manual_seed
(
0
)
batch_size
=
16
total_entries
=
10
*
batch_size
state
=
torch
.
randn
(
total_entries
,
dim
,
dstate
,
dtype
=
itype
,
device
=
device
)
state_indices
=
torch
.
randperm
(
total_entries
)[:
batch_size
].
to
(
dtype
=
torch
.
int32
,
device
=
device
)
x
=
torch
.
randn
(
batch_size
,
dim
,
device
=
device
,
dtype
=
itype
)
dt
=
torch
.
randn
(
batch_size
,
dim
,
device
=
device
,
dtype
=
itype
)
dt_bias
=
torch
.
rand
(
dim
,
device
=
device
)
-
4.0
A
=
-
torch
.
rand
(
dim
,
dstate
,
device
=
device
)
-
1.0
B
=
torch
.
randn
(
batch_size
,
dstate
,
device
=
device
)
C
=
torch
.
randn
(
batch_size
,
dstate
,
device
=
device
)
D
=
torch
.
randn
(
dim
,
device
=
device
)
z
=
torch
.
randn_like
(
x
)
if
has_z
else
None
state_ref
=
state
[
state_indices
,
:].
detach
().
clone
()
out
=
selective_state_update
(
state
,
x
,
dt
,
A
,
B
,
C
,
D
=
D
,
z
=
z
,
dt_bias
=
dt_bias
,
dt_softplus
=
True
,
state_batch_indices
=
state_indices
)
out_ref
=
selective_state_update_ref
(
state_ref
,
x
,
dt
,
A
,
B
,
C
,
D
=
D
,
z
=
z
,
dt_bias
=
dt_bias
,
dt_softplus
=
True
)
assert
torch
.
allclose
(
state
[
state_indices
,
:],
state_ref
,
rtol
=
rtol
,
atol
=
atol
)
assert
torch
.
allclose
(
out
,
out_ref
,
rtol
=
rtol
,
atol
=
atol
)
@
pytest
.
mark
.
parametrize
(
"itype"
,
[
torch
.
float32
,
torch
.
float16
,
torch
.
bfloat16
])
@
pytest
.
mark
.
parametrize
(
"has_z"
,
[
False
,
True
])
@
pytest
.
mark
.
parametrize
(
"tie_hdim"
,
[
False
,
True
])
@
pytest
.
mark
.
parametrize
(
"ngroups"
,
[
1
,
2
,
4
])
@
pytest
.
mark
.
parametrize
(
"dstate"
,
[
16
,
32
,
64
])
@
pytest
.
mark
.
parametrize
(
"dim"
,
[
2048
,
4096
])
def
test_selective_state_update_with_heads_with_batch_indices
(
dim
,
dstate
,
ngroups
,
has_z
,
tie_hdim
,
itype
):
device
=
"cuda"
rtol
,
atol
=
(
3e-4
,
1e-3
)
if
itype
==
torch
.
float32
else
(
5e-3
,
3e-2
)
if
itype
==
torch
.
bfloat16
:
rtol
,
atol
=
1e-1
,
1e-1
# set seed
torch
.
random
.
manual_seed
(
0
)
batch_size
=
16
headdim
=
64
nheads
=
dim
//
headdim
total_entries
=
10
*
batch_size
state
=
torch
.
randn
(
total_entries
,
nheads
,
headdim
,
dstate
,
dtype
=
itype
,
device
=
device
)
state_indices
=
torch
.
randperm
(
total_entries
)[:
batch_size
].
to
(
dtype
=
torch
.
int32
,
device
=
device
)
x
=
torch
.
randn
(
batch_size
,
nheads
,
headdim
,
device
=
device
,
dtype
=
itype
)
if
not
tie_hdim
:
dt
=
torch
.
randn
(
batch_size
,
nheads
,
headdim
,
device
=
device
,
dtype
=
itype
)
dt_bias
=
torch
.
rand
(
nheads
,
headdim
,
device
=
device
)
-
4.0
A
=
-
torch
.
rand
(
nheads
,
headdim
,
dstate
,
device
=
device
)
-
1.0
D
=
torch
.
randn
(
nheads
,
headdim
,
device
=
device
)
else
:
dt
=
repeat
(
torch
.
randn
(
batch_size
,
nheads
,
device
=
device
,
dtype
=
itype
),
"b h -> b h p"
,
p
=
headdim
)
dt_bias
=
repeat
(
torch
.
rand
(
nheads
,
device
=
device
)
-
4.0
,
"h -> h p"
,
p
=
headdim
)
A
=
repeat
(
-
torch
.
rand
(
nheads
,
device
=
device
)
-
1.0
,
"h -> h p n"
,
p
=
headdim
,
n
=
dstate
)
D
=
repeat
(
torch
.
randn
(
nheads
,
device
=
device
),
"h -> h p"
,
p
=
headdim
)
B
=
torch
.
randn
(
batch_size
,
ngroups
,
dstate
,
device
=
device
)
C
=
torch
.
randn
(
batch_size
,
ngroups
,
dstate
,
device
=
device
)
z
=
torch
.
randn_like
(
x
)
if
has_z
else
None
state_ref
=
state
[
state_indices
,
:].
detach
().
clone
()
out
=
selective_state_update
(
state
,
x
,
dt
,
A
,
B
,
C
,
D
=
D
,
z
=
z
,
dt_bias
=
dt_bias
,
dt_softplus
=
True
,
state_batch_indices
=
state_indices
)
out_ref
=
selective_state_update_ref
(
state_ref
,
x
,
dt
,
A
,
B
,
C
,
D
=
D
,
z
=
z
,
dt_bias
=
dt_bias
,
dt_softplus
=
True
)
print
(
f
"Output max diff:
{
(
out
-
out_ref
).
abs
().
max
().
item
()
}
"
)
print
(
f
"Output mean diff:
{
(
out
-
out_ref
).
abs
().
mean
().
item
()
}
"
)
assert
torch
.
allclose
(
state
[
state_indices
,
:],
state_ref
,
rtol
=
rtol
,
atol
=
atol
)
assert
torch
.
allclose
(
out
,
out_ref
,
rtol
=
rtol
,
atol
=
atol
)
vllm/model_executor/layers/mamba/ops/mamba_ssm.py
View file @
db9120cd
# Copyright (c) 2024, Tri Dao, Albert Gu.
# Adapted from https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/triton/selective_state_update.py
import
torch
import
triton
...
...
@@ -27,6 +28,10 @@ else:
{
"HAS_DT_BIAS"
:
lambda
args
:
args
[
"dt_bias_ptr"
]
is
not
None
})
@
triton
.
heuristics
({
"HAS_D"
:
lambda
args
:
args
[
"D_ptr"
]
is
not
None
})
@
triton
.
heuristics
({
"HAS_Z"
:
lambda
args
:
args
[
"z_ptr"
]
is
not
None
})
@
triton
.
heuristics
({
"HAS_STATE_BATCH_INDICES"
:
lambda
args
:
args
[
"state_batch_indices_ptr"
]
is
not
None
})
@
triton
.
heuristics
(
{
"BLOCK_SIZE_DSTATE"
:
lambda
args
:
triton
.
next_power_of_2
(
args
[
"dstate"
])})
@
triton
.
jit
...
...
@@ -42,6 +47,7 @@ def _selective_scan_update_kernel(
D_ptr
,
z_ptr
,
out_ptr
,
state_batch_indices_ptr
,
# Matrix dimensions
batch
,
nheads
,
...
...
@@ -85,12 +91,24 @@ def _selective_scan_update_kernel(
HAS_DT_BIAS
:
tl
.
constexpr
,
HAS_D
:
tl
.
constexpr
,
HAS_Z
:
tl
.
constexpr
,
HAS_STATE_BATCH_INDICES
:
tl
.
constexpr
,
BLOCK_SIZE_DSTATE
:
tl
.
constexpr
,
):
pid_m
=
tl
.
program_id
(
axis
=
0
)
pid_b
=
tl
.
program_id
(
axis
=
1
)
pid_h
=
tl
.
program_id
(
axis
=
2
)
# If HAS_STATE_BATCH_INDICES is true, then the ssm state's batch coordinate
# is taken from the state_batch_indices_ptr Otherwise, the state coordinate
# is the same as the batch id.
if
HAS_STATE_BATCH_INDICES
:
state_batch_indices_ptr
+=
pid_b
state_batch_idx
=
tl
.
load
(
state_batch_indices_ptr
)
state_ptr
+=
(
state_batch_idx
*
stride_state_batch
+
pid_h
*
stride_state_head
)
else
:
state_ptr
+=
pid_b
*
stride_state_batch
+
pid_h
*
stride_state_head
x_ptr
+=
pid_b
*
stride_x_batch
+
pid_h
*
stride_x_head
dt_ptr
+=
pid_b
*
stride_dt_batch
+
pid_h
*
stride_dt_head
if
HAS_DT_BIAS
:
...
...
@@ -177,7 +195,8 @@ def selective_state_update(state,
D
=
None
,
z
=
None
,
dt_bias
=
None
,
dt_softplus
=
False
):
dt_softplus
=
False
,
state_batch_indices
=
None
):
"""
Argument:
state: (batch, dim, dstate) or (batch, nheads, dim, dstate)
...
...
@@ -211,7 +230,10 @@ def selective_state_update(state,
z
=
z
.
unsqueeze
(
1
)
if
dt_bias
is
not
None
and
dt_bias
.
dim
()
==
1
:
dt_bias
=
dt_bias
.
unsqueeze
(
0
)
batch
,
nheads
,
dim
,
dstate
=
state
.
shape
_
,
nheads
,
dim
,
dstate
=
state
.
shape
batch
=
x
.
shape
[
0
]
assert
x
.
shape
==
(
batch
,
nheads
,
dim
)
assert
dt
.
shape
==
x
.
shape
assert
A
.
shape
==
(
nheads
,
dim
,
dstate
)
...
...
@@ -225,6 +247,8 @@ def selective_state_update(state,
assert
z
.
shape
==
x
.
shape
if
dt_bias
is
not
None
:
assert
dt_bias
.
shape
==
(
nheads
,
dim
)
if
state_batch_indices
is
not
None
:
assert
state_batch_indices
.
shape
==
(
batch
,
)
out
=
torch
.
empty_like
(
x
)
grid
=
lambda
META
:
(
triton
.
cdiv
(
dim
,
META
[
'BLOCK_SIZE_M'
]),
batch
,
nheads
)
z_strides
=
((
z
.
stride
(
0
),
z
.
stride
(
1
),
z
.
stride
(
2
))
if
z
is
not
None
else
...
...
@@ -249,6 +273,7 @@ def selective_state_update(state,
D
,
z
,
out
,
state_batch_indices
,
batch
,
nheads
,
dim
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment