Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
0a1ab1e5
Unverified
Commit
0a1ab1e5
authored
Dec 16, 2025
by
Michael Goin
Committed by
GitHub
Dec 16, 2025
Browse files
[Perf][Kernels] Vectorize `csrc/activations_kernels.cu` (#29512)
Signed-off-by:
mgoin
<
mgoin64@gmail.com
>
parent
b6ec077e
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
176 additions
and
38 deletions
+176
-38
benchmarks/kernels/benchmark_activation.py
benchmarks/kernels/benchmark_activation.py
+2
-2
csrc/activation_kernels.cu
csrc/activation_kernels.cu
+174
-36
No files found.
benchmarks/kernels/benchmark_activation.py
View file @
0a1ab1e5
...
@@ -13,8 +13,8 @@ from vllm.triton_utils import triton
...
@@ -13,8 +13,8 @@ from vllm.triton_utils import triton
from
vllm.utils.argparse_utils
import
FlexibleArgumentParser
from
vllm.utils.argparse_utils
import
FlexibleArgumentParser
from
vllm.utils.torch_utils
import
STR_DTYPE_TO_TORCH_DTYPE
from
vllm.utils.torch_utils
import
STR_DTYPE_TO_TORCH_DTYPE
batch_size_range
=
[
1
,
16
,
32
,
64
,
128
]
batch_size_range
=
[
1
,
16
,
128
]
seq_len_range
=
[
1
,
16
,
64
,
1
28
,
256
,
512
,
1024
,
2048
,
4096
]
seq_len_range
=
[
1
,
16
,
64
,
1
024
,
4096
]
intermediate_size
=
[
3072
,
9728
,
12288
]
intermediate_size
=
[
3072
,
9728
,
12288
]
configs
=
list
(
itertools
.
product
(
batch_size_range
,
seq_len_range
,
intermediate_size
))
configs
=
list
(
itertools
.
product
(
batch_size_range
,
seq_len_range
,
intermediate_size
))
...
...
csrc/activation_kernels.cu
View file @
0a1ab1e5
...
@@ -15,19 +15,61 @@ __device__ __forceinline__ scalar_t compute(const scalar_t& x,
...
@@ -15,19 +15,61 @@ __device__ __forceinline__ scalar_t compute(const scalar_t& x,
const
scalar_t
&
y
)
{
const
scalar_t
&
y
)
{
return
act_first
?
ACT_FN
(
x
)
*
y
:
x
*
ACT_FN
(
y
);
return
act_first
?
ACT_FN
(
x
)
*
y
:
x
*
ACT_FN
(
y
);
}
}
// Activation and gating kernel template.
// Check if all pointers are 16-byte aligned for int4 vectorized access
__device__
__forceinline__
bool
is_16byte_aligned
(
const
void
*
ptr
)
{
return
(
reinterpret_cast
<
uintptr_t
>
(
ptr
)
&
15
)
==
0
;
}
// Activation and gating kernel template.
template
<
typename
scalar_t
,
scalar_t
(
*
ACT_FN
)(
const
scalar_t
&
),
template
<
typename
scalar_t
,
scalar_t
(
*
ACT_FN
)(
const
scalar_t
&
),
bool
act_first
>
bool
act_first
>
__global__
void
act_and_mul_kernel
(
__global__
void
act_and_mul_kernel
(
scalar_t
*
__restrict__
out
,
// [..., d]
scalar_t
*
__restrict__
out
,
// [..., d]
const
scalar_t
*
__restrict__
input
,
// [..., 2, d]
const
scalar_t
*
__restrict__
input
,
// [..., 2, d]
const
int
d
)
{
const
int
d
)
{
constexpr
int
VEC_SIZE
=
16
/
sizeof
(
scalar_t
);
const
int64_t
token_idx
=
blockIdx
.
x
;
const
int64_t
token_idx
=
blockIdx
.
x
;
for
(
int64_t
idx
=
threadIdx
.
x
;
idx
<
d
;
idx
+=
blockDim
.
x
)
{
const
scalar_t
*
x_ptr
=
input
+
token_idx
*
2
*
d
;
const
scalar_t
x
=
VLLM_LDG
(
&
input
[
token_idx
*
2
*
d
+
idx
]);
const
scalar_t
*
y_ptr
=
x_ptr
+
d
;
const
scalar_t
y
=
VLLM_LDG
(
&
input
[
token_idx
*
2
*
d
+
d
+
idx
]);
scalar_t
*
out_ptr
=
out
+
token_idx
*
d
;
out
[
token_idx
*
d
+
idx
]
=
compute
<
scalar_t
,
ACT_FN
,
act_first
>
(
x
,
y
);
// Check alignment for 128-bit vectorized access.
// All three pointers must be 16-byte aligned for safe int4 operations.
const
bool
aligned
=
is_16byte_aligned
(
x_ptr
)
&&
is_16byte_aligned
(
y_ptr
)
&&
is_16byte_aligned
(
out_ptr
);
if
(
aligned
&&
d
>=
VEC_SIZE
)
{
// Fast path: 128-bit vectorized loop
const
int4
*
x_vec
=
reinterpret_cast
<
const
int4
*>
(
x_ptr
);
const
int4
*
y_vec
=
reinterpret_cast
<
const
int4
*>
(
y_ptr
);
int4
*
out_vec
=
reinterpret_cast
<
int4
*>
(
out_ptr
);
const
int
num_vecs
=
d
/
VEC_SIZE
;
const
int
vec_end
=
num_vecs
*
VEC_SIZE
;
for
(
int
i
=
threadIdx
.
x
;
i
<
num_vecs
;
i
+=
blockDim
.
x
)
{
int4
x
=
VLLM_LDG
(
&
x_vec
[
i
]),
y
=
VLLM_LDG
(
&
y_vec
[
i
]),
r
;
auto
*
xp
=
reinterpret_cast
<
scalar_t
*>
(
&
x
);
auto
*
yp
=
reinterpret_cast
<
scalar_t
*>
(
&
y
);
auto
*
rp
=
reinterpret_cast
<
scalar_t
*>
(
&
r
);
#pragma unroll
for
(
int
j
=
0
;
j
<
VEC_SIZE
;
j
++
)
{
rp
[
j
]
=
compute
<
scalar_t
,
ACT_FN
,
act_first
>
(
xp
[
j
],
yp
[
j
]);
}
out_vec
[
i
]
=
r
;
}
// Scalar cleanup for remaining elements
for
(
int
i
=
vec_end
+
threadIdx
.
x
;
i
<
d
;
i
+=
blockDim
.
x
)
{
out_ptr
[
i
]
=
compute
<
scalar_t
,
ACT_FN
,
act_first
>
(
VLLM_LDG
(
&
x_ptr
[
i
]),
VLLM_LDG
(
&
y_ptr
[
i
]));
}
}
else
{
// Scalar fallback for unaligned data or small d
for
(
int64_t
idx
=
threadIdx
.
x
;
idx
<
d
;
idx
+=
blockDim
.
x
)
{
const
scalar_t
x
=
VLLM_LDG
(
&
x_ptr
[
idx
]);
const
scalar_t
y
=
VLLM_LDG
(
&
y_ptr
[
idx
]);
out_ptr
[
idx
]
=
compute
<
scalar_t
,
ACT_FN
,
act_first
>
(
x
,
y
);
}
}
}
}
}
...
@@ -120,50 +162,115 @@ template <typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&, const float)>
...
@@ -120,50 +162,115 @@ template <typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&, const float)>
__global__
void
act_and_mul_kernel_with_param
(
__global__
void
act_and_mul_kernel_with_param
(
scalar_t
*
__restrict__
out
,
const
scalar_t
*
__restrict__
input
,
const
int
d
,
scalar_t
*
__restrict__
out
,
const
scalar_t
*
__restrict__
input
,
const
int
d
,
const
float
param
)
{
const
float
param
)
{
constexpr
int
VEC_SIZE
=
16
/
sizeof
(
scalar_t
);
const
int64_t
token_idx
=
blockIdx
.
x
;
const
int64_t
token_idx
=
blockIdx
.
x
;
for
(
int64_t
idx
=
threadIdx
.
x
;
idx
<
d
;
idx
+=
blockDim
.
x
)
{
const
scalar_t
*
x_ptr
=
input
+
token_idx
*
2
*
d
;
const
scalar_t
x
=
VLLM_LDG
(
&
input
[
token_idx
*
2
*
d
+
idx
]);
const
scalar_t
*
y_ptr
=
x_ptr
+
d
;
const
scalar_t
y
=
VLLM_LDG
(
&
input
[
token_idx
*
2
*
d
+
d
+
idx
]);
scalar_t
*
out_ptr
=
out
+
token_idx
*
d
;
out
[
token_idx
*
d
+
idx
]
=
ACT_FN
(
x
,
param
)
*
y
;
// Check alignment for 128-bit vectorized access
const
bool
aligned
=
is_16byte_aligned
(
x_ptr
)
&&
is_16byte_aligned
(
y_ptr
)
&&
is_16byte_aligned
(
out_ptr
);
if
(
aligned
&&
d
>=
VEC_SIZE
)
{
// Fast path: 128-bit vectorized loop
const
int4
*
x_vec
=
reinterpret_cast
<
const
int4
*>
(
x_ptr
);
const
int4
*
y_vec
=
reinterpret_cast
<
const
int4
*>
(
y_ptr
);
int4
*
out_vec
=
reinterpret_cast
<
int4
*>
(
out_ptr
);
const
int
num_vecs
=
d
/
VEC_SIZE
;
const
int
vec_end
=
num_vecs
*
VEC_SIZE
;
for
(
int
i
=
threadIdx
.
x
;
i
<
num_vecs
;
i
+=
blockDim
.
x
)
{
int4
x
=
VLLM_LDG
(
&
x_vec
[
i
]),
y
=
VLLM_LDG
(
&
y_vec
[
i
]),
r
;
auto
*
xp
=
reinterpret_cast
<
scalar_t
*>
(
&
x
);
auto
*
yp
=
reinterpret_cast
<
scalar_t
*>
(
&
y
);
auto
*
rp
=
reinterpret_cast
<
scalar_t
*>
(
&
r
);
#pragma unroll
for
(
int
j
=
0
;
j
<
VEC_SIZE
;
j
++
)
{
rp
[
j
]
=
ACT_FN
(
xp
[
j
],
param
)
*
yp
[
j
];
}
out_vec
[
i
]
=
r
;
}
// Scalar cleanup for remaining elements
for
(
int
i
=
vec_end
+
threadIdx
.
x
;
i
<
d
;
i
+=
blockDim
.
x
)
{
out_ptr
[
i
]
=
ACT_FN
(
VLLM_LDG
(
&
x_ptr
[
i
]),
param
)
*
VLLM_LDG
(
&
y_ptr
[
i
]);
}
}
else
{
// Scalar fallback for unaligned data or small d
for
(
int64_t
idx
=
threadIdx
.
x
;
idx
<
d
;
idx
+=
blockDim
.
x
)
{
const
scalar_t
x
=
VLLM_LDG
(
&
x_ptr
[
idx
]);
const
scalar_t
y
=
VLLM_LDG
(
&
y_ptr
[
idx
]);
out_ptr
[
idx
]
=
ACT_FN
(
x
,
param
)
*
y
;
}
}
}
}
}
template
<
typename
T
>
template
<
typename
T
>
__device__
__forceinline__
T
swigluoai_and_mul
(
const
T
&
gate
,
const
T
&
up
,
__device__
__forceinline__
T
swigluoai_and_mul
(
const
T
&
gate
,
const
T
&
up
,
float
alpha
,
float
limit
)
{
float
alpha
,
float
limit
)
{
// clamp gate: min=None, max=limit
// Clamp gate to (-inf, limit] and up to [-limit, limit]
const
float
gate_f
=
(
float
)
gate
;
const
float
g
=
fminf
((
float
)
gate
,
limit
);
const
float
clamped_gate
=
gate_f
>
limit
?
limit
:
gate_f
;
const
float
u
=
fmaxf
(
fminf
((
float
)
up
,
limit
),
-
limit
);
// glu = gate * sigmoid(gate * alpha), then return (up + 1) * glu
// clamp up: min=-limit, max=limit
return
(
T
)((
u
+
1.0
f
)
*
g
/
(
1.0
f
+
expf
(
-
g
*
alpha
)));
const
float
up_f
=
(
float
)
up
;
const
float
clamped_up
=
up_f
>
limit
?
limit
:
(
up_f
<
-
limit
?
-
limit
:
up_f
);
// glu = gate * sigmoid(gate * alpha)
const
float
sigmoid_val
=
1.0
f
/
(
1.0
f
+
expf
(
-
clamped_gate
*
alpha
));
const
float
glu
=
clamped_gate
*
sigmoid_val
;
// (up + 1) * glu
return
(
T
)((
clamped_up
+
1.0
f
)
*
glu
);
}
}
// Interleaved gate/up: input has [gate0, up0, gate1, up1, ...].
template
<
typename
scalar_t
,
template
<
typename
scalar_t
,
scalar_t
(
*
ACT_FN
)(
const
scalar_t
&
,
const
scalar_t
&
,
const
float
,
scalar_t
(
*
ACT_FN
)(
const
scalar_t
&
,
const
scalar_t
&
,
const
float
,
const
float
)>
const
float
)>
__global__
void
swigluoai_and_mul_kernel
(
__global__
void
swigluoai_and_mul_kernel
(
scalar_t
*
__restrict__
out
,
// [..., d]
scalar_t
*
__restrict__
out
,
// [..., d]
const
scalar_t
*
__restrict__
input
,
// [..., 2
,
d]
const
scalar_t
*
__restrict__
input
,
// [..., 2
*
d]
(interleaved)
const
int
d
,
const
float
alpha
,
const
float
limit
)
{
const
int
d
,
const
float
alpha
,
const
float
limit
)
{
// For interleaved data: input has 2*d elements per token (gate/up pairs)
// output has d elements per token
constexpr
int
VEC_SIZE
=
16
/
sizeof
(
scalar_t
);
constexpr
int
PAIRS
=
VEC_SIZE
/
2
;
// Number of gate/up pairs per int4 load
const
int64_t
token_idx
=
blockIdx
.
x
;
const
int64_t
token_idx
=
blockIdx
.
x
;
// TODO: Vectorize loads and stores.
const
scalar_t
*
in_ptr
=
input
+
token_idx
*
2
*
d
;
for
(
int64_t
idx
=
threadIdx
.
x
;
idx
<
d
;
idx
+=
blockDim
.
x
)
{
scalar_t
*
out_ptr
=
out
+
token_idx
*
d
;
// gate = x[..., ::2] (even indices)
const
scalar_t
gate
=
VLLM_LDG
(
&
input
[
token_idx
*
2
*
d
+
2
*
idx
]);
// Check alignment for 128-bit vectorized access on input.
// up = x[..., 1::2] (odd indices)
// For output we use int2 (64-bit) which has 8-byte alignment requirement.
const
scalar_t
up
=
VLLM_LDG
(
&
input
[
token_idx
*
2
*
d
+
2
*
idx
+
1
]);
const
bool
in_aligned
=
is_16byte_aligned
(
in_ptr
);
const
bool
out_aligned
=
out
[
token_idx
*
d
+
idx
]
=
ACT_FN
(
gate
,
up
,
alpha
,
limit
);
(
reinterpret_cast
<
uintptr_t
>
(
out_ptr
)
&
7
)
==
0
;
// 8-byte for int2
if
(
in_aligned
&&
out_aligned
&&
d
>=
PAIRS
)
{
// Fast path: vectorized loop
// Each int4 load gives VEC_SIZE elements = PAIRS gate/up pairs
// Each int2 store writes PAIRS output elements
const
int4
*
in_vec
=
reinterpret_cast
<
const
int4
*>
(
in_ptr
);
int2
*
out_vec
=
reinterpret_cast
<
int2
*>
(
out_ptr
);
const
int
num_vecs
=
d
/
PAIRS
;
const
int
vec_end
=
num_vecs
*
PAIRS
;
for
(
int
i
=
threadIdx
.
x
;
i
<
num_vecs
;
i
+=
blockDim
.
x
)
{
int4
v
=
VLLM_LDG
(
&
in_vec
[
i
]);
int2
r
;
auto
*
vp
=
reinterpret_cast
<
scalar_t
*>
(
&
v
);
auto
*
rp
=
reinterpret_cast
<
scalar_t
*>
(
&
r
);
#pragma unroll
for
(
int
j
=
0
;
j
<
PAIRS
;
j
++
)
{
rp
[
j
]
=
ACT_FN
(
vp
[
2
*
j
],
vp
[
2
*
j
+
1
],
alpha
,
limit
);
}
out_vec
[
i
]
=
r
;
}
// Scalar cleanup for remaining elements
for
(
int
i
=
vec_end
+
threadIdx
.
x
;
i
<
d
;
i
+=
blockDim
.
x
)
{
out_ptr
[
i
]
=
ACT_FN
(
VLLM_LDG
(
&
in_ptr
[
2
*
i
]),
VLLM_LDG
(
&
in_ptr
[
2
*
i
+
1
]),
alpha
,
limit
);
}
}
else
{
// Scalar fallback for unaligned data or small d
for
(
int64_t
idx
=
threadIdx
.
x
;
idx
<
d
;
idx
+=
blockDim
.
x
)
{
// gate = x[..., ::2] (even indices)
const
scalar_t
gate
=
VLLM_LDG
(
&
in_ptr
[
2
*
idx
]);
// up = x[..., 1::2] (odd indices)
const
scalar_t
up
=
VLLM_LDG
(
&
in_ptr
[
2
*
idx
+
1
]);
out_ptr
[
idx
]
=
ACT_FN
(
gate
,
up
,
alpha
,
limit
);
}
}
}
}
}
...
@@ -217,10 +324,41 @@ __global__ void activation_kernel(
...
@@ -217,10 +324,41 @@ __global__ void activation_kernel(
scalar_t
*
__restrict__
out
,
// [..., d]
scalar_t
*
__restrict__
out
,
// [..., d]
const
scalar_t
*
__restrict__
input
,
// [..., d]
const
scalar_t
*
__restrict__
input
,
// [..., d]
const
int
d
)
{
const
int
d
)
{
constexpr
int
VEC_SIZE
=
16
/
sizeof
(
scalar_t
);
const
int64_t
token_idx
=
blockIdx
.
x
;
const
int64_t
token_idx
=
blockIdx
.
x
;
for
(
int64_t
idx
=
threadIdx
.
x
;
idx
<
d
;
idx
+=
blockDim
.
x
)
{
const
scalar_t
*
in_ptr
=
input
+
token_idx
*
d
;
const
scalar_t
x
=
VLLM_LDG
(
&
input
[
token_idx
*
d
+
idx
]);
scalar_t
*
out_ptr
=
out
+
token_idx
*
d
;
out
[
token_idx
*
d
+
idx
]
=
ACT_FN
(
x
);
// Check alignment for 128-bit vectorized access
const
bool
aligned
=
is_16byte_aligned
(
in_ptr
)
&&
is_16byte_aligned
(
out_ptr
);
if
(
aligned
&&
d
>=
VEC_SIZE
)
{
// Fast path: 128-bit vectorized loop
const
int4
*
in_vec
=
reinterpret_cast
<
const
int4
*>
(
in_ptr
);
int4
*
out_vec
=
reinterpret_cast
<
int4
*>
(
out_ptr
);
const
int
num_vecs
=
d
/
VEC_SIZE
;
const
int
vec_end
=
num_vecs
*
VEC_SIZE
;
for
(
int
i
=
threadIdx
.
x
;
i
<
num_vecs
;
i
+=
blockDim
.
x
)
{
int4
v
=
VLLM_LDG
(
&
in_vec
[
i
]),
r
;
auto
*
vp
=
reinterpret_cast
<
scalar_t
*>
(
&
v
);
auto
*
rp
=
reinterpret_cast
<
scalar_t
*>
(
&
r
);
#pragma unroll
for
(
int
j
=
0
;
j
<
VEC_SIZE
;
j
++
)
{
rp
[
j
]
=
ACT_FN
(
vp
[
j
]);
}
out_vec
[
i
]
=
r
;
}
// Scalar cleanup for remaining elements
for
(
int
i
=
vec_end
+
threadIdx
.
x
;
i
<
d
;
i
+=
blockDim
.
x
)
{
out_ptr
[
i
]
=
ACT_FN
(
VLLM_LDG
(
&
in_ptr
[
i
]));
}
}
else
{
// Scalar fallback for unaligned data or small d
for
(
int64_t
idx
=
threadIdx
.
x
;
idx
<
d
;
idx
+=
blockDim
.
x
)
{
const
scalar_t
x
=
VLLM_LDG
(
&
in_ptr
[
idx
]);
out_ptr
[
idx
]
=
ACT_FN
(
x
);
}
}
}
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment