Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
8110e445
Unverified
Commit
8110e445
authored
Sep 17, 2024
by
Tyler Michael Smith
Committed by
GitHub
Sep 17, 2024
Browse files
[Kernel] Change interface to Mamba causal_conv1d_update for continuous batching (#8012)
parent
09deb472
Changes
7
Show whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
114 additions
and
16 deletions
+114
-16
csrc/mamba/causal_conv1d/causal_conv1d.cu
csrc/mamba/causal_conv1d/causal_conv1d.cu
+27
-3
csrc/mamba/causal_conv1d/causal_conv1d.h
csrc/mamba/causal_conv1d/causal_conv1d.h
+4
-0
csrc/ops.h
csrc/ops.h
+4
-5
csrc/torch_bindings.cpp
csrc/torch_bindings.cpp
+3
-2
tests/kernels/test_causal_conv1d.py
tests/kernels/test_causal_conv1d.py
+58
-0
vllm/_custom_ops.py
vllm/_custom_ops.py
+10
-4
vllm/model_executor/layers/mamba/ops/causal_conv1d.py
vllm/model_executor/layers/mamba/ops/causal_conv1d.py
+8
-2
No files found.
csrc/mamba/causal_conv1d/causal_conv1d.cu
View file @
8110e445
...
@@ -198,7 +198,8 @@ causal_conv1d_update(const at::Tensor &x,
...
@@ -198,7 +198,8 @@ causal_conv1d_update(const at::Tensor &x,
const
at
::
Tensor
&
conv_state
,
const
at
::
Tensor
&
conv_state
,
const
at
::
Tensor
&
weight
,
const
at
::
Tensor
&
weight
,
const
c10
::
optional
<
at
::
Tensor
>
&
bias_
,
const
c10
::
optional
<
at
::
Tensor
>
&
bias_
,
bool
silu_activation
)
{
bool
silu_activation
,
const
c10
::
optional
<
at
::
Tensor
>
&
conv_state_indices_
)
{
auto
input_type
=
x
.
scalar_type
();
auto
input_type
=
x
.
scalar_type
();
auto
weight_type
=
weight
.
scalar_type
();
auto
weight_type
=
weight
.
scalar_type
();
TORCH_CHECK
(
input_type
==
at
::
ScalarType
::
Float
||
input_type
==
at
::
ScalarType
::
Half
||
input_type
==
at
::
ScalarType
::
BFloat16
);
TORCH_CHECK
(
input_type
==
at
::
ScalarType
::
Float
||
input_type
==
at
::
ScalarType
::
Half
||
input_type
==
at
::
ScalarType
::
BFloat16
);
...
@@ -216,7 +217,6 @@ causal_conv1d_update(const at::Tensor &x,
...
@@ -216,7 +217,6 @@ causal_conv1d_update(const at::Tensor &x,
const
int
width
=
weight
.
size
(
-
1
);
const
int
width
=
weight
.
size
(
-
1
);
CHECK_SHAPE
(
x
,
batch_size
,
dim
);
CHECK_SHAPE
(
x
,
batch_size
,
dim
);
CHECK_SHAPE
(
conv_state
,
batch_size
,
dim
,
width
);
CHECK_SHAPE
(
weight
,
dim
,
width
);
CHECK_SHAPE
(
weight
,
dim
,
width
);
TORCH_CHECK
(
width
>=
2
&&
width
<=
4
,
"causal_conv1d only supports width between 2 and 4"
);
TORCH_CHECK
(
width
>=
2
&&
width
<=
4
,
"causal_conv1d only supports width between 2 and 4"
);
...
@@ -241,6 +241,22 @@ causal_conv1d_update(const at::Tensor &x,
...
@@ -241,6 +241,22 @@ causal_conv1d_update(const at::Tensor &x,
params
.
conv_state_c_stride
=
conv_state
.
stride
(
1
);
params
.
conv_state_c_stride
=
conv_state
.
stride
(
1
);
params
.
conv_state_l_stride
=
conv_state
.
stride
(
2
);
params
.
conv_state_l_stride
=
conv_state
.
stride
(
2
);
if
(
conv_state_indices_
.
has_value
())
{
auto
conv_state_indices
=
conv_state_indices_
.
value
();
TORCH_CHECK
(
conv_state_indices
.
scalar_type
()
==
torch
::
kInt32
)
TORCH_CHECK
(
conv_state_indices
.
is_cuda
());
TORCH_CHECK
(
conv_state_indices
.
stride
(
0
)
==
1
)
CHECK_SHAPE
(
conv_state_indices
,
batch_size
);
int
conv_state_entries
=
conv_state
.
size
(
0
);
CHECK_SHAPE
(
conv_state
,
conv_state_entries
,
dim
,
width
);
params
.
conv_state_indices_ptr
=
conv_state_indices
.
data_ptr
<
int32_t
>
();
}
else
{
CHECK_SHAPE
(
conv_state
,
batch_size
,
dim
,
width
);
params
.
conv_state_indices_ptr
=
nullptr
;
}
// Otherwise the kernel will be launched from cuda:0 device
// Otherwise the kernel will be launched from cuda:0 device
// Cast to char to avoid compiler warning about narrowing
// Cast to char to avoid compiler warning about narrowing
at
::
cuda
::
CUDAGuard
device_guard
{(
char
)
x
.
get_device
()};
at
::
cuda
::
CUDAGuard
device_guard
{(
char
)
x
.
get_device
()};
...
@@ -646,8 +662,16 @@ void causal_conv1d_update_kernel(ConvParamsBase params) {
...
@@ -646,8 +662,16 @@ void causal_conv1d_update_kernel(ConvParamsBase params) {
const
int
channel_id
=
blockIdx
.
y
*
kNThreads
+
tidx
;
const
int
channel_id
=
blockIdx
.
y
*
kNThreads
+
tidx
;
input_t
*
x
=
reinterpret_cast
<
input_t
*>
(
params
.
x_ptr
)
+
batch_id
*
params
.
x_batch_stride
input_t
*
x
=
reinterpret_cast
<
input_t
*>
(
params
.
x_ptr
)
+
batch_id
*
params
.
x_batch_stride
+
channel_id
*
params
.
x_c_stride
;
+
channel_id
*
params
.
x_c_stride
;
input_t
*
conv_state
=
reinterpret_cast
<
input_t
*>
(
params
.
conv_state_ptr
)
+
batch_id
*
params
.
conv_state_batch_stride
// If params.conv_state_batch_indices is set, then the conv state is gathered from the conv state tensor
// along the batch axis. Otherwise, the conv state coordinate is the same as the batch id.
const
int
conv_state_batch_coord
=
params
.
conv_state_indices_ptr
==
nullptr
?
batch_id
:
params
.
conv_state_indices_ptr
[
batch_id
];
input_t
*
conv_state
=
reinterpret_cast
<
input_t
*>
(
params
.
conv_state_ptr
)
+
conv_state_batch_coord
*
params
.
conv_state_batch_stride
+
channel_id
*
params
.
conv_state_c_stride
;
+
channel_id
*
params
.
conv_state_c_stride
;
weight_t
*
weight
=
reinterpret_cast
<
weight_t
*>
(
params
.
weight_ptr
)
+
channel_id
*
params
.
weight_c_stride
;
weight_t
*
weight
=
reinterpret_cast
<
weight_t
*>
(
params
.
weight_ptr
)
+
channel_id
*
params
.
weight_c_stride
;
input_t
*
out
=
reinterpret_cast
<
input_t
*>
(
params
.
out_ptr
)
+
batch_id
*
params
.
out_batch_stride
input_t
*
out
=
reinterpret_cast
<
input_t
*>
(
params
.
out_ptr
)
+
batch_id
*
params
.
out_batch_stride
+
channel_id
*
params
.
out_c_stride
;
+
channel_id
*
params
.
out_c_stride
;
...
...
csrc/mamba/causal_conv1d/causal_conv1d.h
View file @
8110e445
...
@@ -36,6 +36,10 @@ struct ConvParamsBase {
...
@@ -36,6 +36,10 @@ struct ConvParamsBase {
void
*
__restrict__
conv_state_ptr
;
void
*
__restrict__
conv_state_ptr
;
// For the continuous batching case. Makes it so that the mamba state for
// the current batch doesn't need to be a contiguous tensor.
int32_t
*
__restrict__
conv_state_indices_ptr
;
void
*
__restrict__
seq_idx_ptr
;
void
*
__restrict__
seq_idx_ptr
;
// No __restrict__ since initial_states could be the same as final_states.
// No __restrict__ since initial_states could be the same as final_states.
...
...
csrc/ops.h
View file @
8110e445
...
@@ -222,11 +222,10 @@ std::vector<torch::Tensor> selective_scan_fwd(
...
@@ -222,11 +222,10 @@ std::vector<torch::Tensor> selective_scan_fwd(
const
c10
::
optional
<
torch
::
Tensor
>&
index_
,
const
c10
::
optional
<
torch
::
Tensor
>&
index_
,
const
c10
::
optional
<
torch
::
Tensor
>&
x
);
const
c10
::
optional
<
torch
::
Tensor
>&
x
);
at
::
Tensor
causal_conv1d_update
(
const
at
::
Tensor
&
x
,
at
::
Tensor
causal_conv1d_update
(
const
at
::
Tensor
&
conv_state
,
const
at
::
Tensor
&
x
,
const
at
::
Tensor
&
conv_state
,
const
at
::
Tensor
&
weight
,
const
at
::
Tensor
&
weight
,
const
c10
::
optional
<
at
::
Tensor
>&
bias
,
bool
silu_activation
,
const
c10
::
optional
<
at
::
Tensor
>&
bias_
,
const
c10
::
optional
<
at
::
Tensor
>&
conv_state_indices
);
bool
silu_activation
);
at
::
Tensor
causal_conv1d_fwd
(
const
at
::
Tensor
&
x
,
const
at
::
Tensor
&
weight
,
at
::
Tensor
causal_conv1d_fwd
(
const
at
::
Tensor
&
x
,
const
at
::
Tensor
&
weight
,
const
c10
::
optional
<
at
::
Tensor
>&
bias_
,
const
c10
::
optional
<
at
::
Tensor
>&
bias_
,
...
...
csrc/torch_bindings.cpp
View file @
8110e445
...
@@ -279,8 +279,9 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
...
@@ -279,8 +279,9 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
"causal_conv1d_update(Tensor! x,"
"causal_conv1d_update(Tensor! x,"
"Tensor! conv_state,"
"Tensor! conv_state,"
"Tensor! weight,"
"Tensor! weight,"
"Tensor? bias_,"
"Tensor? bias,"
"bool silu_activation) -> Tensor"
);
"bool silu_activation,"
"Tensor? conv_state_indices) -> Tensor"
);
ops
.
impl
(
"causal_conv1d_update"
,
torch
::
kCUDA
,
&
causal_conv1d_update
);
ops
.
impl
(
"causal_conv1d_update"
,
torch
::
kCUDA
,
&
causal_conv1d_update
);
ops
.
def
(
ops
.
def
(
...
...
tests/kernels/test_causal_conv1d.py
View file @
8110e445
...
@@ -203,3 +203,61 @@ def test_causal_conv1d_update(batch, dim, width, has_bias, silu_activation,
...
@@ -203,3 +203,61 @@ def test_causal_conv1d_update(batch, dim, width, has_bias, silu_activation,
assert
torch
.
equal
(
conv_state
,
conv_state_ref
)
assert
torch
.
equal
(
conv_state
,
conv_state_ref
)
assert
torch
.
allclose
(
out
,
out_ref
,
rtol
=
rtol
,
atol
=
atol
)
assert
torch
.
allclose
(
out
,
out_ref
,
rtol
=
rtol
,
atol
=
atol
)
@
pytest
.
mark
.
parametrize
(
"itype"
,
[
torch
.
float32
,
torch
.
float16
,
torch
.
bfloat16
])
@
pytest
.
mark
.
parametrize
(
"silu_activation"
,
[
False
,
True
])
@
pytest
.
mark
.
parametrize
(
"has_bias"
,
[
False
,
True
])
@
pytest
.
mark
.
parametrize
(
"seqlen"
,
[
1
,
4
,
5
])
@
pytest
.
mark
.
parametrize
(
"width"
,
[
2
,
3
,
4
])
@
pytest
.
mark
.
parametrize
(
"dim"
,
[
2048
,
2048
+
16
,
4096
])
def
test_causal_conv1d_update_with_batch_gather
(
dim
,
width
,
seqlen
,
has_bias
,
silu_activation
,
itype
):
device
=
"cuda"
rtol
,
atol
=
(
3e-4
,
1e-3
)
if
itype
==
torch
.
float32
else
(
3e-3
,
5e-3
)
if
itype
==
torch
.
bfloat16
:
rtol
,
atol
=
1e-2
,
5e-2
# set seed
torch
.
random
.
manual_seed
(
0
)
batch
=
64
x
=
torch
.
randn
(
batch
,
dim
,
device
=
device
,
dtype
=
itype
)
total_entries
=
10
*
batch
conv_state
=
torch
.
randn
(
total_entries
,
dim
,
width
,
device
=
device
,
dtype
=
itype
)
conv_state_indices
=
torch
.
randperm
(
total_entries
)[:
batch
].
to
(
dtype
=
torch
.
int32
,
device
=
device
)
weight
=
torch
.
randn
(
dim
,
width
,
device
=
device
,
dtype
=
itype
,
requires_grad
=
True
)
if
has_bias
:
bias
=
torch
.
randn
(
dim
,
device
=
device
,
dtype
=
itype
,
requires_grad
=
True
)
else
:
bias
=
None
conv_state_ref
=
conv_state
[
conv_state_indices
,
:].
detach
().
clone
()
activation
=
None
if
not
silu_activation
else
"silu"
out
=
causal_conv1d_update
(
x
,
conv_state
,
weight
,
bias
,
activation
=
activation
,
conv_state_indices
=
conv_state_indices
)
out_ref
=
causal_conv1d_update_ref
(
x
,
conv_state_ref
,
weight
,
bias
,
activation
=
activation
)
print
(
f
"Output max diff:
{
(
out
-
out_ref
).
abs
().
max
().
item
()
}
"
)
print
(
f
"Output mean diff:
{
(
out
-
out_ref
).
abs
().
mean
().
item
()
}
"
)
assert
torch
.
equal
(
conv_state
[
conv_state_indices
,
:],
conv_state_ref
)
assert
torch
.
allclose
(
out
,
out_ref
,
rtol
=
rtol
,
atol
=
atol
)
vllm/_custom_ops.py
View file @
8110e445
...
@@ -768,11 +768,17 @@ def causal_conv1d_fwd(x: torch.Tensor, weight: torch.Tensor,
...
@@ -768,11 +768,17 @@ def causal_conv1d_fwd(x: torch.Tensor, weight: torch.Tensor,
silu_activation
)
silu_activation
)
def
causal_conv1d_update
(
x
:
torch
.
Tensor
,
conv_state
:
torch
.
Tensor
,
def
causal_conv1d_update
(
weight
:
torch
.
Tensor
,
bias_
:
Optional
[
torch
.
Tensor
],
x
:
torch
.
Tensor
,
silu_activation
:
bool
)
->
torch
.
Tensor
:
conv_state
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
bias_
:
Optional
[
torch
.
Tensor
],
silu_activation
:
bool
,
conv_state_indices
:
Optional
[
torch
.
Tensor
],
)
->
torch
.
Tensor
:
return
torch
.
ops
.
_C
.
causal_conv1d_update
(
x
,
conv_state
,
weight
,
bias_
,
return
torch
.
ops
.
_C
.
causal_conv1d_update
(
x
,
conv_state
,
weight
,
bias_
,
silu_activation
)
silu_activation
,
conv_state_indices
)
def
selective_scan_fwd
(
u
:
torch
.
Tensor
,
delta
:
torch
.
Tensor
,
A
:
torch
.
Tensor
,
def
selective_scan_fwd
(
u
:
torch
.
Tensor
,
delta
:
torch
.
Tensor
,
A
:
torch
.
Tensor
,
...
...
vllm/model_executor/layers/mamba/ops/causal_conv1d.py
View file @
8110e445
# Copyright (c) 2024, Tri Dao.
# Copyright (c) 2024, Tri Dao.
# Adapted from https://github.com/Dao-AILab/causal-conv1d/blob/main/causal_conv1d/causal_conv1d_interface.py
from
typing
import
Optional
from
typing
import
Optional
...
@@ -70,12 +71,17 @@ def causal_conv1d_update(x: torch.Tensor,
...
@@ -70,12 +71,17 @@ def causal_conv1d_update(x: torch.Tensor,
conv_state
:
torch
.
Tensor
,
conv_state
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
,
activation
:
Optional
[
str
]
=
None
):
activation
:
Optional
[
str
]
=
None
,
conv_state_indices
:
Optional
[
torch
.
Tensor
]
=
None
):
"""
"""
x: (batch, dim)
x: (batch, dim)
conv_state: (batch, dim, width)
conv_state: (batch, dim, width)
weight: (dim, width)
weight: (dim, width)
bias: (dim,)
bias: (dim,)
conv_state_indices: (batch,), dtype int32
If not None, the conv_state is a larger tensor along the batch dim,
and we are selecting the batch coords specified by conv_state_indices.
Useful for a continuous batching scenario.
out: (batch, dim)
out: (batch, dim)
"""
"""
...
@@ -83,4 +89,4 @@ def causal_conv1d_update(x: torch.Tensor,
...
@@ -83,4 +89,4 @@ def causal_conv1d_update(x: torch.Tensor,
raise
NotImplementedError
(
"activation must be None, silu, or swish"
)
raise
NotImplementedError
(
"activation must be None, silu, or swish"
)
activation_bool
=
activation
in
[
"silu"
,
"swish"
]
activation_bool
=
activation
in
[
"silu"
,
"swish"
]
return
ops
.
causal_conv1d_update
(
x
,
conv_state
,
weight
,
bias
,
return
ops
.
causal_conv1d_update
(
x
,
conv_state
,
weight
,
bias
,
activation_bool
)
activation_bool
,
conv_state_indices
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment