Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
d231153f
Commit
d231153f
authored
Aug 13, 2024
by
zhuwenwen
Browse files
feat:optimize act_and_mul_kernel
parent
d1787c31
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
95 additions
and
13 deletions
+95
-13
csrc/activation_kernels.cu
csrc/activation_kernels.cu
+95
-13
No files found.
csrc/activation_kernels.cu
View file @
d231153f
#include <ATen/cuda/CUDAContext.h>
#include <torch/all.h>
#include <c10/cuda/CUDAGuard.h>
#include <ATen/native/cuda/MemoryAccess.cuh>
#include <cmath>
#include <cstdint>
...
...
@@ -23,6 +24,64 @@ __global__ void act_and_mul_kernel(
}
}
// Vectorized activation-and-gate kernel, single pass per thread.
//
// For the token selected by blockIdx.x, computes
//   out[token, i] = ACT_FN(input[token, 0, i]) * input[token, 1, i]
// with each thread handling the VEC contiguous elements that start at
// threadIdx.x * VEC.
//
// Launch / data preconditions (not checked inside the kernel):
//   * one block per token (gridDim.x == number of tokens);
//   * blockDim.x * VEC >= d, since a thread processes at most one vector and
//     elements beyond blockDim.x * VEC would be left unwritten;
//   * d is a multiple of VEC, so no vector straddles the end of a row;
//   * `out` and `input` are aligned to sizeof(scalar_t) * VEC bytes, as
//     required by the aligned_vector loads and stores.
template <typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&), int VEC>
__global__ void act_and_mul_kernel_vectorize1(
    scalar_t* __restrict__ out,          // [..., d]
    const scalar_t* __restrict__ input,  // [..., 2, d]
    const int d) {
  using VecType = at::native::memory::aligned_vector<scalar_t, VEC>;

  // 64-bit token index: token_idx * 2 * d can exceed INT_MAX for large
  // batches, so the row offsets must not be computed in 32-bit arithmetic.
  const int64_t token_idx = blockIdx.x;
  const int idx = threadIdx.x * VEC;

  if (idx < d) {
    // First half of the row feeds the activation, second half is the gate.
    const scalar_t* act_row = input + token_idx * 2 * d;
    const scalar_t* mul_row = act_row + d;

    // One vector-wide load per operand into properly aligned per-thread
    // copies (no punning of scalar arrays, no const cast on the input).
    const VecType x1 = *reinterpret_cast<const VecType*>(act_row + idx);
    const VecType x2 = *reinterpret_cast<const VecType*>(mul_row + idx);

    VecType y;
#pragma unroll
    for (int i = 0; i < VEC; i++) {
      // Operands are already in per-thread storage, so they are read
      // directly; a global-memory load helper does not apply here.
      y.val[i] = ACT_FN(x1.val[i]) * x2.val[i];
    }

    // One vector-wide store of the VEC results.
    *reinterpret_cast<VecType*>(out + token_idx * d + idx) = y;
  }
}
// Vectorized activation-and-gate kernel, looping variant.
//
// For the token selected by blockIdx.x, computes
//   out[token, i] = ACT_FN(input[token, 0, i]) * input[token, 1, i]
// Each thread starts at element threadIdx.x * VEC and advances by
// blockDim.x * VEC elements per iteration until it reaches d, so any block
// size covers the whole row.
//
// Launch / data preconditions (not checked inside the kernel):
//   * one block per token (gridDim.x == number of tokens);
//   * d is a multiple of VEC, so no vector straddles the end of a row;
//   * `out` and `input` are aligned to sizeof(scalar_t) * VEC bytes, as
//     required by the aligned_vector loads and stores.
template <typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&), int VEC>
__global__ void act_and_mul_kernel_vectorize2(
    scalar_t* __restrict__ out,          // [..., d]
    const scalar_t* __restrict__ input,  // [..., 2, d]
    const int d) {
  using VecType = at::native::memory::aligned_vector<scalar_t, VEC>;

  // 64-bit token index: token_idx * 2 * d can exceed INT_MAX for large
  // batches, so the row offsets must not be computed in 32-bit arithmetic.
  const int64_t token_idx = blockIdx.x;

  // Row base pointers are loop-invariant; compute them once.
  // First half of the input row feeds the activation, second half is the gate.
  const scalar_t* act_row = input + token_idx * 2 * d;
  const scalar_t* mul_row = act_row + d;
  scalar_t* out_row = out + token_idx * d;

  const int step = blockDim.x * VEC;
  for (int idx = threadIdx.x * VEC; idx < d; idx += step) {
    // One vector-wide load per operand into properly aligned per-thread
    // copies (no punning of scalar arrays, no const cast on the input).
    const VecType x1 = *reinterpret_cast<const VecType*>(act_row + idx);
    const VecType x2 = *reinterpret_cast<const VecType*>(mul_row + idx);

    VecType y;
#pragma unroll
    for (int i = 0; i < VEC; i++) {
      // Operands are already in per-thread storage, so they are read
      // directly; a global-memory load helper does not apply here.
      y.val[i] = ACT_FN(x1.val[i]) * x2.val[i];
    }

    // One vector-wide store of the VEC results.
    *reinterpret_cast<VecType*>(out_row + idx) = y;
  }
}
template
<
typename
T
>
__device__
__forceinline__
T
silu_kernel
(
const
T
&
x
)
{
// x * sigmoid(x)
...
...
@@ -54,19 +113,42 @@ __device__ __forceinline__ T gelu_tanh_kernel(const T& x) {
}
// namespace vllm
// Launch activation and gating kernel.
#define LAUNCH_ACTIVATION_GATE_KERNEL(KERNEL) \
int d = input.size(-1) / 2; \
int64_t num_tokens = input.numel() / input.size(-1); \
dim3 grid(num_tokens); \
dim3 block(std::min(d, 1024)); \
const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \
VLLM_DISPATCH_FLOATING_TYPES( \
input.scalar_type(), "act_and_mul_kernel", [&] { \
vllm::act_and_mul_kernel<scalar_t, KERNEL<scalar_t>> \
<<<grid, block, 0, stream>>>(out.data_ptr<scalar_t>(), \
input.data_ptr<scalar_t>(), d); \
// Launch activation and gating kernel.
//
// Expects tensors named `input` ([..., 2 * d]) and `out` ([..., d]) in the
// enclosing scope and launches one block per token on the current stream.
//
// Kernel selection:
//   * vectorized kernels are used only when d is a multiple of 8, d <= 16384,
//     and both data pointers are aligned to 8 elements (the widest vector
//     access the vectorized kernels perform);
//   * within that path the block size / vector width is picked so that
//     blockDim.x * VEC covers d in one pass for d <= 4096, and the looping
//     kernel is used for larger d;
//   * everything else falls back to the scalar act_and_mul_kernel.
//
// Nothing is launched when there are no tokens, because a zero-sized grid is
// an invalid launch configuration.
//
// NOTE: comments inside the macro body must use /* */ because of the line
// continuations.
#define LAUNCH_ACTIVATION_GATE_KERNEL(KERNEL)                                  \
  int d = input.size(-1) / 2;                                                  \
  int64_t num_tokens = input.numel() / input.size(-1);                         \
  dim3 grid(num_tokens);                                                       \
  dim3 block(std::min(d, 1024));                                               \
  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));            \
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();                \
  VLLM_DISPATCH_FLOATING_TYPES(                                                \
      input.scalar_type(), "act_and_mul_kernel", [&] {                         \
        scalar_t* out_ptr = out.data_ptr<scalar_t>();                          \
        const scalar_t* in_ptr = input.data_ptr<scalar_t>();                   \
        /* Widest vector access used below is 8 elements. */                   \
        const uintptr_t vec_bytes = sizeof(scalar_t) * 8;                      \
        const bool ptrs_aligned =                                              \
            (reinterpret_cast<uintptr_t>(out_ptr) % vec_bytes == 0) &&         \
            (reinterpret_cast<uintptr_t>(in_ptr) % vec_bytes == 0);            \
        if (num_tokens <= 0) {                                                 \
          /* Empty input: a zero-sized grid cannot be launched. */             \
        } else if (0 == d % 8 && d <= 16384 && ptrs_aligned) {                 \
          if (d <= 512) {                                                      \
            vllm::act_and_mul_kernel_vectorize1<scalar_t, KERNEL<scalar_t>, 2> \
                <<<grid, 256, 0, stream>>>(out_ptr, in_ptr, d);                \
          } else if (d <= 1024) {                                              \
            vllm::act_and_mul_kernel_vectorize1<scalar_t, KERNEL<scalar_t>, 8> \
                <<<grid, 128, 0, stream>>>(out_ptr, in_ptr, d);                \
          } else if (d <= 2048) {                                              \
            vllm::act_and_mul_kernel_vectorize1<scalar_t, KERNEL<scalar_t>, 8> \
                <<<grid, 256, 0, stream>>>(out_ptr, in_ptr, d);                \
          } else if (d <= 4096) {                                              \
            vllm::act_and_mul_kernel_vectorize1<scalar_t, KERNEL<scalar_t>, 8> \
                <<<grid, 512, 0, stream>>>(out_ptr, in_ptr, d);                \
          } else {                                                             \
            vllm::act_and_mul_kernel_vectorize2<scalar_t, KERNEL<scalar_t>, 8> \
                <<<grid, 1024, 0, stream>>>(out_ptr, in_ptr, d);               \
          }                                                                    \
        } else {                                                               \
          vllm::act_and_mul_kernel<scalar_t, KERNEL<scalar_t>>                 \
              <<<grid, block, 0, stream>>>(out_ptr, in_ptr, d);                \
        }                                                                      \
      });
void
silu_and_mul
(
torch
::
Tensor
&
out
,
// [..., d]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment