sglang, commit 7c0541b3 (unverified)
Authored by Lianmin Zheng on Mar 09, 2025; committed via GitHub on Mar 09, 2025
Move activation.cu to sgl-kernel/elementwise (#4250)
Parent: e8a69e4d
Showing 2 changed files with 83 additions and 1 deletion (+83 -1):
  sgl-kernel/csrc/elementwise/activation.cu  +82 -0
  sgl-kernel/setup.py                        +1 -1
sgl-kernel/csrc/elementwise/activation.cu (new file, mode 100644)
/*
* Copyright (c) 2024 by FlashInfer team.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <flashinfer/activation.cuh>
#include "pytorch_extension_utils.h"

using namespace flashinfer;

__device__ __forceinline__ float silu(const float& val) {
  return val / (1.0f + __expf(-val));
}

__device__ __forceinline__ float gelu(const float& val) {
  constexpr float kAlpha = M_SQRT1_2;
  return val * 0.5f * (1.0f + ::erf(val * kAlpha));
}

__device__ __forceinline__ float gelu_tanh(const float& val) {
  const float cdf = 0.5f * (1.0f + math::tanh((0.7978845608028654f * (val + 0.044715f * val * val * val))));
  return val * cdf;
}

void silu_and_mul(at::Tensor& out, at::Tensor& input, int64_t cuda_stream) {
  int d = input.size(-1) / 2;
  int64_t num_tokens = input.numel() / input.size(-1);
  dim3 grid(num_tokens);

  cudaStream_t stream = reinterpret_cast<cudaStream_t>(cuda_stream);
  DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(input.scalar_type(), c_type, [&] {
    uint32_t vec_size = 16 / sizeof(c_type);
    dim3 block(std::min(d / vec_size, 1024U));

    flashinfer::activation::act_and_mul_kernel<c_type, silu><<<grid, block, 0, stream>>>(
        static_cast<c_type*>(out.data_ptr()), static_cast<c_type*>(input.data_ptr()), d);

    return true;
  });
}

void gelu_tanh_and_mul(at::Tensor& out, at::Tensor& input, int64_t cuda_stream) {
  int d = input.size(-1) / 2;
  int64_t num_tokens = input.numel() / input.size(-1);
  dim3 grid(num_tokens);

  cudaStream_t stream = reinterpret_cast<cudaStream_t>(cuda_stream);
  DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(input.scalar_type(), c_type, [&] {
    uint32_t vec_size = 16 / sizeof(c_type);
    dim3 block(std::min(d / vec_size, 1024U));

    flashinfer::activation::act_and_mul_kernel<c_type, gelu_tanh><<<grid, block, 0, stream>>>(
        static_cast<c_type*>(out.data_ptr()), static_cast<c_type*>(input.data_ptr()), d);

    return true;
  });
}

void gelu_and_mul(at::Tensor& out, at::Tensor& input, int64_t cuda_stream) {
  int d = input.size(-1) / 2;
  int64_t num_tokens = input.numel() / input.size(-1);
  dim3 grid(num_tokens);

  cudaStream_t stream = reinterpret_cast<cudaStream_t>(cuda_stream);
  DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(input.scalar_type(), c_type, [&] {
    uint32_t vec_size = 16 / sizeof(c_type);
    dim3 block(std::min(d / vec_size, 1024U));

    flashinfer::activation::act_and_mul_kernel<c_type, gelu><<<grid, block, 0, stream>>>(
        static_cast<c_type*>(out.data_ptr()), static_cast<c_type*>(input.data_ptr()), d);

    return true;
  });
}
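
All three wrappers dispatch flashinfer's act_and_mul_kernel, which reads a tensor whose last dimension has size 2*d, applies the activation to the first half, multiplies it elementwise by the second half, and writes a tensor of width d into out. A minimal PyTorch reference of that semantics, for orientation only; the function names below are made up for this sketch and are not the extension's Python API:

import torch
import torch.nn.functional as F

def silu_and_mul_ref(x: torch.Tensor) -> torch.Tensor:
    # x: (..., 2 * d) -> (..., d); SiLU-gated multiply, as in silu_and_mul above.
    d = x.shape[-1] // 2
    return F.silu(x[..., :d]) * x[..., d:]

def gelu_and_mul_ref(x: torch.Tensor) -> torch.Tensor:
    # Exact (erf-based) GELU gate, matching the gelu() device function.
    d = x.shape[-1] // 2
    return F.gelu(x[..., :d]) * x[..., d:]

def gelu_tanh_and_mul_ref(x: torch.Tensor) -> torch.Tensor:
    # Tanh-approximated GELU gate; the 0.7978845608... and 0.044715 constants
    # in gelu_tanh() are sqrt(2/pi) and the cubic term of this approximation.
    d = x.shape[-1] // 2
    return F.gelu(x[..., :d], approximate="tanh") * x[..., d:]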
sgl-kernel/setup.py
@@ -97,6 +97,7 @@ sources = [
     "csrc/allreduce/trt_reduce_internal.cu",
     "csrc/allreduce/trt_reduce_kernel.cu",
     "csrc/attention/lightning_attention_decode_kernel.cu",
+    "csrc/elementwise/activation.cu",
     "csrc/elementwise/fused_add_rms_norm_kernel.cu",
     "csrc/elementwise/rope.cu",
     "csrc/gemm/bmm_fp8.cu",
@@ -111,7 +112,6 @@ sources = [
     "csrc/speculative/eagle_utils.cu",
     "csrc/speculative/speculative_sampling.cu",
     "csrc/torch_extension.cc",
-    "3rdparty/flashinfer/csrc/activation.cu",
     "3rdparty/flashinfer/csrc/norm.cu",
     "3rdparty/flashinfer/csrc/renorm.cu",
     "3rdparty/flashinfer/csrc/sampling.cu",
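
The setup.py side of the move is a one-line swap in the extension's source list: the kernel is now compiled from sgl-kernel's own csrc tree instead of the vendored 3rdparty/flashinfer copy. For orientation, a hedged sketch of how a sources list like this is typically fed to a CUDA extension build; the real setup.py is only partially visible in this hunk, so everything here beyond the two activation.cu paths is an assumption:

from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

# Hypothetical, trimmed-down wiring; the actual sgl-kernel setup.py lists many
# more sources plus include dirs and nvcc flags not shown in this diff.
sources = [
    "csrc/elementwise/activation.cu",  # the file added by this commit
    "csrc/torch_extension.cc",         # extension entry point listed in the hunk above
]

setup(
    name="sgl-kernel",
    ext_modules=[
        CUDAExtension(
            name="sgl_kernel.ops",  # assumed module name, for illustration only
            sources=sources,
        )
    ],
    cmdclass={"build_ext": BuildExtension},
)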