Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
8cbe1538
Unverified
Commit
8cbe1538
authored
Sep 10, 2025
by
Yi Zhang
Committed by
GitHub
Sep 09, 2025
Browse files
Add mamba kernel (#10234)
parent
8471e5e6
Changes
8
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
1418 additions
and
0 deletions
+1418
-0
sgl-kernel/CMakeLists.txt
sgl-kernel/CMakeLists.txt
+1
-0
sgl-kernel/csrc/common_extension.cc
sgl-kernel/csrc/common_extension.cc
+25
-0
sgl-kernel/csrc/mamba/causal_conv1d.cu
sgl-kernel/csrc/mamba/causal_conv1d.cu
+669
-0
sgl-kernel/csrc/mamba/causal_conv1d.h
sgl-kernel/csrc/mamba/causal_conv1d.h
+159
-0
sgl-kernel/include/sgl_kernel_ops.h
sgl-kernel/include/sgl_kernel_ops.h
+24
-0
sgl-kernel/python/sgl_kernel/__init__.py
sgl-kernel/python/sgl_kernel/__init__.py
+1
-0
sgl-kernel/python/sgl_kernel/mamba.py
sgl-kernel/python/sgl_kernel/mamba.py
+50
-0
sgl-kernel/tests/test_causal_conv1d.py
sgl-kernel/tests/test_causal_conv1d.py
+489
-0
No files found.
sgl-kernel/CMakeLists.txt
View file @
8cbe1538
...
...
@@ -303,6 +303,7 @@ set(SOURCES
"csrc/moe/cutlass_moe/w4a8/w4a8_moe_data.cu"
"csrc/moe/cutlass_moe/w4a8/w4a8_grouped_mm_c3x.cu"
"csrc/moe/marlin_moe_wna16/ops.cu"
"csrc/mamba/causal_conv1d.cu"
"csrc/moe/moe_align_kernel.cu"
"csrc/moe/moe_fused_gate.cu"
"csrc/moe/moe_topk_softmax_kernels.cu"
...
...
sgl-kernel/csrc/common_extension.cc
View file @
8cbe1538
...
...
@@ -438,6 +438,31 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
m
.
impl
(
"copy_to_gpu_no_ce"
,
torch
::
kCUDA
,
&
copy_to_gpu_no_ce
);
m
.
def
(
"concat_mla_k(Tensor! k, Tensor k_nope, Tensor k_rope) -> ()"
);
m
.
impl
(
"concat_mla_k"
,
torch
::
kCUDA
,
&
concat_mla_k
);
/*
* From csrc/mamba
*/
m
.
def
(
"causal_conv1d_update(Tensor! x,"
"Tensor! conv_state,"
"Tensor! weight,"
"Tensor? bias_,"
"bool silu_activation,"
"Tensor? cache_seqlens_,"
"Tensor? conv_state_indices,"
"int pad_slot_id) -> ()"
);
m
.
impl
(
"causal_conv1d_update"
,
torch
::
kCUDA
,
&
causal_conv1d_update
);
m
.
def
(
"causal_conv1d_fwd(Tensor! x, Tensor! weight,"
"Tensor? bias_,"
"Tensor!? conv_states,"
"Tensor? query_start_loc,"
"Tensor? cache_indices,"
"Tensor? has_initial_state,"
"bool silu_activation,"
"int pad_slot_id) -> ()"
);
m
.
impl
(
"causal_conv1d_fwd"
,
torch
::
kCUDA
,
&
causal_conv1d_fwd
);
}
REGISTER_EXTENSION
(
common_ops
)
sgl-kernel/csrc/mamba/causal_conv1d.cu
0 → 100644
View file @
8cbe1538
This diff is collapsed.
Click to expand it.
sgl-kernel/csrc/mamba/causal_conv1d.h
0 → 100644
View file @
8cbe1538
/******************************************************************************
* Copyright (c) 2024, Tri Dao.
******************************************************************************/
// clang-format off
// adapted from https://github.com/Dao-AILab/causal-conv1d/blob/main/csrc/causal_conv1d.h
#pragma once
#include <cuda_bf16.h>
#include <cuda_fp16.h>
////////////////////////////////////////////////////////////////////////////////////////////////////
struct
ConvParamsBase
{
using
index_t
=
uint32_t
;
int
batch
,
dim
,
seqlen
,
width
;
int64_t
pad_slot_id
;
bool
silu_activation
;
index_t
x_batch_stride
;
index_t
x_c_stride
;
index_t
x_l_stride
;
index_t
weight_c_stride
;
index_t
weight_width_stride
;
index_t
out_batch_stride
;
index_t
out_c_stride
;
index_t
out_l_stride
;
int
conv_state_len
;
index_t
conv_state_batch_stride
;
index_t
conv_state_c_stride
;
index_t
conv_state_l_stride
;
// Common data pointers.
void
*
__restrict__
x_ptr
;
void
*
__restrict__
weight_ptr
;
void
*
__restrict__
bias_ptr
;
void
*
__restrict__
out_ptr
;
void
*
__restrict__
conv_state_ptr
;
void
*
__restrict__
query_start_loc_ptr
;
void
*
__restrict__
has_initial_state_ptr
;
void
*
__restrict__
cache_indices_ptr
;
int32_t
*
__restrict__
cache_seqlens
;
// For the continuous batching case. Makes it so that the mamba state for
// the current batch doesn't need to be a contiguous tensor.
int32_t
*
__restrict__
conv_state_indices_ptr
;
void
*
__restrict__
seq_idx_ptr
;
// No __restrict__ since initial_states could be the same as final_states.
void
*
initial_states_ptr
;
index_t
initial_states_batch_stride
;
index_t
initial_states_l_stride
;
index_t
initial_states_c_stride
;
void
*
final_states_ptr
;
index_t
final_states_batch_stride
;
index_t
final_states_l_stride
;
index_t
final_states_c_stride
;
void
*
conv_states_ptr
;
index_t
conv_states_batch_stride
;
index_t
conv_states_l_stride
;
index_t
conv_states_c_stride
;
};
#ifndef USE_ROCM
#include <cuda_bf16.h>
template
<
typename
T
>
__device__
inline
T
shuffle_xor
(
T
val
,
int
offset
)
{
return
__shfl_xor_sync
(
uint32_t
(
-
1
),
val
,
offset
);
}
constexpr
size_t
custom_max
(
std
::
initializer_list
<
size_t
>
ilist
)
{
return
std
::
max
(
ilist
);
}
template
<
typename
T
>
constexpr
T
constexpr_min
(
T
a
,
T
b
)
{
return
std
::
min
(
a
,
b
);
}
#else
#include <hip/hip_bf16.h>
template
<
typename
T
>
__device__
inline
T
shuffle_xor
(
T
val
,
int
offset
)
{
return
__shfl_xor
(
val
,
offset
);
}
constexpr
size_t
custom_max
(
std
::
initializer_list
<
size_t
>
ilist
)
{
return
*
std
::
max_element
(
ilist
.
begin
(),
ilist
.
end
());
}
template
<
typename
T
>
constexpr
T
constexpr_min
(
T
a
,
T
b
)
{
return
a
<
b
?
a
:
b
;
}
#endif
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
int
BYTES
>
struct
BytesToType
{};
template
<
>
struct
BytesToType
<
16
>
{
using
Type
=
uint4
;
static_assert
(
sizeof
(
Type
)
==
16
);
};
template
<
>
struct
BytesToType
<
8
>
{
using
Type
=
uint64_t
;
static_assert
(
sizeof
(
Type
)
==
8
);
};
template
<
>
struct
BytesToType
<
4
>
{
using
Type
=
uint32_t
;
static_assert
(
sizeof
(
Type
)
==
4
);
};
template
<
>
struct
BytesToType
<
2
>
{
using
Type
=
uint16_t
;
static_assert
(
sizeof
(
Type
)
==
2
);
};
template
<
>
struct
BytesToType
<
1
>
{
using
Type
=
uint8_t
;
static_assert
(
sizeof
(
Type
)
==
1
);
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
typename
T
>
struct
SumOp
{
__device__
inline
T
operator
()(
T
const
&
x
,
T
const
&
y
)
{
return
x
+
y
;
}
};
template
<
int
THREADS
>
struct
Allreduce
{
static_assert
(
THREADS
==
32
||
THREADS
==
16
||
THREADS
==
8
||
THREADS
==
4
);
template
<
typename
T
,
typename
Operator
>
static
__device__
inline
T
run
(
T
x
,
Operator
&
op
)
{
constexpr
int
OFFSET
=
THREADS
/
2
;
x
=
op
(
x
,
__shfl_xor_sync
(
uint32_t
(
-
1
),
x
,
OFFSET
));
return
Allreduce
<
OFFSET
>::
run
(
x
,
op
);
}
};
template
<
>
struct
Allreduce
<
2
>
{
template
<
typename
T
,
typename
Operator
>
static
__device__
inline
T
run
(
T
x
,
Operator
&
op
)
{
x
=
op
(
x
,
__shfl_xor_sync
(
uint32_t
(
-
1
),
x
,
1
));
return
x
;
}
};
sgl-kernel/include/sgl_kernel_ops.h
View file @
8cbe1538
...
...
@@ -724,3 +724,27 @@ void store_kv_cache(at::Tensor k_cache, at::Tensor v_cache, at::Tensor out_loc,
void
copy_to_gpu_no_ce
(
const
at
::
Tensor
&
input
,
at
::
Tensor
&
output
);
void
concat_mla_k
(
torch
::
Tensor
k
,
torch
::
Tensor
k_nope
,
torch
::
Tensor
k_rope
);
/*
* From csrc/mamba
*/
void
causal_conv1d_update
(
const
at
::
Tensor
&
x
,
const
at
::
Tensor
&
conv_state
,
const
at
::
Tensor
&
weight
,
const
std
::
optional
<
at
::
Tensor
>&
bias_
,
bool
silu_activation
,
const
std
::
optional
<
at
::
Tensor
>&
cache_seqlens_
,
const
std
::
optional
<
at
::
Tensor
>&
conv_state_indices_
,
int64_t
pad_slot_id
);
void
causal_conv1d_fwd
(
const
at
::
Tensor
&
x
,
const
at
::
Tensor
&
weight
,
const
std
::
optional
<
at
::
Tensor
>&
bias_
,
const
std
::
optional
<
at
::
Tensor
>&
conv_states
,
const
std
::
optional
<
at
::
Tensor
>&
query_start_loc
,
const
std
::
optional
<
at
::
Tensor
>&
cache_indices
,
const
std
::
optional
<
at
::
Tensor
>&
has_initial_state
,
bool
silu_activation
,
int64_t
pad_slot_id
);
sgl-kernel/python/sgl_kernel/__init__.py
View file @
8cbe1538
...
...
@@ -34,6 +34,7 @@ from sgl_kernel.elementwise import (
rmsnorm
,
silu_and_mul
,
)
from
sgl_kernel.mamba
import
causal_conv1d_fwd
,
causal_conv1d_update
if
torch
.
version
.
hip
is
not
None
:
from
sgl_kernel.elementwise
import
gelu_quick
...
...
sgl-kernel/python/sgl_kernel/mamba.py
0 → 100644
View file @
8cbe1538
from
typing
import
Optional
import
torch
# mamba
def
causal_conv1d_fwd
(
x
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
bias_
:
Optional
[
torch
.
Tensor
],
conv_states
:
Optional
[
torch
.
Tensor
],
query_start_loc
:
Optional
[
torch
.
Tensor
],
cache_indices
:
Optional
[
torch
.
Tensor
],
has_initial_state
:
Optional
[
torch
.
Tensor
],
silu_activation
:
bool
,
pad_slot_id
:
int
,
):
torch
.
ops
.
sgl_kernel
.
causal_conv1d_fwd
(
x
,
weight
,
bias_
,
conv_states
,
query_start_loc
,
cache_indices
,
has_initial_state
,
silu_activation
,
pad_slot_id
,
)
def
causal_conv1d_update
(
x
:
torch
.
Tensor
,
conv_state
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
bias_
:
Optional
[
torch
.
Tensor
],
silu_activation
:
bool
,
cache_seqlens
:
Optional
[
torch
.
Tensor
],
conv_state_indices
:
Optional
[
torch
.
Tensor
],
pad_slot_id
:
int
,
):
torch
.
ops
.
sgl_kernel
.
causal_conv1d_update
(
x
,
conv_state
,
weight
,
bias_
,
silu_activation
,
cache_seqlens
,
conv_state_indices
,
pad_slot_id
,
)
sgl-kernel/tests/test_causal_conv1d.py
0 → 100644
View file @
8cbe1538
# Adapted from https://github.com/vllm-project/vllm/blob/main/tests/kernels/mamba/test_causal_conv1d.py
from
typing
import
Optional
import
torch
from
sgl_kernel
import
causal_conv1d_fwd
from
sgl_kernel
import
causal_conv1d_update
as
causal_conv1d_update_kernel
PAD_SLOT_ID
=
-
1
def
causal_conv1d_fn
(
x
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
,
query_start_loc
:
Optional
[
torch
.
Tensor
]
=
None
,
cache_indices
:
Optional
[
torch
.
Tensor
]
=
None
,
has_initial_state
:
Optional
[
torch
.
Tensor
]
=
None
,
conv_states
:
Optional
[
torch
.
Tensor
]
=
None
,
activation
:
Optional
[
str
]
=
"silu"
,
pad_slot_id
:
int
=
PAD_SLOT_ID
,
):
"""
x: (batch, dim, seqlen) or (dim,cu_seq_len) for varlen
sequences are concatenated from left to right for varlen
weight: (dim, width)
bias: (dim,)
query_start_loc: (batch + 1) int32
The cumulative sequence lengths of the sequences in
the batch, used to index into sequence. prepended by 0.
for example: query_start_loc = torch.Tensor([0,10,16,17]),
x.shape=(dim,17)
cache_indices: (batch) int32
indicates the corresponding state index,
like so: conv_state = conv_states[cache_indices[batch_id]]
has_initial_state: (batch) bool
indicates whether should the kernel take the current state as initial
state for the calculations
conv_states: (...,dim,width - 1) itype
updated inplace if provided
activation: either None or "silu" or "swish"
pad_slot_id: int
if cache_indices is passed, lets the kernel identify padded
entries that will not be processed,
for example: cache_indices = [pad_slot_id, 1, 20, pad_slot_id]
in this case, the kernel will not process entries at
indices 0 and 3
out: (batch, dim, seqlen)
"""
if
activation
not
in
[
None
,
"silu"
,
"swish"
]:
raise
NotImplementedError
(
"activation must be None, silu, or swish"
)
if
x
.
stride
(
-
1
)
!=
1
:
x
=
x
.
contiguous
()
bias
=
bias
.
contiguous
()
if
bias
is
not
None
else
None
causal_conv1d_fwd
(
x
,
weight
,
bias
,
conv_states
,
query_start_loc
,
cache_indices
,
has_initial_state
,
activation
in
[
"silu"
,
"swish"
],
pad_slot_id
,
)
return
x
def
causal_conv1d_update
(
x
:
torch
.
Tensor
,
conv_state
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
,
activation
:
Optional
[
str
]
=
None
,
cache_seqlens
:
Optional
[
torch
.
Tensor
]
=
None
,
conv_state_indices
:
Optional
[
torch
.
Tensor
]
=
None
,
pad_slot_id
:
int
=
PAD_SLOT_ID
,
):
"""
x: (batch, dim) or (batch, dim, seqlen)
conv_state: (batch, dim, state_len), where state_len >= width - 1
weight: (dim, width)
bias: (dim,)
cache_seqlens: (batch,), dtype int32.
If not None, the conv_state is treated as a circular buffer.
The conv_state will be updated by copying x to the conv_state
starting at the index
@cache_seqlens % state_len.
conv_state_indices: (batch,), dtype int32
If not None, the conv_state is a larger tensor along the batch dim,
and we are selecting the batch coords specified by conv_state_indices.
Useful for a continuous batching scenario.
pad_slot_id: int
if cache_indices is passed, lets the kernel identify padded
entries that will not be processed,
for example: cache_indices = [pad_slot_id, 1 ,20 ,pad_slot_id]
in this case, the kernel will not process entries at
indices 0 and 3
out: (batch, dim) or (batch, dim, seqlen)
"""
if
activation
not
in
[
None
,
"silu"
,
"swish"
]:
raise
NotImplementedError
(
f
"activation must be None, silu, or swish, actual:
{
activation
}
"
)
activation_val
=
activation
in
[
"silu"
,
"swish"
]
unsqueeze
=
x
.
dim
()
==
2
if
unsqueeze
:
x
=
x
.
unsqueeze
(
-
1
)
causal_conv1d_update_kernel
(
x
,
conv_state
,
weight
,
bias
,
activation_val
,
cache_seqlens
,
conv_state_indices
,
pad_slot_id
,
)
if
unsqueeze
:
x
=
x
.
squeeze
(
-
1
)
return
x
# SPDX-License-Identifier: Apache-2.0
from
typing
import
Optional
import
pytest
import
torch
import
torch.nn.functional
as
F
def
causal_conv1d_ref
(
x
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
,
initial_states
:
Optional
[
torch
.
Tensor
]
=
None
,
return_final_states
:
bool
=
False
,
final_states_out
:
Optional
[
torch
.
Tensor
]
=
None
,
activation
:
Optional
[
str
]
=
"silu"
,
):
"""
x: (batch, dim, seqlen)
weight: (dim, width)
bias: (dim,)
initial_states: (batch, dim, width - 1)
final_states_out: (batch, dim, width - 1)
out: (batch, dim, seqlen)
"""
if
activation
not
in
[
None
,
"silu"
,
"swish"
]:
raise
NotImplementedError
(
"activation must be None, silu, or swish"
)
dtype_in
=
x
.
dtype
x
=
x
.
to
(
weight
.
dtype
)
seqlen
=
x
.
shape
[
-
1
]
dim
,
width
=
weight
.
shape
if
initial_states
is
None
:
out
=
F
.
conv1d
(
x
,
weight
.
unsqueeze
(
1
),
bias
,
padding
=
width
-
1
,
groups
=
dim
)
else
:
x
=
torch
.
cat
([
initial_states
,
x
],
dim
=-
1
)
out
=
F
.
conv1d
(
x
,
weight
.
unsqueeze
(
1
),
bias
,
padding
=
0
,
groups
=
dim
)
out
=
out
[...,
:
seqlen
]
if
return_final_states
:
final_states
=
F
.
pad
(
x
,
(
width
-
1
-
x
.
shape
[
-
1
],
0
)).
to
(
dtype_in
)
# (batch, dim, width - 1)
if
final_states_out
is
not
None
:
final_states_out
.
copy_
(
final_states
)
else
:
final_states_out
=
final_states
out
=
(
out
if
activation
is
None
else
F
.
silu
(
out
)).
to
(
dtype
=
dtype_in
)
return
(
out
,
None
)
if
not
return_final_states
else
(
out
,
final_states_out
)
def
causal_conv1d_update_ref
(
x
,
conv_state
,
weight
,
bias
=
None
,
activation
=
None
,
cache_seqlens
=
None
):
"""
x: (batch, dim) or (batch, dim, seqlen)
conv_state: (batch, dim, state_len), where state_len >= width - 1
weight: (dim, width)
bias: (dim,)
cache_seqlens: (batch,), dtype int32.
If not None, the conv_state is treated as a circular buffer.
The conv_state will be updated by copying x to the
conv_state starting at the index
@cache_seqlens % state_len before performing the convolution.
out: (batch, dim) or (batch, dim, seqlen)
"""
if
activation
not
in
[
None
,
"silu"
,
"swish"
]:
raise
NotImplementedError
(
"activation must be None, silu, or swish"
)
dtype_in
=
x
.
dtype
unsqueeze
=
x
.
dim
()
==
2
if
unsqueeze
:
x
=
x
.
unsqueeze
(
-
1
)
batch
,
dim
,
seqlen
=
x
.
shape
width
=
weight
.
shape
[
1
]
state_len
=
conv_state
.
shape
[
-
1
]
assert
conv_state
.
shape
==
(
batch
,
dim
,
state_len
)
assert
weight
.
shape
==
(
dim
,
width
)
if
cache_seqlens
is
None
:
x_new
=
torch
.
cat
([
conv_state
,
x
],
dim
=-
1
).
to
(
weight
.
dtype
)
# (batch, dim, state_len + seqlen)
conv_state
.
copy_
(
x_new
[:,
:,
-
state_len
:])
else
:
width_idx
=
torch
.
arange
(
-
(
width
-
1
),
0
,
dtype
=
torch
.
long
,
device
=
x
.
device
).
unsqueeze
(
0
)
+
cache_seqlens
.
unsqueeze
(
1
)
width_idx
=
(
torch
.
remainder
(
width_idx
,
state_len
).
unsqueeze
(
1
).
expand
(
-
1
,
dim
,
-
1
)
)
x_new
=
torch
.
cat
([
conv_state
.
gather
(
2
,
width_idx
),
x
],
dim
=-
1
).
to
(
weight
.
dtype
)
copy_idx
=
torch
.
arange
(
seqlen
,
dtype
=
torch
.
long
,
device
=
x
.
device
).
unsqueeze
(
0
)
+
cache_seqlens
.
unsqueeze
(
1
)
copy_idx
=
torch
.
remainder
(
copy_idx
,
state_len
).
unsqueeze
(
1
).
expand
(
-
1
,
dim
,
-
1
)
conv_state
.
scatter_
(
2
,
copy_idx
,
x
)
out
=
F
.
conv1d
(
x_new
,
weight
.
unsqueeze
(
1
),
bias
,
padding
=
0
,
groups
=
dim
)[
:,
:,
-
seqlen
:
]
if
unsqueeze
:
out
=
out
.
squeeze
(
-
1
)
return
(
out
if
activation
is
None
else
F
.
silu
(
out
)).
to
(
dtype
=
dtype_in
)
@
pytest
.
mark
.
parametrize
(
"itype"
,
[
torch
.
bfloat16
,
torch
.
float
])
@
pytest
.
mark
.
parametrize
(
"silu_activation"
,
[
True
])
@
pytest
.
mark
.
parametrize
(
"has_bias"
,
[
True
])
@
pytest
.
mark
.
parametrize
(
"has_initial_state"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"width"
,
[
4
])
@
pytest
.
mark
.
parametrize
(
"seqlen"
,
[
1
,
8
,
16
,
32
,
64
,
128
,
256
,
512
,
784
,
1024
,
1025
,
2048
,
4096
]
)
@
pytest
.
mark
.
parametrize
(
"dim"
,
[
64
])
@
pytest
.
mark
.
parametrize
(
"batch"
,
[
1
])
def
test_causal_conv1d
(
batch
,
dim
,
seqlen
,
width
,
has_bias
,
silu_activation
,
has_initial_state
,
itype
):
device
=
"cuda"
rtol
,
atol
=
(
3e-4
,
1e-3
)
if
itype
==
torch
.
float32
else
(
3e-3
,
5e-3
)
if
itype
==
torch
.
bfloat16
:
rtol
,
atol
=
1e-2
,
5e-2
x
=
torch
.
randn
(
batch
,
dim
,
seqlen
,
device
=
device
,
dtype
=
itype
).
contiguous
()
weight
=
torch
.
randn
(
dim
,
width
,
device
=
device
,
dtype
=
itype
)
bias
=
torch
.
randn
(
dim
,
device
=
device
,
dtype
=
itype
)
if
has_bias
else
None
if
has_initial_state
:
initial_states
=
torch
.
randn
(
batch
,
dim
,
width
-
1
,
device
=
device
,
dtype
=
itype
)
has_initial_state_tensor
=
torch
.
ones
(
batch
,
dtype
=
torch
.
bool
,
device
=
x
.
device
)
else
:
initial_states
=
None
has_initial_state_tensor
=
None
x_ref
=
x
.
clone
()
weight_ref
=
weight
.
clone
()
bias_ref
=
bias
.
clone
()
if
bias
is
not
None
else
None
initial_states_ref
=
initial_states
.
clone
()
if
initial_states
is
not
None
else
None
activation
=
None
if
not
silu_activation
else
"silu"
out
=
causal_conv1d_fn
(
x
,
weight
,
bias
,
activation
=
activation
,
conv_states
=
initial_states
,
has_initial_state
=
has_initial_state_tensor
,
)
out_ref
,
final_states_ref
=
causal_conv1d_ref
(
x_ref
,
weight_ref
,
bias_ref
,
initial_states
=
initial_states_ref
,
return_final_states
=
True
,
activation
=
activation
,
)
if
has_initial_state
:
assert
initial_states
is
not
None
and
final_states_ref
is
not
None
assert
torch
.
allclose
(
initial_states
,
final_states_ref
,
rtol
=
rtol
,
atol
=
atol
)
assert
torch
.
allclose
(
out
,
out_ref
,
rtol
=
rtol
,
atol
=
atol
)
@
pytest
.
mark
.
parametrize
(
"itype"
,
[
torch
.
bfloat16
])
@
pytest
.
mark
.
parametrize
(
"silu_activation"
,
[
False
,
True
])
@
pytest
.
mark
.
parametrize
(
"has_bias"
,
[
False
,
True
])
@
pytest
.
mark
.
parametrize
(
"seqlen"
,
[
1
])
@
pytest
.
mark
.
parametrize
(
"width"
,
[
4
])
@
pytest
.
mark
.
parametrize
(
"dim"
,
[
2048
,
2048
+
16
,
4096
])
def
test_causal_conv1d_update
(
dim
,
width
,
seqlen
,
has_bias
,
silu_activation
,
itype
):
device
=
"cuda"
rtol
,
atol
=
(
3e-4
,
1e-3
)
if
itype
==
torch
.
float32
else
(
3e-3
,
5e-3
)
if
itype
==
torch
.
bfloat16
:
rtol
,
atol
=
1e-2
,
5e-2
batch
=
2
x
=
torch
.
randn
(
batch
,
dim
,
seqlen
,
device
=
device
,
dtype
=
itype
)
x_ref
=
x
.
clone
()
conv_state
=
torch
.
randn
(
batch
,
dim
,
width
-
1
,
device
=
device
,
dtype
=
itype
)
weight
=
torch
.
randn
(
dim
,
width
,
device
=
device
,
dtype
=
itype
)
bias
=
torch
.
randn
(
dim
,
device
=
device
,
dtype
=
itype
)
if
has_bias
else
None
conv_state_ref
=
conv_state
.
detach
().
clone
()
activation
=
None
if
not
silu_activation
else
"silu"
out
=
causal_conv1d_update
(
x
,
conv_state
,
weight
,
bias
,
activation
=
activation
)
out_ref
=
causal_conv1d_update_ref
(
x_ref
,
conv_state_ref
,
weight
,
bias
,
activation
=
activation
)
assert
torch
.
equal
(
conv_state
,
conv_state_ref
)
assert
torch
.
allclose
(
out
,
out_ref
,
rtol
=
rtol
,
atol
=
atol
)
@
pytest
.
mark
.
parametrize
(
"itype"
,
[
torch
.
float32
,
torch
.
float16
,
torch
.
bfloat16
])
@
pytest
.
mark
.
parametrize
(
"silu_activation"
,
[
False
,
True
])
@
pytest
.
mark
.
parametrize
(
"has_bias"
,
[
False
,
True
])
@
pytest
.
mark
.
parametrize
(
"seqlen"
,
[
1
,
4
,
5
])
@
pytest
.
mark
.
parametrize
(
"width"
,
[
2
,
3
,
4
])
@
pytest
.
mark
.
parametrize
(
"dim"
,
[
2048
,
2048
+
16
,
4096
])
# tests correctness in case subset of the sequences are padded
@
pytest
.
mark
.
parametrize
(
"with_padding"
,
[
True
,
False
])
def
test_causal_conv1d_update_with_batch_gather
(
with_padding
,
dim
,
width
,
seqlen
,
has_bias
,
silu_activation
,
itype
):
device
=
"cuda"
rtol
,
atol
=
(
3e-4
,
1e-3
)
if
itype
==
torch
.
float32
else
(
3e-3
,
5e-3
)
if
itype
==
torch
.
bfloat16
:
rtol
,
atol
=
1e-2
,
5e-2
batch_size
=
3
padding
=
5
if
with_padding
else
0
padded_batch_size
=
batch_size
+
padding
total_entries
=
10
*
batch_size
x
=
torch
.
randn
(
padded_batch_size
,
dim
,
1
,
device
=
device
,
dtype
=
itype
)
x_ref
=
x
.
clone
()
conv_state_indices
=
torch
.
randperm
(
total_entries
)[:
batch_size
].
to
(
dtype
=
torch
.
int32
,
device
=
device
)
unused_states_bool
=
torch
.
ones
(
total_entries
,
dtype
=
torch
.
bool
,
device
=
device
)
unused_states_bool
[
conv_state_indices
]
=
False
padded_state_indices
=
torch
.
concat
(
[
conv_state_indices
,
torch
.
as_tensor
([
PAD_SLOT_ID
]
*
padding
,
dtype
=
torch
.
int32
,
device
=
device
),
],
dim
=
0
,
)
conv_state
=
torch
.
randn
(
total_entries
,
dim
,
width
-
1
,
device
=
device
,
dtype
=
itype
)
conv_state_for_padding_test
=
conv_state
.
clone
()
weight
=
torch
.
randn
(
dim
,
width
,
device
=
device
,
dtype
=
itype
)
bias
=
torch
.
randn
(
dim
,
device
=
device
,
dtype
=
itype
)
if
has_bias
else
None
conv_state_ref
=
conv_state
[
conv_state_indices
,
:].
detach
().
clone
()
activation
=
None
if
not
silu_activation
else
"silu"
out
=
causal_conv1d_update
(
x
,
conv_state
,
weight
,
bias
,
activation
=
activation
,
conv_state_indices
=
padded_state_indices
,
pad_slot_id
=
PAD_SLOT_ID
,
)
out_ref
=
causal_conv1d_update_ref
(
x_ref
[:
batch_size
],
conv_state_ref
,
weight
,
bias
,
activation
=
activation
)
assert
torch
.
equal
(
conv_state
[
conv_state_indices
,
:],
conv_state_ref
)
assert
torch
.
allclose
(
out
[:
batch_size
],
out_ref
,
rtol
=
rtol
,
atol
=
atol
)
assert
torch
.
equal
(
conv_state
[
unused_states_bool
],
conv_state_for_padding_test
[
unused_states_bool
]
)
@
pytest
.
mark
.
parametrize
(
"itype"
,
[
torch
.
bfloat16
])
@
pytest
.
mark
.
parametrize
(
"silu_activation"
,
[
True
])
@
pytest
.
mark
.
parametrize
(
"has_bias"
,
[
True
])
@
pytest
.
mark
.
parametrize
(
"width"
,
[
4
])
@
pytest
.
mark
.
parametrize
(
"seqlen"
,
[
8
,
16
,
32
,
64
,
128
,
256
,
512
,
784
,
1024
,
2048
,
2049
,
4096
]
)
@
pytest
.
mark
.
parametrize
(
"dim"
,
[
64
,
4096
])
# tests correctness in case subset of the sequences are padded
@
pytest
.
mark
.
parametrize
(
"with_padding"
,
[
True
,
False
])
def
test_causal_conv1d_varlen
(
with_padding
,
dim
,
seqlen
,
width
,
has_bias
,
silu_activation
,
itype
):
device
=
"cuda"
torch
.
cuda
.
empty_cache
()
rtol
,
atol
=
(
3e-4
,
1e-3
)
if
itype
==
torch
.
float32
else
(
3e-3
,
5e-3
)
if
itype
==
torch
.
bfloat16
:
rtol
,
atol
=
1e-2
,
5e-2
seqlens
=
[]
batch_size
=
4
if
seqlen
<
10
:
batch_size
=
1
padding
=
3
if
with_padding
else
0
padded_batch_size
=
batch_size
+
padding
nsplits
=
padded_batch_size
-
1
eos_pos
=
torch
.
randperm
(
seqlen
-
1
)[:
nsplits
].
sort
().
values
seqlens
.
append
(
torch
.
diff
(
torch
.
cat
([
torch
.
tensor
([
-
1
]),
eos_pos
,
torch
.
tensor
([
seqlen
-
1
])])
).
tolist
()
)
assert
sum
(
seqlens
[
-
1
])
==
seqlen
assert
all
(
s
>
0
for
s
in
seqlens
[
-
1
])
total_entries
=
batch_size
*
10
cumsum
=
torch
.
cumsum
(
torch
.
tensor
(
seqlens
[
0
]),
dim
=
0
).
to
(
torch
.
int32
)
cumsum
=
torch
.
concat
([
torch
.
tensor
([
0
],
dtype
=
torch
.
int32
),
cumsum
],
dim
=
0
)
x
=
torch
.
randn
(
1
,
4096
+
dim
+
64
,
seqlen
,
device
=
device
,
dtype
=
itype
)[
:,
4096
:
4096
+
dim
,
:
]
weight
=
torch
.
randn
(
dim
,
width
,
device
=
device
,
dtype
=
itype
)
bias
=
torch
.
randn
(
dim
,
device
=
device
,
dtype
=
itype
)
if
has_bias
else
None
x_ref
=
x
.
clone
()
weight_ref
=
weight
.
clone
()
bias_ref
=
bias
.
clone
()
if
bias
is
not
None
else
None
activation
=
None
if
not
silu_activation
else
"silu"
final_states
=
torch
.
randn
(
total_entries
,
dim
,
width
-
1
,
device
=
x
.
device
,
dtype
=
x
.
dtype
)
final_states_ref
=
final_states
.
clone
()
has_initial_states
=
torch
.
randint
(
0
,
2
,
(
cumsum
.
shape
[
0
]
-
1
,),
dtype
=
torch
.
bool
,
device
=
x
.
device
)
state_indices
=
torch
.
randperm
(
total_entries
,
dtype
=
torch
.
int32
,
device
=
x
.
device
)[
:
batch_size
]
padded_state_indices
=
torch
.
concat
(
[
state_indices
,
torch
.
as_tensor
([
PAD_SLOT_ID
]
*
padding
,
dtype
=
torch
.
int32
,
device
=
device
),
],
dim
=-
1
,
)
out
=
causal_conv1d_fn
(
x
.
squeeze
(
0
),
weight
,
bias
,
cumsum
.
cuda
(),
padded_state_indices
,
has_initial_states
,
final_states
,
activation
,
PAD_SLOT_ID
,
)
out_ref
=
[]
out_ref_b
=
[]
splits
=
[
torch
.
split
(
var
,
seqlens
[
0
],
dim
=-
1
)
for
var
in
(
x_ref
)]
for
i
in
range
(
len
(
seqlens
[
0
])):
x_s
=
[
v
[
i
].
unsqueeze
(
0
)
for
v
in
splits
][
0
]
if
padded_state_indices
[
i
]
==
PAD_SLOT_ID
:
continue
out_ref_b
.
append
(
causal_conv1d_ref
(
x_s
,
weight_ref
,
bias_ref
,
activation
=
activation
,
return_final_states
=
True
,
final_states_out
=
final_states_ref
[
padded_state_indices
[
i
]].
unsqueeze
(
0
),
initial_states
=
(
final_states_ref
[
padded_state_indices
[
i
]].
unsqueeze
(
0
)
if
has_initial_states
[
i
]
else
None
),
)
)
out_ref
.
append
(
torch
.
cat
([
t
[
0
]
for
t
in
out_ref_b
],
dim
=
2
))
out_ref_tensor
=
torch
.
cat
(
out_ref
,
dim
=
0
)
unpadded_out
=
out
[:,
:
out_ref_tensor
.
shape
[
-
1
]]
assert
torch
.
allclose
(
unpadded_out
,
out_ref_tensor
,
rtol
=
rtol
,
atol
=
atol
)
assert
torch
.
allclose
(
final_states
[
state_indices
],
final_states_ref
[
state_indices
],
rtol
=
rtol
,
atol
=
atol
,
)
if
__name__
==
"__main__"
:
pytest
.
main
([
__file__
])
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment