Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
bd620b01
Unverified
Commit
bd620b01
authored
Jun 20, 2024
by
Roger Wang
Committed by
GitHub
Jun 21, 2024
Browse files
[Kernel][CPU] Add Quick `gelu` to CPU (#5717)
parent
d9a252bc
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
29 additions
and
0 deletions
+29
-0
csrc/cpu/activation.cpp
csrc/cpu/activation.cpp
+19
-0
csrc/cpu/torch_bindings.cpp
csrc/cpu/torch_bindings.cpp
+4
-0
vllm/_ipex_ops.py
vllm/_ipex_ops.py
+3
-0
vllm/model_executor/layers/activation.py
vllm/model_executor/layers/activation.py
+3
-0
No files found.
csrc/cpu/activation.cpp
View file @
bd620b01
...
@@ -59,6 +59,13 @@ FORCE_INLINE vec_op::FP32Vec8 gelu_fast_act(const vec_op::FP32Vec8& x) {
...
@@ -59,6 +59,13 @@ FORCE_INLINE vec_op::FP32Vec8 gelu_fast_act(const vec_op::FP32Vec8& x) {
return
w3
*
x
*
(
ones
+
t
);
return
w3
*
x
*
(
ones
+
t
);
}
}
// Quick GELU approximation applied to an FP32Vec8 vector:
//   x / (1 + exp(-1.702 * x)),
// which is algebraically x * sigmoid(1.702 * x).
FORCE_INLINE vec_op::FP32Vec8 gelu_quick_act(const vec_op::FP32Vec8& x) {
  const vec_op::FP32Vec8 zero_vec(0.0);
  const vec_op::FP32Vec8 one_vec(1.0);
  const vec_op::FP32Vec8 slope(1.702f);

  // The negation is expressed as (0 - slope * x), using only the binary
  // operators this block already relies on.
  vec_op::FP32Vec8 neg_scaled = zero_vec - slope * x;
  const vec_op::FP32Vec8 denominator = one_vec + neg_scaled.exp();
  return x / denominator;
}
FORCE_INLINE
vec_op
::
FP32Vec8
gelu_act
(
const
vec_op
::
FP32Vec8
&
x
)
{
FORCE_INLINE
vec_op
::
FP32Vec8
gelu_act
(
const
vec_op
::
FP32Vec8
&
x
)
{
const
vec_op
::
FP32Vec8
ones
(
1.0
);
const
vec_op
::
FP32Vec8
ones
(
1.0
);
const
vec_op
::
FP32Vec8
w1
(
M_SQRT1_2
);
const
vec_op
::
FP32Vec8
w1
(
M_SQRT1_2
);
...
@@ -142,3 +149,15 @@ void gelu_fast(torch::Tensor& out, torch::Tensor& input) {
...
@@ -142,3 +149,15 @@ void gelu_fast(torch::Tensor& out, torch::Tensor& input) {
CPU_KERNEL_GUARD_OUT
(
gelu_fast_impl
)
CPU_KERNEL_GUARD_OUT
(
gelu_fast_impl
)
});
});
}
}
// CPU entry point for the quick-GELU activation.
//
// Treats `input` as `num_tokens` rows of length `d` (its last dimension) and
// hands the raw data pointers of `input` and `out` to `activation_kernel`
// instantiated with `gelu_quick_act`. The element type is selected by
// VLLM_DISPATCH_FLOATING_TYPES from `input.scalar_type()`.
//
// NOTE(review): `out` is assumed to have the same dtype and element count as
// `input`; nothing here checks that -- confirm against callers.
void gelu_quick(torch::Tensor& out, torch::Tensor& input) {
  // An empty input has nothing to compute. Returning early also avoids the
  // integer division by zero below when the last dimension has size 0.
  if (input.numel() == 0) {
    return;
  }

  int num_tokens = input.numel() / input.size(-1);
  int d = input.size(-1);

  VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "gelu_quick_impl", [&] {
    CPU_KERNEL_GUARD_IN(gelu_quick_impl)
    activation_kernel<scalar_t, gelu_quick_act, false>(
        num_tokens, d, input.data_ptr<scalar_t>(), out.data_ptr<scalar_t>());
    CPU_KERNEL_GUARD_OUT(gelu_quick_impl)
  });
}
csrc/cpu/torch_bindings.cpp
View file @
bd620b01
...
@@ -58,6 +58,10 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
...
@@ -58,6 +58,10 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops
.
def
(
"gelu_fast(Tensor! out, Tensor input) -> ()"
);
ops
.
def
(
"gelu_fast(Tensor! out, Tensor input) -> ()"
);
ops
.
impl
(
"gelu_fast"
,
torch
::
kCPU
,
&
gelu_fast
);
ops
.
impl
(
"gelu_fast"
,
torch
::
kCPU
,
&
gelu_fast
);
// Quick GELU implementation.
ops
.
def
(
"gelu_quick(Tensor! out, Tensor input) -> ()"
);
ops
.
impl
(
"gelu_quick"
,
torch
::
kCPU
,
&
gelu_quick
);
// Layernorm
// Layernorm
// Apply Root Mean Square (RMS) Normalization to the input tensor.
// Apply Root Mean Square (RMS) Normalization to the input tensor.
ops
.
def
(
ops
.
def
(
...
...
vllm/_ipex_ops.py
View file @
bd620b01
...
@@ -43,6 +43,9 @@ class ipex_ops:
...
@@ -43,6 +43,9 @@ class ipex_ops:
def
gelu_new
(
out
:
torch
.
Tensor
,
x
:
torch
.
Tensor
)
->
None
:
def
gelu_new
(
out
:
torch
.
Tensor
,
x
:
torch
.
Tensor
)
->
None
:
out
.
copy_
(
torch
.
nn
.
functional
.
gelu
(
x
))
out
.
copy_
(
torch
.
nn
.
functional
.
gelu
(
x
))
# TODO add implementation of gelu_quick here
# def gelu_quick(out: torch.Tensor, x: torch.Tensor) -> None:
def
paged_attention_v1
(
def
paged_attention_v1
(
out
:
torch
.
Tensor
,
out
:
torch
.
Tensor
,
query
:
torch
.
Tensor
,
query
:
torch
.
Tensor
,
...
...
vllm/model_executor/layers/activation.py
View file @
bd620b01
...
@@ -155,6 +155,9 @@ class QuickGELU(CustomOp):
...
@@ -155,6 +155,9 @@ class QuickGELU(CustomOp):
ops
.
gelu_quick
(
out
,
x
)
ops
.
gelu_quick
(
out
,
x
)
return
out
return
out
# TODO implement forward_xpu for QuickGELU
# def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
class
ScaledActivation
(
nn
.
Module
):
class
ScaledActivation
(
nn
.
Module
):
"""An activation function with post-scale parameters.
"""An activation function with post-scale parameters.
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment