Commit 7c0541b3
Authored Mar 09, 2025 by Lianmin Zheng, committed by GitHub on Mar 09, 2025
Move activation.cu to sgl-kernel/elementwise (#4250)
parent e8a69e4d
Showing 2 changed files with 83 additions and 1 deletion
sgl-kernel/csrc/elementwise/activation.cu  +82 −0
sgl-kernel/setup.py  +1 −1
sgl-kernel/csrc/elementwise/activation.cu
0 → 100644
/*
* Copyright (c) 2024 by FlashInfer team.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <flashinfer/activation.cuh>
#include "pytorch_extension_utils.h"

using namespace flashinfer;

__device__ __forceinline__ float silu(const float& val) {
  return val / (1.0f + __expf(-val));
}

__device__ __forceinline__ float gelu(const float& val) {
  constexpr float kAlpha = M_SQRT1_2;
  return val * 0.5f * (1.0f + ::erf(val * kAlpha));
}

__device__ __forceinline__ float gelu_tanh(const float& val) {
  const float cdf =
      0.5f * (1.0f + math::tanh((0.7978845608028654f * (val + 0.044715f * val * val * val))));
  return val * cdf;
}

void silu_and_mul(at::Tensor& out, at::Tensor& input, int64_t cuda_stream) {
  int d = input.size(-1) / 2;
  int64_t num_tokens = input.numel() / input.size(-1);
  dim3 grid(num_tokens);

  cudaStream_t stream = reinterpret_cast<cudaStream_t>(cuda_stream);
  DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(input.scalar_type(), c_type, [&] {
    uint32_t vec_size = 16 / sizeof(c_type);
    dim3 block(std::min(d / vec_size, 1024U));
    flashinfer::activation::act_and_mul_kernel<c_type, silu><<<grid, block, 0, stream>>>(
        static_cast<c_type*>(out.data_ptr()), static_cast<c_type*>(input.data_ptr()), d);

    return true;
  });
}

void gelu_tanh_and_mul(at::Tensor& out, at::Tensor& input, int64_t cuda_stream) {
  int d = input.size(-1) / 2;
  int64_t num_tokens = input.numel() / input.size(-1);
  dim3 grid(num_tokens);

  cudaStream_t stream = reinterpret_cast<cudaStream_t>(cuda_stream);
  DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(input.scalar_type(), c_type, [&] {
    uint32_t vec_size = 16 / sizeof(c_type);
    dim3 block(std::min(d / vec_size, 1024U));
    flashinfer::activation::act_and_mul_kernel<c_type, gelu_tanh><<<grid, block, 0, stream>>>(
        static_cast<c_type*>(out.data_ptr()), static_cast<c_type*>(input.data_ptr()), d);

    return true;
  });
}

void gelu_and_mul(at::Tensor& out, at::Tensor& input, int64_t cuda_stream) {
  int d = input.size(-1) / 2;
  int64_t num_tokens = input.numel() / input.size(-1);
  dim3 grid(num_tokens);

  cudaStream_t stream = reinterpret_cast<cudaStream_t>(cuda_stream);
  DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(input.scalar_type(), c_type, [&] {
    uint32_t vec_size = 16 / sizeof(c_type);
    dim3 block(std::min(d / vec_size, 1024U));
    flashinfer::activation::act_and_mul_kernel<c_type, gelu><<<grid, block, 0, stream>>>(
        static_cast<c_type*>(out.data_ptr()), static_cast<c_type*>(input.data_ptr()), d);

    return true;
  });
}
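Not part of the diff, but for context: each wrapper above launches flashinfer's fused act_and_mul_kernel, which splits the last dimension of input into two halves of width d, applies the activation to the first half, and multiplies it by the second half. Below is a minimal caller sketch; the tensor shapes, the use of the current ATen CUDA stream, and the example function name are illustrative assumptions, not taken from this commit.

// Illustrative usage sketch (assumption: the wrapper declarations above are visible
// and this translation unit is linked against the extension). For silu_and_mul,
// input has shape [num_tokens, 2 * d] and out has shape [num_tokens, d]; the kernel
// computes out = silu(input[..., :d]) * input[..., d:].
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>

void run_silu_and_mul_example() {  // hypothetical helper, not from the commit
  // Shapes chosen only for illustration.
  at::Tensor input = at::randn({8, 2 * 4096}, at::dtype(at::kHalf).device(at::kCUDA));
  at::Tensor out = at::empty({8, 4096}, input.options());

  // The wrappers take the CUDA stream as an int64_t handle.
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  silu_and_mul(out, input, reinterpret_cast<int64_t>(stream));
}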
sgl-kernel/setup.py
...
@@ -97,6 +97,7 @@ sources = [
     "csrc/allreduce/trt_reduce_internal.cu",
     "csrc/allreduce/trt_reduce_kernel.cu",
     "csrc/attention/lightning_attention_decode_kernel.cu",
+    "csrc/elementwise/activation.cu",
     "csrc/elementwise/fused_add_rms_norm_kernel.cu",
     "csrc/elementwise/rope.cu",
     "csrc/gemm/bmm_fp8.cu",
...
@@ -111,7 +112,6 @@ sources = [
     "csrc/speculative/eagle_utils.cu",
     "csrc/speculative/speculative_sampling.cu",
     "csrc/torch_extension.cc",
-    "3rdparty/flashinfer/csrc/activation.cu",
     "3rdparty/flashinfer/csrc/norm.cu",
     "3rdparty/flashinfer/csrc/renorm.cu",
     "3rdparty/flashinfer/csrc/sampling.cu",
...