Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
711aa9d5
Commit
711aa9d5
authored
Jul 30, 2025
by
zhuwenwen
Browse files
Merge tag 'v0.10.0' into v0.10.0-dev
parents
751c492c
6d8d0a24
Changes
519
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
276 additions
and
999 deletions
+276
-999
csrc/cpu/quant.cpp
csrc/cpu/quant.cpp
+12
-9
csrc/cpu/sgl-kernels/common.h
csrc/cpu/sgl-kernels/common.h
+1
-1
csrc/cpu/sgl-kernels/gemm.h
csrc/cpu/sgl-kernels/gemm.h
+1
-1
csrc/cpu/sgl-kernels/gemm_int8.cpp
csrc/cpu/sgl-kernels/gemm_int8.cpp
+1
-1
csrc/cpu/sgl-kernels/vec.h
csrc/cpu/sgl-kernels/vec.h
+1
-1
csrc/cpu/shm.cpp
csrc/cpu/shm.cpp
+48
-21
csrc/cpu/torch_bindings.cpp
csrc/cpu/torch_bindings.cpp
+2
-1
csrc/cuda_compat.h
csrc/cuda_compat.h
+3
-3
csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c2x.hpp
csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c2x.hpp
+3
-3
csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp
csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp
+4
-4
csrc/layernorm_kernels.cu
csrc/layernorm_kernels.cu
+40
-23
csrc/layernorm_quant_kernels.cu
csrc/layernorm_quant_kernels.cu
+25
-14
csrc/mamba/causal_conv1d/causal_conv1d.cu
csrc/mamba/causal_conv1d/causal_conv1d.cu
+0
-656
csrc/mamba/causal_conv1d/causal_conv1d.h
csrc/mamba/causal_conv1d/causal_conv1d.h
+0
-159
csrc/mamba/causal_conv1d/static_switch.h
csrc/mamba/causal_conv1d/static_switch.h
+0
-28
csrc/mamba/mamba_ssm/selective_scan_fwd.cu
csrc/mamba/mamba_ssm/selective_scan_fwd.cu
+35
-25
csrc/moe/moe_align_sum_kernels.cu
csrc/moe/moe_align_sum_kernels.cu
+55
-16
csrc/ops.h
csrc/ops.h
+5
-16
csrc/opt/layernorm_kernels_opt.cu
csrc/opt/layernorm_kernels_opt.cu
+36
-12
csrc/quantization/cutlass_w8a8/moe/blockwise_scaled_group_mm_sm100.cu
...ation/cutlass_w8a8/moe/blockwise_scaled_group_mm_sm100.cu
+4
-5
No files found.
Too many changes to show.
To preserve performance only
519 of 519+
files are displayed.
Plain diff
Email patch
csrc/cpu/quant.cpp
View file @
711aa9d5
...
...
@@ -36,7 +36,7 @@ struct KernelVecType<c10::Half> {
using
cvt_vec_type
=
vec_op
::
FP32Vec16
;
};
#ifdef
__AVX512F__
#if
def
ined(
__AVX512F__
) || defined(__aarch64__)
template
<
bool
AZP
,
typename
scalar_t
>
void
static_scaled_int8_quant_impl
(
const
scalar_t
*
input
,
int8_t
*
output
,
const
float
*
scale
,
const
int32_t
*
azp
,
...
...
@@ -598,8 +598,9 @@ void static_scaled_int8_quant_impl(const scalar_t* input, int8_t* output,
const
float
*
scale
,
const
int32_t
*
azp
,
const
int
num_tokens
,
const
int
hidden_size
)
{
TORCH_CHECK
(
false
,
"static_scaled_int8_quant_impl requires AVX512/powerpc64 support."
)
TORCH_CHECK
(
false
,
"static_scaled_int8_quant_impl requires AVX512/powerpc64/AArch64 "
"support."
)
}
template
<
typename
scalar_t
>
...
...
@@ -607,9 +608,9 @@ void dynamic_scaled_int8_quant_impl(const scalar_t* input, int8_t* output,
float
*
scale
,
int32_t
*
azp
,
const
int
num_tokens
,
const
int
hidden_size
)
{
TORCH_CHECK
(
false
,
"dynamic_scaled_int8_quant_impl requires
AVX512/powerpc64 support."
)
TORCH_CHECK
(
false
,
"dynamic_scaled_int8_quant_impl requires "
"
AVX512/powerpc64
/AArch64
support."
)
}
template
<
bool
PerChannel
,
typename
scalar_t
>
...
...
@@ -617,7 +618,8 @@ void static_quant_epilogue(const float* input, scalar_t* output,
const
float
a_scale
,
const
float
*
b_scale
,
const
int32_t
*
azp_with_adj
,
const
int
num_tokens
,
const
int
hidden_size
)
{
TORCH_CHECK
(
false
,
"static_quant_epilogue requires AVX512/powerpc64 support."
)
TORCH_CHECK
(
false
,
"static_quant_epilogue requires AVX512/powerpc64/AArch64 support."
)
}
template
<
typename
scalar_t
>
...
...
@@ -626,8 +628,9 @@ void dynamic_quant_epilogue(const float* input, scalar_t* output,
const
int32_t
*
azp
,
const
int32_t
*
azp_with_adj
,
const
scalar_t
*
bias
,
const
int
num_tokens
,
const
int
hidden_size
)
{
TORCH_CHECK
(
false
,
"dynamic_quant_epilogue requires AVX512/powerpc64 support."
)
TORCH_CHECK
(
false
,
"dynamic_quant_epilogue requires AVX512/powerpc64/AArch64 support."
)
}
#endif
}
// namespace
...
...
csrc/cpu/sgl-kernels/common.h
View file @
711aa9d5
...
...
@@ -58,7 +58,7 @@ namespace {
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_LAST_DIM_CONTIGUOUS(x) \
TORCH_CHECK(x.strides()[x.strides().size() - 1] == 1, #x "must be contiguous at last dimen
t
ion")
TORCH_CHECK(x.strides()[x.strides().size() - 1] == 1, #x "must be contiguous at last dimen
s
ion")
#define CHECK_INPUT(x) \
CHECK_CPU(x); \
...
...
csrc/cpu/sgl-kernels/gemm.h
View file @
711aa9d5
...
...
@@ -126,7 +126,7 @@ void fused_experts_int4_w4a16_kernel_impl(
int64_t
topk
,
int64_t
num_tokens_post_pad
);
// shared expert impleme
m
ntation for int8 w8a8
// shared expert implementation for int8 w8a8
template
<
typename
scalar_t
>
void
shared_expert_int8_kernel_impl
(
scalar_t
*
__restrict__
output
,
...
...
csrc/cpu/sgl-kernels/gemm_int8.cpp
View file @
711aa9d5
...
...
@@ -41,7 +41,7 @@ struct tinygemm_kernel_nn<at::BFloat16, has_bias, BLOCK_M, BLOCK_N> {
__m512
vd0
;
__m512
vd1
[
COLS
];
// oops! 4x4 spills but luckly we use 4x2
// oops! 4x4 spills but luck
i
ly we use 4x2
__m512
vbias
[
COLS
];
// [NOTE]: s8s8 igemm compensation in avx512-vnni
...
...
csrc/cpu/sgl-kernels/vec.h
View file @
711aa9d5
...
...
@@ -37,7 +37,7 @@ inline Vectorized<at::BFloat16> convert_from_float_ext<at::BFloat16>(const Vecto
#define CVT_FP16_TO_FP32(a) \
_mm512_cvtps_ph(a, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC))
// this doesn't hane
l
NaN.
// this doesn't han
dl
e NaN.
inline
__m512bh
cvt_e4m3_bf16_intrinsic_no_nan
(
__m256i
fp8_vec
)
{
const
__m512i
x
=
_mm512_cvtepu8_epi16
(
fp8_vec
);
...
...
csrc/cpu/shm.cpp
View file @
711aa9d5
...
...
@@ -7,7 +7,7 @@
namespace
{
#define MAX_SHM_RANK_NUM 8
#define PER_THREAD_SHM_BUFFER_BYTES (
2
* 1024 * 1024)
#define PER_THREAD_SHM_BUFFER_BYTES (
4
* 1024 * 1024)
static_assert
(
PER_THREAD_SHM_BUFFER_BYTES
%
2
==
0
);
#define PER_THREAD_SHM_BUFFER_OFFSET (PER_THREAD_SHM_BUFFER_BYTES >> 1)
#define MIN_THREAD_PROCESS_SIZE (256)
...
...
@@ -34,9 +34,10 @@ struct KernelVecType<c10::Half> {
};
struct
ThreadSHMContext
{
volatile
char
_curr_thread_stamp
;
volatile
char
_ready_thread_stamp
;
char
_padding1
[
6
];
volatile
char
_curr_thread_stamp
[
2
];
volatile
char
_ready_thread_stamp
[
2
];
int
local_stamp_buffer_idx
;
int
remote_stamp_buffer_idx
;
int
thread_id
;
int
thread_num
;
int
rank
;
...
...
@@ -45,23 +46,28 @@ struct ThreadSHMContext {
int
swizzled_ranks
[
MAX_SHM_RANK_NUM
];
void
*
thread_shm_ptrs
[
MAX_SHM_RANK_NUM
];
ThreadSHMContext
*
shm_contexts
[
MAX_SHM_RANK_NUM
];
size_t
_thread_buffer_mask
;
char
_padding2
[
56
];
size_t
_thread_buffer_mask
[
2
]
;
char
_padding2
[
40
];
ThreadSHMContext
(
const
int
thread_id
,
const
int
thread_num
,
const
int
rank
,
const
int
group_size
,
void
*
thread_shm_ptr
)
:
_curr_thread_stamp
(
1
),
_
re
ady_thread_stamp
(
0
),
:
local_stamp_buffer_idx
(
0
),
re
mote_stamp_buffer_idx
(
0
),
thread_id
(
thread_id
),
thread_num
(
thread_num
),
rank
(
rank
),
group_size
(
group_size
),
_spinning_count
(
0
),
_thread_buffer_mask
(
0
)
{
_spinning_count
(
0
)
{
static_assert
(
sizeof
(
ThreadSHMContext
)
%
64
==
0
);
TORCH_CHECK
(
group_size
<=
MAX_SHM_RANK_NUM
);
TORCH_CHECK
((
size_t
)
this
%
64
==
0
);
TORCH_CHECK
((
size_t
)
thread_shm_ptr
%
64
==
0
);
_curr_thread_stamp
[
0
]
=
1
;
_curr_thread_stamp
[
1
]
=
1
;
_ready_thread_stamp
[
0
]
=
0
;
_ready_thread_stamp
[
1
]
=
0
;
_thread_buffer_mask
[
0
]
=
0
;
_thread_buffer_mask
[
1
]
=
0
;
for
(
int
i
=
0
;
i
<
MAX_SHM_RANK_NUM
;
++
i
)
{
shm_contexts
[
i
]
=
nullptr
;
thread_shm_ptrs
[
i
]
=
nullptr
;
...
...
@@ -70,6 +76,11 @@ struct ThreadSHMContext {
set_context
(
rank
,
this
,
thread_shm_ptr
);
}
void
set_stamp_buffer_idx
(
int
local
,
int
remote
)
{
local_stamp_buffer_idx
=
local
;
remote_stamp_buffer_idx
=
remote
;
}
void
set_context
(
int
rank
,
ThreadSHMContext
*
ptr
,
void
*
thread_shm_ptr
)
{
TORCH_CHECK
(
rank
<
MAX_SHM_RANK_NUM
);
TORCH_CHECK
(
ptr
);
...
...
@@ -84,23 +95,27 @@ struct ThreadSHMContext {
T
*
get_thread_shm_ptr
(
int
rank
)
{
return
reinterpret_cast
<
T
*>
(
reinterpret_cast
<
int8_t
*>
(
thread_shm_ptrs
[
rank
])
+
(
PER_THREAD_SHM_BUFFER_OFFSET
&
_thread_buffer_mask
));
(
PER_THREAD_SHM_BUFFER_OFFSET
&
_thread_buffer_mask
[
local_stamp_buffer_idx
]));
}
void
next_buffer
()
{
_thread_buffer_mask
^=
0xFFFFFFFFFFFFFFFF
;
}
void
next_buffer
()
{
_thread_buffer_mask
[
local_stamp_buffer_idx
]
^=
0xFFFFFFFFFFFFFFFF
;
}
char
get_curr_stamp
()
const
{
return
_curr_thread_stamp
;
}
char
get_curr_stamp
(
int
idx
)
const
{
return
_curr_thread_stamp
[
idx
]
;
}
char
get_ready_stamp
()
const
{
return
_ready_thread_stamp
;
}
char
get_ready_stamp
(
int
idx
)
const
{
return
_ready_thread_stamp
[
idx
]
;
}
void
next_stamp
()
{
_mm_mfence
();
_curr_thread_stamp
+=
1
;
_curr_thread_stamp
[
local_stamp_buffer_idx
]
+=
1
;
}
void
commit_ready_stamp
()
{
_mm_mfence
();
_ready_thread_stamp
=
_curr_thread_stamp
;
_ready_thread_stamp
[
local_stamp_buffer_idx
]
=
_curr_thread_stamp
[
local_stamp_buffer_idx
];
}
int
get_swizzled_rank
(
int
idx
)
{
return
swizzled_ranks
[
idx
];
}
...
...
@@ -117,10 +132,11 @@ struct ThreadSHMContext {
void
wait_for_one
(
int
rank
,
Cond
&&
cond
)
{
ThreadSHMContext
*
rank_ctx
=
shm_contexts
[
rank
];
for
(;;)
{
char
local_curr_stamp
=
get_curr_stamp
();
char
local_ready_stamp
=
get_ready_stamp
();
char
rank_curr_stamp
=
rank_ctx
->
get_curr_stamp
();
char
rank_ready_stamp
=
rank_ctx
->
get_ready_stamp
();
char
local_curr_stamp
=
get_curr_stamp
(
local_stamp_buffer_idx
);
char
local_ready_stamp
=
get_ready_stamp
(
local_stamp_buffer_idx
);
char
rank_curr_stamp
=
rank_ctx
->
get_curr_stamp
(
remote_stamp_buffer_idx
);
char
rank_ready_stamp
=
rank_ctx
->
get_ready_stamp
(
remote_stamp_buffer_idx
);
if
(
cond
(
local_curr_stamp
,
local_ready_stamp
,
rank_curr_stamp
,
rank_ready_stamp
))
{
break
;
...
...
@@ -361,6 +377,15 @@ void shm_cc_loop(ThreadSHMContext* ctx, int64_t elem_num, F&& inner_func) {
}
}
}
// Applies the same (local, remote) stamp-buffer index pair to every per-thread
// context. `ctx` is treated as the base of an array holding `ctx->thread_num`
// consecutive ThreadSHMContext objects; the count is read once from the first
// element before any context is updated.
void reset_threads_stamp_buffer_idx(ThreadSHMContext* ctx, int local,
                                    int remote) {
  ThreadSHMContext* cursor = ctx;
  for (int remaining = ctx->thread_num; remaining > 0;
       --remaining, ++cursor) {
    cursor->set_stamp_buffer_idx(local, remote);
  }
}
};
// namespace shm_cc_ops
namespace
shm_cc_ops
{
...
...
@@ -632,6 +657,7 @@ void shm_send_tensor_list_impl(ThreadSHMContext* ctx, int64_t dst,
TensorListMeta
*
metadata
=
new
(
metadata_tensor
.
data_ptr
())
TensorListMeta
();
metadata
->
bind_tensor_list
(
tensor_list_with_metadata
);
shm_cc_ops
::
reset_threads_stamp_buffer_idx
(
ctx
,
0
,
1
);
shm_cc_ops
::
shm_cc_loop
<
int8_t
>
(
ctx
,
metadata
->
total_bytes
,
[
&
](
ThreadSHMContext
*
thread_ctx
,
int64_t
data_offset
,
...
...
@@ -659,6 +685,7 @@ std::vector<torch::Tensor> shm_recv_tensor_list_impl(ThreadSHMContext* ctx,
torch
::
Tensor
metadata_tensor
=
torch
::
empty
({
sizeof
(
TensorListMeta
)},
options
);
shm_cc_ops
::
reset_threads_stamp_buffer_idx
(
ctx
,
1
,
0
);
ctx
->
wait_for_one
(
src
,
ThreadSHMContext
::
check_stamp_ready
);
shm_cc_ops
::
memcpy
(
metadata_tensor
.
data_ptr
(),
ctx
->
get_thread_shm_ptr
<
void
>
(
src
),
...
...
@@ -677,7 +704,7 @@ std::vector<torch::Tensor> shm_recv_tensor_list_impl(ThreadSHMContext* ctx,
ctx
,
metadata
.
total_bytes
,
[
&
](
ThreadSHMContext
*
thread_ctx
,
int64_t
data_offset
,
int64_t
data_elem_num
,
bool
fast_mode
)
{
ctx
->
wait_for_one
(
src
,
ThreadSHMContext
::
check_stamp_ready
);
thread_
ctx
->
wait_for_one
(
src
,
ThreadSHMContext
::
check_stamp_ready
);
int64_t
curr_shm_offset
=
0
;
while
(
curr_shm_offset
<
data_elem_num
)
{
MemPiece
frag
=
metadata
.
get_data
(
data_offset
+
curr_shm_offset
);
...
...
csrc/cpu/torch_bindings.cpp
View file @
711aa9d5
...
...
@@ -151,8 +151,9 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops
.
impl
(
"rotary_embedding"
,
torch
::
kCPU
,
&
rotary_embedding
);
// Quantization
#ifdef
__AVX512F__
#if
def
ined(
__AVX512F__
) || defined(__aarch64__)
at
::
Tag
stride_tag
=
at
::
Tag
::
needs_fixed_stride_order
;
// Compute int8 quantized tensor for given scaling factor.
ops
.
def
(
"static_scaled_int8_quant(Tensor! out, Tensor input, Tensor scale,"
...
...
csrc/cuda_compat.h
View file @
711aa9d5
...
...
@@ -4,10 +4,10 @@
#include <hip/hip_runtime.h>
#endif
#if
n
def
USE_ROCM
#define WARP_SIZE
32
#if
def
ined(
USE_ROCM
) && defined(__GFX9__)
#define WARP_SIZE
64
#else
#define WARP_SIZE
warpSize
#define WARP_SIZE
32
#endif
#ifndef USE_ROCM
...
...
csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c2x.hpp
View file @
711aa9d5
...
...
@@ -153,7 +153,7 @@ struct ScaledEpilogueBias
cutlass
::
epilogue
::
threadblock
::
Sm80EVT
<
Compute0
,
ScaleB
,
Accum
>
;
using
Compute1
=
cutlass
::
epilogue
::
threadblock
::
VisitorCompute
<
cutlass
::
multiply_add
,
ElementD
,
float
,
cutlass
::
homogeneous_
multiply_add
,
ElementD
,
float
,
cutlass
::
FloatRoundStyle
::
round_to_nearest
>
;
public:
...
...
@@ -210,7 +210,7 @@ struct ScaledEpilogueBiasAzp
EVTComputeAzp
>
;
using
ComputeScaleBiasA
=
cutlass
::
epilogue
::
threadblock
::
VisitorCompute
<
cutlass
::
multiply_add
,
ElementD
,
float
,
cutlass
::
homogeneous_
multiply_add
,
ElementD
,
float
,
cutlass
::
FloatRoundStyle
::
round_to_nearest
>
;
public:
...
...
@@ -288,7 +288,7 @@ struct ScaledEpilogueBiasAzpToken
EVTComputeAcc
>
;
using
ComputeScaleBiasA
=
cutlass
::
epilogue
::
threadblock
::
VisitorCompute
<
cutlass
::
multiply_add
,
ElementD
,
float
,
cutlass
::
homogeneous_
multiply_add
,
ElementD
,
float
,
cutlass
::
FloatRoundStyle
::
round_to_nearest
>
;
public:
...
...
csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp
View file @
711aa9d5
...
...
@@ -195,7 +195,7 @@ struct ScaledEpilogueBias
cutlass
::
epilogue
::
fusion
::
Sm90EVT
<
Compute0
,
ScaleB
,
Accum
>
;
using
Compute1
=
cutlass
::
epilogue
::
fusion
::
Sm90Compute
<
cutlass
::
multiply_add
,
ElementD
,
float
,
cutlass
::
homogeneous_
multiply_add
,
ElementD
,
float
,
cutlass
::
FloatRoundStyle
::
round_to_nearest
>
;
public:
...
...
@@ -238,7 +238,7 @@ struct ScaledEpilogueColumnBias
cutlass
::
epilogue
::
fusion
::
Sm90EVT
<
Compute0
,
ScaleB
,
Accum
>
;
using
Compute1
=
cutlass
::
epilogue
::
fusion
::
Sm90Compute
<
cutlass
::
multiply_add
,
ElementD
,
float
,
cutlass
::
homogeneous_
multiply_add
,
ElementD
,
float
,
cutlass
::
FloatRoundStyle
::
round_to_nearest
>
;
public:
...
...
@@ -295,7 +295,7 @@ struct ScaledEpilogueBiasAzp
cutlass
::
epilogue
::
fusion
::
Sm90EVT
<
ComputeScaleB
,
ScaleB
,
EVTComputeAzp
>
;
using
ComputeScaleBiasA
=
cutlass
::
epilogue
::
fusion
::
Sm90Compute
<
cutlass
::
multiply_add
,
ElementD
,
float
,
cutlass
::
homogeneous_
multiply_add
,
ElementD
,
float
,
cutlass
::
FloatRoundStyle
::
round_to_nearest
>
;
public:
...
...
@@ -371,7 +371,7 @@ struct ScaledEpilogueBiasAzpToken
cutlass
::
epilogue
::
fusion
::
Sm90EVT
<
ComputeScaleB
,
ScaleB
,
EVTComputeAcc
>
;
using
ComputeScaleBiasA
=
cutlass
::
epilogue
::
fusion
::
Sm90Compute
<
cutlass
::
multiply_add
,
ElementD
,
float
,
cutlass
::
homogeneous_
multiply_add
,
ElementD
,
float
,
cutlass
::
FloatRoundStyle
::
round_to_nearest
>
;
public:
...
...
csrc/layernorm_kernels.cu
View file @
711aa9d5
...
...
@@ -15,15 +15,16 @@ namespace vllm {
// TODO(woosuk): Further optimize this kernel.
template
<
typename
scalar_t
>
__global__
void
rms_norm_kernel
(
scalar_t
*
__restrict__
out
,
// [..., hidden_size]
const
scalar_t
*
__restrict__
input
,
// [..., hidden_size]
scalar_t
*
__restrict__
out
,
// [..., hidden_size]
const
scalar_t
*
__restrict__
input
,
// [..., hidden_size]
const
int64_t
input_stride
,
const
scalar_t
*
__restrict__
weight
,
// [hidden_size]
const
float
epsilon
,
const
int
num_tokens
,
const
int
hidden_size
)
{
__shared__
float
s_variance
;
float
variance
=
0.0
f
;
for
(
int
idx
=
threadIdx
.
x
;
idx
<
hidden_size
;
idx
+=
blockDim
.
x
)
{
const
float
x
=
(
float
)
input
[
blockIdx
.
x
*
hidden_siz
e
+
idx
];
const
float
x
=
(
float
)
input
[
blockIdx
.
x
*
input_strid
e
+
idx
];
variance
+=
x
*
x
;
}
...
...
@@ -37,7 +38,7 @@ __global__ void rms_norm_kernel(
__syncthreads
();
for
(
int
idx
=
threadIdx
.
x
;
idx
<
hidden_size
;
idx
+=
blockDim
.
x
)
{
float
x
=
(
float
)
input
[
blockIdx
.
x
*
hidden_siz
e
+
idx
];
float
x
=
(
float
)
input
[
blockIdx
.
x
*
input_strid
e
+
idx
];
out
[
blockIdx
.
x
*
hidden_size
+
idx
]
=
((
scalar_t
)(
x
*
s_variance
))
*
weight
[
idx
];
}
...
...
@@ -50,7 +51,8 @@ __global__ void rms_norm_kernel(
template
<
typename
scalar_t
,
int
width
>
__global__
std
::
enable_if_t
<
(
width
>
0
)
&&
_typeConvert
<
scalar_t
>::
exists
>
fused_add_rms_norm_kernel
(
scalar_t
*
__restrict__
input
,
// [..., hidden_size]
scalar_t
*
__restrict__
input
,
// [..., hidden_size]
const
int64_t
input_stride
,
scalar_t
*
__restrict__
residual
,
// [..., hidden_size]
const
scalar_t
*
__restrict__
weight
,
// [hidden_size]
const
float
epsilon
,
const
int
num_tokens
,
const
int
hidden_size
)
{
...
...
@@ -59,6 +61,7 @@ fused_add_rms_norm_kernel(
static_assert
(
sizeof
(
_f16Vec
<
scalar_t
,
width
>
)
==
sizeof
(
scalar_t
)
*
width
);
const
int
vec_hidden_size
=
hidden_size
/
width
;
const
int64_t
vec_input_stride
=
input_stride
/
width
;
__shared__
float
s_variance
;
float
variance
=
0.0
f
;
/* These and the argument pointers are all declared `restrict` as they are
...
...
@@ -73,7 +76,8 @@ fused_add_rms_norm_kernel(
for
(
int
idx
=
threadIdx
.
x
;
idx
<
vec_hidden_size
;
idx
+=
blockDim
.
x
)
{
int
id
=
blockIdx
.
x
*
vec_hidden_size
+
idx
;
_f16Vec
<
scalar_t
,
width
>
temp
=
input_v
[
id
];
int64_t
strided_id
=
blockIdx
.
x
*
vec_input_stride
+
idx
;
_f16Vec
<
scalar_t
,
width
>
temp
=
input_v
[
strided_id
];
temp
+=
residual_v
[
id
];
variance
+=
temp
.
sum_squares
();
residual_v
[
id
]
=
temp
;
...
...
@@ -90,10 +94,11 @@ fused_add_rms_norm_kernel(
for
(
int
idx
=
threadIdx
.
x
;
idx
<
vec_hidden_size
;
idx
+=
blockDim
.
x
)
{
int
id
=
blockIdx
.
x
*
vec_hidden_size
+
idx
;
int64_t
strided_id
=
blockIdx
.
x
*
vec_input_stride
+
idx
;
_f16Vec
<
scalar_t
,
width
>
temp
=
residual_v
[
id
];
temp
*=
s_variance
;
temp
*=
weight_v
[
idx
];
input_v
[
id
]
=
temp
;
input_v
[
strided_
id
]
=
temp
;
}
}
...
...
@@ -103,7 +108,8 @@ fused_add_rms_norm_kernel(
template
<
typename
scalar_t
,
int
width
>
__global__
std
::
enable_if_t
<
(
width
==
0
)
||
!
_typeConvert
<
scalar_t
>::
exists
>
fused_add_rms_norm_kernel
(
scalar_t
*
__restrict__
input
,
// [..., hidden_size]
scalar_t
*
__restrict__
input
,
// [..., hidden_size]
const
int64_t
input_stride
,
scalar_t
*
__restrict__
residual
,
// [..., hidden_size]
const
scalar_t
*
__restrict__
weight
,
// [hidden_size]
const
float
epsilon
,
const
int
num_tokens
,
const
int
hidden_size
)
{
...
...
@@ -111,7 +117,7 @@ fused_add_rms_norm_kernel(
float
variance
=
0.0
f
;
for
(
int
idx
=
threadIdx
.
x
;
idx
<
hidden_size
;
idx
+=
blockDim
.
x
)
{
scalar_t
z
=
input
[
blockIdx
.
x
*
hidden_siz
e
+
idx
];
scalar_t
z
=
input
[
blockIdx
.
x
*
input_strid
e
+
idx
];
z
+=
residual
[
blockIdx
.
x
*
hidden_size
+
idx
];
float
x
=
(
float
)
z
;
variance
+=
x
*
x
;
...
...
@@ -129,7 +135,7 @@ fused_add_rms_norm_kernel(
for
(
int
idx
=
threadIdx
.
x
;
idx
<
hidden_size
;
idx
+=
blockDim
.
x
)
{
float
x
=
(
float
)
residual
[
blockIdx
.
x
*
hidden_size
+
idx
];
input
[
blockIdx
.
x
*
hidden_siz
e
+
idx
]
=
input
[
blockIdx
.
x
*
input_strid
e
+
idx
]
=
((
scalar_t
)(
x
*
s_variance
))
*
weight
[
idx
];
}
}
...
...
@@ -141,11 +147,12 @@ void rms_norm(torch::Tensor& out, // [..., hidden_size]
torch
::
Tensor
&
weight
,
// [hidden_size]
double
epsilon
)
{
TORCH_CHECK
(
out
.
is_contiguous
());
TORCH_CHECK
(
input
.
is_contiguous
()
);
TORCH_CHECK
(
input
.
stride
(
-
1
)
==
1
);
TORCH_CHECK
(
weight
.
is_contiguous
());
int
hidden_size
=
input
.
size
(
-
1
);
int
num_tokens
=
input
.
numel
()
/
hidden_size
;
int64_t
input_stride
=
input
.
stride
(
-
2
);
dim3
grid
(
num_tokens
);
dim3
block
(
std
::
min
(
hidden_size
,
1024
));
...
...
@@ -153,26 +160,29 @@ void rms_norm(torch::Tensor& out, // [..., hidden_size]
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
VLLM_DISPATCH_FLOATING_TYPES
(
input
.
scalar_type
(),
"rms_norm_kernel"
,
[
&
]
{
vllm
::
rms_norm_kernel
<
scalar_t
><<<
grid
,
block
,
0
,
stream
>>>
(
out
.
data_ptr
<
scalar_t
>
(),
input
.
data_ptr
<
scalar_t
>
(),
out
.
data_ptr
<
scalar_t
>
(),
input
.
data_ptr
<
scalar_t
>
(),
input_stride
,
weight
.
data_ptr
<
scalar_t
>
(),
epsilon
,
num_tokens
,
hidden_size
);
});
}
#define LAUNCH_FUSED_ADD_RMS_NORM(width)
\
VLLM_DISPATCH_FLOATING_TYPES(
\
input.scalar_type(), "fused_add_rms_norm_kernel", [&] {
\
vllm::fused_add_rms_norm_kernel<scalar_t, width>
\
<<<grid, block, 0, stream>>>(
input.data_ptr<scalar_t>(),
\
residual.data_ptr<scalar_t>(),
\
weight.data_ptr<scalar_t>(), epsilon
, \
num_tokens, hidden_size); \
#define LAUNCH_FUSED_ADD_RMS_NORM(width) \
VLLM_DISPATCH_FLOATING_TYPES( \
input.scalar_type(), "fused_add_rms_norm_kernel", [&] { \
vllm::fused_add_rms_norm_kernel<scalar_t, width> \
<<<grid, block, 0, stream>>>(
\
input.data_ptr<scalar_t>(), input_stride,
\
residual.data_ptr<scalar_t>(), weight.data_ptr<scalar_t>()
, \
epsilon,
num_tokens, hidden_size);
\
});
void
fused_add_rms_norm
(
torch
::
Tensor
&
input
,
// [..., hidden_size]
torch
::
Tensor
&
residual
,
// [..., hidden_size]
torch
::
Tensor
&
weight
,
// [hidden_size]
double
epsilon
)
{
TORCH_CHECK
(
residual
.
is_contiguous
());
TORCH_CHECK
(
weight
.
is_contiguous
());
int
hidden_size
=
input
.
size
(
-
1
);
int64_t
input_stride
=
input
.
stride
(
-
2
);
int
num_tokens
=
input
.
numel
()
/
hidden_size
;
dim3
grid
(
num_tokens
);
...
...
@@ -194,9 +204,16 @@ void fused_add_rms_norm(torch::Tensor& input, // [..., hidden_size]
auto
inp_ptr
=
reinterpret_cast
<
std
::
uintptr_t
>
(
input
.
data_ptr
());
auto
res_ptr
=
reinterpret_cast
<
std
::
uintptr_t
>
(
residual
.
data_ptr
());
auto
wt_ptr
=
reinterpret_cast
<
std
::
uintptr_t
>
(
weight
.
data_ptr
());
bool
ptrs_are_aligned
=
inp_ptr
%
16
==
0
&&
res_ptr
%
16
==
0
&&
wt_ptr
%
16
==
0
;
if
(
ptrs_are_aligned
&&
hidden_size
%
8
==
0
)
{
constexpr
int
vector_width
=
8
;
constexpr
int
req_alignment_bytes
=
vector_width
*
2
;
// vector_width * sizeof(bfloat16 or float16) (float32
// falls back to non-vectorized version anyway)
bool
ptrs_are_aligned
=
inp_ptr
%
req_alignment_bytes
==
0
&&
res_ptr
%
req_alignment_bytes
==
0
&&
wt_ptr
%
req_alignment_bytes
==
0
;
bool
offsets_are_multiple_of_vector_width
=
hidden_size
%
vector_width
==
0
&&
input_stride
%
vector_width
==
0
;
if
(
ptrs_are_aligned
&&
offsets_are_multiple_of_vector_width
)
{
LAUNCH_FUSED_ADD_RMS_NORM
(
8
);
}
else
{
LAUNCH_FUSED_ADD_RMS_NORM
(
0
);
...
...
csrc/layernorm_quant_kernels.cu
View file @
711aa9d5
...
...
@@ -25,8 +25,9 @@ namespace vllm {
// TODO(woosuk): Further optimize this kernel.
template
<
typename
scalar_t
,
typename
fp8_type
>
__global__
void
rms_norm_static_fp8_quant_kernel
(
fp8_type
*
__restrict__
out
,
// [..., hidden_size]
const
scalar_t
*
__restrict__
input
,
// [..., hidden_size]
fp8_type
*
__restrict__
out
,
// [..., hidden_size]
const
scalar_t
*
__restrict__
input
,
// [..., hidden_size]
const
int
input_stride
,
const
scalar_t
*
__restrict__
weight
,
// [hidden_size]
const
float
*
__restrict__
scale
,
// [1]
const
float
epsilon
,
const
int
num_tokens
,
const
int
hidden_size
)
{
...
...
@@ -34,7 +35,7 @@ __global__ void rms_norm_static_fp8_quant_kernel(
float
variance
=
0.0
f
;
for
(
int
idx
=
threadIdx
.
x
;
idx
<
hidden_size
;
idx
+=
blockDim
.
x
)
{
const
float
x
=
(
float
)
input
[
blockIdx
.
x
*
hidden_siz
e
+
idx
];
const
float
x
=
(
float
)
input
[
blockIdx
.
x
*
input_strid
e
+
idx
];
variance
+=
x
*
x
;
}
...
...
@@ -51,7 +52,7 @@ __global__ void rms_norm_static_fp8_quant_kernel(
float
const
scale_inv
=
1.0
f
/
*
scale
;
for
(
int
idx
=
threadIdx
.
x
;
idx
<
hidden_size
;
idx
+=
blockDim
.
x
)
{
float
x
=
(
float
)
input
[
blockIdx
.
x
*
hidden_siz
e
+
idx
];
float
x
=
(
float
)
input
[
blockIdx
.
x
*
input_strid
e
+
idx
];
float
const
out_norm
=
((
scalar_t
)(
x
*
s_variance
))
*
weight
[
idx
];
out
[
blockIdx
.
x
*
hidden_size
+
idx
]
=
scaled_fp8_conversion
<
true
,
fp8_type
>
(
out_norm
,
scale_inv
);
...
...
@@ -65,8 +66,9 @@ __global__ void rms_norm_static_fp8_quant_kernel(
template
<
typename
scalar_t
,
int
width
,
typename
fp8_type
>
__global__
std
::
enable_if_t
<
(
width
>
0
)
&&
_typeConvert
<
scalar_t
>::
exists
>
fused_add_rms_norm_static_fp8_quant_kernel
(
fp8_type
*
__restrict__
out
,
// [..., hidden_size]
scalar_t
*
__restrict__
input
,
// [..., hidden_size]
fp8_type
*
__restrict__
out
,
// [..., hidden_size]
scalar_t
*
__restrict__
input
,
// [..., hidden_size]
const
int
input_stride
,
scalar_t
*
__restrict__
residual
,
// [..., hidden_size]
const
scalar_t
*
__restrict__
weight
,
// [hidden_size]
const
float
*
__restrict__
scale
,
// [1]
...
...
@@ -76,6 +78,7 @@ fused_add_rms_norm_static_fp8_quant_kernel(
static_assert
(
sizeof
(
_f16Vec
<
scalar_t
,
width
>
)
==
sizeof
(
scalar_t
)
*
width
);
const
int
vec_hidden_size
=
hidden_size
/
width
;
const
int
vec_input_stride
=
input_stride
/
width
;
__shared__
float
s_variance
;
float
variance
=
0.0
f
;
/* These and the argument pointers are all declared `restrict` as they are
...
...
@@ -89,8 +92,9 @@ fused_add_rms_norm_static_fp8_quant_kernel(
reinterpret_cast
<
const
_f16Vec
<
scalar_t
,
width
>*>
(
weight
);
for
(
int
idx
=
threadIdx
.
x
;
idx
<
vec_hidden_size
;
idx
+=
blockDim
.
x
)
{
int
stride_id
=
blockIdx
.
x
*
vec_input_stride
+
idx
;
int
id
=
blockIdx
.
x
*
vec_hidden_size
+
idx
;
_f16Vec
<
scalar_t
,
width
>
temp
=
input_v
[
id
];
_f16Vec
<
scalar_t
,
width
>
temp
=
input_v
[
stride_
id
];
temp
+=
residual_v
[
id
];
variance
+=
temp
.
sum_squares
();
residual_v
[
id
]
=
temp
;
...
...
@@ -127,8 +131,9 @@ fused_add_rms_norm_static_fp8_quant_kernel(
template
<
typename
scalar_t
,
int
width
,
typename
fp8_type
>
__global__
std
::
enable_if_t
<
(
width
==
0
)
||
!
_typeConvert
<
scalar_t
>::
exists
>
fused_add_rms_norm_static_fp8_quant_kernel
(
fp8_type
*
__restrict__
out
,
// [..., hidden_size]
scalar_t
*
__restrict__
input
,
// [..., hidden_size]
fp8_type
*
__restrict__
out
,
// [..., hidden_size]
scalar_t
*
__restrict__
input
,
// [..., hidden_size]
const
int
input_stride
,
scalar_t
*
__restrict__
residual
,
// [..., hidden_size]
const
scalar_t
*
__restrict__
weight
,
// [hidden_size]
const
float
*
__restrict__
scale
,
// [1]
...
...
@@ -137,7 +142,7 @@ fused_add_rms_norm_static_fp8_quant_kernel(
float
variance
=
0.0
f
;
for
(
int
idx
=
threadIdx
.
x
;
idx
<
hidden_size
;
idx
+=
blockDim
.
x
)
{
scalar_t
z
=
input
[
blockIdx
.
x
*
hidden_siz
e
+
idx
];
scalar_t
z
=
input
[
blockIdx
.
x
*
input_strid
e
+
idx
];
z
+=
residual
[
blockIdx
.
x
*
hidden_size
+
idx
];
float
x
=
(
float
)
z
;
variance
+=
x
*
x
;
...
...
@@ -171,7 +176,9 @@ void rms_norm_static_fp8_quant(torch::Tensor& out, // [..., hidden_size]
torch
::
Tensor
&
weight
,
// [hidden_size]
torch
::
Tensor
&
scale
,
// [1]
double
epsilon
)
{
TORCH_CHECK
(
out
.
is_contiguous
());
int
hidden_size
=
input
.
size
(
-
1
);
int
input_stride
=
input
.
stride
(
-
2
);
int
num_tokens
=
input
.
numel
()
/
hidden_size
;
dim3
grid
(
num_tokens
);
...
...
@@ -185,8 +192,9 @@ void rms_norm_static_fp8_quant(torch::Tensor& out, // [..., hidden_size]
vllm
::
rms_norm_static_fp8_quant_kernel
<
scalar_t
,
fp8_t
>
<<<
grid
,
block
,
0
,
stream
>>>
(
out
.
data_ptr
<
fp8_t
>
(),
input
.
data_ptr
<
scalar_t
>
(),
weight
.
data_ptr
<
scalar_t
>
(),
scale
.
data_ptr
<
float
>
(),
epsilon
,
num_tokens
,
hidden_size
);
input_stride
,
weight
.
data_ptr
<
scalar_t
>
(),
scale
.
data_ptr
<
float
>
(),
epsilon
,
num_tokens
,
hidden_size
);
});
});
}
...
...
@@ -200,7 +208,7 @@ void rms_norm_static_fp8_quant(torch::Tensor& out, // [..., hidden_size]
width, fp8_t> \
<<<grid, block, 0, stream>>>( \
out.data_ptr<fp8_t>(), input.data_ptr<scalar_t>(), \
residual.data_ptr<scalar_t>(),
\
input_stride,
residual.data_ptr<scalar_t>(), \
weight.data_ptr<scalar_t>(), scale.data_ptr<float>(), \
epsilon, num_tokens, hidden_size); \
}); \
...
...
@@ -212,7 +220,10 @@ void fused_add_rms_norm_static_fp8_quant(
torch
::
Tensor
&
weight
,
// [hidden_size]
torch
::
Tensor
&
scale
,
// [1]
double
epsilon
)
{
TORCH_CHECK
(
out
.
is_contiguous
());
TORCH_CHECK
(
residual
.
is_contiguous
());
int
hidden_size
=
input
.
size
(
-
1
);
int
input_stride
=
input
.
stride
(
-
2
);
int
num_tokens
=
input
.
numel
()
/
hidden_size
;
dim3
grid
(
num_tokens
);
...
...
@@ -236,7 +247,7 @@ void fused_add_rms_norm_static_fp8_quant(
auto
wt_ptr
=
reinterpret_cast
<
std
::
uintptr_t
>
(
weight
.
data_ptr
());
bool
ptrs_are_aligned
=
inp_ptr
%
16
==
0
&&
res_ptr
%
16
==
0
&&
wt_ptr
%
16
==
0
;
if
(
ptrs_are_aligned
&&
hidden_size
%
8
==
0
)
{
if
(
ptrs_are_aligned
&&
hidden_size
%
8
==
0
&&
input_stride
%
8
==
0
)
{
LAUNCH_FUSED_ADD_RMS_NORM
(
8
);
}
else
{
LAUNCH_FUSED_ADD_RMS_NORM
(
0
);
...
...
csrc/mamba/causal_conv1d/causal_conv1d.cu
deleted
100644 → 0
View file @
751c492c
// clang-format off
// adapted from https://github.com/Dao-AILab/causal-conv1d/blob/main/csrc/causal_conv1d_fwd.cu
// and https://github.com/Dao-AILab/causal-conv1d/blob/main/csrc/causal_conv1d_update.cu
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include "causal_conv1d.h"
#include <c10/util/BFloat16.h>
#include <c10/util/Half.h>
#include <c10/cuda/CUDAException.h> // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK
#include <cub/block/block_load.cuh>
#include <cub/block/block_store.cuh>
#ifdef USE_ROCM
namespace
cub
=
hipcub
;
#endif
#include "static_switch.h"
// Fails via TORCH_CHECK unless x.sizes() equals the dimension list given in
// the variadic arguments; the error message names the tensor expression and
// the expected shape exactly as written at the call site.
#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
// Dispatches on the runtime scalar type ITYPE. For Half, BFloat16 and Float it
// defines the aliases `input_t` and `weight_t` (both set to the same C++ type)
// and then invokes the callable passed as the trailing argument. Any other
// scalar type raises an error whose message starts with NAME.
//
// The unsupported-type branch uses TORCH_CHECK(false, ...) instead of the
// deprecated AT_ERROR alias; the message text is unchanged.
#define DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(ITYPE, NAME, ...)                        \
    if (ITYPE == at::ScalarType::Half) {                                                     \
        using input_t = at::Half;                                                            \
        using weight_t = at::Half;                                                           \
        __VA_ARGS__();                                                                       \
    } else if (ITYPE == at::ScalarType::BFloat16) {                                          \
        using input_t = at::BFloat16;                                                        \
        using weight_t = at::BFloat16;                                                       \
        __VA_ARGS__();                                                                       \
    } else if (ITYPE == at::ScalarType::Float) {                                             \
        using input_t = float;                                                               \
        using weight_t = float;                                                              \
        __VA_ARGS__();                                                                       \
    } else {                                                                                 \
        TORCH_CHECK(false, #NAME, " not implemented for input type '", toString(ITYPE), "'"); \
    }
// Forward declaration of the forward-pass launcher, templated on the input and
// weight element types. It takes the filled-in ConvParamsBase and the CUDA
// stream; its definition is not part of this excerpt.
template
<
typename
input_t
,
typename
weight_t
>
void
causal_conv1d_fwd_cuda
(
ConvParamsBase
&
params
,
cudaStream_t
stream
);
// Forward declaration of the update launcher, templated on the input and
// weight element types. It takes the filled-in ConvParamsBase and the CUDA
// stream; its definition is not part of this excerpt.
template
<
typename
input_t
,
typename
weight_t
>
void
causal_conv1d_update_cuda
(
ConvParamsBase
&
params
,
cudaStream_t
stream
);
// Fills `params` for a causal-conv1d forward launch.
//
// The struct is zeroed first, then populated with the problem sizes, the
// activation flag, the padding slot id, raw data pointers for the tensors,
// and the strides of `x`, `weight` and `out`.
//
// Optional tensors (`bias`, `query_start_loc`, `cache_indices`,
// `has_initial_state`) contribute a null pointer when absent. A non-null
// `query_start_loc` pointer selects the variable-length layout, which changes
// which tensor dimensions supply the batch / channel / length strides.
void set_conv_params_fwd(ConvParamsBase &params,
                         // sizes
                         const size_t batch,
                         const size_t dim,
                         const size_t seqlen,
                         const size_t width,
                         // device pointers
                         const at::Tensor x,
                         const at::Tensor weight,
                         const at::Tensor out,
                         const std::optional<at::Tensor>& bias,
                         bool silu_activation,
                         int64_t pad_slot_id,
                         const std::optional<at::Tensor>& query_start_loc = std::nullopt,
                         const std::optional<at::Tensor>& cache_indices = std::nullopt,
                         const std::optional<at::Tensor>& has_initial_state = std::nullopt) {
    // Data pointer of an optional tensor, or nullptr when it is not provided.
    const auto ptr_or_null = [](const std::optional<at::Tensor>& maybe_tensor) -> void* {
        if (maybe_tensor.has_value()) {
            return maybe_tensor.value().data_ptr();
        }
        return nullptr;
    };

    // Reset the parameters so every field not set below is zero.
    memset(&params, 0, sizeof(params));

    // Problem sizes and scalar options.
    params.batch = batch;
    params.dim = dim;
    params.seqlen = seqlen;
    params.width = width;
    params.pad_slot_id = pad_slot_id;
    params.silu_activation = silu_activation;

    // Required tensors, plus the optional bias.
    params.x_ptr = x.data_ptr();
    params.weight_ptr = weight.data_ptr();
    params.bias_ptr = ptr_or_null(bias);
    params.out_ptr = out.data_ptr();

    // Optional bookkeeping tensors.
    params.query_start_loc_ptr = ptr_or_null(query_start_loc);
    params.cache_indices_ptr = ptr_or_null(cache_indices);
    params.has_initial_state_ptr = ptr_or_null(has_initial_state);

    // Variable-length mode is keyed off the stored pointer, so an optional
    // that is present but carries a null data pointer still counts as
    // fixed-length.
    const bool varlen = params.query_start_loc_ptr != nullptr;

    // Dimensions of x / out that provide each stride in the two layouts.
    const int batch_axis = varlen ? 1 : 0;
    const int channel_axis = varlen ? 0 : 1;
    const int length_axis = varlen ? 1 : -1;

    // All strides are in elements, not bytes.
    params.x_batch_stride = x.stride(batch_axis);
    params.x_c_stride = x.stride(channel_axis);
    params.x_l_stride = x.stride(length_axis);
    params.weight_c_stride = weight.stride(0);
    params.weight_width_stride = weight.stride(1);
    params.out_batch_stride = out.stride(batch_axis);
    params.out_c_stride = out.stride(channel_axis);
    params.out_l_stride = out.stride(length_axis);
}
// Host entry point for the causal conv1d forward pass. The result is written in place
// into `x` (the output tensor aliases the input).
//
// x                 : (batch, dim, seqlen), or (dim, total_seqlen) when
//                     `query_start_loc` is given (variable-length mode).
// weight            : (dim, width), same dtype as x, width in [2, 4].
// bias_             : optional (dim), same dtype as weight.
// conv_states       : optional state tensor, same dtype as x; strides 0/1/2 are read.
// query_start_loc   : optional int32 tensor; batch size is its first extent minus one.
// cache_indices     : optional int32 (batch) tensor.
// has_initial_state : optional bool (batch) tensor.
// silu_activation   : apply SiLU to the convolution output.
// pad_slot_id       : cache index value that marks a padded entry (kernel exits early).
//
// Every violated precondition raises through TORCH_CHECK.
void causal_conv1d_fwd(const at::Tensor &x, const at::Tensor &weight,
                  const std::optional<at::Tensor> &bias_,
                  const std::optional<at::Tensor> &conv_states,
                  const std::optional<at::Tensor> &query_start_loc,
                  const std::optional<at::Tensor> &cache_indices,
                  const std::optional<at::Tensor> &has_initial_state,
                  bool silu_activation,
                  // used to identify padding entries if cache_indices provided
                  // in case of padding, the kernel will return early
                  int64_t pad_slot_id) {
    auto input_type = x.scalar_type();
    auto weight_type = weight.scalar_type();
    TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16);
    TORCH_CHECK(weight_type == at::ScalarType::Float || weight_type == at::ScalarType::Half || weight_type == at::ScalarType::BFloat16);
    // The dispatch below derives weight_t from the input dtype only, so a mismatching
    // weight tensor would be reinterpreted with the wrong element type in the kernel.
    TORCH_CHECK(weight_type == input_type, "weight type must equal to input type, other variations are disabled due to binary size limitations");

    TORCH_CHECK(x.is_cuda());
    TORCH_CHECK(weight.is_cuda());

    const bool varlen = query_start_loc.has_value();

    // Validate query_start_loc before its first extent is read for the batch size.
    if (varlen) {
        const auto &query_start_loc_ = query_start_loc.value();
        TORCH_CHECK(query_start_loc_.scalar_type() == at::ScalarType::Int);
        TORCH_CHECK(query_start_loc_.is_cuda());
        TORCH_CHECK(query_start_loc_.dim() >= 1, "query_start_loc must have at least one dimension");
    }

    // Check the rank before indexing sizes[]: IntArrayRef indexing is not bounds-checked
    // in release builds.
    TORCH_CHECK(x.dim() == (varlen ? 2 : 3),
                "x must be 2-D (dim, total_seqlen) when query_start_loc is given, otherwise 3-D (batch, dim, seqlen)");

    const auto sizes = x.sizes();
    const int batch_size = varlen ? query_start_loc.value().sizes()[0] - 1 : sizes[0];
    const int dim = varlen ? sizes[0] : sizes[1];
    const int seqlen = varlen ? sizes[1] : sizes[2];
    const int width = weight.size(-1);
    if (varlen) {
        CHECK_SHAPE(x, dim, seqlen);
    } else {
        CHECK_SHAPE(x, batch_size, dim, seqlen);
    }
    CHECK_SHAPE(weight, dim, width);
    // causal_conv1d_fwd_cuda only launches kernels for these widths; reject anything
    // else instead of silently returning without computing.
    TORCH_CHECK(width >= 2 && width <= 4, "causal_conv1d only supports width between 2 and 4");

    if (bias_.has_value()) {
        const auto &bias = bias_.value();
        TORCH_CHECK(bias.scalar_type() == weight_type);
        TORCH_CHECK(bias.is_cuda());
        TORCH_CHECK(bias.stride(-1) == 1);
        CHECK_SHAPE(bias, dim);
    }

    if (has_initial_state.has_value()) {
        const auto &has_initial_state_ = has_initial_state.value();
        TORCH_CHECK(has_initial_state_.scalar_type() == at::ScalarType::Bool);
        TORCH_CHECK(has_initial_state_.is_cuda());
        CHECK_SHAPE(has_initial_state_, batch_size);
    }

    if (cache_indices.has_value()) {
        const auto &cache_indices_ = cache_indices.value();
        TORCH_CHECK(cache_indices_.scalar_type() == at::ScalarType::Int);
        TORCH_CHECK(cache_indices_.is_cuda());
        CHECK_SHAPE(cache_indices_, batch_size);
    }

    // In-place operation: the output aliases the input.
    at::Tensor out = x;

    ConvParamsBase params;
    set_conv_params_fwd(params, batch_size, dim, seqlen, width, x, weight, out,
                        bias_,
                        silu_activation,
                        pad_slot_id,
                        query_start_loc,
                        cache_indices,
                        has_initial_state
                        );

    if (conv_states.has_value()) {
        const auto &conv_states_ = conv_states.value();
        TORCH_CHECK(conv_states_.scalar_type() == input_type);
        TORCH_CHECK(conv_states_.is_cuda());
        params.conv_states_ptr = conv_states_.data_ptr();
        params.conv_states_batch_stride = conv_states_.stride(0);
        params.conv_states_c_stride = conv_states_.stride(1);
        params.conv_states_l_stride = conv_states_.stride(2);
    } else {
        params.conv_states_ptr = nullptr;
    }

    // Run on the device that owns x, using the current stream of that device.
    const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
    auto stream = at::cuda::getCurrentCUDAStream().stream();
    DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(x.scalar_type(), "causal_conv1d_fwd", [&] {
        causal_conv1d_fwd_cuda<input_t, weight_t>(params, stream);
    });
}
// Host entry point for the incremental (decode-time) causal conv1d update. The result
// is written in place into `x`, and `conv_state` is updated by the kernel.
//
// x                   : (batch, dim, seqlen).
// conv_state          : (batch, dim, state_len), or (entries, dim, state_len) when
//                       `conv_state_indices_` is given; state_len >= width - 1.
// weight              : (dim, width), same dtype as x, width in [2, 4].
// bias_               : optional (dim), same dtype as weight.
// cache_seqlens_      : optional int32 (batch); when present the circular-buffer kernel
//                       variant is selected by the launcher.
// conv_state_indices_ : optional int32 (batch) gather indices into conv_state.
// pad_slot_id         : index value that marks a padded entry (kernel exits early).
//
// Every violated precondition raises through TORCH_CHECK.
void causal_conv1d_update(const at::Tensor &x,
                     const at::Tensor &conv_state,
                     const at::Tensor &weight,
                     const std::optional<at::Tensor> &bias_,
                     bool silu_activation,
                     const std::optional<at::Tensor> &cache_seqlens_,
                     const std::optional<at::Tensor> &conv_state_indices_,
                     // used to identify padding entries if cache_indices provided
                     // in case of padding, the kernel will return early
                     int64_t pad_slot_id) {
    auto input_type = x.scalar_type();
    auto weight_type = weight.scalar_type();
    TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16);
    TORCH_CHECK(weight_type == at::ScalarType::Float || weight_type == at::ScalarType::Half || weight_type == at::ScalarType::BFloat16);
    TORCH_CHECK(weight_type == input_type, "weight type must equal to input type, other variations are disabled due to binary size limitations");
    TORCH_CHECK(conv_state.scalar_type() == input_type);

    TORCH_CHECK(x.is_cuda());
    TORCH_CHECK(conv_state.is_cuda());
    TORCH_CHECK(weight.is_cuda());

    // Check the rank before indexing sizes[]: IntArrayRef indexing is not bounds-checked
    // in release builds.
    TORCH_CHECK(x.dim() == 3, "x must be 3-D (batch, dim, seqlen)");

    const auto sizes = x.sizes();
    const int batch_size = sizes[0];
    const int dim = sizes[1];
    const int seqlen = sizes[2];
    const int width = weight.size(-1);
    const int conv_state_len = conv_state.size(2);
    TORCH_CHECK(conv_state_len >= width - 1);

    CHECK_SHAPE(x, batch_size, dim, seqlen);
    CHECK_SHAPE(weight, dim, width);

    TORCH_CHECK(width >= 2 && width <= 4, "causal_conv1d only supports width between 2 and 4");

    if (bias_.has_value()) {
        const auto &bias = bias_.value();
        TORCH_CHECK(bias.scalar_type() == weight_type);
        TORCH_CHECK(bias.is_cuda());
        TORCH_CHECK(bias.stride(-1) == 1);
        CHECK_SHAPE(bias, dim);
    }

    // In-place operation: the output aliases the input.
    at::Tensor out = x;

    ConvParamsBase params;
    set_conv_params_fwd(params, batch_size, dim, seqlen, width, x, weight, out,
                        bias_,
                        silu_activation,
                        pad_slot_id);
    params.conv_state_ptr = conv_state.data_ptr();
    params.conv_state_len = conv_state_len;
    // All stride are in elements, not bytes.
    params.conv_state_batch_stride = conv_state.stride(0);
    params.conv_state_c_stride = conv_state.stride(1);
    params.conv_state_l_stride = conv_state.stride(2);

    if (cache_seqlens_.has_value()) {
        const auto &cache_seqlens = cache_seqlens_.value();
        TORCH_CHECK(cache_seqlens.scalar_type() == torch::kInt32);
        TORCH_CHECK(cache_seqlens.is_cuda());
        TORCH_CHECK(cache_seqlens.stride(-1) == 1);
        CHECK_SHAPE(cache_seqlens, batch_size);
        params.cache_seqlens = cache_seqlens.data_ptr<int32_t>();
    } else {
        params.cache_seqlens = nullptr;
    }

    if (conv_state_indices_.has_value()) {
        const auto &conv_state_indices = conv_state_indices_.value();
        // Each check is terminated with ';' (the original omitted it on two of them and
        // only compiled because of how the macro happens to expand).
        TORCH_CHECK(conv_state_indices.scalar_type() == torch::kInt32);
        TORCH_CHECK(conv_state_indices.is_cuda());
        TORCH_CHECK(conv_state_indices.stride(0) == 1);
        CHECK_SHAPE(conv_state_indices, batch_size);

        // With gather indices the leading extent of conv_state is independent of batch.
        int conv_state_entries = conv_state.size(0);
        CHECK_SHAPE(conv_state, conv_state_entries, dim, conv_state_len);

        params.conv_state_indices_ptr = conv_state_indices.data_ptr<int32_t>();
    } else {
        CHECK_SHAPE(conv_state, batch_size, dim, conv_state_len);
        params.conv_state_indices_ptr = nullptr;
    }

    // Run on the device that owns x, using the current stream of that device.
    const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
    auto stream = at::cuda::getCurrentCUDAStream().stream();
    DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(x.scalar_type(), "causal_conv1d_update", [&] {
        causal_conv1d_update_cuda<input_t, weight_t>(params, stream);
    });
}
// Compile-time configuration of the forward kernel: element types, block size, filter
// width, per-thread vector width and the dynamic shared-memory budget.
template<int kNThreads_, int kWidth_, bool kIsVecLoad_, typename input_t_, typename weight_t_>
struct Causal_conv1d_fwd_kernel_traits {
    // Element types.
    using input_t = input_t_;
    using weight_t = weight_t_;

    // Launch shape and filter width.
    static constexpr int kNThreads = kNThreads_;
    static constexpr int kWidth = kWidth_;
    static constexpr bool kIsVecLoad = kIsVecLoad_;

    // Only 2-byte and 4-byte element types are supported.
    static constexpr int kNBytes = sizeof(input_t);
    static_assert(kNBytes == 2 || kNBytes == 4);

    // Elements handled per thread: 4 for 4-byte types, 8 for 2-byte types (16 bytes).
    static constexpr int kNElts = kNBytes == 4 ? 4 : 8;
    static_assert(kWidth <= kNElts);

    // Integer type as wide as one thread's kNElts elements.
    using vec_t = typename BytesToType<kNBytes * kNElts>::Type;

    // CUB block-wide load/store helpers: element-wise (warp transpose) and vectorised
    // (direct, one vec_t per thread).
    using BlockLoadT = cub::BlockLoad<input_t, kNThreads, kNElts, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
    using BlockLoadVecT = cub::BlockLoad<vec_t, kNThreads, 1, cub::BLOCK_LOAD_DIRECT>;
    using BlockStoreT = cub::BlockStore<input_t, kNThreads, kNElts, cub::BLOCK_STORE_WARP_TRANSPOSE>;
    using BlockStoreVecT = cub::BlockStore<vec_t, kNThreads, 1, cub::BLOCK_STORE_DIRECT>;

    // Shared memory: CUB temp storage (none is budgeted for the vectorised path)
    // followed by one vec_t per thread used to exchange data between threads.
    static constexpr int kSmemIOSize = kIsVecLoad
        ? 0
        : custom_max({sizeof(typename BlockLoadT::TempStorage),
                      sizeof(typename BlockStoreT::TempStorage)});
    static constexpr int kSmemExchangeSize = kNThreads * kNBytes * kNElts;
    static constexpr int kSmemSize = kSmemIOSize + kSmemExchangeSize;
};
// Forward kernel. One thread block handles one (batch_id = blockIdx.x,
// channel_id = blockIdx.y) pair and walks its sequence in chunks of
// kNThreads * kNElts elements. Each thread computes kNElts outputs per chunk; the
// preceding kNElts inputs it needs come from the neighbouring thread (or the previous
// chunk) through `smem_exchange`. When `conv_states_ptr` is set, the last kWidth - 1
// inputs are additionally written back to the state slot selected by the cache index.
template<typename Ktraits>
__global__ __launch_bounds__(Ktraits::kNThreads)
void causal_conv1d_fwd_kernel(ConvParamsBase params) {
    constexpr int kWidth = Ktraits::kWidth;
    constexpr int kNThreads = Ktraits::kNThreads;
    constexpr int kNElts = Ktraits::kNElts;
    constexpr bool kIsVecLoad = Ktraits::kIsVecLoad;
    // Number of sequence elements the whole block consumes per loop iteration.
    constexpr int kChunkSize = kNThreads * kNElts;
    using input_t = typename Ktraits::input_t;
    using vec_t = typename Ktraits::vec_t;
    using weight_t = typename Ktraits::weight_t;

    // Shared memory. The CUB temp storages all alias the start of the buffer; the
    // exchange area starts kSmemIOSize bytes in.
    extern __shared__ char smem_[];
    auto& smem_load = reinterpret_cast<typename Ktraits::BlockLoadT::TempStorage&>(smem_);
    auto& smem_load_vec = reinterpret_cast<typename Ktraits::BlockLoadVecT::TempStorage&>(smem_);
    auto& smem_store = reinterpret_cast<typename Ktraits::BlockStoreT::TempStorage&>(smem_);
    auto& smem_store_vec = reinterpret_cast<typename Ktraits::BlockStoreVecT::TempStorage&>(smem_);
    vec_t *smem_exchange = reinterpret_cast<vec_t *>(smem_ + Ktraits::kSmemIOSize);

    const int tidx = threadIdx.x;
    const int batch_id = blockIdx.x;
    const int channel_id = blockIdx.y;

    // A non-null query_start_loc selects variable-length mode: the sequence start and
    // length are then read from consecutive entries of that array.
    const int *query_start_loc = reinterpret_cast<int *>(params.query_start_loc_ptr);
    const bool kVarlen = query_start_loc != nullptr;
    int sequence_start_index = batch_id;
    int seqlen = params.seqlen;
    if (kVarlen) {
        sequence_start_index = query_start_loc[batch_id];
        seqlen = query_start_loc[batch_id + 1] - sequence_start_index;
    }

    input_t *x = reinterpret_cast<input_t *>(params.x_ptr) + sequence_start_index * params.x_batch_stride
        + channel_id * params.x_c_stride;
    weight_t *weight = reinterpret_cast<weight_t *>(params.weight_ptr) + channel_id * params.weight_c_stride;
    input_t *out = reinterpret_cast<input_t *>(params.out_ptr) + sequence_start_index * params.out_batch_stride
        + channel_id * params.out_c_stride;

    float bias_val = 0.f;
    if (params.bias_ptr != nullptr) {
        bias_val = float(reinterpret_cast<weight_t *>(params.bias_ptr)[channel_id]);
    }

    bool has_initial_state = false;
    if (params.has_initial_state_ptr != nullptr) {
        has_initial_state = reinterpret_cast<bool *>(params.has_initial_state_ptr)[batch_id];
    }

    // Without cache indices the state slot is simply the batch id.
    const int *cache_indices = reinterpret_cast<int *>(params.cache_indices_ptr);
    const int cache_index = cache_indices == nullptr ? batch_id : cache_indices[batch_id];
    // cache_index == params.pad_slot_id is defined as padding, so we exit early
    if (cache_index == params.pad_slot_id) {
        return;
    }

    input_t *conv_states = nullptr;
    if (params.conv_states_ptr != nullptr) {
        conv_states = reinterpret_cast<input_t *>(params.conv_states_ptr)
            + cache_index * params.conv_states_batch_stride
            + channel_id * params.conv_states_c_stride;
    }

    // Thread 0 will load the last elements of the previous chunk, so we initialize those to 0
    // (or to the stored initial state, placed in the trailing kWidth - 1 slots).
    if (tidx == 0) {
        input_t initial_state[kNElts] = {0};
        if (has_initial_state) {
            #pragma unroll
            for (int w = 0; w < kWidth - 1; ++w) {
                initial_state[kNElts - 1 - (kWidth - 2) + w] = conv_states[w];
            }
        }
        smem_exchange[kNThreads - 1] = reinterpret_cast<vec_t *>(initial_state)[0];
    }

    float weight_vals[kWidth];
    #pragma unroll
    for (int i = 0; i < kWidth; ++i) {
        weight_vals[i] = float(weight[i * params.weight_width_stride]);
    }

    const int n_chunks = (seqlen + kChunkSize - 1) / kChunkSize;

    // Offset, within the last chunk, of the first of the kWidth - 1 trailing inputs that
    // form the final state. A negative value means the state starts in the chunk before.
    const int final_state_position = ((seqlen - (kWidth - 1)) - (n_chunks - 1) * kChunkSize);
    const bool final_state_spans_chunks =
        conv_states != nullptr && final_state_position < 0 && seqlen > kWidth;

    for (int chunk = 0; chunk < n_chunks; ++chunk) {
        // Second half receives this thread's inputs; first half the predecessor's.
        input_t x_vals_load[2 * kNElts] = {0};
        if constexpr (kIsVecLoad) {
            typename Ktraits::BlockLoadVecT(smem_load_vec).Load(
                reinterpret_cast<vec_t*>(x),
                *reinterpret_cast<vec_t (*)[1]>(&x_vals_load[kNElts]),
                (seqlen - chunk * kChunkSize) / kNElts);
        } else {
            __syncthreads();
            typename Ktraits::BlockLoadT(smem_load).Load(
                x,
                *reinterpret_cast<input_t (*)[kNElts]>(&x_vals_load[kNElts]),
                seqlen - chunk * kChunkSize);
        }
        x += kChunkSize;
        __syncthreads();

        // Thread kNThreads - 1 don't write yet, so that thread 0 can read
        // the last elements of the previous chunk.
        if (tidx < kNThreads - 1) {
            smem_exchange[tidx] = reinterpret_cast<vec_t *>(x_vals_load)[1];
        }
        __syncthreads();
        reinterpret_cast<vec_t *>(x_vals_load)[0] = smem_exchange[tidx > 0 ? tidx - 1 : kNThreads - 1];
        __syncthreads();
        // Now thread kNThreads - 1 can write the last elements of the current chunk.
        if (tidx == kNThreads - 1) {
            smem_exchange[tidx] = reinterpret_cast<vec_t *>(x_vals_load)[1];
        }

        float x_vals[2 * kNElts];
        #pragma unroll
        for (int i = 0; i < 2 * kNElts; ++i) {
            x_vals[i] = float(x_vals_load[i]);
        }

        // out[i] = bias + sum_w weight[w] * x[i - (kWidth - 1 - w)]
        float out_vals[kNElts];
        #pragma unroll
        for (int i = 0; i < kNElts; ++i) {
            out_vals[i] = bias_val;
            #pragma unroll
            for (int w = 0; w < kWidth; ++w) {
                out_vals[i] += weight_vals[w] * x_vals[kNElts + i - (kWidth - w - 1)];
            }
        }

        if (params.silu_activation) {
            #pragma unroll
            for (int i = 0; i < kNElts; ++i) {
                out_vals[i] = out_vals[i] / (1 + expf(-out_vals[i]));
            }
        }

        input_t out_vals_store[kNElts];
        #pragma unroll
        for (int i = 0; i < kNElts; ++i) {
            out_vals_store[i] = out_vals[i];
        }
        if constexpr (kIsVecLoad) {
            typename Ktraits::BlockStoreVecT(smem_store_vec).Store(
                reinterpret_cast<vec_t*>(out),
                reinterpret_cast<vec_t (&)[1]>(out_vals_store),
                (seqlen - chunk * kChunkSize) / kNElts);
        } else {
            typename Ktraits::BlockStoreT(smem_store).Store(out, out_vals_store, seqlen - chunk * kChunkSize);
        }
        out += kChunkSize;

        // in case the final state is separated between the last "smem_exchange" and
        // and the one before it (chunk = n_chunks - 1 and chunk = n_chunks - 2),
        // (which occurs when `final_state_position` is a non-positive index)
        // we load the correct data from smem_exchange from both chunks, the last chunk iteration and the one before it
        if (final_state_spans_chunks) {
            input_t vals_load[kNElts] = {0};
            if ((chunk == n_chunks - 2) && (tidx == kNThreads - 1)) {
                // chunk = n_chunks - 2, a segment of the final state sits in the last index
                reinterpret_cast<vec_t *>(vals_load)[0] = smem_exchange[kNThreads - 1];
                #pragma unroll
                for (int w = 0; w < -final_state_position; ++w) {
                    conv_states[w] = vals_load[kNElts + final_state_position + w];
                }
            }
            if ((chunk == n_chunks - 1) && tidx == 0) {
                // chunk = n_chunks - 1, the second segment of the final state first positions
                reinterpret_cast<vec_t *>(vals_load)[0] = smem_exchange[0];
                for (int w = -final_state_position; w < kWidth - 1; ++w) {
                    conv_states[w] = vals_load[w + final_state_position];
                }
                return;
            }
        }
    }

    // Final state is stored in the smem_exchange last token slot,
    // in case seqlen < kWidth, we would need to take the final state from the
    // initial state which is stored in conv_states
    // in case seqlen > kWidth, we would need to load the last kWidth - 1 data
    // and load it into conv_state accordingly
    const int last_thread = final_state_position / kNElts;

    if (conv_states != nullptr && tidx == last_thread) {
        input_t x_vals_load[kNElts * 2] = {0};
        // in case we are on the first kWidth tokens
        if (last_thread == 0 && seqlen < kWidth) {
            // Need to take the initial state
            reinterpret_cast<vec_t *>(x_vals_load)[0] = smem_exchange[0];
            const int offset = seqlen - (kWidth - 1);
            #pragma unroll
            for (int w = 0; w < kWidth - 1; ++w) {
                // pad the existing state
                if ((w - seqlen) >= 0 && has_initial_state) {
                    conv_states[w - seqlen] = conv_states[w];
                } else if ((w - seqlen) >= 0 && !has_initial_state) {
                    conv_states[w - seqlen] = input_t(0.0f);
                }
            }
            #pragma unroll
            for (int w = 0; w < kWidth - 1; ++w) {
                if (offset + w >= 0)
                    conv_states[w] = x_vals_load[offset + w];
            }
        } else {
            // in case the final state is in between the threads data
            const int offset = ((seqlen - (kWidth - 1)) % (kNElts));
            if ((offset + kWidth - 2) >= kNElts && (last_thread + 1 < kNThreads)) {
                // In case last_thread == kNThreads - 1, accessing last_thread + 1 will result in a
                // illegal access error on H100.
                // Therefore, we access last_thread + 1, only if the final state data sits there
                reinterpret_cast<vec_t *>(x_vals_load)[1] = smem_exchange[last_thread + 1];
            }
            reinterpret_cast<vec_t *>(x_vals_load)[0] = smem_exchange[last_thread];
            #pragma unroll
            for (int w = 0; w < kWidth - 1; ++w) {
                conv_states[w] = x_vals_load[offset + w];
            }
        }
    }
}
// Launches causal_conv1d_fwd_kernel for a fixed block size and filter width.
//
// The vectorised load/store path is selected only when the sequence length is a
// multiple of the per-thread element count and the input is not variable-length.
// The grid is (batch, dim): one thread block per (sequence, channel) pair.
template<int kNThreads, int kWidth, typename input_t, typename weight_t>
void causal_conv1d_fwd_launch(ConvParamsBase &params, cudaStream_t stream) {
    static constexpr int kNElts = sizeof(input_t) == 4 ? 4 : 8;
    const bool kVarlen = params.query_start_loc_ptr != nullptr;

    // Resolved outside BOOL_SWITCH because preprocessor directives must not appear
    // inside the arguments of a macro invocation.
#ifdef USE_ROCM
    constexpr bool kIsRocm = true;
#else
    constexpr bool kIsRocm = false;
#endif

    BOOL_SWITCH(params.seqlen % kNElts == 0 && !kVarlen, kIsVecLoad, [&] {
        using Ktraits = Causal_conv1d_fwd_kernel_traits<kNThreads, kWidth, kIsVecLoad, input_t, weight_t>;
        constexpr int kSmemSize = Ktraits::kSmemSize;
        dim3 grid(params.batch, params.dim);

        auto kernel = &causal_conv1d_fwd_kernel<Ktraits>;

        // Dynamic shared memory of 48 KiB or more has to be opted into explicitly.
        if (kSmemSize >= 48 * 1024) {
            C10_CUDA_CHECK(cudaFuncSetAttribute(
                reinterpret_cast<const void *>(kernel), cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
            // The warning only applies to AMD builds; it used to be printed on every
            // platform and was followed by a redundant second newline.
            if constexpr (kIsRocm) {
                std::cerr << "Warning (causal_conv1d fwd launch): attempting to set maxDynamicSharedMemorySize on an AMD GPU which is currently a non-op (in ROCm versions <= 6.1). This might lead to undefined behavior.\n";
            }
        }
        kernel<<<grid, Ktraits::kNThreads, kSmemSize, stream>>>(params);

        C10_CUDA_KERNEL_LAUNCH_CHECK();
    });
}
// Selects the forward launcher specialised for `params.width` (2, 3 or 4), always with
// 128 threads per block. Any other width launches nothing.
template<typename input_t, typename weight_t>
void causal_conv1d_fwd_cuda(ConvParamsBase &params, cudaStream_t stream) {
    switch (params.width) {
        case 2:
            causal_conv1d_fwd_launch<128, 2, input_t, weight_t>(params, stream);
            break;
        case 3:
            causal_conv1d_fwd_launch<128, 3, input_t, weight_t>(params, stream);
            break;
        case 4:
            causal_conv1d_fwd_launch<128, 4, input_t, weight_t>(params, stream);
            break;
        default:
            // Unsupported width: no kernel is launched.
            break;
    }
}
// Explicit instantiations of the forward launcher for the three element types the
// dispatch macro can select (float, at::Half, at::BFloat16); weight type == input type.
template
void
causal_conv1d_fwd_cuda
<
float
,
float
>(
ConvParamsBase
&
params
,
cudaStream_t
stream
);
template
void
causal_conv1d_fwd_cuda
<
at
::
Half
,
at
::
Half
>(
ConvParamsBase
&
params
,
cudaStream_t
stream
);
template
void
causal_conv1d_fwd_cuda
<
at
::
BFloat16
,
at
::
BFloat16
>(
ConvParamsBase
&
params
,
cudaStream_t
stream
);
// Compile-time configuration of the update kernel: block size, filter width and the
// element types. Only 2-byte and 4-byte input element types are accepted.
template<int kNThreads_, int kWidth_, typename input_t_, typename weight_t_>
struct Causal_conv1d_update_kernel_traits {
    // Launch shape and filter width.
    static constexpr int kNThreads = kNThreads_;
    static constexpr int kWidth = kWidth_;

    // Element types.
    using input_t = input_t_;
    using weight_t = weight_t_;

    // Size in bytes of one input element.
    static constexpr int kNBytes = sizeof(input_t);
    static_assert(kNBytes == 2 || kNBytes == 4);
};
// Update kernel. Each thread owns one channel (channel_id = blockIdx.y * kNThreads +
// threadIdx.x) of one batch entry (blockIdx.x). It reads the last kWidth - 1 values of
// that channel's state, then for every new input element writes the element into the
// state, computes one output and slides its kWidth-wide register window by one.
//
// kIsCircularBuffer = false: the state is shifted left by `seqlen` positions and the
//                            new inputs are appended at the end.
// kIsCircularBuffer = true : the state is used as a ring buffer whose write position is
//                            derived from params.cache_seqlens[batch_id].
template<typename Ktraits, bool kIsCircularBuffer>
__global__ __launch_bounds__(Ktraits::kNThreads)
void causal_conv1d_update_kernel(ConvParamsBase params) {
    constexpr int kWidth = Ktraits::kWidth;
    constexpr int kNThreads = Ktraits::kNThreads;
    using input_t = typename Ktraits::input_t;
    using weight_t = typename Ktraits::weight_t;

    const int tidx = threadIdx.x;
    const int batch_id = blockIdx.x;
    const int channel_id = blockIdx.y * kNThreads + tidx;
    // The last block along y may extend past the channel count.
    if (channel_id >= params.dim) {
        return;
    }

    input_t *x = reinterpret_cast<input_t *>(params.x_ptr) + batch_id * params.x_batch_stride
        + channel_id * params.x_c_stride;

    // If params.conv_state_batch_indices is set, then the conv state is gathered from the conv state tensor
    // along the batch axis. Otherwise, the conv state coordinate is the same as the batch id.
    int state_batch = batch_id;
    if (params.conv_state_indices_ptr != nullptr) {
        state_batch = params.conv_state_indices_ptr[batch_id];
    }
    const int conv_state_batch_coord = state_batch;
    // conv_state_batch_coord == params.pad_slot_id is defined as padding so we exit early
    if (conv_state_batch_coord == params.pad_slot_id) {
        return;
    }

    input_t *conv_state = reinterpret_cast<input_t *>(params.conv_state_ptr)
        + conv_state_batch_coord * params.conv_state_batch_stride
        + channel_id * params.conv_state_c_stride;
    weight_t *weight = reinterpret_cast<weight_t *>(params.weight_ptr) + channel_id * params.weight_c_stride;
    input_t *out = reinterpret_cast<input_t *>(params.out_ptr) + batch_id * params.out_batch_stride
        + channel_id * params.out_c_stride;

    float bias_val = 0.f;
    if (params.bias_ptr != nullptr) {
        bias_val = float(reinterpret_cast<weight_t *>(params.bias_ptr)[channel_id]);
    }

    int state_len = params.conv_state_len;
    int advance_len = params.seqlen;

    // Ring-buffer position of the oldest of the kWidth - 1 values to read; it stays
    // unused (and cache_seqlens is not dereferenced) in the non-circular variant.
    int cache_seqlen = 0;
    if constexpr (kIsCircularBuffer) {
        cache_seqlen = params.cache_seqlens[batch_id] % state_len;
    }
    int update_idx = cache_seqlen - (kWidth - 1);
    if (update_idx < 0) {
        update_idx = update_idx + state_len;
    }

    float weight_vals[kWidth] = {0};
    #pragma unroll
    for (int i = 0; i < kWidth; ++i) {
        weight_vals[i] = float(weight[i * params.weight_width_stride]);
    }

    // Sliding window: slots [0, kWidth - 2] hold history, slot kWidth - 1 the new input.
    float x_vals[kWidth] = {0};
    if constexpr (!kIsCircularBuffer) {
        // Shift the older part of the state left by advance_len positions.
        #pragma unroll 2
        for (int i = 0; i < state_len - advance_len - (kWidth - 1); ++i) {
            conv_state[i * params.conv_state_l_stride] = conv_state[(i + advance_len) * params.conv_state_l_stride];
        }
        // Read the trailing kWidth - 1 values and move them left as well (when the
        // destination index is valid).
        #pragma unroll
        for (int i = 0; i < kWidth - 1; ++i) {
            input_t state_val = conv_state[(state_len - (kWidth - 1) + i) * params.conv_state_l_stride];
            if (i < advance_len + (kWidth - 1) && state_len - advance_len - (kWidth - 1) + i >= 0) {
                conv_state[(state_len - advance_len - (kWidth - 1) + i) * params.conv_state_l_stride] = state_val;
            }
            x_vals[i] = float(state_val);
        }
    } else {
        // Read kWidth - 1 consecutive ring-buffer entries, wrapping at state_len.
        #pragma unroll
        for (int i = 0; i < kWidth - 1; ++i) {
            input_t state_val = conv_state[update_idx * params.conv_state_l_stride];
            x_vals[i] = float(state_val);
            update_idx = update_idx + 1 >= state_len ? update_idx + 1 - state_len : update_idx + 1;
        }
    }

    #pragma unroll 2
    for (int i = 0; i < params.seqlen; ++i) {
        input_t x_val = x[i * params.x_l_stride];

        // Store the new input into the state.
        if constexpr (!kIsCircularBuffer) {
            if (i < advance_len && state_len - advance_len + i >= 0) {
                conv_state[(state_len - advance_len + i) * params.conv_state_l_stride] = x_val;
            }
        } else {
            conv_state[update_idx * params.conv_state_l_stride] = x_val;
            ++update_idx;
            if (update_idx >= state_len) {
                update_idx = update_idx - state_len;
            }
        }

        x_vals[kWidth - 1] = float(x_val);

        float out_val = bias_val;
        #pragma unroll
        for (int j = 0; j < kWidth; ++j) {
            out_val += weight_vals[j] * x_vals[j];
        }
        if (params.silu_activation) {
            out_val = out_val / (1 + expf(-out_val));
        }
        out[i * params.out_l_stride] = input_t(out_val);

        // Shift the input buffer by 1
        #pragma unroll
        for (int k = 0; k < kWidth - 1; ++k) {
            x_vals[k] = x_vals[k + 1];
        }
    }
}
// Launches causal_conv1d_update_kernel for a fixed block size and filter width.
// The grid is (batch, ceil(dim / kNThreads)); no dynamic shared memory is requested.
// The circular-buffer variant is used exactly when params.cache_seqlens is non-null.
template<int kNThreads, int kWidth, typename input_t, typename weight_t>
void causal_conv1d_update_launch(ConvParamsBase &params, cudaStream_t stream) {
    using Ktraits = Causal_conv1d_update_kernel_traits<kNThreads, kWidth, input_t, weight_t>;

    const int channel_blocks = (params.dim + kNThreads - 1) / kNThreads;
    dim3 grid(params.batch, channel_blocks);

    // Both variants share the same function-pointer type.
    auto kernel = &causal_conv1d_update_kernel<Ktraits, false>;
    if (params.cache_seqlens != nullptr) {
        kernel = &causal_conv1d_update_kernel<Ktraits, true>;
    }

    kernel<<<grid, Ktraits::kNThreads, 0, stream>>>(params);
    C10_CUDA_KERNEL_LAUNCH_CHECK();
}
// Selects the update launcher specialised for `params.width` (2, 3 or 4), always with
// 64 threads per block. Any other width launches nothing.
template<typename input_t, typename weight_t>
void causal_conv1d_update_cuda(ConvParamsBase &params, cudaStream_t stream) {
    switch (params.width) {
        case 2:
            causal_conv1d_update_launch<64, 2, input_t, weight_t>(params, stream);
            break;
        case 3:
            causal_conv1d_update_launch<64, 3, input_t, weight_t>(params, stream);
            break;
        case 4:
            causal_conv1d_update_launch<64, 4, input_t, weight_t>(params, stream);
            break;
        default:
            // Unsupported width: no kernel is launched.
            break;
    }
}
// Explicit instantiations of the update launcher for the three element types the
// dispatch macro can select (float, at::Half, at::BFloat16); weight type == input type.
template
void
causal_conv1d_update_cuda
<
float
,
float
>(
ConvParamsBase
&
params
,
cudaStream_t
stream
);
template
void
causal_conv1d_update_cuda
<
at
::
Half
,
at
::
Half
>(
ConvParamsBase
&
params
,
cudaStream_t
stream
);
template
void
causal_conv1d_update_cuda
<
at
::
BFloat16
,
at
::
BFloat16
>(
ConvParamsBase
&
params
,
cudaStream_t
stream
);
csrc/mamba/causal_conv1d/causal_conv1d.h
deleted
100644 → 0
View file @
751c492c
/******************************************************************************
* Copyright (c) 2024, Tri Dao.
******************************************************************************/
// clang-format off
// adapted from https://github.com/Dao-AILab/causal-conv1d/blob/main/csrc/causal_conv1d.h
#pragma once
#include <cuda_bf16.h>
#include <cuda_fp16.h>
////////////////////////////////////////////////////////////////////////////////////////////////////
// Plain parameter bundle passed by value to the causal conv1d kernels.
// Member order is part of the layout shared by host and device code; do not reorder.
struct ConvParamsBase {
    // Type used for all strides (in elements, not bytes).
    using index_t = uint32_t;

    // Problem sizes.
    int batch;
    int dim;
    int seqlen;
    int width;

    // Cache/state index value that marks a padded entry.
    int64_t pad_slot_id;
    // Apply SiLU to the convolution output.
    bool silu_activation;

    // Strides of x, weight and out.
    index_t x_batch_stride;
    index_t x_c_stride;
    index_t x_l_stride;
    index_t weight_c_stride;
    index_t weight_width_stride;
    index_t out_batch_stride;
    index_t out_c_stride;
    index_t out_l_stride;

    // State tensor used by the update path.
    int conv_state_len;
    index_t conv_state_batch_stride;
    index_t conv_state_c_stride;
    index_t conv_state_l_stride;

    // Common data pointers.
    void *__restrict__ x_ptr;
    void *__restrict__ weight_ptr;
    void *__restrict__ bias_ptr;
    void *__restrict__ out_ptr;

    void *__restrict__ conv_state_ptr;
    void *__restrict__ query_start_loc_ptr;
    void *__restrict__ has_initial_state_ptr;
    void *__restrict__ cache_indices_ptr;
    int32_t *__restrict__ cache_seqlens;

    // For the continuous batching case. Makes it so that the mamba state for
    // the current batch doesn't need to be a contiguous tensor.
    int32_t *__restrict__ conv_state_indices_ptr;

    void *__restrict__ seq_idx_ptr;

    // No __restrict__ since initial_states could be the same as final_states.
    void *initial_states_ptr;
    index_t initial_states_batch_stride;
    index_t initial_states_l_stride;
    index_t initial_states_c_stride;

    void *final_states_ptr;
    index_t final_states_batch_stride;
    index_t final_states_l_stride;
    index_t final_states_c_stride;

    // State tensor used by the forward path.
    void *conv_states_ptr;
    index_t conv_states_batch_stride;
    index_t conv_states_l_stride;
    index_t conv_states_c_stride;
};
// Platform-specific helpers. Both branches provide the same three names:
//   shuffle_xor   - exchange `val` with the lane whose id differs by `offset` (xor)
//   custom_max    - constexpr maximum of an initializer list of size_t
//   constexpr_min - constexpr minimum of two values
#ifndef USE_ROCM
    #include <cuda_bf16.h>

    template<typename T>
    __device__ inline T shuffle_xor(T val, int offset) {
        // All mask bits set.
        constexpr uint32_t kAllLanes = uint32_t(-1);
        return __shfl_xor_sync(kAllLanes, val, offset);
    }

    constexpr size_t custom_max(std::initializer_list<size_t> ilist)
    {
        return std::max(ilist);
    }

    template<typename T>
    constexpr T constexpr_min(T a, T b) {
        return std::min(a, b);
    }

#else
    #include <hip/hip_bf16.h>

    template<typename T>
    __device__ inline T shuffle_xor(T val, int offset) {
        return __shfl_xor(val, offset);
    }

    constexpr size_t custom_max(std::initializer_list<size_t> ilist)
    {
        const auto largest = std::max_element(ilist.begin(), ilist.end());
        return *largest;
    }

    template<typename T>
    constexpr T constexpr_min(T a, T b) {
        return a < b ? a : b;
    }
#endif
////////////////////////////////////////////////////////////////////////////////////////////////////
// Maps a byte count to an unsigned integer (or vector) type of exactly that size.
// The primary template is empty, so an unsupported size has no `Type` member.
template<int BYTES> struct BytesToType {};

template<> struct BytesToType<1> {
    using Type = uint8_t;
    static_assert(sizeof(Type) == 1);
};

template<> struct BytesToType<2> {
    using Type = uint16_t;
    static_assert(sizeof(Type) == 2);
};

template<> struct BytesToType<4> {
    using Type = uint32_t;
    static_assert(sizeof(Type) == 4);
};

template<> struct BytesToType<8> {
    using Type = uint64_t;
    static_assert(sizeof(Type) == 8);
};

template<> struct BytesToType<16> {
    using Type = uint4;
    static_assert(sizeof(Type) == 16);
};
////////////////////////////////////////////////////////////////////////////////////////////////////
// Binary addition functor, usable as the `Operator` of Allreduce::run.
template<typename T>
struct SumOp {
    __device__ inline T operator()(const T &x, const T &y) {
        return x + y;
    }
};
// Butterfly all-reduce over groups of THREADS lanes: at each step a lane combines its
// value with the lane whose id differs by OFFSET (xor), halving OFFSET until it is 1.
// The exchange goes through shuffle_xor(), the wrapper this header defines for both
// the CUDA and the ROCm branch, instead of calling the CUDA-only intrinsic directly.
template<int THREADS>
struct Allreduce {
    static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4);
    template<typename T, typename Operator>
    static __device__ inline T run(T x, Operator &op) {
        constexpr int OFFSET = THREADS / 2;
        x = op(x, shuffle_xor(x, OFFSET));
        return Allreduce<OFFSET>::run(x, op);
    }
};

// Recursion end: a single exchange with the neighbouring lane.
template<>
struct Allreduce<2> {
    template<typename T, typename Operator>
    static __device__ inline T run(T x, Operator &op) {
        x = op(x, shuffle_xor(x, 1));
        return x;
    }
};
csrc/mamba/causal_conv1d/static_switch.h
deleted
100644 → 0
View file @
751c492c
// Inspired by
// https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h
// and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h
// clang-format off
// adapted from https://github.com/Dao-AILab/causal-conv1d/blob/main/csrc/static_switch.h
#pragma once
/// Turns a runtime boolean into a compile-time constant visible to the given callable.
///
/// @param COND       - a boolean expression to switch by (evaluated once at runtime)
/// @param CONST_NAME - a name given for the constexpr bool variable.
/// @param ...        - a callable; it is invoked with CONST_NAME == true in one branch
///                     and CONST_NAME == false in the other, so its body is compiled
///                     for both values.
///
/// The macro expands to an immediately-invoked lambda and yields whatever the
/// callable returns.
///
/// Usage:
/// ```
/// BOOL_SWITCH(flag, BoolConst, [&] {
///     some_function<BoolConst>(...);
/// });
/// ```
#define BOOL_SWITCH(COND, CONST_NAME, ...) \
[&] { \
if (COND) { \
static constexpr bool CONST_NAME = true; \
return __VA_ARGS__(); \
} else { \
static constexpr bool CONST_NAME = false; \
return __VA_ARGS__(); \
} \
}()
csrc/mamba/mamba_ssm/selective_scan_fwd.cu
View file @
711aa9d5
...
...
@@ -7,7 +7,11 @@
#include <c10/util/BFloat16.h>
#include <c10/util/Half.h>
#include <c10/cuda/CUDAException.h> // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK
#ifdef USE_ROCM
#include <c10/hip/HIPException.h> // For C10_HIP_CHECK and C10_HIP_KERNEL_LAUNCH_CHECK
#else
#include <c10/cuda/CUDAException.h> // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK
#endif
#ifndef USE_ROCM
#include <cub/block/block_load.cuh>
...
...
@@ -312,19 +316,25 @@ void selective_scan_fwd_launch(SSMParamsBase ¶ms, cudaStream_t stream) {
// kIsVariableB, kIsVariableC and kHasZ are all set to True to reduce binary size
constexpr
bool
kIsVariableB
=
true
;
constexpr
bool
kIsVariableC
=
true
;
constexpr
bool
kHasZ
=
true
;
BOOL_SWITCH
(
params
.
seqlen
%
(
kNThreads
*
kNItems
)
==
0
,
kIsEvenLen
,
[
&
]
{
BOOL_SWITCH
(
params
.
query_start_loc_ptr
!=
nullptr
,
kVarlen
,
[
&
]
{
using
Ktraits
=
Selective_Scan_fwd_kernel_traits
<
kNThreads
,
kNItems
,
kNRows
,
kIsEvenLen
,
kIsVariableB
,
kIsVariableC
,
kHasZ
,
kVarlen
,
input_t
,
weight_t
>
;
constexpr
int
kSmemSize
=
Ktraits
::
kSmemSize
+
kNRows
*
MAX_DSTATE
*
sizeof
(
typename
Ktraits
::
scan_t
);
dim3
grid
(
params
.
batch
,
params
.
dim
/
kNRows
);
auto
kernel
=
&
selective_scan_fwd_kernel
<
Ktraits
>
;
if
(
kSmemSize
>=
48
*
1024
)
{
C10_CUDA_CHECK
(
cudaFuncSetAttribute
(
(
void
*
)
kernel
,
cudaFuncAttributeMaxDynamicSharedMemorySize
,
kSmemSize
));
}
kernel
<<<
grid
,
Ktraits
::
kNThreads
,
kSmemSize
,
stream
>>>
(
params
);
C10_CUDA_KERNEL_LAUNCH_CHECK
();
BOOL_SWITCH
(
params
.
z_ptr
!=
nullptr
,
kHasZ
,
[
&
]
{
BOOL_SWITCH
(
params
.
query_start_loc_ptr
!=
nullptr
,
kVarlen
,
[
&
]
{
using
Ktraits
=
Selective_Scan_fwd_kernel_traits
<
kNThreads
,
kNItems
,
kNRows
,
kIsEvenLen
,
kIsVariableB
,
kIsVariableC
,
kHasZ
,
kVarlen
,
input_t
,
weight_t
>
;
constexpr
int
kSmemSize
=
Ktraits
::
kSmemSize
+
kNRows
*
MAX_DSTATE
*
sizeof
(
typename
Ktraits
::
scan_t
);
dim3
grid
(
params
.
batch
,
params
.
dim
/
kNRows
);
auto
kernel
=
&
selective_scan_fwd_kernel
<
Ktraits
>
;
if
(
kSmemSize
>=
48
*
1024
)
{
#ifdef USE_ROCM
C10_HIP_CHECK
(
hipFuncSetAttribute
(
reinterpret_cast
<
const
void
*>
(
kernel
),
hipFuncAttributeMaxDynamicSharedMemorySize
,
kSmemSize
));
#else
C10_CUDA_CHECK
(
cudaFuncSetAttribute
(
kernel
,
cudaFuncAttributeMaxDynamicSharedMemorySize
,
kSmemSize
));
#endif
}
kernel
<<<
grid
,
Ktraits
::
kNThreads
,
kSmemSize
,
stream
>>>
(
params
);
C10_CUDA_KERNEL_LAUNCH_CHECK
();
});
});
});
}
...
...
@@ -612,19 +622,20 @@ void selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta,
at
::
Tensor
z
,
out_z
;
const
bool
has_z
=
z_
.
has_value
();
TORCH_CHECK
(
has_z
,
"has_z = False is disabled in favor of reduced binary size"
)
z
=
z_
.
value
();
TORCH_CHECK
(
z
.
scalar_type
()
==
input_type
);
TORCH_CHECK
(
z
.
is_cuda
());
TORCH_CHECK
(
z
.
stride
(
-
1
)
==
1
||
z
.
size
(
-
1
)
==
1
);
if
(
varlen
){
CHECK_SHAPE
(
z
,
dim
,
seqlen
);
}
else
{
CHECK_SHAPE
(
z
,
batch_size
,
dim
,
seqlen
);
if
(
has_z
)
{
z
=
z_
.
value
();
TORCH_CHECK
(
z
.
scalar_type
()
==
input_type
);
TORCH_CHECK
(
z
.
is_cuda
());
TORCH_CHECK
(
z
.
stride
(
-
1
)
==
1
||
z
.
size
(
-
1
)
==
1
);
if
(
varlen
){
CHECK_SHAPE
(
z
,
dim
,
seqlen
);
}
else
{
CHECK_SHAPE
(
z
,
batch_size
,
dim
,
seqlen
);
}
out_z
=
z
;
}
out_z
=
z
;
// Right now u has BHL layout and delta has HBL layout, and we want out to have HBL layout
at
::
Tensor
out
=
delta
;
TORCH_CHECK
(
ssm_states
.
scalar_type
()
==
input_type
);
...
...
@@ -653,4 +664,3 @@ void selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta,
selective_scan_fwd_cuda
<
input_t
,
weight_t
>
(
params
,
stream
);
});
}
csrc/moe/moe_align_sum_kernels.cu
View file @
711aa9d5
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <cub/cub.cuh>
#include <ATen/ATen.h>
#include <ATen/cuda/Atomic.cuh>
...
...
@@ -19,9 +20,14 @@ __global__ void moe_align_block_size_kernel(
int32_t
*
__restrict__
sorted_token_ids
,
int32_t
*
__restrict__
expert_ids
,
int32_t
*
__restrict__
total_tokens_post_pad
,
int32_t
num_experts
,
int32_t
padded_num_experts
,
int32_t
experts_per_warp
,
int32_t
block_size
,
size_t
numel
,
int32_t
*
__restrict__
cumsum
)
{
size_t
numel
,
int32_t
*
__restrict__
cumsum
,
int32_t
max_num_tokens_padded
)
{
extern
__shared__
int32_t
shared_counts
[];
// Initialize sorted_token_ids with numel
for
(
size_t
it
=
threadIdx
.
x
;
it
<
max_num_tokens_padded
;
it
+=
blockDim
.
x
)
{
sorted_token_ids
[
it
]
=
numel
;
}
const
int
warp_id
=
threadIdx
.
x
/
WARP_SIZE
;
const
int
my_expert_start
=
warp_id
*
experts_per_warp
;
...
...
@@ -45,18 +51,27 @@ __global__ void moe_align_block_size_kernel(
__syncthreads
();
if
(
threadIdx
.
x
==
0
)
{
cumsum
[
0
]
=
0
;
for
(
int
i
=
1
;
i
<=
num_experts
;
++
i
)
{
int
expert_count
=
0
;
int
warp_idx
=
(
i
-
1
)
/
experts_per_warp
;
int
expert_offset
=
(
i
-
1
)
%
experts_per_warp
;
expert_count
=
shared_counts
[
warp_idx
*
experts_per_warp
+
expert_offset
];
// Compute prefix sum over token counts per expert
using
BlockScan
=
cub
::
BlockScan
<
int32_t
,
1024
>
;
__shared__
typename
BlockScan
::
TempStorage
temp_storage
;
cumsum
[
i
]
=
cumsum
[
i
-
1
]
+
CEILDIV
(
expert_count
,
block_size
)
*
block_size
;
}
*
total_tokens_post_pad
=
cumsum
[
num_experts
];
int
expert_count
=
0
;
int
expert_id
=
threadIdx
.
x
;
if
(
expert_id
<
num_experts
)
{
int
warp_idx
=
expert_id
/
experts_per_warp
;
int
expert_offset
=
expert_id
%
experts_per_warp
;
expert_count
=
shared_counts
[
warp_idx
*
experts_per_warp
+
expert_offset
];
expert_count
=
CEILDIV
(
expert_count
,
block_size
)
*
block_size
;
}
int
cumsum_val
;
BlockScan
(
temp_storage
).
ExclusiveSum
(
expert_count
,
cumsum_val
);
if
(
expert_id
<=
num_experts
)
{
cumsum
[
expert_id
]
=
cumsum_val
;
}
if
(
expert_id
==
num_experts
)
{
*
total_tokens_post_pad
=
cumsum_val
;
}
__syncthreads
();
...
...
@@ -67,6 +82,13 @@ __global__ void moe_align_block_size_kernel(
expert_ids
[
i
/
block_size
]
=
threadIdx
.
x
;
}
}
// Fill remaining expert_ids with 0
const
size_t
fill_start_idx
=
cumsum
[
num_experts
]
/
block_size
+
threadIdx
.
x
;
const
size_t
expert_ids_size
=
CEILDIV
(
max_num_tokens_padded
,
block_size
);
for
(
size_t
i
=
fill_start_idx
;
i
<
expert_ids_size
;
i
+=
blockDim
.
x
)
{
expert_ids
[
i
]
=
0
;
}
}
template
<
typename
scalar_t
>
...
...
@@ -178,7 +200,12 @@ __global__ void moe_align_block_size_small_batch_expert_kernel(
const
scalar_t
*
__restrict__
topk_ids
,
int32_t
*
__restrict__
sorted_token_ids
,
int32_t
*
__restrict__
expert_ids
,
int32_t
*
__restrict__
total_tokens_post_pad
,
int32_t
num_experts
,
int32_t
block_size
,
size_t
numel
)
{
int32_t
block_size
,
size_t
numel
,
int32_t
max_num_tokens_padded
)
{
// Initialize sorted_token_ids with numel
for
(
size_t
it
=
threadIdx
.
x
;
it
<
max_num_tokens_padded
;
it
+=
blockDim
.
x
)
{
sorted_token_ids
[
it
]
=
numel
;
}
const
size_t
tid
=
threadIdx
.
x
;
const
size_t
stride
=
blockDim
.
x
;
...
...
@@ -226,6 +253,13 @@ __global__ void moe_align_block_size_small_batch_expert_kernel(
}
}
// Fill remaining expert_ids with 0
const
size_t
fill_start_idx
=
cumsum
[
num_experts
]
/
block_size
+
threadIdx
.
x
;
const
size_t
expert_ids_size
=
CEILDIV
(
max_num_tokens_padded
,
block_size
);
for
(
size_t
i
=
fill_start_idx
;
i
<
expert_ids_size
;
i
+=
blockDim
.
x
)
{
expert_ids
[
i
]
=
0
;
}
for
(
size_t
i
=
tid
;
i
<
numel
;
i
+=
stride
)
{
int32_t
expert_id
=
topk_ids
[
i
];
int32_t
rank_post_pad
=
...
...
@@ -252,13 +286,17 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
int
threads
=
1024
;
threads
=
((
threads
+
WARP_SIZE
-
1
)
/
WARP_SIZE
)
*
WARP_SIZE
;
// BlockScan uses 1024 threads and assigns one thread per expert.
TORCH_CHECK
(
padded_num_experts
<
1024
,
"padded_num_experts must be less than 1024"
);
VLLM_DISPATCH_INTEGRAL_AND_UNSIGNED_TYPES
(
topk_ids
.
scalar_type
(),
"moe_align_block_size_kernel"
,
[
&
]
{
// calc needed amount of shared mem for `cumsum` tensors
auto
options_int
=
torch
::
TensorOptions
().
dtype
(
torch
::
kInt
).
device
(
topk_ids
.
device
());
torch
::
Tensor
cumsum_buffer
=
torch
::
zeros
({
num_experts
+
1
},
options_int
);
torch
::
empty
({
num_experts
+
1
},
options_int
);
bool
small_batch_expert_mode
=
(
topk_ids
.
numel
()
<
1024
)
&&
(
num_experts
<=
64
);
...
...
@@ -276,7 +314,7 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
sorted_token_ids
.
data_ptr
<
int32_t
>
(),
experts_ids
.
data_ptr
<
int32_t
>
(),
num_tokens_post_pad
.
data_ptr
<
int32_t
>
(),
num_experts
,
block_size
,
topk_ids
.
numel
());
topk_ids
.
numel
()
,
sorted_token_ids
.
size
(
0
)
);
}
else
{
auto
align_kernel
=
vllm
::
moe
::
moe_align_block_size_kernel
<
scalar_t
>
;
...
...
@@ -290,7 +328,8 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
experts_ids
.
data_ptr
<
int32_t
>
(),
num_tokens_post_pad
.
data_ptr
<
int32_t
>
(),
num_experts
,
padded_num_experts
,
experts_per_warp
,
block_size
,
topk_ids
.
numel
(),
cumsum_buffer
.
data_ptr
<
int32_t
>
());
topk_ids
.
numel
(),
cumsum_buffer
.
data_ptr
<
int32_t
>
(),
sorted_token_ids
.
size
(
0
));
const
int
block_threads
=
std
::
min
(
256
,
(
int
)
threads
);
const
int
num_blocks
=
...
...
csrc/ops.h
View file @
711aa9d5
...
...
@@ -428,6 +428,11 @@ void scaled_fp4_experts_quant(
torch
::
Tensor
const
&
input
,
torch
::
Tensor
const
&
input_global_scale
,
torch
::
Tensor
const
&
input_offset_by_experts
,
torch
::
Tensor
const
&
output_scale_offset_by_experts
);
void
per_token_group_quant_fp8
(
const
torch
::
Tensor
&
input
,
torch
::
Tensor
&
output_q
,
torch
::
Tensor
&
output_s
,
int64_t
group_size
,
double
eps
,
double
fp8_min
,
double
fp8_max
,
bool
scale_ue8m0
);
#endif
void
static_scaled_int8_quant
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
input
,
...
...
@@ -467,22 +472,6 @@ void selective_scan_fwd(const torch::Tensor& u, const torch::Tensor& delta,
const
std
::
optional
<
torch
::
Tensor
>&
has_initial_state
,
const
torch
::
Tensor
&
ssm_states
,
int64_t
pad_slot_id
);
void
causal_conv1d_update
(
const
at
::
Tensor
&
x
,
const
at
::
Tensor
&
conv_state
,
const
at
::
Tensor
&
weight
,
const
std
::
optional
<
at
::
Tensor
>&
bias_
,
bool
silu_activation
,
const
std
::
optional
<
at
::
Tensor
>&
cache_seqlens_
,
const
std
::
optional
<
at
::
Tensor
>&
conv_state_indices_
,
int64_t
pad_slot_id
);
void
causal_conv1d_fwd
(
const
at
::
Tensor
&
x
,
const
at
::
Tensor
&
weight
,
const
std
::
optional
<
at
::
Tensor
>&
bias_
,
const
std
::
optional
<
at
::
Tensor
>&
conv_states
,
const
std
::
optional
<
at
::
Tensor
>&
query_start_loc
,
const
std
::
optional
<
at
::
Tensor
>&
cache_indices
,
const
std
::
optional
<
at
::
Tensor
>&
has_initial_state
,
bool
silu_activation
,
int64_t
pad_slot_id
);
using
fptr_t
=
int64_t
;
fptr_t
init_custom_ar
(
const
std
::
vector
<
int64_t
>&
fake_ipc_ptrs
,
torch
::
Tensor
&
rank_data
,
int64_t
rank
,
...
...
csrc/opt/layernorm_kernels_opt.cu
View file @
711aa9d5
...
...
@@ -22,13 +22,14 @@ template <typename scalar_t>
__global__
void
rms_norm_kernel
(
scalar_t
*
__restrict__
out
,
// [..., hidden_size]
const
scalar_t
*
__restrict__
input
,
// [..., hidden_size]
const
int64_t
input_stride
,
const
scalar_t
*
__restrict__
weight
,
// [hidden_size]
const
float
epsilon
,
const
int
num_tokens
,
const
int
hidden_size
)
{
__shared__
float
s_variance
;
float
variance
=
0.0
f
;
for
(
int
idx
=
threadIdx
.
x
;
idx
<
hidden_size
;
idx
+=
blockDim
.
x
)
{
const
float
x
=
(
float
)
input
[
blockIdx
.
x
*
hidden_siz
e
+
idx
];
const
float
x
=
(
float
)
input
[
blockIdx
.
x
*
input_strid
e
+
idx
];
variance
+=
x
*
x
;
}
...
...
@@ -42,7 +43,7 @@ __global__ void rms_norm_kernel(
__syncthreads
();
for
(
int
idx
=
threadIdx
.
x
;
idx
<
hidden_size
;
idx
+=
blockDim
.
x
)
{
float
x
=
(
float
)
input
[
blockIdx
.
x
*
hidden_siz
e
+
idx
];
float
x
=
(
float
)
input
[
blockIdx
.
x
*
input_strid
e
+
idx
];
out
[
blockIdx
.
x
*
hidden_size
+
idx
]
=
((
scalar_t
)(
x
*
s_variance
))
*
weight
[
idx
];
}
...
...
@@ -57,6 +58,7 @@ template <typename scalar_t, int width>
__global__
std
::
enable_if_t
<
(
width
>
0
)
&&
_typeConvert
<
scalar_t
>::
exists
>
fused_add_rms_norm_kernel
(
scalar_t
*
__restrict__
input
,
// [..., hidden_size]
const
int64_t
input_stride
,
scalar_t
*
__restrict__
residual
,
// [..., hidden_size]
const
scalar_t
*
__restrict__
weight
,
// [hidden_size]
const
float
epsilon
,
const
int
num_tokens
,
const
int
hidden_size
)
{
...
...
@@ -65,6 +67,7 @@ fused_add_rms_norm_kernel(
static_assert
(
sizeof
(
_f16Vec
<
scalar_t
,
width
>
)
==
sizeof
(
scalar_t
)
*
width
);
const
int
vec_hidden_size
=
hidden_size
/
width
;
const
int64_t
vec_input_stride
=
input_stride
/
width
;
__shared__
float
s_variance
;
float
variance
=
0.0
f
;
/* These and the argument pointers are all declared `restrict` as they are
...
...
@@ -79,7 +82,8 @@ fused_add_rms_norm_kernel(
for
(
int
idx
=
threadIdx
.
x
;
idx
<
vec_hidden_size
;
idx
+=
blockDim
.
x
)
{
int
id
=
blockIdx
.
x
*
vec_hidden_size
+
idx
;
_f16Vec
<
scalar_t
,
width
>
temp
=
input_v
[
id
];
int64_t
strided_id
=
blockIdx
.
x
*
vec_input_stride
+
idx
;
_f16Vec
<
scalar_t
,
width
>
temp
=
input_v
[
strided_id
];
temp
+=
residual_v
[
id
];
variance
+=
temp
.
sum_squares
();
residual_v
[
id
]
=
temp
;
...
...
@@ -96,10 +100,11 @@ fused_add_rms_norm_kernel(
for
(
int
idx
=
threadIdx
.
x
;
idx
<
vec_hidden_size
;
idx
+=
blockDim
.
x
)
{
int
id
=
blockIdx
.
x
*
vec_hidden_size
+
idx
;
int64_t
strided_id
=
blockIdx
.
x
*
vec_input_stride
+
idx
;
_f16Vec
<
scalar_t
,
width
>
temp
=
residual_v
[
id
];
temp
*=
s_variance
;
temp
*=
weight_v
[
idx
];
input_v
[
id
]
=
temp
;
input_v
[
strided_
id
]
=
temp
;
}
}
...
...
@@ -110,6 +115,7 @@ template <typename scalar_t, int width>
__global__
std
::
enable_if_t
<
(
width
==
0
)
||
!
_typeConvert
<
scalar_t
>::
exists
>
fused_add_rms_norm_kernel
(
scalar_t
*
__restrict__
input
,
// [..., hidden_size]
const
int64_t
input_stride
,
scalar_t
*
__restrict__
residual
,
// [..., hidden_size]
const
scalar_t
*
__restrict__
weight
,
// [hidden_size]
const
float
epsilon
,
const
int
num_tokens
,
const
int
hidden_size
)
{
...
...
@@ -117,7 +123,7 @@ fused_add_rms_norm_kernel(
float
variance
=
0.0
f
;
for
(
int
idx
=
threadIdx
.
x
;
idx
<
hidden_size
;
idx
+=
blockDim
.
x
)
{
scalar_t
z
=
input
[
blockIdx
.
x
*
hidden_siz
e
+
idx
];
scalar_t
z
=
input
[
blockIdx
.
x
*
input_strid
e
+
idx
];
z
+=
residual
[
blockIdx
.
x
*
hidden_size
+
idx
];
float
x
=
(
float
)
z
;
variance
+=
x
*
x
;
...
...
@@ -135,7 +141,7 @@ fused_add_rms_norm_kernel(
for
(
int
idx
=
threadIdx
.
x
;
idx
<
hidden_size
;
idx
+=
blockDim
.
x
)
{
float
x
=
(
float
)
residual
[
blockIdx
.
x
*
hidden_size
+
idx
];
input
[
blockIdx
.
x
*
hidden_siz
e
+
idx
]
=
input
[
blockIdx
.
x
*
input_strid
e
+
idx
]
=
((
scalar_t
)(
x
*
s_variance
))
*
weight
[
idx
];
}
}
...
...
@@ -253,8 +259,14 @@ void rms_norm_opt(torch::Tensor& out, // [..., hidden_size]
torch
::
Tensor
&
input
,
// [..., hidden_size]
torch
::
Tensor
&
weight
,
// [hidden_size]
double
epsilon
)
{
TORCH_CHECK
(
out
.
is_contiguous
());
TORCH_CHECK
(
input
.
stride
(
-
1
)
==
1
);
TORCH_CHECK
(
weight
.
is_contiguous
());
int
hidden_size
=
input
.
size
(
-
1
);
int
num_tokens
=
input
.
numel
()
/
hidden_size
;
int64_t
input_stride
=
input
.
stride
(
-
2
);
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device_of
(
input
));
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
auto
inp_ptr
=
reinterpret_cast
<
std
::
uintptr_t
>
(
input
.
data_ptr
());
...
...
@@ -300,7 +312,7 @@ void rms_norm_opt(torch::Tensor& out, // [..., hidden_size]
VLLM_DISPATCH_FLOATING_TYPES
(
input
.
scalar_type
(),
"rms_norm_kernel"
,
[
&
]
{
vllm
::
rms_norm_kernel
<
scalar_t
><<<
grid
,
block
,
0
,
stream
>>>
(
out
.
data_ptr
<
scalar_t
>
(),
input
.
data_ptr
<
scalar_t
>
(),
out
.
data_ptr
<
scalar_t
>
(),
input
.
data_ptr
<
scalar_t
>
(),
input_stride
,
weight
.
data_ptr
<
scalar_t
>
(),
epsilon
,
num_tokens
,
hidden_size
);
});
}
...
...
@@ -310,10 +322,10 @@ void rms_norm_opt(torch::Tensor& out, // [..., hidden_size]
VLLM_DISPATCH_FLOATING_TYPES( \
input.scalar_type(), "fused_add_rms_norm_kernel", [&] { \
vllm::fused_add_rms_norm_kernel<scalar_t, width> \
<<<grid, block, 0, stream>>>(
input.data_ptr<scalar_t>(),
\
residual.data_ptr<scalar_t>(),
\
weight.data_ptr<scalar_t>(), epsilon,
\
num_tokens, hidden_size);
\
<<<grid, block, 0, stream>>>(
\
input.data_ptr<scalar_t>(), input_stride,
\
residual.data_ptr<scalar_t>(), weight.data_ptr<scalar_t>(),
\
epsilon,
num_tokens, hidden_size); \
});
...
...
@@ -322,7 +334,10 @@ void fused_add_rms_norm_opt(torch::Tensor& input, // [..., hidden_size]
torch
::
Tensor
&
residual
,
// [..., hidden_size]
torch
::
Tensor
&
weight
,
// [hidden_size]
double
epsilon
)
{
TORCH_CHECK
(
residual
.
is_contiguous
());
TORCH_CHECK
(
weight
.
is_contiguous
());
int
hidden_size
=
input
.
size
(
-
1
);
int64_t
input_stride
=
input
.
stride
(
-
2
);
int
num_tokens
=
input
.
numel
()
/
hidden_size
;
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device_of
(
input
));
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
...
...
@@ -380,7 +395,16 @@ void fused_add_rms_norm_opt(torch::Tensor& input, // [..., hidden_size]
bytes.
*/
if
(
ptrs_are_aligned
&&
hidden_size
%
8
==
0
)
{
constexpr
int
vector_width
=
8
;
constexpr
int
req_alignment_bytes
=
vector_width
*
2
;
// vector_width * sizeof(bfloat16 or float16) (float32
// falls back to non-vectorized version anyway)
bool
ptrs_are_aligned
=
inp_ptr
%
req_alignment_bytes
==
0
&&
res_ptr
%
req_alignment_bytes
==
0
&&
wt_ptr
%
req_alignment_bytes
==
0
;
bool
offsets_are_multiple_of_vector_width
=
hidden_size
%
vector_width
==
0
&&
input_stride
%
vector_width
==
0
;
if
(
ptrs_are_aligned
&&
offsets_are_multiple_of_vector_width
)
{
LAUNCH_FUSED_ADD_RMS_NORM
(
8
);
}
else
{
LAUNCH_FUSED_ADD_RMS_NORM
(
0
);
...
...
csrc/quantization/cutlass_w8a8/moe/blockwise_scaled_group_mm_sm100.cu
View file @
711aa9d5
...
...
@@ -201,11 +201,10 @@ void run_blockwise_scaled_group_mm(
reinterpret_cast
<
typename
ScheduleConfig
::
LayoutSFB
*>
(
layout_sfb
.
data_ptr
())};
cutlass
::
KernelHardwareInfo
hw_info
;
hw_info
.
device_id
=
a_ptrs
.
get_device
();
hw_info
.
sm_count
=
cutlass
::
KernelHardwareInfo
::
query_device_multiprocessor_count
(
hw_info
.
device_id
);
int
device_id
=
a_ptrs
.
device
().
index
();
static
const
cutlass
::
KernelHardwareInfo
hw_info
{
device_id
,
cutlass
::
KernelHardwareInfo
::
query_device_multiprocessor_count
(
device_id
)};
// Epilogue Arguments
typename
GemmKernel
::
EpilogueArguments
epilogue_args
{
...
...
Prev
1
2
3
4
5
6
7
…
26
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment