Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
41ca62cf
Unverified
Commit
41ca62cf
authored
Jun 05, 2024
by
Woosuk Kwon
Committed by
GitHub
Jun 05, 2024
Browse files
[Misc] Add CustomOp interface for device portability (#5255)
parent
974fc9b8
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
100 additions
and
27 deletions
+100
-27
tests/kernels/test_activation.py
tests/kernels/test_activation.py
+2
-2
tests/kernels/test_layernorm.py
tests/kernels/test_layernorm.py
+1
-1
tests/kernels/test_pos_encoding.py
tests/kernels/test_pos_encoding.py
+4
-3
vllm/model_executor/custom_op.py
vllm/model_executor/custom_op.py
+60
-0
vllm/model_executor/layers/activation.py
vllm/model_executor/layers/activation.py
+21
-13
vllm/model_executor/layers/layernorm.py
vllm/model_executor/layers/layernorm.py
+6
-4
vllm/model_executor/layers/rotary_embedding.py
vllm/model_executor/layers/rotary_embedding.py
+6
-4
No files found.
tests/kernels/test_activation.py
View file @
41ca62cf
...
@@ -44,7 +44,7 @@ def test_act_and_mul(
...
@@ -44,7 +44,7 @@ def test_act_and_mul(
elif
activation
==
"gelu_tanh"
:
elif
activation
==
"gelu_tanh"
:
layer
=
GeluAndMul
(
approximate
=
"tanh"
)
layer
=
GeluAndMul
(
approximate
=
"tanh"
)
out
=
layer
(
x
)
out
=
layer
(
x
)
ref_out
=
layer
.
_
forward
(
x
)
ref_out
=
layer
.
forward
_native
(
x
)
# The SiLU and GELU implementations are equivalent to the native PyTorch
# The SiLU and GELU implementations are equivalent to the native PyTorch
# implementations, so we can do exact comparison.
# implementations, so we can do exact comparison.
assert
torch
.
allclose
(
out
,
ref_out
,
atol
=
0.0
,
rtol
=
0.0
)
assert
torch
.
allclose
(
out
,
ref_out
,
atol
=
0.0
,
rtol
=
0.0
)
...
@@ -72,7 +72,7 @@ def test_activation(
...
@@ -72,7 +72,7 @@ def test_activation(
x
=
torch
.
randn
(
num_tokens
,
d
,
dtype
=
dtype
)
x
=
torch
.
randn
(
num_tokens
,
d
,
dtype
=
dtype
)
layer
=
activation
()
layer
=
activation
()
out
=
layer
(
x
)
out
=
layer
(
x
)
ref_out
=
layer
.
_
forward
(
x
)
ref_out
=
layer
.
forward
_native
(
x
)
assert
torch
.
allclose
(
out
,
assert
torch
.
allclose
(
out
,
ref_out
,
ref_out
,
atol
=
get_default_atol
(
out
),
atol
=
get_default_atol
(
out
),
...
...
tests/kernels/test_layernorm.py
View file @
41ca62cf
...
@@ -42,7 +42,7 @@ def test_rms_norm(
...
@@ -42,7 +42,7 @@ def test_rms_norm(
# NOTE(woosuk): The reference implementation should be executed first
# NOTE(woosuk): The reference implementation should be executed first
# because the custom kernel is in-place.
# because the custom kernel is in-place.
ref_out
=
layer
.
_
forward
(
x
,
residual
)
ref_out
=
layer
.
forward
_native
(
x
,
residual
)
out
=
layer
(
x
,
residual
)
out
=
layer
(
x
,
residual
)
# NOTE(woosuk): LayerNorm operators (including RMS) typically have larger
# NOTE(woosuk): LayerNorm operators (including RMS) typically have larger
# numerical errors than other operators because they involve reductions.
# numerical errors than other operators because they involve reductions.
...
...
tests/kernels/test_pos_encoding.py
View file @
41ca62cf
...
@@ -64,7 +64,7 @@ def test_rotary_embedding(
...
@@ -64,7 +64,7 @@ def test_rotary_embedding(
# NOTE(woosuk): The reference implementation should be executed first
# NOTE(woosuk): The reference implementation should be executed first
# because the custom kernel is in-place.
# because the custom kernel is in-place.
ref_query
,
ref_key
=
rope
.
_
forward
(
positions
,
query
,
key
)
ref_query
,
ref_key
=
rope
.
forward
_native
(
positions
,
query
,
key
)
out_query
,
out_key
=
rope
.
forward
(
positions
,
query
,
key
)
out_query
,
out_key
=
rope
.
forward
(
positions
,
query
,
key
)
# Compare the results.
# Compare the results.
assert
torch
.
allclose
(
out_query
,
assert
torch
.
allclose
(
out_query
,
...
@@ -121,7 +121,7 @@ def test_batched_rotary_embedding(
...
@@ -121,7 +121,7 @@ def test_batched_rotary_embedding(
# NOTE(woosuk): The reference implementation should be executed first
# NOTE(woosuk): The reference implementation should be executed first
# because the custom kernel is in-place.
# because the custom kernel is in-place.
ref_query
,
ref_key
=
rope
.
_
forward
(
positions
,
query
,
key
)
ref_query
,
ref_key
=
rope
.
forward
_native
(
positions
,
query
,
key
)
out_query
,
out_key
=
rope
.
forward
(
positions
,
out_query
,
out_key
=
rope
.
forward
(
positions
,
query
,
query
,
key
,
key
,
...
@@ -195,7 +195,8 @@ def test_batched_rotary_embedding_multi_lora(
...
@@ -195,7 +195,8 @@ def test_batched_rotary_embedding_multi_lora(
# NOTE(woosuk): The reference implementation should be executed first
# NOTE(woosuk): The reference implementation should be executed first
# because the custom kernel is in-place.
# because the custom kernel is in-place.
ref_query
,
ref_key
=
rope
.
_forward
(
positions
,
query
,
key
,
query_offsets
)
ref_query
,
ref_key
=
rope
.
forward_native
(
positions
,
query
,
key
,
query_offsets
)
out_query
,
out_key
=
rope
.
forward
(
positions
,
query
,
key
,
out_query
,
out_key
=
rope
.
forward
(
positions
,
query
,
key
,
query_offsets
.
flatten
())
query_offsets
.
flatten
())
# Compare the results.
# Compare the results.
...
...
vllm/model_executor/custom_op.py
0 → 100644
View file @
41ca62cf
import
torch.nn
as
nn
from
vllm.utils
import
is_cpu
,
is_hip
class
CustomOp
(
nn
.
Module
):
def
__init__
(
self
,
*
args
,
**
kwargs
):
super
().
__init__
()
self
.
_forward_method
=
self
.
dispatch_forward
()
def
forward
(
self
,
*
args
,
**
kwargs
):
return
self
.
_forward_method
(
*
args
,
**
kwargs
)
def
forward_native
(
self
,
*
args
,
**
kwargs
):
"""PyTorch-native implementation of the forward method.
This method is optional. If implemented, it can be used with compilers
such as torch.compile or PyTorch XLA. Also, it can be used for testing
purposes.
"""
raise
NotImplementedError
def
forward_cuda
(
self
,
*
args
,
**
kwargs
):
raise
NotImplementedError
def
forward_hip
(
self
,
*
args
,
**
kwargs
):
# By default, we assume that HIP ops are compatible with CUDA ops.
return
self
.
forward_cuda
(
*
args
,
**
kwargs
)
def
forward_xpu
(
self
,
*
args
,
**
kwargs
):
# By default, we assume that XPU ops are compatible with CUDA ops.
# NOTE(woosuk): This is a placeholder for future extensions.
return
self
.
forward_cuda
(
*
args
,
**
kwargs
)
def
forward_cpu
(
self
,
*
args
,
**
kwargs
):
# By default, we assume that CPU ops are compatible with CUDA ops.
return
self
.
forward_cuda
(
*
args
,
**
kwargs
)
def
forward_tpu
(
self
,
*
args
,
**
kwargs
):
# By default, we assume that TPU ops are compatible with the
# PyTorch-native implementation.
# NOTE(woosuk): This is a placeholder for future extensions.
return
self
.
forward_native
(
*
args
,
**
kwargs
)
def
forward_gaudi
(
self
,
*
args
,
**
kwargs
):
# By default, we assume that Gaudi ops are compatible with the
# PyTorch-native implementation.
# NOTE(woosuk): This is a placeholder for future extensions.
return
self
.
forward_native
(
*
args
,
**
kwargs
)
def
dispatch_forward
(
self
):
# NOTE(woosuk): Here we assume that vLLM was built for only one
# specific backend. Currently, we do not support dynamic dispatching.
if
is_hip
():
return
self
.
forward_hip
elif
is_cpu
():
return
self
.
forward_cpu
else
:
return
self
.
forward_cuda
vllm/model_executor/layers/activation.py
View file @
41ca62cf
...
@@ -6,14 +6,14 @@ import torch
...
@@ -6,14 +6,14 @@ import torch
import
torch.nn
as
nn
import
torch.nn
as
nn
import
torch.nn.functional
as
F
import
torch.nn.functional
as
F
from
vllm
import
_custom_ops
as
ops
from
vllm.distributed
import
(
divide
,
get_tensor_model_parallel_rank
,
from
vllm.distributed
import
(
divide
,
get_tensor_model_parallel_rank
,
get_tensor_model_parallel_world_size
)
get_tensor_model_parallel_world_size
)
from
vllm.model_executor.custom_op
import
CustomOp
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.model_executor.utils
import
set_weight_attrs
class
SiluAndMul
(
nn
.
Module
):
class
SiluAndMul
(
CustomOp
):
"""An activation function for SwiGLU.
"""An activation function for SwiGLU.
The function computes x -> silu(x[:d]) * x[d:] where d = x.shape[-1] // 2.
The function computes x -> silu(x[:d]) * x[d:] where d = x.shape[-1] // 2.
...
@@ -23,12 +23,14 @@ class SiluAndMul(nn.Module):
...
@@ -23,12 +23,14 @@ class SiluAndMul(nn.Module):
return: (num_tokens, d) or (batch_size, seq_len, d)
return: (num_tokens, d) or (batch_size, seq_len, d)
"""
"""
def
_
forward
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
forward
_native
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
"""PyTorch-native implementation equivalent to forward()."""
"""PyTorch-native implementation equivalent to forward()."""
d
=
x
.
shape
[
-
1
]
//
2
d
=
x
.
shape
[
-
1
]
//
2
return
F
.
silu
(
x
[...,
:
d
])
*
x
[...,
d
:]
return
F
.
silu
(
x
[...,
:
d
])
*
x
[...,
d
:]
def
forward
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
forward_cuda
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
from
vllm
import
_custom_ops
as
ops
d
=
x
.
shape
[
-
1
]
//
2
d
=
x
.
shape
[
-
1
]
//
2
output_shape
=
(
x
.
shape
[:
-
1
]
+
(
d
,
))
output_shape
=
(
x
.
shape
[:
-
1
]
+
(
d
,
))
out
=
torch
.
empty
(
output_shape
,
dtype
=
x
.
dtype
,
device
=
x
.
device
)
out
=
torch
.
empty
(
output_shape
,
dtype
=
x
.
dtype
,
device
=
x
.
device
)
...
@@ -36,7 +38,7 @@ class SiluAndMul(nn.Module):
...
@@ -36,7 +38,7 @@ class SiluAndMul(nn.Module):
return
out
return
out
class
GeluAndMul
(
nn
.
Module
):
class
GeluAndMul
(
CustomOp
):
"""An activation function for GeGLU.
"""An activation function for GeGLU.
The function computes x -> GELU(x[:d]) * x[d:] where d = x.shape[-1] // 2.
The function computes x -> GELU(x[:d]) * x[d:] where d = x.shape[-1] // 2.
...
@@ -52,12 +54,14 @@ class GeluAndMul(nn.Module):
...
@@ -52,12 +54,14 @@ class GeluAndMul(nn.Module):
if
approximate
not
in
(
"none"
,
"tanh"
):
if
approximate
not
in
(
"none"
,
"tanh"
):
raise
ValueError
(
f
"Unknown approximate mode:
{
approximate
}
"
)
raise
ValueError
(
f
"Unknown approximate mode:
{
approximate
}
"
)
def
_
forward
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
forward
_native
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
"""PyTorch-native implementation equivalent to forward()."""
"""PyTorch-native implementation equivalent to forward()."""
d
=
x
.
shape
[
-
1
]
//
2
d
=
x
.
shape
[
-
1
]
//
2
return
F
.
gelu
(
x
[...,
:
d
],
approximate
=
self
.
approximate
)
*
x
[...,
d
:]
return
F
.
gelu
(
x
[...,
:
d
],
approximate
=
self
.
approximate
)
*
x
[...,
d
:]
def
forward
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
forward_cuda
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
from
vllm
import
_custom_ops
as
ops
d
=
x
.
shape
[
-
1
]
//
2
d
=
x
.
shape
[
-
1
]
//
2
output_shape
=
(
x
.
shape
[:
-
1
]
+
(
d
,
))
output_shape
=
(
x
.
shape
[:
-
1
]
+
(
d
,
))
out
=
torch
.
empty
(
output_shape
,
dtype
=
x
.
dtype
,
device
=
x
.
device
)
out
=
torch
.
empty
(
output_shape
,
dtype
=
x
.
dtype
,
device
=
x
.
device
)
...
@@ -71,28 +75,32 @@ class GeluAndMul(nn.Module):
...
@@ -71,28 +75,32 @@ class GeluAndMul(nn.Module):
return
f
'approximate=
{
repr
(
self
.
approximate
)
}
'
return
f
'approximate=
{
repr
(
self
.
approximate
)
}
'
class
NewGELU
(
nn
.
Module
):
class
NewGELU
(
CustomOp
):
def
_
forward
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
forward
_native
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
"""PyTorch-native implementation equivalent to forward()."""
"""PyTorch-native implementation equivalent to forward()."""
c
=
math
.
sqrt
(
2.0
/
math
.
pi
)
c
=
math
.
sqrt
(
2.0
/
math
.
pi
)
return
0.5
*
x
*
(
1.0
+
torch
.
tanh
(
c
*
return
0.5
*
x
*
(
1.0
+
torch
.
tanh
(
c
*
(
x
+
0.044715
*
torch
.
pow
(
x
,
3.0
))))
(
x
+
0.044715
*
torch
.
pow
(
x
,
3.0
))))
def
forward
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
forward_cuda
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
from
vllm
import
_custom_ops
as
ops
out
=
torch
.
empty_like
(
x
)
out
=
torch
.
empty_like
(
x
)
ops
.
gelu_new
(
out
,
x
)
ops
.
gelu_new
(
out
,
x
)
return
out
return
out
class
FastGELU
(
nn
.
Module
):
class
FastGELU
(
CustomOp
):
def
_
forward
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
forward
_native
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
"""PyTorch-native implementation equivalent to forward()."""
"""PyTorch-native implementation equivalent to forward()."""
return
0.5
*
x
*
(
1.0
+
torch
.
tanh
(
x
*
0.7978845608
*
return
0.5
*
x
*
(
1.0
+
torch
.
tanh
(
x
*
0.7978845608
*
(
1.0
+
0.044715
*
x
*
x
)))
(
1.0
+
0.044715
*
x
*
x
)))
def
forward
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
forward_cuda
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
from
vllm
import
_custom_ops
as
ops
out
=
torch
.
empty_like
(
x
)
out
=
torch
.
empty_like
(
x
)
ops
.
gelu_fast
(
out
,
x
)
ops
.
gelu_fast
(
out
,
x
)
return
out
return
out
...
...
vllm/model_executor/layers/layernorm.py
View file @
41ca62cf
...
@@ -4,10 +4,10 @@ from typing import Optional, Tuple, Union
...
@@ -4,10 +4,10 @@ from typing import Optional, Tuple, Union
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
from
vllm
import
_custom_ops
as
ops
from
vllm
.model_executor.custom_op
import
CustomOp
class
RMSNorm
(
nn
.
Module
):
class
RMSNorm
(
CustomOp
):
"""Root mean square normalization.
"""Root mean square normalization.
Computes x -> w * x / sqrt(E[x^2] + eps) where w is the learned weight.
Computes x -> w * x / sqrt(E[x^2] + eps) where w is the learned weight.
...
@@ -23,7 +23,7 @@ class RMSNorm(nn.Module):
...
@@ -23,7 +23,7 @@ class RMSNorm(nn.Module):
self
.
weight
=
nn
.
Parameter
(
torch
.
ones
(
hidden_size
))
self
.
weight
=
nn
.
Parameter
(
torch
.
ones
(
hidden_size
))
self
.
variance_epsilon
=
eps
self
.
variance_epsilon
=
eps
def
_
forward
(
def
forward
_native
(
self
,
self
,
x
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
residual
:
Optional
[
torch
.
Tensor
]
=
None
,
residual
:
Optional
[
torch
.
Tensor
]
=
None
,
...
@@ -43,11 +43,13 @@ class RMSNorm(nn.Module):
...
@@ -43,11 +43,13 @@ class RMSNorm(nn.Module):
else
:
else
:
return
x
,
residual
return
x
,
residual
def
forward
(
def
forward
_cuda
(
self
,
self
,
x
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
residual
:
Optional
[
torch
.
Tensor
]
=
None
,
residual
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Union
[
torch
.
Tensor
,
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]]:
)
->
Union
[
torch
.
Tensor
,
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]]:
from
vllm
import
_custom_ops
as
ops
if
residual
is
not
None
:
if
residual
is
not
None
:
ops
.
fused_add_rms_norm
(
ops
.
fused_add_rms_norm
(
x
,
x
,
...
...
vllm/model_executor/layers/rotary_embedding.py
View file @
41ca62cf
...
@@ -27,7 +27,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union
...
@@ -27,7 +27,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
from
vllm
import
_custom_ops
as
ops
from
vllm
.model_executor.custom_op
import
CustomOp
def
_rotate_neox
(
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
_rotate_neox
(
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
...
@@ -43,7 +43,7 @@ def _rotate_gptj(x: torch.Tensor) -> torch.Tensor:
...
@@ -43,7 +43,7 @@ def _rotate_gptj(x: torch.Tensor) -> torch.Tensor:
return
x
.
flatten
(
-
2
)
return
x
.
flatten
(
-
2
)
class
RotaryEmbedding
(
nn
.
Module
):
class
RotaryEmbedding
(
CustomOp
):
"""Original rotary positional embedding."""
"""Original rotary positional embedding."""
def
__init__
(
def
__init__
(
...
@@ -93,7 +93,7 @@ class RotaryEmbedding(nn.Module):
...
@@ -93,7 +93,7 @@ class RotaryEmbedding(nn.Module):
cache
=
torch
.
cat
((
cos
,
sin
),
dim
=-
1
)
cache
=
torch
.
cat
((
cos
,
sin
),
dim
=-
1
)
return
cache
return
cache
def
_
forward
(
def
forward
_native
(
self
,
self
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
query
:
torch
.
Tensor
,
query
:
torch
.
Tensor
,
...
@@ -138,13 +138,15 @@ class RotaryEmbedding(nn.Module):
...
@@ -138,13 +138,15 @@ class RotaryEmbedding(nn.Module):
key
=
key
.
flatten
(
-
2
)
key
=
key
.
flatten
(
-
2
)
return
query
,
key
return
query
,
key
def
forward
(
def
forward
_cuda
(
self
,
self
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
query
:
torch
.
Tensor
,
query
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
offsets
:
Optional
[
torch
.
Tensor
]
=
None
,
offsets
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
from
vllm
import
_custom_ops
as
ops
self
.
cos_sin_cache
=
self
.
cos_sin_cache
.
to
(
positions
.
device
,
self
.
cos_sin_cache
=
self
.
cos_sin_cache
.
to
(
positions
.
device
,
dtype
=
query
.
dtype
)
dtype
=
query
.
dtype
)
# ops.rotary_embedding()/batched_rotary_embedding()
# ops.rotary_embedding()/batched_rotary_embedding()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment