Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
78029b34
Unverified
Commit
78029b34
authored
Dec 08, 2024
by
zhou fan
Committed by
GitHub
Dec 08, 2024
Browse files
[BugFix][Kernel]: fix illegal memory access in causal_conv1d when conv_states is None (#10928)
Signed-off-by:
xffxff
<
1247714429@qq.com
>
parent
c889d588
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
23 additions
and
18 deletions
+23
-18
csrc/mamba/causal_conv1d/causal_conv1d.cu
csrc/mamba/causal_conv1d/causal_conv1d.cu
+1
-1
tests/kernels/test_causal_conv1d.py
tests/kernels/test_causal_conv1d.py
+22
-17
No files found.
csrc/mamba/causal_conv1d/causal_conv1d.cu
View file @
78029b34
...
...
@@ -424,7 +424,7 @@ void causal_conv1d_fwd_kernel(ConvParamsBase params) {
// and the one before it (chunk = n_chunks - 1 and chunk = n_chunks - 2),
// (which occurs when `final_state_position` is a non-positivie index)
// we load the correct data from smem_exchange from both chunks, the last chunk iteration and the one before it
if
(
final_state_position
<
0
&&
seqlen
>
kWidth
){
if
(
conv_states
!=
nullptr
&&
final_state_position
<
0
&&
seqlen
>
kWidth
){
input_t
vals_load
[
kNElts
]
=
{
0
};
if
((
chunk
==
n_chunks
-
2
)
&&
(
tidx
==
kNThreads
-
1
)){
// chunk = n_chunks - 2, a segment of the final state sits in the last index
...
...
tests/kernels/test_causal_conv1d.py
View file @
78029b34
...
...
@@ -149,13 +149,14 @@ def causal_conv1d_opcheck_fn(x: torch.Tensor,
@
pytest
.
mark
.
parametrize
(
"itype"
,
[
torch
.
bfloat16
,
torch
.
float
])
@
pytest
.
mark
.
parametrize
(
"silu_activation"
,
[
True
])
@
pytest
.
mark
.
parametrize
(
"has_bias"
,
[
True
])
@
pytest
.
mark
.
parametrize
(
"has_initial_state"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"width"
,
[
4
])
@
pytest
.
mark
.
parametrize
(
'seqlen'
,
[
1
,
8
,
16
,
32
,
64
,
128
,
256
,
512
,
784
,
1024
,
1025
,
2048
,
4096
])
@
pytest
.
mark
.
parametrize
(
'dim'
,
[
64
])
@
pytest
.
mark
.
parametrize
(
'batch'
,
[
1
])
def
test_causal_conv1d
(
batch
,
dim
,
seqlen
,
width
,
has_bias
,
silu_activation
,
itype
):
has_initial_state
,
itype
):
device
=
"cuda"
rtol
,
atol
=
(
3e-4
,
1e-3
)
if
itype
==
torch
.
float32
else
(
3e-3
,
5e-3
)
if
itype
==
torch
.
bfloat16
:
...
...
@@ -167,11 +168,18 @@ def test_causal_conv1d(batch, dim, seqlen, width, has_bias, silu_activation,
weight
=
torch
.
randn
(
dim
,
width
,
device
=
device
,
dtype
=
itype
)
bias
=
torch
.
randn
(
dim
,
device
=
device
,
dtype
=
itype
)
if
has_bias
else
None
if
has_initial_state
:
initial_states
=
torch
.
randn
(
batch
,
dim
,
width
-
1
,
device
=
device
,
dtype
=
itype
)
has_initial_state_tensor
=
torch
.
ones
(
batch
,
dtype
=
torch
.
bool
,
device
=
x
.
device
)
else
:
initial_states
=
None
has_initial_state_tensor
=
None
x_ref
=
x
.
clone
()
weight_ref
=
weight
.
clone
()
bias_ref
=
bias
.
clone
()
if
bias
is
not
None
else
None
...
...
@@ -183,9 +191,7 @@ def test_causal_conv1d(batch, dim, seqlen, width, has_bias, silu_activation,
bias
,
activation
=
activation
,
conv_states
=
initial_states
,
has_initial_state
=
torch
.
ones
(
batch
,
dtype
=
torch
.
bool
,
device
=
x
.
device
))
has_initial_state
=
has_initial_state_tensor
)
out_ref
,
final_states_ref
=
causal_conv1d_ref
(
x_ref
,
weight_ref
,
...
...
@@ -193,6 +199,7 @@ def test_causal_conv1d(batch, dim, seqlen, width, has_bias, silu_activation,
initial_states
=
initial_states_ref
,
return_final_states
=
True
,
activation
=
activation
)
if
has_initial_state
:
assert
initial_states
is
not
None
and
final_states_ref
is
not
None
assert
torch
.
allclose
(
initial_states
,
final_states_ref
,
...
...
@@ -205,9 +212,7 @@ def test_causal_conv1d(batch, dim, seqlen, width, has_bias, silu_activation,
bias
,
activation
=
activation
,
conv_states
=
initial_states
,
has_initial_state
=
torch
.
ones
(
batch
,
dtype
=
torch
.
bool
,
device
=
x
.
device
))
has_initial_state
=
has_initial_state_tensor
)
@
pytest
.
mark
.
parametrize
(
"itype"
,
[
torch
.
bfloat16
])
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment