change / sglang · Commits

Unverified commit 8cbe1538, authored Sep 10, 2025 by Yi Zhang, committed by GitHub Sep 09, 2025

Add mamba kernel (#10234)

parent 8471e5e6
Showing 8 changed files with 1418 additions and 0 deletions.
sgl-kernel/CMakeLists.txt                   +1    -0
sgl-kernel/csrc/common_extension.cc         +25   -0
sgl-kernel/csrc/mamba/causal_conv1d.cu      +669  -0
sgl-kernel/csrc/mamba/causal_conv1d.h       +159  -0
sgl-kernel/include/sgl_kernel_ops.h         +24   -0
sgl-kernel/python/sgl_kernel/__init__.py    +1    -0
sgl-kernel/python/sgl_kernel/mamba.py       +50   -0
sgl-kernel/tests/test_causal_conv1d.py      +489  -0
sgl-kernel/CMakeLists.txt

@@ -303,6 +303,7 @@ set(SOURCES
    "csrc/moe/cutlass_moe/w4a8/w4a8_moe_data.cu"
    "csrc/moe/cutlass_moe/w4a8/w4a8_grouped_mm_c3x.cu"
    "csrc/moe/marlin_moe_wna16/ops.cu"
    "csrc/mamba/causal_conv1d.cu"
    "csrc/moe/moe_align_kernel.cu"
    "csrc/moe/moe_fused_gate.cu"
    "csrc/moe/moe_topk_softmax_kernels.cu"
sgl-kernel/csrc/common_extension.cc

@@ -438,6 +438,31 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
  m.impl("copy_to_gpu_no_ce", torch::kCUDA, &copy_to_gpu_no_ce);

  m.def("concat_mla_k(Tensor! k, Tensor k_nope, Tensor k_rope) -> ()");
  m.impl("concat_mla_k", torch::kCUDA, &concat_mla_k);

  /*
   * From csrc/mamba
   */
  m.def(
      "causal_conv1d_update(Tensor! x,"
      "Tensor! conv_state,"
      "Tensor! weight,"
      "Tensor? bias_,"
      "bool silu_activation,"
      "Tensor? cache_seqlens_,"
      "Tensor? conv_state_indices,"
      "int pad_slot_id) -> ()");
  m.impl("causal_conv1d_update", torch::kCUDA, &causal_conv1d_update);

  m.def(
      "causal_conv1d_fwd(Tensor! x, Tensor! weight,"
      "Tensor? bias_,"
      "Tensor!? conv_states,"
      "Tensor? query_start_loc,"
      "Tensor? cache_indices,"
      "Tensor? has_initial_state,"
      "bool silu_activation,"
      "int pad_slot_id) -> ()");
  m.impl("causal_conv1d_fwd", torch::kCUDA, &causal_conv1d_fwd);
}

REGISTER_EXTENSION(common_ops)
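Once the extension is loaded, the two schemas registered above are reachable as torch.ops.sgl_kernel.causal_conv1d_fwd and torch.ops.sgl_kernel.causal_conv1d_update. A minimal sketch of a direct call (illustrative only; it assumes sgl_kernel is installed on a CUDA machine, and the shapes are made up):

    # Illustrative sketch, not part of this commit.
    import torch
    import sgl_kernel  # importing registers torch.ops.sgl_kernel.*

    x = torch.randn(2, 64, 128, device="cuda", dtype=torch.float16)   # (batch, dim, seqlen)
    weight = torch.randn(64, 4, device="cuda", dtype=torch.float16)   # (dim, width)

    # Matches the schema above: optional tensors may be None, silu_activation is a bool,
    # pad_slot_id an int. The op writes its output into x in place.
    torch.ops.sgl_kernel.causal_conv1d_fwd(
        x, weight, None, None, None, None, None, True, -1
    )

In practice the Python wrappers in sgl_kernel/mamba.py (added below) are the intended entry point.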
sgl-kernel/csrc/mamba/causal_conv1d.cu (new file, mode 100644)

// clang-format off
// adapted from https://github.com/Dao-AILab/causal-conv1d/blob/main/csrc/causal_conv1d_fwd.cu
// and https://github.com/Dao-AILab/causal-conv1d/blob/main/csrc/causal_conv1d_update.cu
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>

#include "causal_conv1d.h"
#include <c10/util/BFloat16.h>
#include <c10/util/Half.h>
#include <c10/cuda/CUDAException.h>  // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK

#include <cub/block/block_load.cuh>
#include <cub/block/block_store.cuh>

#define BOOL_SWITCH(COND, CONST_NAME, ...) \
  [&] { \
    if (COND) { \
      static constexpr bool CONST_NAME = true; \
      return __VA_ARGS__(); \
    } else { \
      static constexpr bool CONST_NAME = false; \
      return __VA_ARGS__(); \
    } \
  }()

#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")

#define DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(ITYPE, NAME, ...) \
  if (ITYPE == at::ScalarType::Half) { \
    using input_t = at::Half; \
    using weight_t = at::Half; \
    __VA_ARGS__(); \
  } else if (ITYPE == at::ScalarType::BFloat16) { \
    using input_t = at::BFloat16; \
    using weight_t = at::BFloat16; \
    __VA_ARGS__(); \
  } else if (ITYPE == at::ScalarType::Float) { \
    using input_t = float; \
    using weight_t = float; \
    __VA_ARGS__(); \
  } else { \
    AT_ERROR(#NAME, " not implemented for input type '", toString(ITYPE), "'"); \
  }

template<typename input_t, typename weight_t>
void causal_conv1d_fwd_cuda(ConvParamsBase &params, cudaStream_t stream);

template<typename input_t, typename weight_t>
void causal_conv1d_update_cuda(ConvParamsBase &params, cudaStream_t stream);

void set_conv_params_fwd(ConvParamsBase &params,
                         // sizes
                         const size_t batch,
                         const size_t dim,
                         const size_t seqlen,
                         const size_t width,
                         // device pointers
                         const at::Tensor x,
                         const at::Tensor weight,
                         const at::Tensor out,
                         const std::optional<at::Tensor>& bias,
                         bool silu_activation,
                         int64_t pad_slot_id,
                         const std::optional<at::Tensor>& query_start_loc = std::nullopt,
                         const std::optional<at::Tensor>& cache_indices = std::nullopt,
                         const std::optional<at::Tensor>& has_initial_state = std::nullopt) {

  // Reset the parameters
  memset(&params, 0, sizeof(params));

  params.batch = batch;
  params.dim = dim;
  params.seqlen = seqlen;
  params.width = width;
  params.pad_slot_id = pad_slot_id;

  params.silu_activation = silu_activation;

  // Set the pointers and strides.
  params.x_ptr = x.data_ptr();
  params.weight_ptr = weight.data_ptr();
  params.bias_ptr = bias.has_value() ? bias.value().data_ptr() : nullptr;
  params.out_ptr = out.data_ptr();
  // All strides are in elements, not bytes.
  params.query_start_loc_ptr = query_start_loc.has_value() ? query_start_loc.value().data_ptr() : nullptr;
  params.cache_indices_ptr = cache_indices.has_value() ? cache_indices.value().data_ptr() : nullptr;
  params.has_initial_state_ptr = has_initial_state.has_value() ? has_initial_state.value().data_ptr() : nullptr;
  const bool varlen = params.query_start_loc_ptr != nullptr;
  params.x_batch_stride = x.stride(varlen ? 1 : 0);
  params.x_c_stride = x.stride(varlen ? 0 : 1);
  params.x_l_stride = x.stride(varlen ? 1 : -1);
  params.weight_c_stride = weight.stride(0);
  params.weight_width_stride = weight.stride(1);
  params.out_batch_stride = out.stride(varlen ? 1 : 0);
  params.out_c_stride = out.stride(varlen ? 0 : 1);
  params.out_l_stride = out.stride(varlen ? 1 : -1);
}

void causal_conv1d_fwd(const at::Tensor &x, const at::Tensor &weight,
                       const std::optional<at::Tensor> &bias_,
                       const std::optional<at::Tensor> &conv_states,
                       const std::optional<at::Tensor> &query_start_loc,
                       const std::optional<at::Tensor> &cache_indices,
                       const std::optional<at::Tensor> &has_initial_state,
                       bool silu_activation,
                       // used to identify padding entries if cache_indices provided
                       // in case of padding, the kernel will return early
                       int64_t pad_slot_id) {
  auto input_type = x.scalar_type();
  auto weight_type = weight.scalar_type();
  TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16);
  TORCH_CHECK(weight_type == at::ScalarType::Float || weight_type == at::ScalarType::Half || weight_type == at::ScalarType::BFloat16);

  TORCH_CHECK(x.is_cuda());
  TORCH_CHECK(weight.is_cuda());

  const bool varlen = query_start_loc.has_value() ? true : false;
  const auto sizes = x.sizes();
  const int batch_size = varlen ? query_start_loc.value().sizes()[0] - 1 : sizes[0];
  const int dim = varlen ? sizes[0] : sizes[1];
  const int seqlen = varlen ? sizes[1] : sizes[2];
  const int width = weight.size(-1);
  if (varlen){
    CHECK_SHAPE(x, dim, seqlen);
  }
  else {
    CHECK_SHAPE(x, batch_size, dim, seqlen);
  }
  CHECK_SHAPE(weight, dim, width);

  if (bias_.has_value()) {
    auto bias = bias_.value();
    TORCH_CHECK(bias.scalar_type() == weight_type);
    TORCH_CHECK(bias.is_cuda());
    TORCH_CHECK(bias.stride(-1) == 1);
    CHECK_SHAPE(bias, dim);
  }

  if (has_initial_state.has_value()) {
    auto has_initial_state_ = has_initial_state.value();
    TORCH_CHECK(has_initial_state_.scalar_type() == at::ScalarType::Bool);
    TORCH_CHECK(has_initial_state_.is_cuda());
    CHECK_SHAPE(has_initial_state_, batch_size);
  }

  if (query_start_loc.has_value()) {
    auto query_start_loc_ = query_start_loc.value();
    TORCH_CHECK(query_start_loc_.scalar_type() == at::ScalarType::Int);
    TORCH_CHECK(query_start_loc_.is_cuda());
  }

  if (cache_indices.has_value()) {
    auto cache_indices_ = cache_indices.value();
    TORCH_CHECK(cache_indices_.scalar_type() == at::ScalarType::Int);
    TORCH_CHECK(cache_indices_.is_cuda());
    CHECK_SHAPE(cache_indices_, batch_size);
  }

  at::Tensor out = x;

  ConvParamsBase params;
  set_conv_params_fwd(params, batch_size, dim, seqlen, width, x, weight, out,
                      bias_,
                      silu_activation,
                      pad_slot_id,
                      query_start_loc,
                      cache_indices,
                      has_initial_state);

  if (conv_states.has_value()) {
    auto conv_states_ = conv_states.value();
    TORCH_CHECK(conv_states_.scalar_type() == input_type);
    TORCH_CHECK(conv_states_.is_cuda());
    params.conv_states_ptr = conv_states_.data_ptr();
    params.conv_states_batch_stride = conv_states_.stride(0);
    params.conv_states_c_stride = conv_states_.stride(1);
    params.conv_states_l_stride = conv_states_.stride(2);
  } else {
    params.conv_states_ptr = nullptr;
  }

  // Otherwise the kernel will be launched from cuda:0 device
  // Cast to char to avoid compiler warning about narrowing
  at::cuda::CUDAGuard device_guard{(char)x.get_device()};
  auto stream = at::cuda::getCurrentCUDAStream().stream();
  DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(x.scalar_type(), "causal_conv1d_fwd", [&] {
    causal_conv1d_fwd_cuda<input_t, weight_t>(params, stream);
  });
}

void causal_conv1d_update(const at::Tensor &x,
                          const at::Tensor &conv_state,
                          const at::Tensor &weight,
                          const std::optional<at::Tensor> &bias_,
                          bool silu_activation,
                          const std::optional<at::Tensor> &cache_seqlens_,
                          const std::optional<at::Tensor> &conv_state_indices_,
                          // used to identify padding entries if cache_indices provided
                          // in case of padding, the kernel will return early
                          int64_t pad_slot_id) {
  auto input_type = x.scalar_type();
  auto weight_type = weight.scalar_type();
  TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16);
  TORCH_CHECK(weight_type == at::ScalarType::Float || weight_type == at::ScalarType::Half || weight_type == at::ScalarType::BFloat16);
  TORCH_CHECK(weight_type == input_type, "weight type must equal to input type, other variations are disabled due to binary size limitations");
  TORCH_CHECK(conv_state.scalar_type() == input_type);

  TORCH_CHECK(x.is_cuda());
  TORCH_CHECK(conv_state.is_cuda());
  TORCH_CHECK(weight.is_cuda());

  const auto sizes = x.sizes();
  const int batch_size = sizes[0];
  const int dim = sizes[1];
  const int seqlen = sizes[2];
  const int width = weight.size(-1);
  const int conv_state_len = conv_state.size(2);
  TORCH_CHECK(conv_state_len >= width - 1);

  CHECK_SHAPE(x, batch_size, dim, seqlen);
  CHECK_SHAPE(weight, dim, width);

  TORCH_CHECK(width >= 2 && width <= 4, "causal_conv1d only supports width between 2 and 4");

  if (bias_.has_value()) {
    auto bias = bias_.value();
    TORCH_CHECK(bias.scalar_type() == weight_type);
    TORCH_CHECK(bias.is_cuda());
    TORCH_CHECK(bias.stride(-1) == 1);
    CHECK_SHAPE(bias, dim);
  }

  at::Tensor out = x;

  ConvParamsBase params;
  set_conv_params_fwd(params, batch_size, dim, seqlen, width, x, weight, out,
                      bias_,
                      silu_activation,
                      pad_slot_id);
  params.conv_state_ptr = conv_state.data_ptr();
  params.conv_state_len = conv_state_len;
  // All strides are in elements, not bytes.
  params.conv_state_batch_stride = conv_state.stride(0);
  params.conv_state_c_stride = conv_state.stride(1);
  params.conv_state_l_stride = conv_state.stride(2);

  if (cache_seqlens_.has_value()) {
    auto cache_seqlens = cache_seqlens_.value();
    TORCH_CHECK(cache_seqlens.scalar_type() == torch::kInt32);
    TORCH_CHECK(cache_seqlens.is_cuda());
    TORCH_CHECK(cache_seqlens.stride(-1) == 1);
    CHECK_SHAPE(cache_seqlens, batch_size);
    params.cache_seqlens = cache_seqlens.data_ptr<int32_t>();
  } else {
    params.cache_seqlens = nullptr;
  }

  if (conv_state_indices_.has_value()) {
    auto conv_state_indices = conv_state_indices_.value();
    TORCH_CHECK(conv_state_indices.scalar_type() == torch::kInt32)
    TORCH_CHECK(conv_state_indices.is_cuda());
    TORCH_CHECK(conv_state_indices.stride(0) == 1)
    CHECK_SHAPE(conv_state_indices, batch_size);

    int conv_state_entries = conv_state.size(0);
    CHECK_SHAPE(conv_state, conv_state_entries, dim, conv_state_len);

    params.conv_state_indices_ptr = conv_state_indices.data_ptr<int32_t>();
  } else {
    CHECK_SHAPE(conv_state, batch_size, dim, conv_state_len);
    params.conv_state_indices_ptr = nullptr;
  }

  // Otherwise the kernel will be launched from cuda:0 device
  // Cast to char to avoid compiler warning about narrowing
  at::cuda::CUDAGuard device_guard{(char)x.get_device()};
  auto stream = at::cuda::getCurrentCUDAStream().stream();
  DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(x.scalar_type(), "causal_conv1d_update", [&] {
    causal_conv1d_update_cuda<input_t, weight_t>(params, stream);
  });
}

template<int kNThreads_, int kWidth_, bool kIsVecLoad_, typename input_t_, typename weight_t_>
struct Causal_conv1d_fwd_kernel_traits {
  using input_t = input_t_;
  using weight_t = weight_t_;
  static constexpr int kNThreads = kNThreads_;
  static constexpr int kWidth = kWidth_;
  static constexpr int kNBytes = sizeof(input_t);
  static_assert(kNBytes == 2 || kNBytes == 4);
  static constexpr int kNElts = kNBytes == 4 ? 4 : 8;
  static_assert(kWidth <= kNElts);
  static constexpr bool kIsVecLoad = kIsVecLoad_;
  using vec_t = typename BytesToType<kNBytes * kNElts>::Type;
  using BlockLoadT = cub::BlockLoad<input_t, kNThreads, kNElts, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
  using BlockLoadVecT = cub::BlockLoad<vec_t, kNThreads, 1, cub::BLOCK_LOAD_DIRECT>;
  using BlockStoreT = cub::BlockStore<input_t, kNThreads, kNElts, cub::BLOCK_STORE_WARP_TRANSPOSE>;
  using BlockStoreVecT = cub::BlockStore<vec_t, kNThreads, 1, cub::BLOCK_STORE_DIRECT>;
  static constexpr int kSmemIOSize = kIsVecLoad ? 0 : custom_max({sizeof(typename BlockLoadT::TempStorage), sizeof(typename BlockStoreT::TempStorage)});
  static constexpr int kSmemExchangeSize = kNThreads * kNBytes * kNElts;
  static constexpr int kSmemSize = kSmemIOSize + kSmemExchangeSize;
};

template<typename Ktraits>
__global__ __launch_bounds__(Ktraits::kNThreads)
void causal_conv1d_fwd_kernel(ConvParamsBase params) {
  constexpr int kWidth = Ktraits::kWidth;
  constexpr int kNThreads = Ktraits::kNThreads;
  constexpr int kNElts = Ktraits::kNElts;
  constexpr bool kIsVecLoad = Ktraits::kIsVecLoad;
  using input_t = typename Ktraits::input_t;
  using vec_t = typename Ktraits::vec_t;
  using weight_t = typename Ktraits::weight_t;

  // Shared memory.
  extern __shared__ char smem_[];
  auto& smem_load = reinterpret_cast<typename Ktraits::BlockLoadT::TempStorage&>(smem_);
  auto& smem_load_vec = reinterpret_cast<typename Ktraits::BlockLoadVecT::TempStorage&>(smem_);
  auto& smem_store = reinterpret_cast<typename Ktraits::BlockStoreT::TempStorage&>(smem_);
  auto& smem_store_vec = reinterpret_cast<typename Ktraits::BlockStoreVecT::TempStorage&>(smem_);
  vec_t *smem_exchange = reinterpret_cast<vec_t *>(smem_ + Ktraits::kSmemIOSize);

  const bool kVarlen = params.query_start_loc_ptr != nullptr;
  const int tidx = threadIdx.x;
  const int batch_id = blockIdx.x;
  const int channel_id = blockIdx.y;
  const int *query_start_loc = kVarlen ? reinterpret_cast<int *>(params.query_start_loc_ptr) : nullptr;
  const int sequence_start_index = kVarlen ? query_start_loc[batch_id] : batch_id;
  const int seqlen = kVarlen ? query_start_loc[batch_id + 1] - sequence_start_index : params.seqlen;

  input_t *x = reinterpret_cast<input_t *>(params.x_ptr) + sequence_start_index * params.x_batch_stride
      + channel_id * params.x_c_stride;
  weight_t *weight = reinterpret_cast<weight_t *>(params.weight_ptr) + channel_id * params.weight_c_stride;
  input_t *out = reinterpret_cast<input_t *>(params.out_ptr) + sequence_start_index * params.out_batch_stride
      + channel_id * params.out_c_stride;
  float bias_val = params.bias_ptr == nullptr ? 0.f : float(reinterpret_cast<weight_t *>(params.bias_ptr)[channel_id]);

  bool has_initial_state = params.has_initial_state_ptr == nullptr ? false
      : reinterpret_cast<bool *>(params.has_initial_state_ptr)[batch_id];

  int *cache_indices = params.cache_indices_ptr == nullptr ? nullptr
      : reinterpret_cast<int *>(params.cache_indices_ptr);
  int cache_index = cache_indices == nullptr ? batch_id : cache_indices[batch_id];
  // cache_index == params.pad_slot_id is defined as padding, so we exit early
  if (cache_index == params.pad_slot_id){
    return;
  }
  input_t *conv_states = params.conv_states_ptr == nullptr ? nullptr
      : reinterpret_cast<input_t *>(params.conv_states_ptr) + cache_index * params.conv_states_batch_stride + channel_id * params.conv_states_c_stride;

  // Thread 0 will load the last elements of the previous chunk, so we initialize those to 0.
  if (tidx == 0) {
    input_t initial_state[kNElts] = {0};
    if (has_initial_state) {
      #pragma unroll
      for (int w = 0; w < kWidth - 1; ++w){ initial_state[kNElts - 1 - (kWidth - 2) + w] = conv_states[w]; }
    }
    smem_exchange[kNThreads - 1] = reinterpret_cast<vec_t *>(initial_state)[0];
  }

  float weight_vals[kWidth];
  #pragma unroll
  for (int i = 0; i < kWidth; ++i) { weight_vals[i] = float(weight[i * params.weight_width_stride]); }

  constexpr int kChunkSize = kNThreads * kNElts;
  const int n_chunks = (seqlen + kChunkSize - 1) / kChunkSize;
  for (int chunk = 0; chunk < n_chunks; ++chunk) {
    input_t x_vals_load[2 * kNElts] = {0};
    if constexpr(kIsVecLoad) {
      typename Ktraits::BlockLoadVecT(smem_load_vec).Load(reinterpret_cast<vec_t*>(x), *reinterpret_cast<vec_t (*)[1]>(&x_vals_load[kNElts]), (seqlen - chunk * kChunkSize) / kNElts);
    } else {
      __syncthreads();
      typename Ktraits::BlockLoadT(smem_load).Load(x, *reinterpret_cast<input_t (*)[kNElts]>(&x_vals_load[kNElts]), seqlen - chunk * kChunkSize);
    }
    x += kChunkSize;
    __syncthreads();
    // Thread kNThreads - 1 doesn't write yet, so that thread 0 can read
    // the last elements of the previous chunk.
    if (tidx < kNThreads - 1) { smem_exchange[tidx] = reinterpret_cast<vec_t *>(x_vals_load)[1]; }
    __syncthreads();
    reinterpret_cast<vec_t *>(x_vals_load)[0] = smem_exchange[tidx > 0 ? tidx - 1 : kNThreads - 1];
    __syncthreads();
    // Now thread kNThreads - 1 can write the last elements of the current chunk.
    if (tidx == kNThreads - 1) { smem_exchange[tidx] = reinterpret_cast<vec_t *>(x_vals_load)[1]; }

    float x_vals[2 * kNElts];
    #pragma unroll
    for (int i = 0; i < 2 * kNElts; ++i) { x_vals[i] = float(x_vals_load[i]); }

    float out_vals[kNElts];
    #pragma unroll
    for (int i = 0; i < kNElts; ++i) {
      out_vals[i] = bias_val;
      #pragma unroll
      for (int w = 0; w < kWidth; ++w) {
        out_vals[i] += weight_vals[w] * x_vals[kNElts + i - (kWidth - w - 1)];
      }
    }

    if (params.silu_activation) {
      #pragma unroll
      for (int i = 0; i < kNElts; ++i) {
        out_vals[i] = out_vals[i] / (1 + expf(-out_vals[i]));
      }
    }

    input_t out_vals_store[kNElts];
    #pragma unroll
    for (int i = 0; i < kNElts; ++i) { out_vals_store[i] = out_vals[i]; }
    if constexpr(kIsVecLoad) {
      typename Ktraits::BlockStoreVecT(smem_store_vec).Store(reinterpret_cast<vec_t*>(out), reinterpret_cast<vec_t (&)[1]>(out_vals_store), (seqlen - chunk * kChunkSize) / kNElts);
    } else {
      typename Ktraits::BlockStoreT(smem_store).Store(out, out_vals_store, seqlen - chunk * kChunkSize);
    }
    out += kChunkSize;

    int final_state_position = ((seqlen - (kWidth - 1)) - (n_chunks - 1) * kChunkSize);
    // in case the final state is separated between the last "smem_exchange" and
    // the one before it (chunk = n_chunks - 1 and chunk = n_chunks - 2),
    // (which occurs when `final_state_position` is a non-positive index)
    // we load the correct data from smem_exchange from both chunks, the last chunk iteration and the one before it
    if (conv_states != nullptr && final_state_position < 0 && seqlen > kWidth){
      input_t vals_load[kNElts] = {0};
      if ((chunk == n_chunks - 2) && (tidx == kNThreads - 1)){
        // chunk = n_chunks - 2, a segment of the final state sits in the last index
        reinterpret_cast<vec_t *>(vals_load)[0] = smem_exchange[kNThreads - 1];
        #pragma unroll
        for (int w = 0; w < -final_state_position; ++w){
          conv_states[w] = vals_load[kNElts + final_state_position + w];
        }
      }
      if ((chunk == n_chunks - 1) && tidx == 0){
        // chunk = n_chunks - 1, the second segment of the final state first positions
        reinterpret_cast<vec_t *>(vals_load)[0] = smem_exchange[0];
        for (int w = -final_state_position; w < kWidth - 1; ++w){
          conv_states[w] = vals_load[w + final_state_position];
        }
        return;
      }
    }
  }

  // Final state is stored in the smem_exchange last token slot,
  // in case seqlen < kWidth, we would need to take the final state from the
  // initial state which is stored in conv_states
  // in case seqlen > kWidth, we would need to load the last kWidth - 1 data
  // and load it into conv_state accordingly
  int last_thread = ((seqlen - (kWidth - 1)) - (n_chunks - 1) * kChunkSize) / kNElts;
  if (conv_states != nullptr && tidx == last_thread) {
    input_t x_vals_load[kNElts * 2] = {0};
    // in case we are on the first kWidth tokens
    if (last_thread == 0 && seqlen < kWidth){
      // Need to take the initial state
      reinterpret_cast<vec_t *>(x_vals_load)[0] = smem_exchange[0];
      const int offset = seqlen - (kWidth - 1);
      #pragma unroll
      for (int w = 0; w < kWidth - 1; ++w){
        // pad the existing state
        if ((w - seqlen) >= 0 && has_initial_state) { conv_states[w - seqlen] = conv_states[w]; }
        else if ((w - seqlen) >= 0 && !has_initial_state) { conv_states[w - seqlen] = input_t(0.0f); }
      }
      #pragma unroll
      for (int w = 0; w < kWidth - 1; ++w){
        if (offset + w >= 0)
          conv_states[w] = x_vals_load[offset + w];
      }
    } else {
      // in case the final state is in between the threads data
      const int offset = ((seqlen - (kWidth - 1)) % (kNElts));
      if ((offset + kWidth - 2) >= kNElts && (last_thread + 1 < kNThreads)){
        // In case last_thread == kNThreads - 1, accessing last_thread + 1 will result in an
        // illegal access error on H100.
        // Therefore, we access last_thread + 1, only if the final state data sits there
        reinterpret_cast<vec_t *>(x_vals_load)[1] = smem_exchange[last_thread + 1];
      }
      reinterpret_cast<vec_t *>(x_vals_load)[0] = smem_exchange[last_thread];
      #pragma unroll
      for (int w = 0; w < kWidth - 1; ++w){
        conv_states[w] = x_vals_load[offset + w];
      }
    }
  }
}

template<int kNThreads, int kWidth, typename input_t, typename weight_t>
void causal_conv1d_fwd_launch(ConvParamsBase &params, cudaStream_t stream) {
  static constexpr int kNElts = sizeof(input_t) == 4 ? 4 : 8;
  const bool kVarlen = params.query_start_loc_ptr != nullptr;
  BOOL_SWITCH(params.seqlen % kNElts == 0 && !kVarlen, kIsVecLoad, [&] {
    using Ktraits = Causal_conv1d_fwd_kernel_traits<kNThreads, kWidth, kIsVecLoad, input_t, weight_t>;
    constexpr int kSmemSize = Ktraits::kSmemSize;
    dim3 grid(params.batch, params.dim);

    auto kernel = &causal_conv1d_fwd_kernel<Ktraits>;

    if (kSmemSize >= 48 * 1024) {
#ifndef USE_ROCM
      C10_CUDA_CHECK(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
#else
      // There is a slight signature discrepancy in HIP and CUDA "FuncSetAttribute" function.
      C10_CUDA_CHECK(cudaFuncSetAttribute((void *) kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
      std::cerr << "Warning (causal_conv1d fwd launch): attempting to set maxDynamicSharedMemorySize on an AMD GPU which is currently a non-op (in ROCm versions <= 6.1). This might lead to undefined behavior. \n" << std::endl;
#endif
    }
    kernel<<<grid, Ktraits::kNThreads, kSmemSize, stream>>>(params);
    C10_CUDA_KERNEL_LAUNCH_CHECK();
  });
}

template<typename input_t, typename weight_t>
void causal_conv1d_fwd_cuda(ConvParamsBase &params, cudaStream_t stream) {
  if (params.width == 2) {
    causal_conv1d_fwd_launch<128, 2, input_t, weight_t>(params, stream);
  } else if (params.width == 3) {
    causal_conv1d_fwd_launch<128, 3, input_t, weight_t>(params, stream);
  } else if (params.width == 4) {
    causal_conv1d_fwd_launch<128, 4, input_t, weight_t>(params, stream);
  }
}

template void causal_conv1d_fwd_cuda<float, float>(ConvParamsBase &params, cudaStream_t stream);
template void causal_conv1d_fwd_cuda<at::Half, at::Half>(ConvParamsBase &params, cudaStream_t stream);
template void causal_conv1d_fwd_cuda<at::BFloat16, at::BFloat16>(ConvParamsBase &params, cudaStream_t stream);

template<int kNThreads_, int kWidth_, typename input_t_, typename weight_t_>
struct Causal_conv1d_update_kernel_traits {
  using input_t = input_t_;
  using weight_t = weight_t_;
  static constexpr int kNThreads = kNThreads_;
  static constexpr int kWidth = kWidth_;
  static constexpr int kNBytes = sizeof(input_t);
  static_assert(kNBytes == 2 || kNBytes == 4);
};

template<typename Ktraits, bool kIsCircularBuffer>
__global__ __launch_bounds__(Ktraits::kNThreads)
void causal_conv1d_update_kernel(ConvParamsBase params) {
  constexpr int kWidth = Ktraits::kWidth;
  constexpr int kNThreads = Ktraits::kNThreads;
  using input_t = typename Ktraits::input_t;
  using weight_t = typename Ktraits::weight_t;

  const int tidx = threadIdx.x;
  const int batch_id = blockIdx.x;
  const int channel_id = blockIdx.y * kNThreads + tidx;
  if (channel_id >= params.dim) return;

  input_t *x = reinterpret_cast<input_t *>(params.x_ptr) + batch_id * params.x_batch_stride
      + channel_id * params.x_c_stride;

  // If params.conv_state_batch_indices is set, then the conv state is gathered from the conv state tensor
  // along the batch axis. Otherwise, the conv state coordinate is the same as the batch id.
  const int conv_state_batch_coord = params.conv_state_indices_ptr == nullptr
      ? batch_id
      : params.conv_state_indices_ptr[batch_id];
  // conv_state_batch_coord == params.pad_slot_id is defined as padding so we exit early
  if (conv_state_batch_coord == params.pad_slot_id){
    return;
  }
  input_t *conv_state = reinterpret_cast<input_t *>(params.conv_state_ptr)
      + conv_state_batch_coord * params.conv_state_batch_stride
      + channel_id * params.conv_state_c_stride;

  weight_t *weight = reinterpret_cast<weight_t *>(params.weight_ptr) + channel_id * params.weight_c_stride;
  input_t *out = reinterpret_cast<input_t *>(params.out_ptr) + batch_id * params.out_batch_stride
      + channel_id * params.out_c_stride;
  float bias_val = params.bias_ptr == nullptr ? 0.f : float(reinterpret_cast<weight_t *>(params.bias_ptr)[channel_id]);

  int state_len = params.conv_state_len;
  int advance_len = params.seqlen;
  int cache_seqlen = kIsCircularBuffer ? params.cache_seqlens[batch_id] % state_len : 0;
  int update_idx = cache_seqlen - (kWidth - 1);
  update_idx = update_idx < 0 ? update_idx + state_len : update_idx;

  float weight_vals[kWidth] = {0};
  #pragma unroll
  for (int i = 0; i < kWidth; ++i) { weight_vals[i] = float(weight[i * params.weight_width_stride]); }

  float x_vals[kWidth] = {0};
  if constexpr (!kIsCircularBuffer) {
    #pragma unroll 2
    for (int i = 0; i < state_len - advance_len - (kWidth - 1); ++i) {
      conv_state[i * params.conv_state_l_stride] = conv_state[(i + advance_len) * params.conv_state_l_stride];
    }
    #pragma unroll
    for (int i = 0; i < kWidth - 1; ++i) {
      input_t state_val = conv_state[(state_len - (kWidth - 1) + i) * params.conv_state_l_stride];
      if (i < advance_len + (kWidth - 1) && state_len - advance_len - (kWidth - 1) + i >= 0) {
        conv_state[(state_len - advance_len - (kWidth - 1) + i) * params.conv_state_l_stride] = state_val;
      }
      x_vals[i] = float(state_val);
    }
  } else {
    #pragma unroll
    for (int i = 0; i < kWidth - 1; ++i, update_idx = update_idx + 1 >= state_len ? update_idx + 1 - state_len : update_idx + 1) {
      input_t state_val = conv_state[update_idx * params.conv_state_l_stride];
      x_vals[i] = float(state_val);
    }
  }
  #pragma unroll 2
  for (int i = 0; i < params.seqlen; ++i) {
    input_t x_val = x[i * params.x_l_stride];
    if constexpr (!kIsCircularBuffer) {
      if (i < advance_len && state_len - advance_len + i >= 0) {
        conv_state[(state_len - advance_len + i) * params.conv_state_l_stride] = x_val;
      }
    } else {
      conv_state[update_idx * params.conv_state_l_stride] = x_val;
      ++update_idx;
      update_idx = update_idx >= state_len ? update_idx - state_len : update_idx;
    }
    x_vals[kWidth - 1] = float(x_val);
    float out_val = bias_val;
    #pragma unroll
    for (int j = 0; j < kWidth; ++j) { out_val += weight_vals[j] * x_vals[j]; }
    if (params.silu_activation) { out_val = out_val / (1 + expf(-out_val)); }
    out[i * params.out_l_stride] = input_t(out_val);
    // Shift the input buffer by 1
    #pragma unroll
    for (int i = 0; i < kWidth - 1; ++i) { x_vals[i] = x_vals[i + 1]; }
  }
}

template<int kNThreads, int kWidth, typename input_t, typename weight_t>
void causal_conv1d_update_launch(ConvParamsBase &params, cudaStream_t stream) {
  using Ktraits = Causal_conv1d_update_kernel_traits<kNThreads, kWidth, input_t, weight_t>;
  dim3 grid(params.batch, (params.dim + kNThreads - 1) / kNThreads);
  auto kernel = params.cache_seqlens == nullptr
      ? &causal_conv1d_update_kernel<Ktraits, false>
      : &causal_conv1d_update_kernel<Ktraits, true>;
  kernel<<<grid, Ktraits::kNThreads, 0, stream>>>(params);
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}

template<typename input_t, typename weight_t>
void causal_conv1d_update_cuda(ConvParamsBase &params, cudaStream_t stream) {
  if (params.width == 2) {
    causal_conv1d_update_launch<64, 2, input_t, weight_t>(params, stream);
  } else if (params.width == 3) {
    causal_conv1d_update_launch<64, 3, input_t, weight_t>(params, stream);
  } else if (params.width == 4) {
    causal_conv1d_update_launch<64, 4, input_t, weight_t>(params, stream);
  }
}

template void causal_conv1d_update_cuda<float, float>(ConvParamsBase &params, cudaStream_t stream);
template void causal_conv1d_update_cuda<at::Half, at::Half>(ConvParamsBase &params, cudaStream_t stream);
template void causal_conv1d_update_cuda<at::BFloat16, at::BFloat16>(ConvParamsBase &params, cudaStream_t stream);
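For orientation, the forward kernel above computes, per batch and channel, a depthwise causal convolution of width kWidth with an optional fused SiLU; the chunked shared-memory exchange only passes the trailing kWidth - 1 elements between chunks and into conv_states. A minimal PyTorch sketch of the same math (it mirrors the causal_conv1d_ref helper used in the tests below, not the kernel's launch path):

    # Illustrative sketch of the computation, not part of this commit.
    import torch
    import torch.nn.functional as F

    def causal_conv1d_torch(x, weight, bias=None, silu_activation=True):
        # x: (batch, dim, seqlen); weight: (dim, width); bias: (dim,) or None
        dim, width = weight.shape
        # Depthwise (groups=dim) conv with left zero padding of width - 1 gives the causal result.
        out = F.conv1d(x, weight.unsqueeze(1), bias, padding=width - 1, groups=dim)
        out = out[..., : x.shape[-1]]  # drop the trailing non-causal positions
        return F.silu(out) if silu_activation else out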
sgl-kernel/csrc/mamba/causal_conv1d.h (new file, mode 100644)

/******************************************************************************
 * Copyright (c) 2024, Tri Dao.
 ******************************************************************************/
// clang-format off
// adapted from https://github.com/Dao-AILab/causal-conv1d/blob/main/csrc/causal_conv1d.h
#pragma once

#include <cuda_bf16.h>
#include <cuda_fp16.h>
////////////////////////////////////////////////////////////////////////////////////////////////////

struct ConvParamsBase {
  using index_t = uint32_t;

  int batch, dim, seqlen, width;
  int64_t pad_slot_id;
  bool silu_activation;

  index_t x_batch_stride;
  index_t x_c_stride;
  index_t x_l_stride;
  index_t weight_c_stride;
  index_t weight_width_stride;
  index_t out_batch_stride;
  index_t out_c_stride;
  index_t out_l_stride;

  int conv_state_len;
  index_t conv_state_batch_stride;
  index_t conv_state_c_stride;
  index_t conv_state_l_stride;

  // Common data pointers.
  void *__restrict__ x_ptr;
  void *__restrict__ weight_ptr;
  void *__restrict__ bias_ptr;
  void *__restrict__ out_ptr;

  void *__restrict__ conv_state_ptr;
  void *__restrict__ query_start_loc_ptr;
  void *__restrict__ has_initial_state_ptr;
  void *__restrict__ cache_indices_ptr;
  int32_t *__restrict__ cache_seqlens;

  // For the continuous batching case. Makes it so that the mamba state for
  // the current batch doesn't need to be a contiguous tensor.
  int32_t *__restrict__ conv_state_indices_ptr;

  void *__restrict__ seq_idx_ptr;

  // No __restrict__ since initial_states could be the same as final_states.
  void *initial_states_ptr;
  index_t initial_states_batch_stride;
  index_t initial_states_l_stride;
  index_t initial_states_c_stride;

  void *final_states_ptr;
  index_t final_states_batch_stride;
  index_t final_states_l_stride;
  index_t final_states_c_stride;

  void *conv_states_ptr;
  index_t conv_states_batch_stride;
  index_t conv_states_l_stride;
  index_t conv_states_c_stride;
};

#ifndef USE_ROCM
  #include <cuda_bf16.h>

  template<typename T>
  __device__ inline T shuffle_xor(T val, int offset) {
    return __shfl_xor_sync(uint32_t(-1), val, offset);
  }

  constexpr size_t custom_max(std::initializer_list<size_t> ilist) {
    return std::max(ilist);
  }

  template<typename T>
  constexpr T constexpr_min(T a, T b) {
    return std::min(a, b);
  }

#else
  #include <hip/hip_bf16.h>

  template<typename T>
  __device__ inline T shuffle_xor(T val, int offset) {
    return __shfl_xor(val, offset);
  }

  constexpr size_t custom_max(std::initializer_list<size_t> ilist) {
    return *std::max_element(ilist.begin(), ilist.end());
  }

  template<typename T>
  constexpr T constexpr_min(T a, T b) {
    return a < b ? a : b;
  }
#endif

////////////////////////////////////////////////////////////////////////////////////////////////////

template<int BYTES> struct BytesToType {};

template<> struct BytesToType<16> {
  using Type = uint4;
  static_assert(sizeof(Type) == 16);
};

template<> struct BytesToType<8> {
  using Type = uint64_t;
  static_assert(sizeof(Type) == 8);
};

template<> struct BytesToType<4> {
  using Type = uint32_t;
  static_assert(sizeof(Type) == 4);
};

template<> struct BytesToType<2> {
  using Type = uint16_t;
  static_assert(sizeof(Type) == 2);
};

template<> struct BytesToType<1> {
  using Type = uint8_t;
  static_assert(sizeof(Type) == 1);
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template<typename T>
struct SumOp {
  __device__ inline T operator()(T const & x, T const & y) { return x + y; }
};

template<int THREADS>
struct Allreduce {
  static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4);
  template<typename T, typename Operator>
  static __device__ inline T run(T x, Operator &op) {
    constexpr int OFFSET = THREADS / 2;
    x = op(x, __shfl_xor_sync(uint32_t(-1), x, OFFSET));
    return Allreduce<OFFSET>::run(x, op);
  }
};

template<>
struct Allreduce<2> {
  template<typename T, typename Operator>
  static __device__ inline T run(T x, Operator &op) {
    x = op(x, __shfl_xor_sync(uint32_t(-1), x, 1));
    return x;
  }
};
sgl-kernel/include/sgl_kernel_ops.h

@@ -724,3 +724,27 @@ void store_kv_cache(at::Tensor k_cache, at::Tensor v_cache, at::Tensor out_loc,
void copy_to_gpu_no_ce(const at::Tensor& input, at::Tensor& output);

void concat_mla_k(torch::Tensor k, torch::Tensor k_nope, torch::Tensor k_rope);

/*
 * From csrc/mamba
 */
void causal_conv1d_update(
    const at::Tensor& x,
    const at::Tensor& conv_state,
    const at::Tensor& weight,
    const std::optional<at::Tensor>& bias_,
    bool silu_activation,
    const std::optional<at::Tensor>& cache_seqlens_,
    const std::optional<at::Tensor>& conv_state_indices_,
    int64_t pad_slot_id);

void causal_conv1d_fwd(
    const at::Tensor& x,
    const at::Tensor& weight,
    const std::optional<at::Tensor>& bias_,
    const std::optional<at::Tensor>& conv_states,
    const std::optional<at::Tensor>& query_start_loc,
    const std::optional<at::Tensor>& cache_indices,
    const std::optional<at::Tensor>& has_initial_state,
    bool silu_activation,
    int64_t pad_slot_id);
sgl-kernel/python/sgl_kernel/__init__.py

@@ -34,6 +34,7 @@ from sgl_kernel.elementwise import (
    rmsnorm,
    silu_and_mul,
)
from sgl_kernel.mamba import causal_conv1d_fwd, causal_conv1d_update

if torch.version.hip is not None:
    from sgl_kernel.elementwise import gelu_quick
sgl-kernel/python/sgl_kernel/mamba.py (new file, mode 100644)

from typing import Optional

import torch


# mamba
def causal_conv1d_fwd(
    x: torch.Tensor,
    weight: torch.Tensor,
    bias_: Optional[torch.Tensor],
    conv_states: Optional[torch.Tensor],
    query_start_loc: Optional[torch.Tensor],
    cache_indices: Optional[torch.Tensor],
    has_initial_state: Optional[torch.Tensor],
    silu_activation: bool,
    pad_slot_id: int,
):
    torch.ops.sgl_kernel.causal_conv1d_fwd(
        x,
        weight,
        bias_,
        conv_states,
        query_start_loc,
        cache_indices,
        has_initial_state,
        silu_activation,
        pad_slot_id,
    )


def causal_conv1d_update(
    x: torch.Tensor,
    conv_state: torch.Tensor,
    weight: torch.Tensor,
    bias_: Optional[torch.Tensor],
    silu_activation: bool,
    cache_seqlens: Optional[torch.Tensor],
    conv_state_indices: Optional[torch.Tensor],
    pad_slot_id: int,
):
    torch.ops.sgl_kernel.causal_conv1d_update(
        x,
        conv_state,
        weight,
        bias_,
        silu_activation,
        cache_seqlens,
        conv_state_indices,
        pad_slot_id,
    )
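A short usage sketch of these wrappers for a single decode step (illustrative only; shapes, dtypes, and values are made up, following the conventions exercised by the tests below):

    # Illustrative sketch, not part of this commit.
    import torch
    from sgl_kernel import causal_conv1d_update

    batch, dim, width = 2, 64, 4
    # One new token per sequence; conv_state holds the previous width - 1 tokens per channel.
    x = torch.randn(batch, dim, 1, device="cuda", dtype=torch.bfloat16)
    conv_state = torch.randn(batch, dim, width - 1, device="cuda", dtype=torch.bfloat16)
    weight = torch.randn(dim, width, device="cuda", dtype=torch.bfloat16)
    bias = torch.randn(dim, device="cuda", dtype=torch.bfloat16)

    # Writes the convolution output into x in place and advances conv_state by one token.
    causal_conv1d_update(
        x, conv_state, weight, bias,
        silu_activation=True, cache_seqlens=None, conv_state_indices=None, pad_slot_id=-1,
    )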
sgl-kernel/tests/test_causal_conv1d.py (new file, mode 100644)

# Adapted from https://github.com/vllm-project/vllm/blob/main/tests/kernels/mamba/test_causal_conv1d.py
from typing import Optional

import torch
from sgl_kernel import causal_conv1d_fwd
from sgl_kernel import causal_conv1d_update as causal_conv1d_update_kernel

PAD_SLOT_ID = -1


def causal_conv1d_fn(
    x: torch.Tensor,
    weight: torch.Tensor,
    bias: Optional[torch.Tensor] = None,
    query_start_loc: Optional[torch.Tensor] = None,
    cache_indices: Optional[torch.Tensor] = None,
    has_initial_state: Optional[torch.Tensor] = None,
    conv_states: Optional[torch.Tensor] = None,
    activation: Optional[str] = "silu",
    pad_slot_id: int = PAD_SLOT_ID,
):
    """
    x: (batch, dim, seqlen) or (dim, cu_seq_len) for varlen
        sequences are concatenated from left to right for varlen
    weight: (dim, width)
    bias: (dim,)
    query_start_loc: (batch + 1) int32
        The cumulative sequence lengths of the sequences in
        the batch, used to index into sequence. Prepended by 0.
        For example: query_start_loc = torch.Tensor([0,10,16,17]),
        x.shape=(dim,17)
    cache_indices: (batch) int32
        indicates the corresponding state index,
        like so: conv_state = conv_states[cache_indices[batch_id]]
    has_initial_state: (batch) bool
        indicates whether the kernel should take the current state as the initial
        state for the calculations
    conv_states: (...,dim,width - 1) itype
        updated inplace if provided
    activation: either None or "silu" or "swish"
    pad_slot_id: int
        if cache_indices is passed, lets the kernel identify padded
        entries that will not be processed,
        for example: cache_indices = [pad_slot_id, 1, 20, pad_slot_id]
        in this case, the kernel will not process entries at
        indices 0 and 3
    out: (batch, dim, seqlen)
    """
    if activation not in [None, "silu", "swish"]:
        raise NotImplementedError("activation must be None, silu, or swish")
    if x.stride(-1) != 1:
        x = x.contiguous()
    bias = bias.contiguous() if bias is not None else None

    causal_conv1d_fwd(
        x,
        weight,
        bias,
        conv_states,
        query_start_loc,
        cache_indices,
        has_initial_state,
        activation in ["silu", "swish"],
        pad_slot_id,
    )
    return x


def causal_conv1d_update(
    x: torch.Tensor,
    conv_state: torch.Tensor,
    weight: torch.Tensor,
    bias: Optional[torch.Tensor] = None,
    activation: Optional[str] = None,
    cache_seqlens: Optional[torch.Tensor] = None,
    conv_state_indices: Optional[torch.Tensor] = None,
    pad_slot_id: int = PAD_SLOT_ID,
):
    """
    x: (batch, dim) or (batch, dim, seqlen)
    conv_state: (batch, dim, state_len), where state_len >= width - 1
    weight: (dim, width)
    bias: (dim,)
    cache_seqlens: (batch,), dtype int32.
        If not None, the conv_state is treated as a circular buffer.
        The conv_state will be updated by copying x to the conv_state
        starting at the index
        @cache_seqlens % state_len.
    conv_state_indices: (batch,), dtype int32
        If not None, the conv_state is a larger tensor along the batch dim,
        and we are selecting the batch coords specified by conv_state_indices.
        Useful for a continuous batching scenario.
    pad_slot_id: int
        if cache_indices is passed, lets the kernel identify padded
        entries that will not be processed,
        for example: cache_indices = [pad_slot_id, 1, 20, pad_slot_id]
        in this case, the kernel will not process entries at
        indices 0 and 3
    out: (batch, dim) or (batch, dim, seqlen)
    """
    if activation not in [None, "silu", "swish"]:
        raise NotImplementedError(
            f"activation must be None, silu, or swish, actual: {activation}"
        )
    activation_val = activation in ["silu", "swish"]
    unsqueeze = x.dim() == 2
    if unsqueeze:
        x = x.unsqueeze(-1)
    causal_conv1d_update_kernel(
        x,
        conv_state,
        weight,
        bias,
        activation_val,
        cache_seqlens,
        conv_state_indices,
        pad_slot_id,
    )
    if unsqueeze:
        x = x.squeeze(-1)
    return x


# SPDX-License-Identifier: Apache-2.0
from typing import Optional

import pytest
import torch
import torch.nn.functional as F


def causal_conv1d_ref(
    x: torch.Tensor,
    weight: torch.Tensor,
    bias: Optional[torch.Tensor] = None,
    initial_states: Optional[torch.Tensor] = None,
    return_final_states: bool = False,
    final_states_out: Optional[torch.Tensor] = None,
    activation: Optional[str] = "silu",
):
    """
    x: (batch, dim, seqlen)
    weight: (dim, width)
    bias: (dim,)
    initial_states: (batch, dim, width - 1)
    final_states_out: (batch, dim, width - 1)
    out: (batch, dim, seqlen)
    """
    if activation not in [None, "silu", "swish"]:
        raise NotImplementedError("activation must be None, silu, or swish")
    dtype_in = x.dtype
    x = x.to(weight.dtype)
    seqlen = x.shape[-1]
    dim, width = weight.shape
    if initial_states is None:
        out = F.conv1d(x, weight.unsqueeze(1), bias, padding=width - 1, groups=dim)
    else:
        x = torch.cat([initial_states, x], dim=-1)
        out = F.conv1d(x, weight.unsqueeze(1), bias, padding=0, groups=dim)
    out = out[..., :seqlen]
    if return_final_states:
        final_states = F.pad(x, (width - 1 - x.shape[-1], 0)).to(
            dtype_in
        )  # (batch, dim, width - 1)
        if final_states_out is not None:
            final_states_out.copy_(final_states)
        else:
            final_states_out = final_states
    out = (out if activation is None else F.silu(out)).to(dtype=dtype_in)
    return (out, None) if not return_final_states else (out, final_states_out)


def causal_conv1d_update_ref(
    x, conv_state, weight, bias=None, activation=None, cache_seqlens=None
):
    """
    x: (batch, dim) or (batch, dim, seqlen)
    conv_state: (batch, dim, state_len), where state_len >= width - 1
    weight: (dim, width)
    bias: (dim,)
    cache_seqlens: (batch,), dtype int32.
        If not None, the conv_state is treated as a circular buffer.
        The conv_state will be updated by copying x to the
        conv_state starting at the index
        @cache_seqlens % state_len before performing the convolution.

    out: (batch, dim) or (batch, dim, seqlen)
    """
    if activation not in [None, "silu", "swish"]:
        raise NotImplementedError("activation must be None, silu, or swish")
    dtype_in = x.dtype
    unsqueeze = x.dim() == 2
    if unsqueeze:
        x = x.unsqueeze(-1)
    batch, dim, seqlen = x.shape
    width = weight.shape[1]
    state_len = conv_state.shape[-1]
    assert conv_state.shape == (batch, dim, state_len)
    assert weight.shape == (dim, width)
    if cache_seqlens is None:
        x_new = torch.cat([conv_state, x], dim=-1).to(
            weight.dtype
        )  # (batch, dim, state_len + seqlen)
        conv_state.copy_(x_new[:, :, -state_len:])
    else:
        width_idx = torch.arange(
            -(width - 1), 0, dtype=torch.long, device=x.device
        ).unsqueeze(0) + cache_seqlens.unsqueeze(1)
        width_idx = (
            torch.remainder(width_idx, state_len).unsqueeze(1).expand(-1, dim, -1)
        )
        x_new = torch.cat([conv_state.gather(2, width_idx), x], dim=-1).to(weight.dtype)
        copy_idx = torch.arange(
            seqlen, dtype=torch.long, device=x.device
        ).unsqueeze(0) + cache_seqlens.unsqueeze(1)
        copy_idx = torch.remainder(copy_idx, state_len).unsqueeze(1).expand(-1, dim, -1)
        conv_state.scatter_(2, copy_idx, x)
    out = F.conv1d(x_new, weight.unsqueeze(1), bias, padding=0, groups=dim)[
        :, :, -seqlen:
    ]
    if unsqueeze:
        out = out.squeeze(-1)
    return (out if activation is None else F.silu(out)).to(dtype=dtype_in)


@pytest.mark.parametrize("itype", [torch.bfloat16, torch.float])
@pytest.mark.parametrize("silu_activation", [True])
@pytest.mark.parametrize("has_bias", [True])
@pytest.mark.parametrize("has_initial_state", [True, False])
@pytest.mark.parametrize("width", [4])
@pytest.mark.parametrize(
    "seqlen", [1, 8, 16, 32, 64, 128, 256, 512, 784, 1024, 1025, 2048, 4096]
)
@pytest.mark.parametrize("dim", [64])
@pytest.mark.parametrize("batch", [1])
def test_causal_conv1d(
    batch, dim, seqlen, width, has_bias, silu_activation, has_initial_state, itype
):
    device = "cuda"
    rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3)
    if itype == torch.bfloat16:
        rtol, atol = 1e-2, 5e-2
    x = torch.randn(batch, dim, seqlen, device=device, dtype=itype).contiguous()

    weight = torch.randn(dim, width, device=device, dtype=itype)
    bias = torch.randn(dim, device=device, dtype=itype) if has_bias else None
    if has_initial_state:
        initial_states = torch.randn(batch, dim, width - 1, device=device, dtype=itype)
        has_initial_state_tensor = torch.ones(batch, dtype=torch.bool, device=x.device)
    else:
        initial_states = None
        has_initial_state_tensor = None
    x_ref = x.clone()
    weight_ref = weight.clone()
    bias_ref = bias.clone() if bias is not None else None
    initial_states_ref = initial_states.clone() if initial_states is not None else None
    activation = None if not silu_activation else "silu"
    out = causal_conv1d_fn(
        x,
        weight,
        bias,
        activation=activation,
        conv_states=initial_states,
        has_initial_state=has_initial_state_tensor,
    )
    out_ref, final_states_ref = causal_conv1d_ref(
        x_ref,
        weight_ref,
        bias_ref,
        initial_states=initial_states_ref,
        return_final_states=True,
        activation=activation,
    )
    if has_initial_state:
        assert initial_states is not None and final_states_ref is not None
        assert torch.allclose(initial_states, final_states_ref, rtol=rtol, atol=atol)
    assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)


@pytest.mark.parametrize("itype", [torch.bfloat16])
@pytest.mark.parametrize("silu_activation", [False, True])
@pytest.mark.parametrize("has_bias", [False, True])
@pytest.mark.parametrize("seqlen", [1])
@pytest.mark.parametrize("width", [4])
@pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096])
def test_causal_conv1d_update(dim, width, seqlen, has_bias, silu_activation, itype):
    device = "cuda"
    rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3)
    if itype == torch.bfloat16:
        rtol, atol = 1e-2, 5e-2
    batch = 2
    x = torch.randn(batch, dim, seqlen, device=device, dtype=itype)
    x_ref = x.clone()
    conv_state = torch.randn(batch, dim, width - 1, device=device, dtype=itype)

    weight = torch.randn(dim, width, device=device, dtype=itype)
    bias = torch.randn(dim, device=device, dtype=itype) if has_bias else None
    conv_state_ref = conv_state.detach().clone()
    activation = None if not silu_activation else "silu"
    out = causal_conv1d_update(x, conv_state, weight, bias, activation=activation)
    out_ref = causal_conv1d_update_ref(
        x_ref, conv_state_ref, weight, bias, activation=activation
    )

    assert torch.equal(conv_state, conv_state_ref)
    assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)


@pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize("silu_activation", [False, True])
@pytest.mark.parametrize("has_bias", [False, True])
@pytest.mark.parametrize("seqlen", [1, 4, 5])
@pytest.mark.parametrize("width", [2, 3, 4])
@pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096])
# tests correctness in case a subset of the sequences are padded
@pytest.mark.parametrize("with_padding", [True, False])
def test_causal_conv1d_update_with_batch_gather(
    with_padding, dim, width, seqlen, has_bias, silu_activation, itype
):
    device = "cuda"
    rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3)
    if itype == torch.bfloat16:
        rtol, atol = 1e-2, 5e-2

    batch_size = 3
    padding = 5 if with_padding else 0
    padded_batch_size = batch_size + padding
    total_entries = 10 * batch_size

    x = torch.randn(padded_batch_size, dim, 1, device=device, dtype=itype)
    x_ref = x.clone()

    conv_state_indices = torch.randperm(total_entries)[:batch_size].to(
        dtype=torch.int32, device=device
    )
    unused_states_bool = torch.ones(total_entries, dtype=torch.bool, device=device)
    unused_states_bool[conv_state_indices] = False
    padded_state_indices = torch.concat(
        [
            conv_state_indices,
            torch.as_tensor([PAD_SLOT_ID] * padding, dtype=torch.int32, device=device),
        ],
        dim=0,
    )
    conv_state = torch.randn(total_entries, dim, width - 1, device=device, dtype=itype)
    conv_state_for_padding_test = conv_state.clone()

    weight = torch.randn(dim, width, device=device, dtype=itype)
    bias = torch.randn(dim, device=device, dtype=itype) if has_bias else None
    conv_state_ref = conv_state[conv_state_indices, :].detach().clone()
    activation = None if not silu_activation else "silu"
    out = causal_conv1d_update(
        x,
        conv_state,
        weight,
        bias,
        activation=activation,
        conv_state_indices=padded_state_indices,
        pad_slot_id=PAD_SLOT_ID,
    )
    out_ref = causal_conv1d_update_ref(
        x_ref[:batch_size], conv_state_ref, weight, bias, activation=activation
    )

    assert torch.equal(conv_state[conv_state_indices, :], conv_state_ref)
    assert torch.allclose(out[:batch_size], out_ref, rtol=rtol, atol=atol)
    assert torch.equal(
        conv_state[unused_states_bool], conv_state_for_padding_test[unused_states_bool]
    )


@pytest.mark.parametrize("itype", [torch.bfloat16])
@pytest.mark.parametrize("silu_activation", [True])
@pytest.mark.parametrize("has_bias", [True])
@pytest.mark.parametrize("width", [4])
@pytest.mark.parametrize(
    "seqlen", [8, 16, 32, 64, 128, 256, 512, 784, 1024, 2048, 2049, 4096]
)
@pytest.mark.parametrize("dim", [64, 4096])
# tests correctness in case a subset of the sequences are padded
@pytest.mark.parametrize("with_padding", [True, False])
def test_causal_conv1d_varlen(
    with_padding, dim, seqlen, width, has_bias, silu_activation, itype
):
    device = "cuda"
    torch.cuda.empty_cache()
    rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3)
    if itype == torch.bfloat16:
        rtol, atol = 1e-2, 5e-2
    seqlens = []
    batch_size = 4
    if seqlen < 10:
        batch_size = 1
    padding = 3 if with_padding else 0
    padded_batch_size = batch_size + padding
    nsplits = padded_batch_size - 1

    eos_pos = torch.randperm(seqlen - 1)[:nsplits].sort().values

    seqlens.append(
        torch.diff(
            torch.cat([torch.tensor([-1]), eos_pos, torch.tensor([seqlen - 1])])
        ).tolist()
    )
    assert sum(seqlens[-1]) == seqlen
    assert all(s > 0 for s in seqlens[-1])

    total_entries = batch_size * 10
    cumsum = torch.cumsum(torch.tensor(seqlens[0]), dim=0).to(torch.int32)
    cumsum = torch.concat([torch.tensor([0], dtype=torch.int32), cumsum], dim=0)
    x = torch.randn(1, 4096 + dim + 64, seqlen, device=device, dtype=itype)[
        :, 4096 : 4096 + dim, :
    ]
    weight = torch.randn(dim, width, device=device, dtype=itype)
    bias = torch.randn(dim, device=device, dtype=itype) if has_bias else None
    x_ref = x.clone()
    weight_ref = weight.clone()
    bias_ref = bias.clone() if bias is not None else None
    activation = None if not silu_activation else "silu"
    final_states = torch.randn(
        total_entries, dim, width - 1, device=x.device, dtype=x.dtype
    )
    final_states_ref = final_states.clone()
    has_initial_states = torch.randint(
        0, 2, (cumsum.shape[0] - 1,), dtype=torch.bool, device=x.device
    )
    state_indices = torch.randperm(total_entries, dtype=torch.int32, device=x.device)[
        :batch_size
    ]
    padded_state_indices = torch.concat(
        [
            state_indices,
            torch.as_tensor([PAD_SLOT_ID] * padding, dtype=torch.int32, device=device),
        ],
        dim=-1,
    )
    out = causal_conv1d_fn(
        x.squeeze(0),
        weight,
        bias,
        cumsum.cuda(),
        padded_state_indices,
        has_initial_states,
        final_states,
        activation,
        PAD_SLOT_ID,
    )
    out_ref = []
    out_ref_b = []

    splits = [torch.split(var, seqlens[0], dim=-1) for var in (x_ref)]
    for i in range(len(seqlens[0])):
        x_s = [v[i].unsqueeze(0) for v in splits][0]
        if padded_state_indices[i] == PAD_SLOT_ID:
            continue
        out_ref_b.append(
            causal_conv1d_ref(
                x_s,
                weight_ref,
                bias_ref,
                activation=activation,
                return_final_states=True,
                final_states_out=final_states_ref[padded_state_indices[i]].unsqueeze(0),
                initial_states=(
                    final_states_ref[padded_state_indices[i]].unsqueeze(0)
                    if has_initial_states[i]
                    else None
                ),
            )
        )
    out_ref.append(torch.cat([t[0] for t in out_ref_b], dim=2))
    out_ref_tensor = torch.cat(out_ref, dim=0)

    unpadded_out = out[:, : out_ref_tensor.shape[-1]]
    assert torch.allclose(unpadded_out, out_ref_tensor, rtol=rtol, atol=atol)
    assert torch.allclose(
        final_states[state_indices],
        final_states_ref[state_indices],
        rtol=rtol,
        atol=atol,
    )


if __name__ == "__main__":
    pytest.main([__file__])