sglang / Commits / 8cbe1538

Commit 8cbe1538 (unverified), parent 8471e5e6
Add mamba kernel (#10234)
Authored Sep 10, 2025 by Yi Zhang; committed by GitHub Sep 09, 2025

Showing 8 changed files with 1418 additions and 0 deletions (+1418, -0)
sgl-kernel/CMakeLists.txt                    +1    -0
sgl-kernel/csrc/common_extension.cc          +25   -0
sgl-kernel/csrc/mamba/causal_conv1d.cu       +669  -0
sgl-kernel/csrc/mamba/causal_conv1d.h        +159  -0
sgl-kernel/include/sgl_kernel_ops.h          +24   -0
sgl-kernel/python/sgl_kernel/__init__.py     +1    -0
sgl-kernel/python/sgl_kernel/mamba.py        +50   -0
sgl-kernel/tests/test_causal_conv1d.py       +489  -0
sgl-kernel/CMakeLists.txt

@@ -303,6 +303,7 @@ set(SOURCES
  "csrc/moe/cutlass_moe/w4a8/w4a8_moe_data.cu"
  "csrc/moe/cutlass_moe/w4a8/w4a8_grouped_mm_c3x.cu"
  "csrc/moe/marlin_moe_wna16/ops.cu"
  "csrc/mamba/causal_conv1d.cu"
  "csrc/moe/moe_align_kernel.cu"
  "csrc/moe/moe_fused_gate.cu"
  "csrc/moe/moe_topk_softmax_kernels.cu"
sgl-kernel/csrc/common_extension.cc

@@ -438,6 +438,31 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
  m.impl("copy_to_gpu_no_ce", torch::kCUDA, &copy_to_gpu_no_ce);

  m.def("concat_mla_k(Tensor! k, Tensor k_nope, Tensor k_rope) -> ()");
  m.impl("concat_mla_k", torch::kCUDA, &concat_mla_k);

  /*
   * From csrc/mamba
   */
  m.def(
      "causal_conv1d_update(Tensor! x,"
      "Tensor! conv_state,"
      "Tensor! weight,"
      "Tensor? bias_,"
      "bool silu_activation,"
      "Tensor? cache_seqlens_,"
      "Tensor? conv_state_indices,"
      "int pad_slot_id) -> ()");
  m.impl("causal_conv1d_update", torch::kCUDA, &causal_conv1d_update);

  m.def(
      "causal_conv1d_fwd(Tensor! x, Tensor! weight,"
      "Tensor? bias_,"
      "Tensor!? conv_states,"
      "Tensor? query_start_loc,"
      "Tensor? cache_indices,"
      "Tensor? has_initial_state,"
      "bool silu_activation,"
      "int pad_slot_id) -> ()");
  m.impl("causal_conv1d_fwd", torch::kCUDA, &causal_conv1d_fwd);
}

REGISTER_EXTENSION(common_ops)
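Once this fragment is registered, the two new ops are reachable from Python through the torch.ops.sgl_kernel namespace with the schemas declared in m.def above (this is exactly what the mamba.py wrappers added later in this commit do). A minimal sketch of a direct call, assuming the sgl_kernel extension is installed and a CUDA device is available; the tensor values are illustrative only:

import torch
import sgl_kernel  # importing the package loads the common_ops extension registered above

# Shapes follow the causal_conv1d_fwd schema: x is (batch, dim, seqlen), weight is (dim, width);
# the optional tensors are passed as None here.
x = torch.randn(2, 64, 128, device="cuda", dtype=torch.float16)
weight = torch.randn(64, 4, device="cuda", dtype=torch.float16)

# silu_activation=True, pad_slot_id=-1; the op writes its result into x in place.
torch.ops.sgl_kernel.causal_conv1d_fwd(
    x, weight, None, None, None, None, None, True, -1
)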
sgl-kernel/csrc/mamba/causal_conv1d.cu (new file, mode 100644)

// clang-format off
// adapted from https://github.com/Dao-AILab/causal-conv1d/blob/main/csrc/causal_conv1d_fwd.cu
// and https://github.com/Dao-AILab/causal-conv1d/blob/main/csrc/causal_conv1d_update.cu
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>

#include "causal_conv1d.h"
#include <c10/util/BFloat16.h>
#include <c10/util/Half.h>
#include <c10/cuda/CUDAException.h>  // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK

#include <cub/block/block_load.cuh>
#include <cub/block/block_store.cuh>

#define BOOL_SWITCH(COND, CONST_NAME, ...)            \
  [&] {                                               \
    if (COND) {                                       \
      static constexpr bool CONST_NAME = true;        \
      return __VA_ARGS__();                           \
    } else {                                          \
      static constexpr bool CONST_NAME = false;       \
      return __VA_ARGS__();                           \
    }                                                 \
  }()

#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")

#define DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(ITYPE, NAME, ...)              \
  if (ITYPE == at::ScalarType::Half) {                                              \
    using input_t = at::Half;                                                       \
    using weight_t = at::Half;                                                      \
    __VA_ARGS__();                                                                  \
  } else if (ITYPE == at::ScalarType::BFloat16) {                                   \
    using input_t = at::BFloat16;                                                   \
    using weight_t = at::BFloat16;                                                  \
    __VA_ARGS__();                                                                  \
  } else if (ITYPE == at::ScalarType::Float) {                                      \
    using input_t = float;                                                          \
    using weight_t = float;                                                         \
    __VA_ARGS__();                                                                  \
  } else {                                                                          \
    AT_ERROR(#NAME, " not implemented for input type '", toString(ITYPE), "'");     \
  }

template<typename input_t, typename weight_t>
void causal_conv1d_fwd_cuda(ConvParamsBase &params, cudaStream_t stream);

template<typename input_t, typename weight_t>
void causal_conv1d_update_cuda(ConvParamsBase &params, cudaStream_t stream);

void set_conv_params_fwd(ConvParamsBase &params,
                         // sizes
                         const size_t batch,
                         const size_t dim,
                         const size_t seqlen,
                         const size_t width,
                         // device pointers
                         const at::Tensor x,
                         const at::Tensor weight,
                         const at::Tensor out,
                         const std::optional<at::Tensor>& bias,
                         bool silu_activation,
                         int64_t pad_slot_id,
                         const std::optional<at::Tensor>& query_start_loc = std::nullopt,
                         const std::optional<at::Tensor>& cache_indices = std::nullopt,
                         const std::optional<at::Tensor>& has_initial_state = std::nullopt) {
    // Reset the parameters
    memset(&params, 0, sizeof(params));

    params.batch = batch;
    params.dim = dim;
    params.seqlen = seqlen;
    params.width = width;
    params.pad_slot_id = pad_slot_id;

    params.silu_activation = silu_activation;

    // Set the pointers and strides.
    params.x_ptr = x.data_ptr();
    params.weight_ptr = weight.data_ptr();
    params.bias_ptr = bias.has_value() ? bias.value().data_ptr() : nullptr;
    params.out_ptr = out.data_ptr();
    // All stride are in elements, not bytes.
    params.query_start_loc_ptr = query_start_loc.has_value() ? query_start_loc.value().data_ptr() : nullptr;
    params.cache_indices_ptr = cache_indices.has_value() ? cache_indices.value().data_ptr() : nullptr;
    params.has_initial_state_ptr = has_initial_state.has_value() ? has_initial_state.value().data_ptr() : nullptr;
    const bool varlen = params.query_start_loc_ptr != nullptr;
    params.x_batch_stride = x.stride(varlen ? 1 : 0);
    params.x_c_stride = x.stride(varlen ? 0 : 1);
    params.x_l_stride = x.stride(varlen ? 1 : -1);
    params.weight_c_stride = weight.stride(0);
    params.weight_width_stride = weight.stride(1);
    params.out_batch_stride = out.stride(varlen ? 1 : 0);
    params.out_c_stride = out.stride(varlen ? 0 : 1);
    params.out_l_stride = out.stride(varlen ? 1 : -1);
}

void causal_conv1d_fwd(const at::Tensor &x, const at::Tensor &weight,
                       const std::optional<at::Tensor> &bias_,
                       const std::optional<at::Tensor> &conv_states,
                       const std::optional<at::Tensor> &query_start_loc,
                       const std::optional<at::Tensor> &cache_indices,
                       const std::optional<at::Tensor> &has_initial_state,
                       bool silu_activation,
                       // used to identify padding entries if cache_indices provided
                       // in case of padding, the kernel will return early
                       int64_t pad_slot_id) {
    auto input_type = x.scalar_type();
    auto weight_type = weight.scalar_type();
    TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16);
    TORCH_CHECK(weight_type == at::ScalarType::Float || weight_type == at::ScalarType::Half || weight_type == at::ScalarType::BFloat16);

    TORCH_CHECK(x.is_cuda());
    TORCH_CHECK(weight.is_cuda());

    const bool varlen = query_start_loc.has_value() ? true : false;
    const auto sizes = x.sizes();
    const int batch_size = varlen ? query_start_loc.value().sizes()[0] - 1 : sizes[0];
    const int dim = varlen ? sizes[0] : sizes[1];
    const int seqlen = varlen ? sizes[1] : sizes[2];
    const int width = weight.size(-1);
    if (varlen){
        CHECK_SHAPE(x, dim, seqlen);
    }
    else {
        CHECK_SHAPE(x, batch_size, dim, seqlen);
    }
    CHECK_SHAPE(weight, dim, width);

    if (bias_.has_value()) {
        auto bias = bias_.value();
        TORCH_CHECK(bias.scalar_type() == weight_type);
        TORCH_CHECK(bias.is_cuda());
        TORCH_CHECK(bias.stride(-1) == 1);
        CHECK_SHAPE(bias, dim);
    }

    if (has_initial_state.has_value()) {
        auto has_initial_state_ = has_initial_state.value();
        TORCH_CHECK(has_initial_state_.scalar_type() == at::ScalarType::Bool);
        TORCH_CHECK(has_initial_state_.is_cuda());
        CHECK_SHAPE(has_initial_state_, batch_size);
    }

    if (query_start_loc.has_value()) {
        auto query_start_loc_ = query_start_loc.value();
        TORCH_CHECK(query_start_loc_.scalar_type() == at::ScalarType::Int);
        TORCH_CHECK(query_start_loc_.is_cuda());
    }

    if (cache_indices.has_value()) {
        auto cache_indices_ = cache_indices.value();
        TORCH_CHECK(cache_indices_.scalar_type() == at::ScalarType::Int);
        TORCH_CHECK(cache_indices_.is_cuda());
        CHECK_SHAPE(cache_indices_, batch_size);
    }

    at::Tensor out = x;

    ConvParamsBase params;
    set_conv_params_fwd(params, batch_size, dim, seqlen, width, x, weight, out,
                        bias_,
                        silu_activation,
                        pad_slot_id,
                        query_start_loc,
                        cache_indices,
                        has_initial_state);

    if (conv_states.has_value()) {
        auto conv_states_ = conv_states.value();
        TORCH_CHECK(conv_states_.scalar_type() == input_type);
        TORCH_CHECK(conv_states_.is_cuda());
        params.conv_states_ptr = conv_states_.data_ptr();
        params.conv_states_batch_stride = conv_states_.stride(0);
        params.conv_states_c_stride = conv_states_.stride(1);
        params.conv_states_l_stride = conv_states_.stride(2);
    } else {
        params.conv_states_ptr = nullptr;
    }

    // Otherwise the kernel will be launched from cuda:0 device
    // Cast to char to avoid compiler warning about narrowing
    at::cuda::CUDAGuard device_guard{(char)x.get_device()};
    auto stream = at::cuda::getCurrentCUDAStream().stream();
    DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(x.scalar_type(), "causal_conv1d_fwd", [&] {
        causal_conv1d_fwd_cuda<input_t, weight_t>(params, stream);
    });
}

void causal_conv1d_update(const at::Tensor &x,
                          const at::Tensor &conv_state,
                          const at::Tensor &weight,
                          const std::optional<at::Tensor> &bias_,
                          bool silu_activation,
                          const std::optional<at::Tensor> &cache_seqlens_,
                          const std::optional<at::Tensor> &conv_state_indices_,
                          // used to identify padding entries if cache_indices provided
                          // in case of padding, the kernel will return early
                          int64_t pad_slot_id) {
    auto input_type = x.scalar_type();
    auto weight_type = weight.scalar_type();
    TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16);
    TORCH_CHECK(weight_type == at::ScalarType::Float || weight_type == at::ScalarType::Half || weight_type == at::ScalarType::BFloat16);
    TORCH_CHECK(weight_type == input_type, "weight type must equal to input type, other variations are disabled due to binary size limitations");
    TORCH_CHECK(conv_state.scalar_type() == input_type);

    TORCH_CHECK(x.is_cuda());
    TORCH_CHECK(conv_state.is_cuda());
    TORCH_CHECK(weight.is_cuda());

    const auto sizes = x.sizes();
    const int batch_size = sizes[0];
    const int dim = sizes[1];
    const int seqlen = sizes[2];
    const int width = weight.size(-1);
    const int conv_state_len = conv_state.size(2);
    TORCH_CHECK(conv_state_len >= width - 1);

    CHECK_SHAPE(x, batch_size, dim, seqlen);
    CHECK_SHAPE(weight, dim, width);

    TORCH_CHECK(width >= 2 && width <= 4, "causal_conv1d only supports width between 2 and 4");

    if (bias_.has_value()) {
        auto bias = bias_.value();
        TORCH_CHECK(bias.scalar_type() == weight_type);
        TORCH_CHECK(bias.is_cuda());
        TORCH_CHECK(bias.stride(-1) == 1);
        CHECK_SHAPE(bias, dim);
    }

    at::Tensor out = x;

    ConvParamsBase params;
    set_conv_params_fwd(params, batch_size, dim, seqlen, width, x, weight, out,
                        bias_,
                        silu_activation,
                        pad_slot_id);
    params.conv_state_ptr = conv_state.data_ptr();
    params.conv_state_len = conv_state_len;
    // All stride are in elements, not bytes.
    params.conv_state_batch_stride = conv_state.stride(0);
    params.conv_state_c_stride = conv_state.stride(1);
    params.conv_state_l_stride = conv_state.stride(2);

    if (cache_seqlens_.has_value()) {
        auto cache_seqlens = cache_seqlens_.value();
        TORCH_CHECK(cache_seqlens.scalar_type() == torch::kInt32);
        TORCH_CHECK(cache_seqlens.is_cuda());
        TORCH_CHECK(cache_seqlens.stride(-1) == 1);
        CHECK_SHAPE(cache_seqlens, batch_size);
        params.cache_seqlens = cache_seqlens.data_ptr<int32_t>();
    } else {
        params.cache_seqlens = nullptr;
    }

    if (conv_state_indices_.has_value()) {
        auto conv_state_indices = conv_state_indices_.value();
        TORCH_CHECK(conv_state_indices.scalar_type() == torch::kInt32)
        TORCH_CHECK(conv_state_indices.is_cuda());
        TORCH_CHECK(conv_state_indices.stride(0) == 1)
        CHECK_SHAPE(conv_state_indices, batch_size);

        int conv_state_entries = conv_state.size(0);
        CHECK_SHAPE(conv_state, conv_state_entries, dim, conv_state_len);

        params.conv_state_indices_ptr = conv_state_indices.data_ptr<int32_t>();
    } else {
        CHECK_SHAPE(conv_state, batch_size, dim, conv_state_len);
        params.conv_state_indices_ptr = nullptr;
    }

    // Otherwise the kernel will be launched from cuda:0 device
    // Cast to char to avoid compiler warning about narrowing
    at::cuda::CUDAGuard device_guard{(char)x.get_device()};
    auto stream = at::cuda::getCurrentCUDAStream().stream();
    DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(x.scalar_type(), "causal_conv1d_update", [&] {
        causal_conv1d_update_cuda<input_t, weight_t>(params, stream);
    });
}

template<int kNThreads_, int kWidth_, bool kIsVecLoad_, typename input_t_, typename weight_t_>
struct Causal_conv1d_fwd_kernel_traits {
    using input_t = input_t_;
    using weight_t = weight_t_;
    static constexpr int kNThreads = kNThreads_;
    static constexpr int kWidth = kWidth_;
    static constexpr int kNBytes = sizeof(input_t);
    static_assert(kNBytes == 2 || kNBytes == 4);
    static constexpr int kNElts = kNBytes == 4 ? 4 : 8;
    static_assert(kWidth <= kNElts);
    static constexpr bool kIsVecLoad = kIsVecLoad_;
    using vec_t = typename BytesToType<kNBytes * kNElts>::Type;
    using BlockLoadT = cub::BlockLoad<input_t, kNThreads, kNElts, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
    using BlockLoadVecT = cub::BlockLoad<vec_t, kNThreads, 1, cub::BLOCK_LOAD_DIRECT>;
    using BlockStoreT = cub::BlockStore<input_t, kNThreads, kNElts, cub::BLOCK_STORE_WARP_TRANSPOSE>;
    using BlockStoreVecT = cub::BlockStore<vec_t, kNThreads, 1, cub::BLOCK_STORE_DIRECT>;
    static constexpr int kSmemIOSize = kIsVecLoad
        ? 0
        : custom_max({sizeof(typename BlockLoadT::TempStorage), sizeof(typename BlockStoreT::TempStorage)});
    static constexpr int kSmemExchangeSize = kNThreads * kNBytes * kNElts;
    static constexpr int kSmemSize = kSmemIOSize + kSmemExchangeSize;
};

template<typename Ktraits>
__global__ __launch_bounds__(Ktraits::kNThreads)
void causal_conv1d_fwd_kernel(ConvParamsBase params) {
    constexpr int kWidth = Ktraits::kWidth;
    constexpr int kNThreads = Ktraits::kNThreads;
    constexpr int kNElts = Ktraits::kNElts;
    constexpr bool kIsVecLoad = Ktraits::kIsVecLoad;
    using input_t = typename Ktraits::input_t;
    using vec_t = typename Ktraits::vec_t;
    using weight_t = typename Ktraits::weight_t;

    // Shared memory.
    extern __shared__ char smem_[];
    auto& smem_load = reinterpret_cast<typename Ktraits::BlockLoadT::TempStorage&>(smem_);
    auto& smem_load_vec = reinterpret_cast<typename Ktraits::BlockLoadVecT::TempStorage&>(smem_);
    auto& smem_store = reinterpret_cast<typename Ktraits::BlockStoreT::TempStorage&>(smem_);
    auto& smem_store_vec = reinterpret_cast<typename Ktraits::BlockStoreVecT::TempStorage&>(smem_);
    vec_t *smem_exchange = reinterpret_cast<vec_t *>(smem_ + Ktraits::kSmemIOSize);

    const bool kVarlen = params.query_start_loc_ptr != nullptr;
    const int tidx = threadIdx.x;
    const int batch_id = blockIdx.x;
    const int channel_id = blockIdx.y;
    const int *query_start_loc = kVarlen ? reinterpret_cast<int *>(params.query_start_loc_ptr) : nullptr;
    const int sequence_start_index = kVarlen ? query_start_loc[batch_id] : batch_id;
    const int seqlen = kVarlen ? query_start_loc[batch_id + 1] - sequence_start_index : params.seqlen;

    input_t *x = reinterpret_cast<input_t *>(params.x_ptr) + sequence_start_index * params.x_batch_stride
        + channel_id * params.x_c_stride;
    weight_t *weight = reinterpret_cast<weight_t *>(params.weight_ptr) + channel_id * params.weight_c_stride;
    input_t *out = reinterpret_cast<input_t *>(params.out_ptr) + sequence_start_index * params.out_batch_stride
        + channel_id * params.out_c_stride;
    float bias_val = params.bias_ptr == nullptr ? 0.f : float(reinterpret_cast<weight_t *>(params.bias_ptr)[channel_id]);

    bool has_initial_state = params.has_initial_state_ptr == nullptr ? false
        : reinterpret_cast<bool *>(params.has_initial_state_ptr)[batch_id];

    int *cache_indices = params.cache_indices_ptr == nullptr ? nullptr
        : reinterpret_cast<int *>(params.cache_indices_ptr);
    int cache_index = cache_indices == nullptr ? batch_id : cache_indices[batch_id];
    // cache_index == params.pad_slot_id is defined as padding, so we exit early
    if (cache_index == params.pad_slot_id){
        return;
    }
    input_t *conv_states = params.conv_states_ptr == nullptr ? nullptr
        : reinterpret_cast<input_t *>(params.conv_states_ptr) + cache_index * params.conv_states_batch_stride
          + channel_id * params.conv_states_c_stride;

    // Thread 0 will load the last elements of the previous chunk, so we initialize those to 0.
    if (tidx == 0) {
        input_t initial_state[kNElts] = {0};
        if (has_initial_state) {
            #pragma unroll
            for (int w = 0; w < kWidth - 1; ++w){ initial_state[kNElts - 1 - (kWidth - 2) + w] = conv_states[w]; }
        }
        smem_exchange[kNThreads - 1] = reinterpret_cast<vec_t *>(initial_state)[0];
    }

    float weight_vals[kWidth];
    #pragma unroll
    for (int i = 0; i < kWidth; ++i) { weight_vals[i] = float(weight[i * params.weight_width_stride]); }

    constexpr int kChunkSize = kNThreads * kNElts;
    const int n_chunks = (seqlen + kChunkSize - 1) / kChunkSize;
    for (int chunk = 0; chunk < n_chunks; ++chunk) {
        input_t x_vals_load[2 * kNElts] = {0};
        if constexpr(kIsVecLoad) {
            typename Ktraits::BlockLoadVecT(smem_load_vec).Load(reinterpret_cast<vec_t*>(x), *reinterpret_cast<vec_t (*)[1]>(&x_vals_load[kNElts]), (seqlen - chunk * kChunkSize) / kNElts);
        } else {
            __syncthreads();
            typename Ktraits::BlockLoadT(smem_load).Load(x, *reinterpret_cast<input_t (*)[kNElts]>(&x_vals_load[kNElts]), seqlen - chunk * kChunkSize);
        }
        x += kChunkSize;
        __syncthreads();
        // Thread kNThreads - 1 don't write yet, so that thread 0 can read
        // the last elements of the previous chunk.
        if (tidx < kNThreads - 1) { smem_exchange[tidx] = reinterpret_cast<vec_t *>(x_vals_load)[1]; }
        __syncthreads();
        reinterpret_cast<vec_t *>(x_vals_load)[0] = smem_exchange[tidx > 0 ? tidx - 1 : kNThreads - 1];
        __syncthreads();
        // Now thread kNThreads - 1 can write the last elements of the current chunk.
        if (tidx == kNThreads - 1) { smem_exchange[tidx] = reinterpret_cast<vec_t *>(x_vals_load)[1]; }

        float x_vals[2 * kNElts];
        #pragma unroll
        for (int i = 0; i < 2 * kNElts; ++i) { x_vals[i] = float(x_vals_load[i]); }

        float out_vals[kNElts];
        #pragma unroll
        for (int i = 0; i < kNElts; ++i) {
            out_vals[i] = bias_val;
            #pragma unroll
            for (int w = 0; w < kWidth; ++w) {
                out_vals[i] += weight_vals[w] * x_vals[kNElts + i - (kWidth - w - 1)];
            }
        }

        if (params.silu_activation) {
            #pragma unroll
            for (int i = 0; i < kNElts; ++i) {
                out_vals[i] = out_vals[i] / (1 + expf(-out_vals[i]));
            }
        }

        input_t out_vals_store[kNElts];
        #pragma unroll
        for (int i = 0; i < kNElts; ++i) { out_vals_store[i] = out_vals[i]; }
        if constexpr(kIsVecLoad) {
            typename Ktraits::BlockStoreVecT(smem_store_vec).Store(reinterpret_cast<vec_t*>(out), reinterpret_cast<vec_t (&)[1]>(out_vals_store), (seqlen - chunk * kChunkSize) / kNElts);
        } else {
            typename Ktraits::BlockStoreT(smem_store).Store(out, out_vals_store, seqlen - chunk * kChunkSize);
        }
        out += kChunkSize;

        int final_state_position = ((seqlen - (kWidth - 1)) - (n_chunks - 1) * kChunkSize);
        // in case the final state is separated between the last "smem_exchange" and
        // and the one before it (chunk = n_chunks - 1 and chunk = n_chunks - 2),
        // (which occurs when `final_state_position` is a non-positive index)
        // we load the correct data from smem_exchange from both chunks, the last chunk iteration and the one before it
        if (conv_states != nullptr && final_state_position < 0 && seqlen > kWidth){
            input_t vals_load[kNElts] = {0};
            if ((chunk == n_chunks - 2) && (tidx == kNThreads - 1)){
                // chunk = n_chunks - 2, a segment of the final state sits in the last index
                reinterpret_cast<vec_t *>(vals_load)[0] = smem_exchange[kNThreads - 1];
                #pragma unroll
                for (int w = 0; w < -final_state_position; ++w){
                    conv_states[w] = vals_load[kNElts + final_state_position + w];
                }
            }
            if ((chunk == n_chunks - 1) && tidx == 0){
                // chunk = n_chunks - 1, the second segment of the final state first positions
                reinterpret_cast<vec_t *>(vals_load)[0] = smem_exchange[0];
                for (int w = -final_state_position; w < kWidth - 1; ++w){
                    conv_states[w] = vals_load[w + final_state_position];
                }
                return;
            }
        }
    }

    // Final state is stored in the smem_exchange last token slot,
    // in case seqlen < kWidth, we would need to take the final state from the
    // initial state which is stored in conv_states
    // in case seqlen > kWidth, we would need to load the last kWidth - 1 data
    // and load it into conv_state accordingly
    int last_thread = ((seqlen - (kWidth - 1)) - (n_chunks - 1) * kChunkSize) / kNElts;
    if (conv_states != nullptr && tidx == last_thread) {
        input_t x_vals_load[kNElts * 2] = {0};
        // in case we are on the first kWidth tokens
        if (last_thread == 0 && seqlen < kWidth){
            // Need to take the initial state
            reinterpret_cast<vec_t *>(x_vals_load)[0] = smem_exchange[0];
            const int offset = seqlen - (kWidth - 1);
            #pragma unroll
            for (int w = 0; w < kWidth - 1; ++w){
                // pad the existing state
                if ((w - seqlen) >= 0 && has_initial_state) { conv_states[w - seqlen] = conv_states[w]; }
                else if ((w - seqlen) >= 0 && !has_initial_state) { conv_states[w - seqlen] = input_t(0.0f); }
            }
            #pragma unroll
            for (int w = 0; w < kWidth - 1; ++w){
                if (offset + w >= 0)
                    conv_states[w] = x_vals_load[offset + w];
            }
        }
        else {
            // in case the final state is in between the threads data
            const int offset = ((seqlen - (kWidth - 1)) % (kNElts));
            if ((offset + kWidth - 2) >= kNElts && (last_thread + 1 < kNThreads)){
                // In case last_thread == kNThreads - 1, accessing last_thread + 1 will result in a
                // illegal access error on H100.
                // Therefore, we access last_thread + 1, only if the final state data sits there
                reinterpret_cast<vec_t *>(x_vals_load)[1] = smem_exchange[last_thread + 1];
            }
            reinterpret_cast<vec_t *>(x_vals_load)[0] = smem_exchange[last_thread];
            #pragma unroll
            for (int w = 0; w < kWidth - 1; ++w){
                conv_states[w] = x_vals_load[offset + w];
            }
        }
    }
}

template<int kNThreads, int kWidth, typename input_t, typename weight_t>
void causal_conv1d_fwd_launch(ConvParamsBase &params, cudaStream_t stream) {
    static constexpr int kNElts = sizeof(input_t) == 4 ? 4 : 8;
    const bool kVarlen = params.query_start_loc_ptr != nullptr;
    BOOL_SWITCH(params.seqlen % kNElts == 0 && !kVarlen, kIsVecLoad, [&] {
        using Ktraits = Causal_conv1d_fwd_kernel_traits<kNThreads, kWidth, kIsVecLoad, input_t, weight_t>;
        constexpr int kSmemSize = Ktraits::kSmemSize;
        dim3 grid(params.batch, params.dim);

        auto kernel = &causal_conv1d_fwd_kernel<Ktraits>;

        if (kSmemSize >= 48 * 1024) {
            #ifndef USE_ROCM
            C10_CUDA_CHECK(cudaFuncSetAttribute(
                kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
            #else
            // There is a slight signature discrepancy in HIP and CUDA "FuncSetAttribute" function.
            C10_CUDA_CHECK(cudaFuncSetAttribute(
                (void *) kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
            std::cerr << "Warning (causal_conv1d fwd launch): attempting to set maxDynamicSharedMemorySize on an AMD GPU which is currently a non-op (in ROCm versions <= 6.1). This might lead to undefined behavior. \n" << std::endl;
            #endif
        }
        kernel<<<grid, Ktraits::kNThreads, kSmemSize, stream>>>(params);
        C10_CUDA_KERNEL_LAUNCH_CHECK();
    });
}

template<typename input_t, typename weight_t>
void causal_conv1d_fwd_cuda(ConvParamsBase &params, cudaStream_t stream) {
    if (params.width == 2) {
        causal_conv1d_fwd_launch<128, 2, input_t, weight_t>(params, stream);
    } else if (params.width == 3) {
        causal_conv1d_fwd_launch<128, 3, input_t, weight_t>(params, stream);
    } else if (params.width == 4) {
        causal_conv1d_fwd_launch<128, 4, input_t, weight_t>(params, stream);
    }
}

template void causal_conv1d_fwd_cuda<float, float>(ConvParamsBase &params, cudaStream_t stream);
template void causal_conv1d_fwd_cuda<at::Half, at::Half>(ConvParamsBase &params, cudaStream_t stream);
template void causal_conv1d_fwd_cuda<at::BFloat16, at::BFloat16>(ConvParamsBase &params, cudaStream_t stream);

template<int kNThreads_, int kWidth_, typename input_t_, typename weight_t_>
struct Causal_conv1d_update_kernel_traits {
    using input_t = input_t_;
    using weight_t = weight_t_;
    static constexpr int kNThreads = kNThreads_;
    static constexpr int kWidth = kWidth_;
    static constexpr int kNBytes = sizeof(input_t);
    static_assert(kNBytes == 2 || kNBytes == 4);
};

template<typename Ktraits, bool kIsCircularBuffer>
__global__ __launch_bounds__(Ktraits::kNThreads)
void causal_conv1d_update_kernel(ConvParamsBase params) {
    constexpr int kWidth = Ktraits::kWidth;
    constexpr int kNThreads = Ktraits::kNThreads;
    using input_t = typename Ktraits::input_t;
    using weight_t = typename Ktraits::weight_t;

    const int tidx = threadIdx.x;
    const int batch_id = blockIdx.x;
    const int channel_id = blockIdx.y * kNThreads + tidx;
    if (channel_id >= params.dim) return;

    input_t *x = reinterpret_cast<input_t *>(params.x_ptr) + batch_id * params.x_batch_stride
        + channel_id * params.x_c_stride;

    // If params.conv_state_batch_indices is set, then the conv state is gathered from the conv state tensor
    // along the batch axis. Otherwise, the conv state coordinate is the same as the batch id.
    const int conv_state_batch_coord = params.conv_state_indices_ptr == nullptr
        ? batch_id
        : params.conv_state_indices_ptr[batch_id];
    // conv_state_batch_coord == params.pad_slot_id is defined as padding so we exit early
    if (conv_state_batch_coord == params.pad_slot_id){
        return;
    }
    input_t *conv_state = reinterpret_cast<input_t *>(params.conv_state_ptr)
        + conv_state_batch_coord * params.conv_state_batch_stride
        + channel_id * params.conv_state_c_stride;

    weight_t *weight = reinterpret_cast<weight_t *>(params.weight_ptr) + channel_id * params.weight_c_stride;
    input_t *out = reinterpret_cast<input_t *>(params.out_ptr) + batch_id * params.out_batch_stride
        + channel_id * params.out_c_stride;
    float bias_val = params.bias_ptr == nullptr ? 0.f : float(reinterpret_cast<weight_t *>(params.bias_ptr)[channel_id]);

    int state_len = params.conv_state_len;
    int advance_len = params.seqlen;
    int cache_seqlen = kIsCircularBuffer ? params.cache_seqlens[batch_id] % state_len : 0;
    int update_idx = cache_seqlen - (kWidth - 1);
    update_idx = update_idx < 0 ? update_idx + state_len : update_idx;

    float weight_vals[kWidth] = {0};
    #pragma unroll
    for (int i = 0; i < kWidth; ++i) { weight_vals[i] = float(weight[i * params.weight_width_stride]); }

    float x_vals[kWidth] = {0};
    if constexpr (!kIsCircularBuffer) {
        #pragma unroll 2
        for (int i = 0; i < state_len - advance_len - (kWidth - 1); ++i) {
            conv_state[i * params.conv_state_l_stride] = conv_state[(i + advance_len) * params.conv_state_l_stride];
        }
        #pragma unroll
        for (int i = 0; i < kWidth - 1; ++i) {
            input_t state_val = conv_state[(state_len - (kWidth - 1) + i) * params.conv_state_l_stride];
            if (i < advance_len + (kWidth - 1) && state_len - advance_len - (kWidth - 1) + i >= 0) {
                conv_state[(state_len - advance_len - (kWidth - 1) + i) * params.conv_state_l_stride] = state_val;
            }
            x_vals[i] = float(state_val);
        }
    } else {
        #pragma unroll
        for (int i = 0; i < kWidth - 1; ++i, update_idx = update_idx + 1 >= state_len ? update_idx + 1 - state_len : update_idx + 1) {
            input_t state_val = conv_state[update_idx * params.conv_state_l_stride];
            x_vals[i] = float(state_val);
        }
    }
    #pragma unroll 2
    for (int i = 0; i < params.seqlen; ++i) {
        input_t x_val = x[i * params.x_l_stride];
        if constexpr (!kIsCircularBuffer) {
            if (i < advance_len && state_len - advance_len + i >= 0) {
                conv_state[(state_len - advance_len + i) * params.conv_state_l_stride] = x_val;
            }
        } else {
            conv_state[update_idx * params.conv_state_l_stride] = x_val;
            ++update_idx;
            update_idx = update_idx >= state_len ? update_idx - state_len : update_idx;
        }
        x_vals[kWidth - 1] = float(x_val);
        float out_val = bias_val;
        #pragma unroll
        for (int j = 0; j < kWidth; ++j) { out_val += weight_vals[j] * x_vals[j]; }
        if (params.silu_activation) { out_val = out_val / (1 + expf(-out_val)); }
        out[i * params.out_l_stride] = input_t(out_val);
        // Shift the input buffer by 1
        #pragma unroll
        for (int i = 0; i < kWidth - 1; ++i) { x_vals[i] = x_vals[i + 1]; }
    }
}

template<int kNThreads, int kWidth, typename input_t, typename weight_t>
void causal_conv1d_update_launch(ConvParamsBase &params, cudaStream_t stream) {
    using Ktraits = Causal_conv1d_update_kernel_traits<kNThreads, kWidth, input_t, weight_t>;
    dim3 grid(params.batch, (params.dim + kNThreads - 1) / kNThreads);
    auto kernel = params.cache_seqlens == nullptr
        ? &causal_conv1d_update_kernel<Ktraits, false>
        : &causal_conv1d_update_kernel<Ktraits, true>;
    kernel<<<grid, Ktraits::kNThreads, 0, stream>>>(params);
    C10_CUDA_KERNEL_LAUNCH_CHECK();
}

template<typename input_t, typename weight_t>
void causal_conv1d_update_cuda(ConvParamsBase &params, cudaStream_t stream) {
    if (params.width == 2) {
        causal_conv1d_update_launch<64, 2, input_t, weight_t>(params, stream);
    } else if (params.width == 3) {
        causal_conv1d_update_launch<64, 3, input_t, weight_t>(params, stream);
    } else if (params.width == 4) {
        causal_conv1d_update_launch<64, 4, input_t, weight_t>(params, stream);
    }
}

template void causal_conv1d_update_cuda<float, float>(ConvParamsBase &params, cudaStream_t stream);
template void causal_conv1d_update_cuda<at::Half, at::Half>(ConvParamsBase &params, cudaStream_t stream);
template void causal_conv1d_update_cuda<at::BFloat16, at::BFloat16>(ConvParamsBase &params, cudaStream_t stream);
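Per channel, the arithmetic performed by causal_conv1d_fwd_kernel is a depthwise causal convolution followed by an optional SiLU: for each position t, out[c, t] = bias[c] + sum over w of weight[c, w] * x[c, t - (width - 1) + w], with out-of-range inputs treated as zero (or taken from the initial conv state). A minimal PyTorch sketch of that reference computation, assuming x is (batch, dim, seqlen) and weight is (dim, width); it mirrors the causal_conv1d_ref helper in the test file below rather than the kernel's chunked shared-memory implementation:

import torch
import torch.nn.functional as F

def causal_conv1d_reference(x, weight, bias=None, silu_activation=True):
    # Depthwise (groups=dim) convolution, left-padded so out[..., t] only sees x[..., <= t].
    dim, width = weight.shape
    out = F.conv1d(x, weight.unsqueeze(1), bias, padding=width - 1, groups=dim)
    out = out[..., : x.shape[-1]]  # drop the extra right-side positions from the symmetric padding
    return F.silu(out) if silu_activation else out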
sgl-kernel/csrc/mamba/causal_conv1d.h (new file, mode 100644)

/******************************************************************************
 * Copyright (c) 2024, Tri Dao.
 ******************************************************************************/
// clang-format off
// adapted from https://github.com/Dao-AILab/causal-conv1d/blob/main/csrc/causal_conv1d.h
#pragma once

#include <cuda_bf16.h>
#include <cuda_fp16.h>

////////////////////////////////////////////////////////////////////////////////////////////////////

struct ConvParamsBase {
    using index_t = uint32_t;

    int batch, dim, seqlen, width;
    int64_t pad_slot_id;
    bool silu_activation;

    index_t x_batch_stride;
    index_t x_c_stride;
    index_t x_l_stride;
    index_t weight_c_stride;
    index_t weight_width_stride;
    index_t out_batch_stride;
    index_t out_c_stride;
    index_t out_l_stride;

    int conv_state_len;
    index_t conv_state_batch_stride;
    index_t conv_state_c_stride;
    index_t conv_state_l_stride;

    // Common data pointers.
    void *__restrict__ x_ptr;
    void *__restrict__ weight_ptr;
    void *__restrict__ bias_ptr;
    void *__restrict__ out_ptr;

    void *__restrict__ conv_state_ptr;
    void *__restrict__ query_start_loc_ptr;
    void *__restrict__ has_initial_state_ptr;
    void *__restrict__ cache_indices_ptr;
    int32_t *__restrict__ cache_seqlens;

    // For the continuous batching case. Makes it so that the mamba state for
    // the current batch doesn't need to be a contiguous tensor.
    int32_t *__restrict__ conv_state_indices_ptr;

    void *__restrict__ seq_idx_ptr;

    // No __restrict__ since initial_states could be the same as final_states.
    void * initial_states_ptr;
    index_t initial_states_batch_stride;
    index_t initial_states_l_stride;
    index_t initial_states_c_stride;

    void * final_states_ptr;
    index_t final_states_batch_stride;
    index_t final_states_l_stride;
    index_t final_states_c_stride;

    void * conv_states_ptr;
    index_t conv_states_batch_stride;
    index_t conv_states_l_stride;
    index_t conv_states_c_stride;
};

#ifndef USE_ROCM
  #include <cuda_bf16.h>

  template<typename T>
  __device__ inline T shuffle_xor(T val, int offset) {
      return __shfl_xor_sync(uint32_t(-1), val, offset);
  }

  constexpr size_t custom_max(std::initializer_list<size_t> ilist) {
      return std::max(ilist);
  }

  template<typename T>
  constexpr T constexpr_min(T a, T b) {
      return std::min(a, b);
  }

#else
  #include <hip/hip_bf16.h>

  template<typename T>
  __device__ inline T shuffle_xor(T val, int offset) {
      return __shfl_xor(val, offset);
  }

  constexpr size_t custom_max(std::initializer_list<size_t> ilist) {
      return *std::max_element(ilist.begin(), ilist.end());
  }

  template<typename T>
  constexpr T constexpr_min(T a, T b) {
      return a < b ? a : b;
  }
#endif

////////////////////////////////////////////////////////////////////////////////////////////////////

template<int BYTES> struct BytesToType {};

template<> struct BytesToType<16> {
    using Type = uint4;
    static_assert(sizeof(Type) == 16);
};

template<> struct BytesToType<8> {
    using Type = uint64_t;
    static_assert(sizeof(Type) == 8);
};

template<> struct BytesToType<4> {
    using Type = uint32_t;
    static_assert(sizeof(Type) == 4);
};

template<> struct BytesToType<2> {
    using Type = uint16_t;
    static_assert(sizeof(Type) == 2);
};

template<> struct BytesToType<1> {
    using Type = uint8_t;
    static_assert(sizeof(Type) == 1);
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template<typename T>
struct SumOp {
__device__ inline T operator()(T const & x, T const & y) { return x + y; }
};

template<int THREADS>
struct Allreduce {
    static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4);
    template<typename T, typename Operator>
    static __device__ inline T run(T x, Operator &op) {
        constexpr int OFFSET = THREADS / 2;
        x = op(x, __shfl_xor_sync(uint32_t(-1), x, OFFSET));
        return Allreduce<OFFSET>::run(x, op);
    }
};

template<>
struct Allreduce<2> {
    template<typename T, typename Operator>
    static __device__ inline T run(T x, Operator &op) {
        x = op(x, __shfl_xor_sync(uint32_t(-1), x, 1));
        return x;
    }
};
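The BytesToType mapping above is what lets the forward kernel load kNElts elements per thread as one 16-byte vector: with kNBytes = sizeof(input_t), the kernel traits pick kNElts = 4 for fp32 and 8 for fp16/bf16, so kNBytes * kNElts is always 16 and vec_t resolves to uint4. A small sketch of the resulting sizes, computed the same way as kChunkSize and kSmemExchangeSize in the traits, assuming the kNThreads = 128 configuration used by causal_conv1d_fwd_cuda:

# Derived per-dtype sizes for the forward kernel configuration (kNThreads = 128).
for name, nbytes in [("float32", 4), ("float16/bfloat16", 2)]:
    knelts = 4 if nbytes == 4 else 8       # elements per vectorized load (kNElts)
    vec_bytes = nbytes * knelts            # always 16 -> BytesToType<16>::Type == uint4
    kchunk = 128 * knelts                  # tokens handled per chunk by one (batch, channel) block
    smem_exchange = 128 * nbytes * knelts  # kSmemExchangeSize in bytes
    print(f"{name}: kNElts={knelts}, vec={vec_bytes}B, kChunkSize={kchunk}, smem_exchange={smem_exchange}B")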
sgl-kernel/include/sgl_kernel_ops.h

@@ -724,3 +724,27 @@ void store_kv_cache(at::Tensor k_cache, at::Tensor v_cache, at::Tensor out_loc,
void copy_to_gpu_no_ce(const at::Tensor& input, at::Tensor& output);

void concat_mla_k(torch::Tensor k, torch::Tensor k_nope, torch::Tensor k_rope);

/*
 * From csrc/mamba
 */
void causal_conv1d_update(
    const at::Tensor& x,
    const at::Tensor& conv_state,
    const at::Tensor& weight,
    const std::optional<at::Tensor>& bias_,
    bool silu_activation,
    const std::optional<at::Tensor>& cache_seqlens_,
    const std::optional<at::Tensor>& conv_state_indices_,
    int64_t pad_slot_id);

void causal_conv1d_fwd(
    const at::Tensor& x,
    const at::Tensor& weight,
    const std::optional<at::Tensor>& bias_,
    const std::optional<at::Tensor>& conv_states,
    const std::optional<at::Tensor>& query_start_loc,
    const std::optional<at::Tensor>& cache_indices,
    const std::optional<at::Tensor>& has_initial_state,
    bool silu_activation,
    int64_t pad_slot_id);
sgl-kernel/python/sgl_kernel/__init__.py

@@ -34,6 +34,7 @@ from sgl_kernel.elementwise import (
    rmsnorm,
    silu_and_mul,
)
from sgl_kernel.mamba import causal_conv1d_fwd, causal_conv1d_update

if torch.version.hip is not None:
    from sgl_kernel.elementwise import gelu_quick
sgl-kernel/python/sgl_kernel/mamba.py (new file, mode 100644)

from typing import Optional

import torch


# mamba
def causal_conv1d_fwd(
    x: torch.Tensor,
    weight: torch.Tensor,
    bias_: Optional[torch.Tensor],
    conv_states: Optional[torch.Tensor],
    query_start_loc: Optional[torch.Tensor],
    cache_indices: Optional[torch.Tensor],
    has_initial_state: Optional[torch.Tensor],
    silu_activation: bool,
    pad_slot_id: int,
):
    torch.ops.sgl_kernel.causal_conv1d_fwd(
        x,
        weight,
        bias_,
        conv_states,
        query_start_loc,
        cache_indices,
        has_initial_state,
        silu_activation,
        pad_slot_id,
    )


def causal_conv1d_update(
    x: torch.Tensor,
    conv_state: torch.Tensor,
    weight: torch.Tensor,
    bias_: Optional[torch.Tensor],
    silu_activation: bool,
    cache_seqlens: Optional[torch.Tensor],
    conv_state_indices: Optional[torch.Tensor],
    pad_slot_id: int,
):
    torch.ops.sgl_kernel.causal_conv1d_update(
        x,
        conv_state,
        weight,
        bias_,
        silu_activation,
        cache_seqlens,
        conv_state_indices,
        pad_slot_id,
    )
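These wrappers are thin pass-throughs to the registered ops and, like the ops themselves, write the result into x in place rather than returning a new tensor. A short usage sketch of the decode-time update path, assuming a persistent (batch, dim, width - 1) conv state kept across steps; shapes and dtype are illustrative, and tests/test_causal_conv1d.py below gives fuller coverage:

import torch
from sgl_kernel import causal_conv1d_update

batch, dim, width = 2, 64, 4
# One new token per sequence plus the rolling conv state carried over from previous steps.
x = torch.randn(batch, dim, 1, device="cuda", dtype=torch.bfloat16)
conv_state = torch.zeros(batch, dim, width - 1, device="cuda", dtype=torch.bfloat16)
weight = torch.randn(dim, width, device="cuda", dtype=torch.bfloat16)

# bias_=None, silu_activation=True; no circular cache_seqlens, no batch gather, pad_slot_id=-1.
causal_conv1d_update(x, conv_state, weight, None, True, None, None, -1)
# x now holds the convolution output and conv_state has been advanced in place.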
sgl-kernel/tests/test_causal_conv1d.py (new file, mode 100644)

# Adapted from https://github.com/vllm-project/vllm/blob/main/tests/kernels/mamba/test_causal_conv1d.py
from typing import Optional

import torch
from sgl_kernel import causal_conv1d_fwd
from sgl_kernel import causal_conv1d_update as causal_conv1d_update_kernel

PAD_SLOT_ID = -1


def causal_conv1d_fn(
    x: torch.Tensor,
    weight: torch.Tensor,
    bias: Optional[torch.Tensor] = None,
    query_start_loc: Optional[torch.Tensor] = None,
    cache_indices: Optional[torch.Tensor] = None,
    has_initial_state: Optional[torch.Tensor] = None,
    conv_states: Optional[torch.Tensor] = None,
    activation: Optional[str] = "silu",
    pad_slot_id: int = PAD_SLOT_ID,
):
    """
    x: (batch, dim, seqlen) or (dim, cu_seq_len) for varlen
        sequences are concatenated from left to right for varlen
    weight: (dim, width)
    bias: (dim,)
    query_start_loc: (batch + 1) int32
        The cumulative sequence lengths of the sequences in
        the batch, used to index into sequence. prepended by 0.
        for example: query_start_loc = torch.Tensor([0, 10, 16, 17]),
        x.shape=(dim, 17)
    cache_indices: (batch) int32
        indicates the corresponding state index,
        like so: conv_state = conv_states[cache_indices[batch_id]]
    has_initial_state: (batch) bool
        indicates whether the kernel should take the current state as the
        initial state for the calculations
    conv_states: (..., dim, width - 1) itype
        updated inplace if provided
    activation: either None or "silu" or "swish"
    pad_slot_id: int
        if cache_indices is passed, lets the kernel identify padded
        entries that will not be processed,
        for example: cache_indices = [pad_slot_id, 1, 20, pad_slot_id]
        in this case, the kernel will not process entries at
        indices 0 and 3
    out: (batch, dim, seqlen)
    """
    if activation not in [None, "silu", "swish"]:
        raise NotImplementedError("activation must be None, silu, or swish")
    if x.stride(-1) != 1:
        x = x.contiguous()
    bias = bias.contiguous() if bias is not None else None

    causal_conv1d_fwd(
        x,
        weight,
        bias,
        conv_states,
        query_start_loc,
        cache_indices,
        has_initial_state,
        activation in ["silu", "swish"],
        pad_slot_id,
    )
    return x


def causal_conv1d_update(
    x: torch.Tensor,
    conv_state: torch.Tensor,
    weight: torch.Tensor,
    bias: Optional[torch.Tensor] = None,
    activation: Optional[str] = None,
    cache_seqlens: Optional[torch.Tensor] = None,
    conv_state_indices: Optional[torch.Tensor] = None,
    pad_slot_id: int = PAD_SLOT_ID,
):
    """
    x: (batch, dim) or (batch, dim, seqlen)
    conv_state: (batch, dim, state_len), where state_len >= width - 1
    weight: (dim, width)
    bias: (dim,)
    cache_seqlens: (batch,), dtype int32.
        If not None, the conv_state is treated as a circular buffer.
        The conv_state will be updated by copying x to the conv_state
        starting at the index
        @cache_seqlens % state_len.
    conv_state_indices: (batch,), dtype int32
        If not None, the conv_state is a larger tensor along the batch dim,
        and we are selecting the batch coords specified by conv_state_indices.
        Useful for a continuous batching scenario.
    pad_slot_id: int
        if cache_indices is passed, lets the kernel identify padded
        entries that will not be processed,
        for example: cache_indices = [pad_slot_id, 1, 20, pad_slot_id]
        in this case, the kernel will not process entries at
        indices 0 and 3
    out: (batch, dim) or (batch, dim, seqlen)
    """
    if activation not in [None, "silu", "swish"]:
        raise NotImplementedError(
            f"activation must be None, silu, or swish, actual: {activation}"
        )
    activation_val = activation in ["silu", "swish"]
    unsqueeze = x.dim() == 2
    if unsqueeze:
        x = x.unsqueeze(-1)
    causal_conv1d_update_kernel(
        x,
        conv_state,
        weight,
        bias,
        activation_val,
        cache_seqlens,
        conv_state_indices,
        pad_slot_id,
    )
    if unsqueeze:
        x = x.squeeze(-1)
    return x


# SPDX-License-Identifier: Apache-2.0

from typing import Optional

import pytest
import torch
import torch.nn.functional as F


def causal_conv1d_ref(
    x: torch.Tensor,
    weight: torch.Tensor,
    bias: Optional[torch.Tensor] = None,
    initial_states: Optional[torch.Tensor] = None,
    return_final_states: bool = False,
    final_states_out: Optional[torch.Tensor] = None,
    activation: Optional[str] = "silu",
):
    """
    x: (batch, dim, seqlen)
    weight: (dim, width)
    bias: (dim,)
    initial_states: (batch, dim, width - 1)
    final_states_out: (batch, dim, width - 1)

    out: (batch, dim, seqlen)
    """
    if activation not in [None, "silu", "swish"]:
        raise NotImplementedError("activation must be None, silu, or swish")
    dtype_in = x.dtype
    x = x.to(weight.dtype)
    seqlen = x.shape[-1]
    dim, width = weight.shape
    if initial_states is None:
        out = F.conv1d(x, weight.unsqueeze(1), bias, padding=width - 1, groups=dim)
    else:
        x = torch.cat([initial_states, x], dim=-1)
        out = F.conv1d(x, weight.unsqueeze(1), bias, padding=0, groups=dim)
    out = out[..., :seqlen]
    if return_final_states:
        final_states = F.pad(x, (width - 1 - x.shape[-1], 0)).to(
            dtype_in
        )  # (batch, dim, width - 1)
        if final_states_out is not None:
            final_states_out.copy_(final_states)
        else:
            final_states_out = final_states
    out = (out if activation is None else F.silu(out)).to(dtype=dtype_in)
    return (out, None) if not return_final_states else (out, final_states_out)


def causal_conv1d_update_ref(
    x, conv_state, weight, bias=None, activation=None, cache_seqlens=None
):
    """
    x: (batch, dim) or (batch, dim, seqlen)
    conv_state: (batch, dim, state_len), where state_len >= width - 1
    weight: (dim, width)
    bias: (dim,)
    cache_seqlens: (batch,), dtype int32.
        If not None, the conv_state is treated as a circular buffer.
        The conv_state will be updated by copying x to the
        conv_state starting at the index
        @cache_seqlens % state_len before performing the convolution.

    out: (batch, dim) or (batch, dim, seqlen)
    """
    if activation not in [None, "silu", "swish"]:
        raise NotImplementedError("activation must be None, silu, or swish")
    dtype_in = x.dtype
    unsqueeze = x.dim() == 2
    if unsqueeze:
        x = x.unsqueeze(-1)
    batch, dim, seqlen = x.shape
    width = weight.shape[1]
    state_len = conv_state.shape[-1]
    assert conv_state.shape == (batch, dim, state_len)
    assert weight.shape == (dim, width)
    if cache_seqlens is None:
        x_new = torch.cat([conv_state, x], dim=-1).to(
            weight.dtype
        )  # (batch, dim, state_len + seqlen)
        conv_state.copy_(x_new[:, :, -state_len:])
    else:
        width_idx = torch.arange(
            -(width - 1), 0, dtype=torch.long, device=x.device
        ).unsqueeze(0) + cache_seqlens.unsqueeze(1)
        width_idx = (
            torch.remainder(width_idx, state_len).unsqueeze(1).expand(-1, dim, -1)
        )
        x_new = torch.cat([conv_state.gather(2, width_idx), x], dim=-1).to(weight.dtype)
        copy_idx = torch.arange(
            seqlen, dtype=torch.long, device=x.device
        ).unsqueeze(0) + cache_seqlens.unsqueeze(1)
        copy_idx = torch.remainder(copy_idx, state_len).unsqueeze(1).expand(-1, dim, -1)
        conv_state.scatter_(2, copy_idx, x)
    out = F.conv1d(x_new, weight.unsqueeze(1), bias, padding=0, groups=dim)[
        :, :, -seqlen:
    ]
    if unsqueeze:
        out = out.squeeze(-1)
    return (out if activation is None else F.silu(out)).to(dtype=dtype_in)


@pytest.mark.parametrize("itype", [torch.bfloat16, torch.float])
@pytest.mark.parametrize("silu_activation", [True])
@pytest.mark.parametrize("has_bias", [True])
@pytest.mark.parametrize("has_initial_state", [True, False])
@pytest.mark.parametrize("width", [4])
@pytest.mark.parametrize(
    "seqlen", [1, 8, 16, 32, 64, 128, 256, 512, 784, 1024, 1025, 2048, 4096]
)
@pytest.mark.parametrize("dim", [64])
@pytest.mark.parametrize("batch", [1])
def test_causal_conv1d(
    batch, dim, seqlen, width, has_bias, silu_activation, has_initial_state, itype
):
    device = "cuda"
    rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3)
    if itype == torch.bfloat16:
        rtol, atol = 1e-2, 5e-2
    x = torch.randn(batch, dim, seqlen, device=device, dtype=itype).contiguous()

    weight = torch.randn(dim, width, device=device, dtype=itype)
    bias = torch.randn(dim, device=device, dtype=itype) if has_bias else None
    if has_initial_state:
        initial_states = torch.randn(batch, dim, width - 1, device=device, dtype=itype)
        has_initial_state_tensor = torch.ones(batch, dtype=torch.bool, device=x.device)
    else:
        initial_states = None
        has_initial_state_tensor = None
    x_ref = x.clone()
    weight_ref = weight.clone()
    bias_ref = bias.clone() if bias is not None else None
    initial_states_ref = initial_states.clone() if initial_states is not None else None
    activation = None if not silu_activation else "silu"
    out = causal_conv1d_fn(
        x,
        weight,
        bias,
        activation=activation,
        conv_states=initial_states,
        has_initial_state=has_initial_state_tensor,
    )
    out_ref, final_states_ref = causal_conv1d_ref(
        x_ref,
        weight_ref,
        bias_ref,
        initial_states=initial_states_ref,
        return_final_states=True,
        activation=activation,
    )
    if has_initial_state:
        assert initial_states is not None and final_states_ref is not None
        assert torch.allclose(initial_states, final_states_ref, rtol=rtol, atol=atol)
    assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)


@pytest.mark.parametrize("itype", [torch.bfloat16])
@pytest.mark.parametrize("silu_activation", [False, True])
@pytest.mark.parametrize("has_bias", [False, True])
@pytest.mark.parametrize("seqlen", [1])
@pytest.mark.parametrize("width", [4])
@pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096])
def test_causal_conv1d_update(dim, width, seqlen, has_bias, silu_activation, itype):
    device = "cuda"
    rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3)
    if itype == torch.bfloat16:
        rtol, atol = 1e-2, 5e-2

    batch = 2
    x = torch.randn(batch, dim, seqlen, device=device, dtype=itype)
    x_ref = x.clone()
    conv_state = torch.randn(batch, dim, width - 1, device=device, dtype=itype)

    weight = torch.randn(dim, width, device=device, dtype=itype)
    bias = torch.randn(dim, device=device, dtype=itype) if has_bias else None
    conv_state_ref = conv_state.detach().clone()
    activation = None if not silu_activation else "silu"
    out = causal_conv1d_update(x, conv_state, weight, bias, activation=activation)
    out_ref = causal_conv1d_update_ref(
        x_ref, conv_state_ref, weight, bias, activation=activation
    )

    assert torch.equal(conv_state, conv_state_ref)
    assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)


@pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize("silu_activation", [False, True])
@pytest.mark.parametrize("has_bias", [False, True])
@pytest.mark.parametrize("seqlen", [1, 4, 5])
@pytest.mark.parametrize("width", [2, 3, 4])
@pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096])
# tests correctness in case subset of the sequences are padded
@pytest.mark.parametrize("with_padding", [True, False])
def test_causal_conv1d_update_with_batch_gather(
    with_padding, dim, width, seqlen, has_bias, silu_activation, itype
):
    device = "cuda"
    rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3)
    if itype == torch.bfloat16:
        rtol, atol = 1e-2, 5e-2

    batch_size = 3
    padding = 5 if with_padding else 0
    padded_batch_size = batch_size + padding
    total_entries = 10 * batch_size

    x = torch.randn(padded_batch_size, dim, 1, device=device, dtype=itype)
    x_ref = x.clone()

    conv_state_indices = torch.randperm(total_entries)[:batch_size].to(
        dtype=torch.int32, device=device
    )
    unused_states_bool = torch.ones(total_entries, dtype=torch.bool, device=device)
    unused_states_bool[conv_state_indices] = False
    padded_state_indices = torch.concat(
        [
            conv_state_indices,
            torch.as_tensor([PAD_SLOT_ID] * padding, dtype=torch.int32, device=device),
        ],
        dim=0,
    )
    conv_state = torch.randn(total_entries, dim, width - 1, device=device, dtype=itype)
    conv_state_for_padding_test = conv_state.clone()

    weight = torch.randn(dim, width, device=device, dtype=itype)
    bias = torch.randn(dim, device=device, dtype=itype) if has_bias else None
    conv_state_ref = conv_state[conv_state_indices, :].detach().clone()
    activation = None if not silu_activation else "silu"
    out = causal_conv1d_update(
        x,
        conv_state,
        weight,
        bias,
        activation=activation,
        conv_state_indices=padded_state_indices,
        pad_slot_id=PAD_SLOT_ID,
    )
    out_ref = causal_conv1d_update_ref(
        x_ref[:batch_size], conv_state_ref, weight, bias, activation=activation
    )

    assert torch.equal(conv_state[conv_state_indices, :], conv_state_ref)
    assert torch.allclose(out[:batch_size], out_ref, rtol=rtol, atol=atol)
    assert torch.equal(
        conv_state[unused_states_bool], conv_state_for_padding_test[unused_states_bool]
    )


@pytest.mark.parametrize("itype", [torch.bfloat16])
@pytest.mark.parametrize("silu_activation", [True])
@pytest.mark.parametrize("has_bias", [True])
@pytest.mark.parametrize("width", [4])
@pytest.mark.parametrize(
    "seqlen", [8, 16, 32, 64, 128, 256, 512, 784, 1024, 2048, 2049, 4096]
)
@pytest.mark.parametrize("dim", [64, 4096])
# tests correctness in case subset of the sequences are padded
@pytest.mark.parametrize("with_padding", [True, False])
def test_causal_conv1d_varlen(
    with_padding, dim, seqlen, width, has_bias, silu_activation, itype
):
    device = "cuda"
    torch.cuda.empty_cache()
    rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3)
    if itype == torch.bfloat16:
        rtol, atol = 1e-2, 5e-2
    seqlens = []
    batch_size = 4
    if seqlen < 10:
        batch_size = 1
    padding = 3 if with_padding else 0
    padded_batch_size = batch_size + padding
    nsplits = padded_batch_size - 1

    eos_pos = torch.randperm(seqlen - 1)[:nsplits].sort().values

    seqlens.append(
        torch.diff(
            torch.cat([torch.tensor([-1]), eos_pos, torch.tensor([seqlen - 1])])
        ).tolist()
    )
    assert sum(seqlens[-1]) == seqlen
    assert all(s > 0 for s in seqlens[-1])

    total_entries = batch_size * 10
    cumsum = torch.cumsum(torch.tensor(seqlens[0]), dim=0).to(torch.int32)
    cumsum = torch.concat([torch.tensor([0], dtype=torch.int32), cumsum], dim=0)
    x = torch.randn(1, 4096 + dim + 64, seqlen, device=device, dtype=itype)[
        :, 4096 : 4096 + dim, :
    ]
    weight = torch.randn(dim, width, device=device, dtype=itype)
    bias = torch.randn(dim, device=device, dtype=itype) if has_bias else None
    x_ref = x.clone()
    weight_ref = weight.clone()
    bias_ref = bias.clone() if bias is not None else None
    activation = None if not silu_activation else "silu"
    final_states = torch.randn(
        total_entries, dim, width - 1, device=x.device, dtype=x.dtype
    )
    final_states_ref = final_states.clone()
    has_initial_states = torch.randint(
        0, 2, (cumsum.shape[0] - 1,), dtype=torch.bool, device=x.device
    )
    state_indices = torch.randperm(total_entries, dtype=torch.int32, device=x.device)[
        :batch_size
    ]
    padded_state_indices = torch.concat(
        [
            state_indices,
            torch.as_tensor([PAD_SLOT_ID] * padding, dtype=torch.int32, device=device),
        ],
        dim=-1,
    )
    out = causal_conv1d_fn(
        x.squeeze(0),
        weight,
        bias,
        cumsum.cuda(),
        padded_state_indices,
        has_initial_states,
        final_states,
        activation,
        PAD_SLOT_ID,
    )
    out_ref = []
    out_ref_b = []

    splits = [torch.split(var, seqlens[0], dim=-1) for var in (x_ref)]
    for i in range(len(seqlens[0])):
        x_s = [v[i].unsqueeze(0) for v in splits][0]
        if padded_state_indices[i] == PAD_SLOT_ID:
            continue
        out_ref_b.append(
            causal_conv1d_ref(
                x_s,
                weight_ref,
                bias_ref,
                activation=activation,
                return_final_states=True,
                final_states_out=final_states_ref[padded_state_indices[i]].unsqueeze(0),
                initial_states=(
                    final_states_ref[padded_state_indices[i]].unsqueeze(0)
                    if has_initial_states[i]
                    else None
                ),
            )
        )
    out_ref.append(torch.cat([t[0] for t in out_ref_b], dim=2))
    out_ref_tensor = torch.cat(out_ref, dim=0)

    unpadded_out = out[:, : out_ref_tensor.shape[-1]]
    assert torch.allclose(unpadded_out, out_ref_tensor, rtol=rtol, atol=atol)
    assert torch.allclose(
        final_states[state_indices],
        final_states_ref[state_indices],
        rtol=rtol,
        atol=atol,
    )


if __name__ == "__main__":
    pytest.main([__file__])