Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
9fb12f78
Unverified
Commit
9fb12f78
authored
Oct 31, 2024
by
Mor Zusman
Committed by
GitHub
Oct 31, 2024
Browse files
[BugFix][Kernel] Fix Illegal memory access in causal_conv1d in H100 (#9838)
Signed-off-by:
mzusman
<
mor.zusmann@gmail.com
>
parent
55650c83
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
40 additions
and
7 deletions
+40
-7
csrc/mamba/causal_conv1d/causal_conv1d.cu
csrc/mamba/causal_conv1d/causal_conv1d.cu
+32
-2
tests/kernels/test_causal_conv1d.py
tests/kernels/test_causal_conv1d.py
+5
-2
tests/kernels/test_mamba_ssm.py
tests/kernels/test_mamba_ssm.py
+3
-3
No files found.
csrc/mamba/causal_conv1d/causal_conv1d.cu
View file @
9fb12f78
...
@@ -418,6 +418,31 @@ void causal_conv1d_fwd_kernel(ConvParamsBase params) {
...
@@ -418,6 +418,31 @@ void causal_conv1d_fwd_kernel(ConvParamsBase params) {
typename
Ktraits
::
BlockStoreT
(
smem_store
).
Store
(
out
,
out_vals_store
,
seqlen
-
chunk
*
kChunkSize
);
typename
Ktraits
::
BlockStoreT
(
smem_store
).
Store
(
out
,
out_vals_store
,
seqlen
-
chunk
*
kChunkSize
);
}
}
out
+=
kChunkSize
;
out
+=
kChunkSize
;
int
final_state_position
=
((
seqlen
-
(
kWidth
-
1
))
-
(
n_chunks
-
1
)
*
kChunkSize
);
// in case the final state is split between the last "smem_exchange" and
// the one before it (chunk = n_chunks - 1 and chunk = n_chunks - 2),
// (which occurs when `final_state_position` is a negative index)
// we load the correct data from smem_exchange from both chunks, the last chunk iteration and the one before it
if
(
final_state_position
<
0
&&
seqlen
>
kWidth
){
input_t
vals_load
[
kNElts
]
=
{
0
};
if
((
chunk
==
n_chunks
-
2
)
&&
(
tidx
==
kNThreads
-
1
)){
// chunk = n_chunks - 2, a segment of the final state sits in the last index
reinterpret_cast
<
vec_t
*>
(
vals_load
)[
0
]
=
smem_exchange
[
kNThreads
-
1
];
#pragma unroll
for
(
int
w
=
0
;
w
<
-
final_state_position
;
++
w
){
conv_states
[
w
]
=
vals_load
[
kNElts
+
final_state_position
+
w
];
}
}
if
((
chunk
==
n_chunks
-
1
)
&&
tidx
==
0
){
// chunk = n_chunks - 1, the second segment of the final state sits in the first positions
reinterpret_cast
<
vec_t
*>
(
vals_load
)[
0
]
=
smem_exchange
[
0
];
for
(
int
w
=
-
final_state_position
;
w
<
kWidth
-
1
;
++
w
){
conv_states
[
w
]
=
vals_load
[
w
+
final_state_position
];
}
return
;
}
}
}
}
// Final state is stored in the smem_exchange last token slot,
// Final state is stored in the smem_exchange last token slot,
// in case seqlen < kWidth, we would need to take the final state from the
// in case seqlen < kWidth, we would need to take the final state from the
...
@@ -446,9 +471,14 @@ void causal_conv1d_fwd_kernel(ConvParamsBase params) {
...
@@ -446,9 +471,14 @@ void causal_conv1d_fwd_kernel(ConvParamsBase params) {
}
}
else
{
else
{
// in case the final state is in between the threads data
// in case the final state is in between the threads data
const
int
offset
=
((
seqlen
-
(
kWidth
-
1
))
%
(
kNElts
));
if
((
offset
+
kWidth
-
2
)
>=
kNElts
&&
(
last_thread
+
1
<
kNThreads
)){
// In case last_thread == kNThreads - 1, accessing last_thread + 1 will result in an
// illegal access error on H100.
// Therefore, we access last_thread + 1 only if the final state data sits there
reinterpret_cast
<
vec_t
*>
(
x_vals_load
)[
1
]
=
smem_exchange
[
last_thread
+
1
];
reinterpret_cast
<
vec_t
*>
(
x_vals_load
)[
1
]
=
smem_exchange
[
last_thread
+
1
];
}
reinterpret_cast
<
vec_t
*>
(
x_vals_load
)[
0
]
=
smem_exchange
[
last_thread
];
reinterpret_cast
<
vec_t
*>
(
x_vals_load
)[
0
]
=
smem_exchange
[
last_thread
];
const
int
offset
=
((
seqlen
-
(
kWidth
-
1
))
%
(
kNElts
));
#pragma unroll
#pragma unroll
for
(
int
w
=
0
;
w
<
kWidth
-
1
;
++
w
){
for
(
int
w
=
0
;
w
<
kWidth
-
1
;
++
w
){
conv_states
[
w
]
=
x_vals_load
[
offset
+
w
];
conv_states
[
w
]
=
x_vals_load
[
offset
+
w
];
...
...
tests/kernels/test_causal_conv1d.py
View file @
9fb12f78
...
@@ -151,7 +151,7 @@ def causal_conv1d_opcheck_fn(x: torch.Tensor,
...
@@ -151,7 +151,7 @@ def causal_conv1d_opcheck_fn(x: torch.Tensor,
@
pytest
.
mark
.
parametrize
(
"has_bias"
,
[
True
])
@
pytest
.
mark
.
parametrize
(
"has_bias"
,
[
True
])
@
pytest
.
mark
.
parametrize
(
"width"
,
[
4
])
@
pytest
.
mark
.
parametrize
(
"width"
,
[
4
])
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
'seqlen'
,
[
1
,
8
,
16
,
32
,
64
,
128
,
256
,
512
,
784
,
1024
,
2048
,
4096
])
'seqlen'
,
[
1
,
8
,
16
,
32
,
64
,
128
,
256
,
512
,
784
,
1024
,
1025
,
2048
,
4096
])
@
pytest
.
mark
.
parametrize
(
'dim'
,
[
64
])
@
pytest
.
mark
.
parametrize
(
'dim'
,
[
64
])
@
pytest
.
mark
.
parametrize
(
'batch'
,
[
1
])
@
pytest
.
mark
.
parametrize
(
'batch'
,
[
1
])
def
test_causal_conv1d
(
batch
,
dim
,
seqlen
,
width
,
has_bias
,
silu_activation
,
def
test_causal_conv1d
(
batch
,
dim
,
seqlen
,
width
,
has_bias
,
silu_activation
,
...
@@ -420,7 +420,10 @@ def test_causal_conv1d_varlen(with_padding, dim, seqlen, width, has_bias,
...
@@ -420,7 +420,10 @@ def test_causal_conv1d_varlen(with_padding, dim, seqlen, width, has_bias,
unpadded_out
=
out
[:,
:
out_ref_tensor
.
shape
[
-
1
]]
unpadded_out
=
out
[:,
:
out_ref_tensor
.
shape
[
-
1
]]
assert
torch
.
allclose
(
unpadded_out
,
out_ref_tensor
,
rtol
=
rtol
,
atol
=
atol
)
assert
torch
.
allclose
(
unpadded_out
,
out_ref_tensor
,
rtol
=
rtol
,
atol
=
atol
)
assert
torch
.
allclose
(
final_states
,
final_states_ref
,
rtol
=
rtol
,
atol
=
atol
)
assert
torch
.
allclose
(
final_states
[
state_indices
],
final_states_ref
[
state_indices
],
rtol
=
rtol
,
atol
=
atol
)
causal_conv1d_opcheck_fn
(
x
.
squeeze
(
0
),
weight
,
bias
,
cumsum
.
cuda
(),
causal_conv1d_opcheck_fn
(
x
.
squeeze
(
0
),
weight
,
bias
,
cumsum
.
cuda
(),
padded_state_indices
,
has_initial_states
,
padded_state_indices
,
has_initial_states
,
...
...
tests/kernels/test_mamba_ssm.py
View file @
9fb12f78
...
@@ -555,7 +555,7 @@ def test_selective_state_update_with_batch_indices(with_padding, dim, dstate,
...
@@ -555,7 +555,7 @@ def test_selective_state_update_with_batch_indices(with_padding, dim, dstate,
device
=
"cuda"
device
=
"cuda"
rtol
,
atol
=
(
3e-4
,
1e-3
)
if
itype
==
torch
.
float32
else
(
5e-3
,
1e-2
)
rtol
,
atol
=
(
3e-4
,
1e-3
)
if
itype
==
torch
.
float32
else
(
5e-3
,
1e-2
)
if
itype
==
torch
.
bfloat16
:
if
itype
==
torch
.
bfloat16
:
rtol
,
atol
=
7
e-
2
,
7
e-
2
rtol
,
atol
=
1
e-
1
,
1
e-
1
if
torch
.
version
.
hip
:
if
torch
.
version
.
hip
:
atol
*=
2
atol
*=
2
# set seed
# set seed
...
@@ -610,8 +610,8 @@ def test_selective_state_update_with_batch_indices(with_padding, dim, dstate,
...
@@ -610,8 +610,8 @@ def test_selective_state_update_with_batch_indices(with_padding, dim, dstate,
dt_bias
=
dt_bias
,
dt_bias
=
dt_bias
,
dt_softplus
=
True
)
dt_softplus
=
True
)
print
(
"Output diff max"
,
(
out
-
out_ref
[
0
]
).
max
())
print
(
"Output diff max"
,
(
out
[:
batch_size
]
-
out_ref
).
max
())
print
(
"Output diff mean"
,
(
out
-
out_ref
[
0
]
).
mean
())
print
(
"Output diff mean"
,
(
out
[:
batch_size
]
-
out_ref
).
mean
())
print
(
"Output state diff max"
,
(
state
[
state_indices
,
:]
-
state_ref
).
max
())
print
(
"Output state diff max"
,
(
state
[
state_indices
,
:]
-
state_ref
).
max
())
print
(
"Output state diff mean"
,
print
(
"Output state diff mean"
,
(
state
[
state_indices
,
:]
-
state_ref
).
mean
())
(
state
[
state_indices
,
:]
-
state_ref
).
mean
())
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment