Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
6dc7aa42
"...git@developer.sourcefind.cn:2222/OpenDAS/vllm_cscc.git" did not exist on "38b14e8b06c4e24e13be1c3e994fee307e4d3c6e"
Commit
6dc7aa42
authored
Dec 31, 2024
by
zhuwenwen
Browse files
remove fatrelu_and_mul
parent
02c3e313
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
36 additions
and
36 deletions
+36
-36
csrc/opt/activation_kernels_opt.cu
csrc/opt/activation_kernels_opt.cu
+36
-36
No files found.
csrc/opt/activation_kernels_opt.cu
View file @
6dc7aa42
...
...
@@ -107,41 +107,41 @@ __device__ __forceinline__ T gelu_tanh_kernel(const T& x) {
return
(
T
)(
0.5
f
*
f
*
(
1.0
f
+
::
tanhf
(
inner
)));
}
template
<
typename
T
>
__device__
__forceinline__
T
fatrelu_kernel
(
const
T
&
x
,
const
float
threshold
)
{
const
float
f
=
(
float
)
x
;
return
(
T
)(
f
>
threshold
?
f
:
0.0
f
);
}
template
<
typename
scalar_t
,
scalar_t
(
*
ACT_FN
)(
const
scalar_t
&
,
const
float
)>
__global__
void
act_and_mul_kernel_with_param
(
scalar_t
*
__restrict__
out
,
const
scalar_t
*
__restrict__
input
,
const
int
d
,
const
float
param
)
{
const
int64_t
token_idx
=
blockIdx
.
x
;
for
(
int64_t
idx
=
threadIdx
.
x
;
idx
<
d
;
idx
+=
blockDim
.
x
)
{
const
scalar_t
x
=
VLLM_LDG
(
&
input
[
token_idx
*
2
*
d
+
idx
]);
const
scalar_t
y
=
VLLM_LDG
(
&
input
[
token_idx
*
2
*
d
+
d
+
idx
]);
out
[
token_idx
*
d
+
idx
]
=
ACT_FN
(
x
,
param
)
*
y
;
}
}
//
template <typename T>
//
__device__ __forceinline__ T fatrelu_kernel(const T& x, const float threshold) {
//
const float f = (float)x;
//
return (T)(f > threshold ? f : 0.0f);
//
}
//
template <typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&, const float)>
//
__global__ void act_and_mul_kernel_with_param(
//
scalar_t* __restrict__ out, const scalar_t* __restrict__ input, const int d,
//
const float param) {
//
const int64_t token_idx = blockIdx.x;
//
for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
//
const scalar_t x = VLLM_LDG(&input[token_idx * 2 * d + idx]);
//
const scalar_t y = VLLM_LDG(&input[token_idx * 2 * d + d + idx]);
//
out[token_idx * d + idx] = ACT_FN(x, param) * y;
//
}
//
}
}
// namespace vllm
#define LAUNCH_ACTIVATION_GATE_KERNEL_WITH_PARAM(KERNEL, PARAM) \
int d = input.size(-1) / 2; \
int64_t num_tokens = input.numel() / input.size(-1); \
dim3 grid(num_tokens); \
dim3 block(std::min(d, 1024)); \
const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \
VLLM_DISPATCH_FLOATING_TYPES( \
input.scalar_type(), "act_and_mul_kernel_with_param", [&] { \
vllm::act_and_mul_kernel_with_param<scalar_t, KERNEL<scalar_t>> \
<<<grid, block, 0, stream>>>(out.data_ptr<scalar_t>(), \
input.data_ptr<scalar_t>(), d, \
PARAM); \
});
//
#define LAUNCH_ACTIVATION_GATE_KERNEL_WITH_PARAM(KERNEL, PARAM) \
//
int d = input.size(-1) / 2; \
//
int64_t num_tokens = input.numel() / input.size(-1); \
//
dim3 grid(num_tokens); \
//
dim3 block(std::min(d, 1024)); \
//
const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \
//
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \
//
VLLM_DISPATCH_FLOATING_TYPES( \
//
input.scalar_type(), "act_and_mul_kernel_with_param", [&] { \
//
vllm::act_and_mul_kernel_with_param<scalar_t, KERNEL<scalar_t>> \
//
<<<grid, block, 0, stream>>>(out.data_ptr<scalar_t>(), \
//
input.data_ptr<scalar_t>(), d, \
//
PARAM); \
//
});
#define LAUNCH_ACTIVATION_GATE_KERNEL(KERNEL) \
...
...
@@ -200,8 +200,8 @@ void gelu_tanh_and_mul_opt(torch::Tensor& out, // [..., d]
LAUNCH_ACTIVATION_GATE_KERNEL
(
vllm
::
gelu_tanh_kernel
);
}
void
fatrelu_and_mul
(
torch
::
Tensor
&
out
,
// [..., d],
torch
::
Tensor
&
input
,
// [..., 2 * d]
double
threshold
)
{
LAUNCH_ACTIVATION_GATE_KERNEL_WITH_PARAM
(
vllm
::
fatrelu_kernel
,
threshold
);
}
\ No newline at end of file
// void fatrelu_and_mul_opt(torch::Tensor& out, // [..., d],
// torch::Tensor& input, // [..., 2 * d]
// double threshold) {
// LAUNCH_ACTIVATION_GATE_KERNEL_WITH_PARAM(vllm::fatrelu_kernel, threshold);
// }
\ No newline at end of file
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment