Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
42f5e7c5
Unverified
Commit
42f5e7c5
authored
Jan 15, 2025
by
Jee Jee Li
Committed by
GitHub
Jan 15, 2025
Browse files
[Kernel] Support MulAndSilu (#11624)
Signed-off-by:
Jee Jee Li
<
pandaleefree@gmail.com
>
parent
a3a3ee4e
Changes
7
Show whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
83 additions
and
36 deletions
+83
-36
csrc/activation_kernels.cu
csrc/activation_kernels.cu
+25
-7
csrc/ops.h
csrc/ops.h
+2
-0
csrc/torch_bindings.cpp
csrc/torch_bindings.cpp
+3
-0
tests/kernels/test_activation.py
tests/kernels/test_activation.py
+13
-7
vllm/model_executor/layers/activation.py
vllm/model_executor/layers/activation.py
+35
-0
vllm/model_executor/models/molmo.py
vllm/model_executor/models/molmo.py
+3
-11
vllm/model_executor/models/ultravox.py
vllm/model_executor/models/ultravox.py
+2
-11
No files found.
csrc/activation_kernels.cu
View file @
42f5e7c5
...
...
@@ -9,8 +9,16 @@
namespace
vllm
{
// Combines the two operands of a gated activation.
// With act_first == true the activation is applied to x and the result is
// multiplied by y; otherwise the activation is applied to y and the result
// is multiplied by x.
template <typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&),
          bool act_first>
__device__ __forceinline__ scalar_t compute(const scalar_t& x,
                                            const scalar_t& y) {
  // act_first is a template parameter, so the branch is fixed per
  // instantiation.
  if (act_first) {
    return ACT_FN(x) * y;
  }
  return x * ACT_FN(y);
}
// Activation and gating kernel template.
template
<
typename
scalar_t
,
scalar_t
(
*
ACT_FN
)(
const
scalar_t
&
)>
template
<
typename
scalar_t
,
scalar_t
(
*
ACT_FN
)(
const
scalar_t
&
),
bool
act_first
>
__global__
void
act_and_mul_kernel
(
scalar_t
*
__restrict__
out
,
// [..., d]
const
scalar_t
*
__restrict__
input
,
// [..., 2, d]
...
...
@@ -19,7 +27,7 @@ __global__ void act_and_mul_kernel(
for
(
int64_t
idx
=
threadIdx
.
x
;
idx
<
d
;
idx
+=
blockDim
.
x
)
{
const
scalar_t
x
=
VLLM_LDG
(
&
input
[
token_idx
*
2
*
d
+
idx
]);
const
scalar_t
y
=
VLLM_LDG
(
&
input
[
token_idx
*
2
*
d
+
d
+
idx
]);
out
[
token_idx
*
d
+
idx
]
=
ACT_FN
(
x
)
*
y
;
out
[
token_idx
*
d
+
idx
]
=
compute
<
scalar_t
,
ACT_FN
,
act_first
>
(
x
,
y
)
;
}
}
...
...
@@ -55,7 +63,9 @@ __device__ __forceinline__ T gelu_tanh_kernel(const T& x) {
}
// namespace vllm
// Launch activation and gating kernel.
#define LAUNCH_ACTIVATION_GATE_KERNEL(KERNEL) \
// Use ACT_FIRST (bool) indicating whether to apply the activation function
// first.
#define LAUNCH_ACTIVATION_GATE_KERNEL(KERNEL, ACT_FIRST) \
int d = input.size(-1) / 2; \
int64_t num_tokens = input.numel() / input.size(-1); \
dim3 grid(num_tokens); \
...
...
@@ -64,7 +74,7 @@ __device__ __forceinline__ T gelu_tanh_kernel(const T& x) {
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \
VLLM_DISPATCH_FLOATING_TYPES( \
input.scalar_type(), "act_and_mul_kernel", [&] { \
vllm::act_and_mul_kernel<scalar_t, KERNEL<scalar_t>
>
\
vllm::act_and_mul_kernel<scalar_t, KERNEL<scalar_t>
, ACT_FIRST>
\
<<<grid, block, 0, stream>>>(out.data_ptr<scalar_t>(), \
input.data_ptr<scalar_t>(), d); \
});
...
...
@@ -72,19 +82,27 @@ __device__ __forceinline__ T gelu_tanh_kernel(const T& x) {
void
silu_and_mul
(
torch
::
Tensor
&
out
,
// [..., d]
torch
::
Tensor
&
input
)
// [..., 2 * d]
{
LAUNCH_ACTIVATION_GATE_KERNEL
(
vllm
::
silu_kernel
);
LAUNCH_ACTIVATION_GATE_KERNEL
(
vllm
::
silu_kernel
,
true
);
}
// Host entry point for the gated activation with the halves swapped relative
// to silu_and_mul: ACT_FIRST is passed as false, so vllm::silu_kernel is
// applied to the second half of the last dimension of `input` and the result
// is multiplied by the first half. The result is written into `out`.
void mul_and_silu(torch::Tensor& out,    // [..., d]
                  torch::Tensor& input)  // [..., 2 * d]
{
  // The difference between mul_and_silu and silu_and_mul is that mul_and_silu
  // applies the silu to the latter half of the input.
  LAUNCH_ACTIVATION_GATE_KERNEL(vllm::silu_kernel, false);
}
void
gelu_and_mul
(
torch
::
Tensor
&
out
,
// [..., d]
torch
::
Tensor
&
input
)
// [..., 2 * d]
{
LAUNCH_ACTIVATION_GATE_KERNEL
(
vllm
::
gelu_kernel
);
LAUNCH_ACTIVATION_GATE_KERNEL
(
vllm
::
gelu_kernel
,
true
);
}
void
gelu_tanh_and_mul
(
torch
::
Tensor
&
out
,
// [..., d]
torch
::
Tensor
&
input
)
// [..., 2 * d]
{
LAUNCH_ACTIVATION_GATE_KERNEL
(
vllm
::
gelu_tanh_kernel
);
LAUNCH_ACTIVATION_GATE_KERNEL
(
vllm
::
gelu_tanh_kernel
,
true
);
}
namespace
vllm
{
...
...
csrc/ops.h
View file @
42f5e7c5
...
...
@@ -86,6 +86,8 @@ void batched_rotary_embedding(torch::Tensor& positions, torch::Tensor& query,
void
silu_and_mul
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
input
);
void
mul_and_silu
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
input
);
void
gelu_and_mul
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
input
);
void
gelu_tanh_and_mul
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
input
);
...
...
csrc/torch_bindings.cpp
View file @
42f5e7c5
...
...
@@ -55,6 +55,9 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops
.
def
(
"silu_and_mul(Tensor! out, Tensor input) -> ()"
);
ops
.
impl
(
"silu_and_mul"
,
torch
::
kCUDA
,
&
silu_and_mul
);
ops
.
def
(
"mul_and_silu(Tensor! out, Tensor input) -> ()"
);
ops
.
impl
(
"mul_and_silu"
,
torch
::
kCUDA
,
&
mul_and_silu
);
// Activation function used in GeGLU with `none` approximation.
ops
.
def
(
"gelu_and_mul(Tensor! out, Tensor input) -> ()"
);
ops
.
impl
(
"gelu_and_mul"
,
torch
::
kCUDA
,
&
gelu_and_mul
);
...
...
tests/kernels/test_activation.py
View file @
42f5e7c5
...
...
@@ -6,8 +6,9 @@ import torch
from
tests.kernels.utils
import
opcheck
from
vllm.model_executor.layers.activation
import
(
FastGELU
,
FatreluAndMul
,
GeluAndMul
,
NewGELU
,
QuickGELU
,
SiluAndMul
)
GeluAndMul
,
MulAndSilu
,
NewGELU
,
QuickGELU
,
SiluAndMul
)
from
vllm.platforms
import
current_platform
from
.allclose_default
import
get_default_atol
,
get_default_rtol
...
...
@@ -21,8 +22,9 @@ CUDA_DEVICES = [
]
@
pytest
.
mark
.
parametrize
(
"activation"
,
[
"silu"
,
"gelu"
,
"gelu_tanh"
,
"fatrelu"
])
@
pytest
.
mark
.
parametrize
(
"activation"
,
[
"silu_and_mul"
,
"mul_and_silu"
,
"gelu"
,
"gelu_tanh"
,
"fatrelu"
])
@
pytest
.
mark
.
parametrize
(
"num_tokens"
,
NUM_TOKENS
)
@
pytest
.
mark
.
parametrize
(
"d"
,
D
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
DTYPES
)
...
...
@@ -40,9 +42,12 @@ def test_act_and_mul(
current_platform
.
seed_everything
(
seed
)
torch
.
set_default_device
(
device
)
x
=
torch
.
randn
(
num_tokens
,
2
*
d
,
dtype
=
dtype
)
if
activation
==
"silu"
:
if
activation
==
"silu
_and_mul
"
:
layer
=
SiluAndMul
()
fn
=
torch
.
ops
.
_C
.
silu_and_mul
if
activation
==
"mul_and_silu"
:
layer
=
MulAndSilu
()
fn
=
torch
.
ops
.
_C
.
mul_and_silu
elif
activation
==
"gelu"
:
layer
=
GeluAndMul
(
approximate
=
"none"
)
fn
=
torch
.
ops
.
_C
.
gelu_and_mul
...
...
@@ -55,8 +60,9 @@ def test_act_and_mul(
fn
=
torch
.
ops
.
_C
.
fatrelu_and_mul
out
=
layer
(
x
)
ref_out
=
layer
.
forward_native
(
x
)
# The SiLU, GELU and FatReLU implementations are equivalent to the native
# PyTorch implementations, so we can do exact comparison.
# The SiluAndMul, MulAndSilu, GELU and FatReLU implementations are
# equivalent to the native PyTorch implementations, so we can do exact
# comparison.
torch
.
testing
.
assert_close
(
out
,
ref_out
,
atol
=
0.0
,
rtol
=
0.0
)
d
=
x
.
shape
[
-
1
]
//
2
...
...
vllm/model_executor/layers/activation.py
View file @
42f5e7c5
...
...
@@ -87,6 +87,41 @@ class SiluAndMul(CustomOp):
return
out
@CustomOp.register("mul_and_silu")
class MulAndSilu(CustomOp):
    """Gated SiLU activation whose gate is the second half of the input.

    Computes ``x[..., :d] * silu(x[..., d:])`` with ``d = x.shape[-1] // 2``.

    Shapes:
        x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d)
        return: (num_tokens, d) or (batch_size, seq_len, d)
    """

    def __init__(self):
        super().__init__()
        if current_platform.is_cuda_alike() or current_platform.is_cpu():
            # NOTE(review): assumes the mul_and_silu op is registered for the
            # active backend -- confirm for CPU builds.
            self.op = torch.ops._C.mul_and_silu
        elif current_platform.is_xpu():
            from vllm._ipex_ops import ipex_ops

            # NOTE(review): silu_and_mul presumably applies SiLU to the first
            # half, i.e. the opposite operand order from this class. Within
            # this class only forward_cuda reads self.op; confirm that no XPU
            # code path dispatches to it.
            self.op = ipex_ops.silu_and_mul

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        half = x.shape[-1] // 2
        value, gate = x[..., :half], x[..., half:]
        return value * F.silu(gate)

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        """Invoke the bound custom op, writing into a new output tensor."""
        half = x.shape[-1] // 2
        # Same leading dimensions as x, with the last dimension halved.
        out = torch.empty((*x.shape[:-1], half),
                          dtype=x.dtype,
                          device=x.device)
        self.op(out, x)
        return out

    # TODO implement forward_xpu for MulAndSilu
@
CustomOp
.
register
(
"gelu_and_mul"
)
class
GeluAndMul
(
CustomOp
):
"""An activation function for GeGLU.
...
...
vllm/model_executor/models/molmo.py
View file @
42f5e7c5
...
...
@@ -23,7 +23,8 @@ from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
from
vllm.inputs
import
(
INPUT_REGISTRY
,
DecoderOnlyInputs
,
DummyData
,
InputContext
,
token_inputs
)
from
vllm.model_executor
import
SamplingMetadata
from
vllm.model_executor.layers.activation
import
QuickGELU
,
SiluAndMul
from
vllm.model_executor.layers.activation
import
(
MulAndSilu
,
QuickGELU
,
SiluAndMul
)
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
MergedColumnParallelLinear
,
...
...
@@ -462,15 +463,6 @@ class MolmoAttention(nn.Module):
return
output
class
SwiGLU
(
nn
.
Module
):
def
forward
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
x
,
gate
=
x
.
chunk
(
2
,
dim
=-
1
)
# Note that the order is reversed compared to
# SiluAndMul.
return
x
*
F
.
silu
(
gate
)
class
LanuageModelMLP
(
nn
.
Module
):
"""Molmo's LLM mlp."""
...
...
@@ -489,7 +481,7 @@ class LanuageModelMLP(nn.Module):
quant_config
=
quant_config
,
)
# Activation function.
self
.
act_fn
=
SwiGLU
()
self
.
act_fn
=
MulAndSilu
()
# Feed-forward output projection.
self
.
down_proj
=
RowParallelLinear
(
self
.
intermediate_size
,
...
...
vllm/model_executor/models/ultravox.py
View file @
42f5e7c5
...
...
@@ -16,7 +16,7 @@ from transformers.models.whisper.modeling_whisper import WhisperEncoder
from
vllm
import
envs
from
vllm.attention
import
AttentionMetadata
from
vllm.config
import
VllmConfig
from
vllm.model_executor.layers.activation
import
SiluAndMul
,
get_act_fn
from
vllm.model_executor.layers.activation
import
MulAndSilu
,
get_act_fn
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.sampler
import
SamplerOutput
,
get_sampler
from
vllm.model_executor.model_loader.loader
import
DefaultModelLoader
...
...
@@ -248,15 +248,6 @@ class StackAudioFrames(nn.Module):
return
audio_embeds
class
FlippedSiluAndMul
(
SiluAndMul
):
"""Ultravox is trained with SwiGLU with flipped halves."""
def
forward
(
self
,
x
:
torch
.
Tensor
):
a
,
b
=
x
.
chunk
(
2
,
dim
=-
1
)
flipped
=
torch
.
cat
((
b
,
a
),
dim
=-
1
)
return
super
().
forward
(
flipped
)
class
UltravoxProjector
(
nn
.
Module
):
def
__init__
(
self
,
config
:
UltravoxConfig
):
...
...
@@ -269,7 +260,7 @@ class UltravoxProjector(nn.Module):
dim
=
self
.
hidden_dim
if
config
.
projector_act
==
"swiglu"
:
self
.
act
=
FlippedSiluAndMul
()
self
.
act
=
MulAndSilu
()
dim
=
dim
//
2
else
:
self
.
act
=
get_act_fn
(
config
.
projector_act
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment