norm / vllm · Commit 9b294976 (unverified)

Add PyTorch-native implementation of custom layers (#1898)

Authored Dec 02, 2023 by Woosuk Kwon; committed via GitHub on Dec 02, 2023.
Parent: 5313c2cb

Showing 6 changed files with 150 additions and 185 deletions (+150, -185):
  tests/kernels/test_activation.py                  +10  -17
  tests/kernels/test_layernorm.py                   +24  -35
  tests/kernels/test_pos_encoding.py                +24 -133
  vllm/model_executor/layers/activation.py          +18   -0
  vllm/model_executor/layers/layernorm.py           +20   -0
  vllm/model_executor/layers/rotary_embedding.py    +54   -0
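The pattern is the same in every file below: each custom layer gains a _forward() method that is a pure-PyTorch equivalent of its CUDA-backed forward(), and the kernel tests instantiate the layer and compare the two paths directly instead of keeping separate reference implementations. A minimal torch-only sketch of that test pattern (SiluAndMulLike and its plain-PyTorch forward() are illustrative stand-ins, not vLLM code):

import torch
import torch.nn as nn
import torch.nn.functional as F


class SiluAndMulLike(nn.Module):
    """Illustrative stand-in for a vLLM custom layer."""

    def _forward(self, x: torch.Tensor) -> torch.Tensor:
        # PyTorch-native path: gate the second half of the last dim with SiLU.
        d = x.shape[-1] // 2
        return F.silu(x[..., :d]) * x[..., d:]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # In vLLM this would allocate `out` and call the fused CUDA op
        # (ops.silu_and_mul); here we reuse the native path so the sketch
        # runs without a GPU.
        return self._forward(x)


# The tests below follow this shape: run both paths, then allclose-compare.
x = torch.randn(7, 2 * 512)
layer = SiluAndMulLike()
assert torch.allclose(layer(x), layer._forward(x), atol=1e-5, rtol=1e-5)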
tests/kernels/test_activation.py

 import pytest
 import torch
-import torch.nn.functional as F
-from transformers.activations import get_activation
-from vllm._C import ops
+from vllm.model_executor.layers.activation import FastGELU, NewGELU, SiluAndMul

 DTYPES = [torch.half, torch.bfloat16, torch.float]
 NUM_TOKENS = [7, 83, 2048]  # Arbitrary values for testing
...
@@ -11,11 +9,6 @@ D = [512, 4096, 5120, 13824]  # Arbitrary values for testing
 SEEDS = [0]

-def ref_silu_and_mul(x: torch.Tensor) -> torch.Tensor:
-    x1, x2 = x.chunk(chunks=2, dim=1)
-    return F.silu(x1) * x2
-
 @pytest.mark.parametrize("num_tokens", NUM_TOKENS)
 @pytest.mark.parametrize("d", D)
 @pytest.mark.parametrize("dtype", DTYPES)
...
@@ -30,9 +23,9 @@ def test_silu_and_mul(
     torch.random.manual_seed(seed)
     torch.cuda.manual_seed(seed)
     x = torch.randn(num_tokens, 2 * d, dtype=dtype, device="cuda")
-    out = torch.empty(num_tokens, d, dtype=dtype, device="cuda")
-    ops.silu_and_mul(out, x)
-    ref_out = ref_silu_and_mul(x)
+    layer = SiluAndMul()
+    out = layer(x)
+    ref_out = layer._forward(x)
     assert torch.allclose(out, ref_out, atol=1e-5, rtol=1e-5)
...
@@ -50,9 +43,9 @@ def test_gelu_new(
     torch.random.manual_seed(seed)
     torch.cuda.manual_seed(seed)
     x = torch.randn(num_tokens, d, dtype=dtype, device="cuda")
-    out = torch.empty(num_tokens, d, dtype=dtype, device="cuda")
-    ops.gelu_new(out, x)
-    ref_out = get_activation("gelu_new")(x)
+    layer = NewGELU()
+    out = layer(x)
+    ref_out = layer._forward(x)
     assert torch.allclose(out, ref_out, atol=1e-5, rtol=1e-5)
...
@@ -69,7 +62,7 @@ def test_gelu_fast(
     torch.random.manual_seed(seed)
     torch.cuda.manual_seed(seed)
     x = torch.randn(num_tokens, d, dtype=dtype, device="cuda")
-    out = torch.empty(num_tokens, d, dtype=dtype, device="cuda")
-    ops.gelu_fast(out, x)
-    ref_out = get_activation("gelu_fast")(x)
+    layer = FastGELU()
+    out = layer(x)
+    ref_out = layer._forward(x)
     assert torch.allclose(out, ref_out, atol=1e-5, rtol=1e-5)
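The deleted ref_silu_and_mul helper and the new SiluAndMul._forward compute the same thing for the [num_tokens, 2 * d] inputs used here; a quick torch-only check (illustrative, not part of the commit):

import torch
import torch.nn.functional as F

x = torch.randn(7, 2 * 16)

# Old reference: chunk along dim 1, SiLU-gate the first half.
x1, x2 = x.chunk(chunks=2, dim=1)
old_ref = F.silu(x1) * x2

# New reference (SiluAndMul._forward): slice the last dimension instead,
# which also works for 3-D [batch_size, seq_len, 2 * d] inputs.
d = x.shape[-1] // 2
new_ref = F.silu(x[..., :d]) * x[..., d:]

assert torch.allclose(old_ref, new_ref)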
tests/kernels/test_layernorm.py

 import pytest
 import torch
-import torch.nn as nn
-from vllm._C import ops
+from vllm.model_executor.layers.layernorm import RMSNorm

 DTYPES = [torch.half, torch.bfloat16, torch.float]
-HIDDEN_SIZES = [67, 768, 2048, 5120, 8192]  # Arbitrary values for testing
 NUM_TOKENS = [7, 83, 4096]  # Arbitrary values for testing
+HIDDEN_SIZES = [768, 5120, 8192]  # Arbitrary values for testing
+ADD_RESIDUAL = [False, True]
 SEEDS = [0]

-class RefRMSNorm(nn.Module):
-
-    def __init__(self, hidden_size, eps=1e-6):
-        super().__init__()
-        weight = torch.empty(hidden_size)
-        weight.normal_(mean=1.0, std=0.1)
-        self.weight = nn.Parameter(weight)
-        self.variance_epsilon = eps
-
-    def forward(self, hidden_states):
-        input_dtype = hidden_states.dtype
-        hidden_states = hidden_states.to(torch.float32)
-        variance = hidden_states.pow(2).mean(-1, keepdim=True)
-        hidden_states = hidden_states * torch.rsqrt(variance +
-                                                    self.variance_epsilon)
-        return self.weight * hidden_states.to(input_dtype)
-
 @pytest.mark.parametrize("num_tokens", NUM_TOKENS)
 @pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
+@pytest.mark.parametrize("add_residual", ADD_RESIDUAL)
 @pytest.mark.parametrize("dtype", DTYPES)
 @pytest.mark.parametrize("seed", SEEDS)
 @torch.inference_mode()
 def test_rms_norm(
     num_tokens: int,
     hidden_size: int,
+    add_residual: bool,
     dtype: torch.dtype,
     seed: int,
 ) -> None:
     torch.random.manual_seed(seed)
     torch.cuda.manual_seed(seed)
-    scale = float(hidden_size**-0.5)
-    x = torch.empty(num_tokens, hidden_size, dtype=dtype, device="cuda")
-    x.uniform_(-scale, scale)
-    ref = RefRMSNorm(hidden_size).to(dtype).cuda()
-    out = torch.empty_like(x)
-    ops.rms_norm(
-        out,
-        x,
-        ref.weight.data,
-        ref.variance_epsilon,
-    )
-    ref_out = ref(x)
-    assert torch.allclose(out, ref_out, atol=1e-2, rtol=1e-5)
+    layer = RMSNorm(hidden_size).to(dtype).cuda()
+    layer.weight.data.normal_(mean=1.0, std=0.1)
+    scale = 1 / (2 * hidden_size)
+    x = torch.randn(num_tokens, hidden_size, dtype=dtype, device="cuda")
+    x *= scale
+    residual = torch.randn_like(x) * scale if add_residual else None
+
+    # NOTE(woosuk): The reference implementation should be executed first
+    # because the custom kernel is in-place.
+    ref_out = layer._forward(x, residual)
+    out = layer(x, residual)
+    # NOTE(woosuk): LayerNorm operators (including RMS) typically have larger
+    # numerical errors than other operators because they involve reductions.
+    # Therefore, we use a larger tolerance.
+    if add_residual:
+        assert torch.allclose(out[0], ref_out[0], atol=1e-2, rtol=1e-2)
+        assert torch.allclose(out[1], ref_out[1], atol=1e-2, rtol=1e-2)
+    else:
+        assert torch.allclose(out, ref_out, atol=1e-2, rtol=1e-2)
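The NOTE about running the reference first exists because the fused CUDA path mutates its inputs; the hazard is easy to reproduce with any in-place op (plain-torch illustration, not the vLLM kernel):

import torch

x = torch.randn(4)

# Reference first, from the pristine input ...
ref = torch.sigmoid(x)
# ... then the in-place op, which overwrites x (much as the fused
# add + RMSNorm kernel overwrites its hidden-state/residual buffers).
x.sigmoid_()
assert torch.allclose(x, ref)
# Reversing the order would compute the "reference" from already-mutated
# data, so the comparison above would no longer be meaningful.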
tests/kernels/test_pos_encoding.py

-from typing import Optional, Tuple
+from typing import Optional

 import pytest
 import torch
-import torch.nn as nn
-import torch.nn.functional as F
-from vllm._C import ops
+from vllm.model_executor.layers.rotary_embedding import get_rope

 IS_NEOX_STYLE = [True, False]
 DTYPES = [torch.half, torch.bfloat16, torch.float]
 HEAD_SIZES = [64, 80, 96, 112, 128, 256]
 ROTARY_DIMS = [None, 32]  # None means rotary dim == head size
-NUM_HEADS = [7, 12, 40, 52]  # Arbitrary values for testing
-NUM_TOKENS = [11, 83, 2048]  # Arbitrary values for testing
+NUM_HEADS = [7, 17]  # Arbitrary values for testing
+BATCH_SIZES = [1, 5]  # Arbitrary values for testing
+SEQ_LENS = [11, 8192]  # Arbitrary values for testing
 SEEDS = [0]

-def rotate_neox(x: torch.Tensor) -> torch.Tensor:
-    x1 = x[..., :x.shape[-1] // 2]
-    x2 = x[..., x.shape[-1] // 2:]
-    return torch.cat((-x2, x1), dim=-1)
-
-
-def rotate_gptj(x: torch.Tensor) -> torch.Tensor:
-    x1 = x[..., ::2]
-    x2 = x[..., 1::2]
-    x = torch.stack((-x2, x1), dim=-1)
-    return x.flatten(-2)
-
-
-def apply_rope(
-    q: torch.Tensor,
-    k: torch.Tensor,
-    cos: torch.Tensor,
-    sin: torch.Tensor,
-    is_neox_style: bool,
-) -> Tuple[torch.Tensor, torch.Tensor]:
-    rotate_fn = rotate_neox if is_neox_style else rotate_gptj
-    q_embed = (q * cos) + (rotate_fn(q) * sin)
-    k_embed = (k * cos) + (rotate_fn(k) * sin)
-    return q_embed, k_embed
-
-
-class RefRotaryEmbedding(nn.Module):
-    """Reference implementation of rotary embedding."""
-
-    def __init__(
-        self,
-        dim: int,
-        is_neox_style: bool,
-        max_position_embeddings: int = 8192,
-        base: int = 10000,
-    ) -> None:
-        super().__init__()
-        self.rotary_dim = dim
-        self.is_neox_style = is_neox_style
-        self.max_position_embeddings = max_position_embeddings
-
-        # Create cos and sin embeddings.
-        inv_freq = 1.0 / (base**(torch.arange(0, dim, 2) / dim))
-        t = torch.arange(max_position_embeddings).float()
-        freqs = torch.einsum("i,j->ij", t, inv_freq.float())
-        if is_neox_style:
-            emb = torch.cat((freqs, freqs), dim=-1)
-        else:
-            emb = torch.repeat_interleave(freqs, 2, -1)
-        cos = emb.cos().to(dtype=inv_freq.dtype)
-        sin = emb.sin().to(dtype=inv_freq.dtype)
-        self.register_buffer("cos_cached", cos, persistent=False)
-        self.register_buffer("sin_cached", sin, persistent=False)
-
-    def forward(
-        self,
-        positions: torch.Tensor,  # [num_tokens]
-        query: torch.Tensor,  # [num_tokens, num_heads, head_size]
-        key: torch.Tensor,  # [num_tokens, num_heads, head_size]
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
-        query_rot = query[..., :self.rotary_dim]
-        query_pass = query[..., self.rotary_dim:]
-        key_rot = key[..., :self.rotary_dim]
-        key_pass = key[..., self.rotary_dim:]
-
-        query_rot = query_rot.transpose(0, 1)
-        key_rot = key_rot.transpose(0, 1)
-        cos = F.embedding(positions, self.cos_cached)
-        sin = F.embedding(positions, self.sin_cached)
-
-        query_rot, key_rot = apply_rope(query_rot, key_rot, cos, sin,
-                                        self.is_neox_style)
-        query_rot = query_rot.transpose(0, 1).contiguous()
-        key_rot = key_rot.transpose(0, 1).contiguous()
-
-        query = torch.cat((query_rot, query_pass), dim=-1)
-        key = torch.cat((key_rot, key_pass), dim=-1)
-
-        # Output query/key shape: [num_tokens, num_tokens, head_size]
-        return query, key
-
-
 @pytest.mark.parametrize("is_neox_style", IS_NEOX_STYLE)
-@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
+@pytest.mark.parametrize("batch_size", BATCH_SIZES)
+@pytest.mark.parametrize("seq_len", SEQ_LENS)
 @pytest.mark.parametrize("num_heads", NUM_HEADS)
 @pytest.mark.parametrize("head_size", HEAD_SIZES)
 @pytest.mark.parametrize("rotary_dim", ROTARY_DIMS)
...
@@ -108,7 +26,8 @@ class RefRotaryEmbedding(nn.Module):
 @torch.inference_mode()
 def test_rotary_embedding(
     is_neox_style: bool,
-    num_tokens: int,
+    batch_size: int,
+    seq_len: int,
     num_heads: int,
     head_size: int,
     rotary_dim: Optional[int],
...
@@ -122,53 +41,25 @@ def test_rotary_embedding(
     torch.random.manual_seed(seed)
     torch.cuda.manual_seed(seed)
-    positions = torch.randint(0, max_position, (num_tokens, ), device="cuda")
-    query = torch.randn(num_tokens,
-                        num_heads * head_size,
-                        dtype=dtype,
-                        device="cuda")
-    key = torch.randn(num_tokens,
-                      num_heads * head_size,
-                      dtype=dtype,
-                      device="cuda")
-
-    # Create the rotary embedding.
-    inv_freq = 1.0 / (base**(
-        torch.arange(0, rotary_dim, 2, dtype=torch.float) / rotary_dim))
-    t = torch.arange(max_position).float()
-    freqs = torch.einsum("i,j -> ij", t, inv_freq)
-    cos = freqs.cos()
-    sin = freqs.sin()
-    cos_sin_cache = torch.cat((cos, sin), dim=-1)
-    cos_sin_cache = cos_sin_cache.to(dtype=dtype, device="cuda")
-
-    # Run the kernel. The kernel is in-place, so we need to clone the inputs.
-    out_query = query.clone()
-    out_key = key.clone()
-    ops.rotary_embedding(
-        positions,
-        out_query,
-        out_key,
-        head_size,
-        cos_sin_cache,
-        is_neox_style,
-    )
-
-    # Run the reference implementation.
-    ref_rotary_embedding = RefRotaryEmbedding(
-        dim=rotary_dim,
-        is_neox_style=is_neox_style,
-        max_position_embeddings=max_position,
-        base=base,
-    ).to(dtype=dtype, device="cuda")
-    ref_query, ref_key = ref_rotary_embedding(
-        positions,
-        query.view(num_tokens, num_heads, head_size),
-        key.view(num_tokens, num_heads, head_size),
-    )
-    ref_query = ref_query.view(num_tokens, num_heads * head_size)
-    ref_key = ref_key.view(num_tokens, num_heads * head_size)
-
+    if rotary_dim is None:
+        rotary_dim = head_size
+    rope = get_rope(head_size, rotary_dim, max_position, base, is_neox_style)
+    rope = rope.to(dtype).cuda()
+
+    positions = torch.randint(0, max_position, (batch_size, seq_len), device="cuda")
+    query = torch.randn(batch_size,
+                        seq_len,
+                        num_heads * head_size,
+                        dtype=dtype,
+                        device="cuda")
+    key = torch.randn_like(query)
+
+    # NOTE(woosuk): The reference implementation should be executed first
+    # because the custom kernel is in-place.
+    ref_query, ref_key = rope._forward(positions, query, key)
+    out_query, out_key = rope.forward(positions, query, key)
     # Compare the results.
     assert torch.allclose(out_query, ref_query, atol=1e-5, rtol=1e-5)
     assert torch.allclose(out_key, ref_key, atol=1e-5, rtol=1e-5)
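The rotation helpers deleted from this test live on as _rotate_neox / _rotate_gptj in vllm/model_executor/layers/rotary_embedding.py (see below); they differ only in how they pair up elements of the last dimension. A standalone copy with a tiny worked example (illustrative, CPU-only):

import torch


def rotate_neox(x: torch.Tensor) -> torch.Tensor:
    # Split the head in half and rotate: [x1, x2] -> [-x2, x1].
    x1 = x[..., :x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2:]
    return torch.cat((-x2, x1), dim=-1)


def rotate_gptj(x: torch.Tensor) -> torch.Tensor:
    # Interleaved pairs are rotated: (x0, x1) -> (-x1, x0), and so on.
    x1 = x[..., ::2]
    x2 = x[..., 1::2]
    return torch.stack((-x2, x1), dim=-1).flatten(-2)


x = torch.arange(1., 5.)      # [1, 2, 3, 4]
print(rotate_neox(x))         # tensor([-3., -4.,  1.,  2.])
print(rotate_gptj(x))         # tensor([-2.,  1., -4.,  3.])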
vllm/model_executor/layers/activation.py

 """Custom activation functions."""
+import math
 from typing import Optional

 import torch
 import torch.nn as nn
+import torch.nn.functional as F

 from vllm._C import ops
 from vllm.model_executor.layers.quantization import QuantizationConfig
...
@@ -22,6 +24,11 @@ class SiluAndMul(nn.Module):
         return: (batch_size, seq_len, d) or (num_tokens, d)
     """

+    def _forward(self, x: torch.Tensor) -> torch.Tensor:
+        """PyTorch-native implementation equivalent to forward()."""
+        d = x.shape[-1] // 2
+        return F.silu(x[..., :d]) * x[..., d:]
+
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         d = x.shape[-1] // 2
         output_shape = (x.shape[:-1] + (d, ))
...
@@ -32,6 +39,12 @@ class SiluAndMul(nn.Module):
 class NewGELU(nn.Module):

+    def _forward(self, x: torch.Tensor) -> torch.Tensor:
+        """PyTorch-native implementation equivalent to forward()."""
+        c = math.sqrt(2.0 / math.pi)
+        return 0.5 * x * (1.0 + torch.tanh(c *
+                                           (x + 0.044715 * torch.pow(x, 3.0))))
+
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         out = torch.empty_like(x)
         ops.gelu_new(out, x)
...
@@ -40,6 +53,11 @@ class NewGELU(nn.Module):
 class FastGELU(nn.Module):

+    def _forward(self, x: torch.Tensor) -> torch.Tensor:
+        """PyTorch-native implementation equivalent to forward()."""
+        return 0.5 * x * (1.0 + torch.tanh(x * 0.7978845608 *
+                                           (1.0 + 0.044715 * x * x)))
+
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         out = torch.empty_like(x)
         ops.gelu_fast(out, x)
...
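The formula in NewGELU._forward is the standard tanh approximation of GELU, so it should agree with torch.nn.functional.gelu(x, approximate="tanh") in recent PyTorch releases; FastGELU inlines the same constant (0.7978845608 is approximately sqrt(2/pi)). A quick check (illustrative, not part of the commit):

import math

import torch
import torch.nn.functional as F

x = torch.randn(1000, dtype=torch.float64)
c = math.sqrt(2.0 / math.pi)
# Same expression as NewGELU._forward in the diff above.
new_gelu = 0.5 * x * (1.0 + torch.tanh(c * (x + 0.044715 * torch.pow(x, 3.0))))
assert torch.allclose(new_gelu, F.gelu(x, approximate="tanh"), atol=1e-6)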
vllm/model_executor/layers/layernorm.py

...
@@ -23,6 +23,26 @@ class RMSNorm(nn.Module):
         self.weight = nn.Parameter(torch.ones(hidden_size))
         self.variance_epsilon = eps

+    def _forward(
+        self,
+        x: torch.Tensor,
+        residual: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+        """PyTorch-native implementation equivalent to forward()."""
+        orig_dtype = x.dtype
+        x = x.to(torch.float32)
+        if residual is not None:
+            x = x + residual.to(torch.float32)
+            residual = x.to(orig_dtype)
+
+        variance = x.pow(2).mean(dim=-1, keepdim=True)
+        x = x * torch.rsqrt(variance + self.variance_epsilon)
+        x = x.to(orig_dtype) * self.weight
+        if residual is None:
+            return x
+        else:
+            return x, residual
+
     def forward(
         self,
         x: torch.Tensor,
...
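For reference, the return contract of the new _forward (and of forward): without a residual the layer returns a single tensor, with one it returns the normalized output together with the updated residual. A standalone torch replica of the _forward numerics (illustrative sketch, not an import from vLLM):

from typing import Optional, Tuple, Union

import torch


def rms_norm_native(
    x: torch.Tensor,
    weight: torch.Tensor,
    residual: Optional[torch.Tensor] = None,
    eps: float = 1e-6,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
    orig_dtype = x.dtype
    x = x.to(torch.float32)
    if residual is not None:
        x = x + residual.to(torch.float32)   # fused residual add
        residual = x.to(orig_dtype)          # updated residual is returned too
    variance = x.pow(2).mean(dim=-1, keepdim=True)
    x = x * torch.rsqrt(variance + eps)
    x = x.to(orig_dtype) * weight
    return x if residual is None else (x, residual)


hidden = torch.randn(2, 8)
weight = torch.ones(8)
out = rms_norm_native(hidden, weight)                        # Tensor
out, new_residual = rms_norm_native(hidden, weight, hidden)  # (Tensor, Tensor)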
vllm/model_executor/layers/rotary_embedding.py

...
@@ -30,6 +30,19 @@ import torch.nn as nn
 from vllm._C import ops

+def _rotate_neox(x: torch.Tensor) -> torch.Tensor:
+    x1 = x[..., :x.shape[-1] // 2]
+    x2 = x[..., x.shape[-1] // 2:]
+    return torch.cat((-x2, x1), dim=-1)
+
+
+def _rotate_gptj(x: torch.Tensor) -> torch.Tensor:
+    x1 = x[..., ::2]
+    x2 = x[..., 1::2]
+    x = torch.stack((-x2, x1), dim=-1)
+    return x.flatten(-2)
+
+
 class RotaryEmbedding(nn.Module):
     """Original rotary positional embedding."""
...
@@ -81,6 +94,47 @@ class RotaryEmbedding(nn.Module):
         cache = torch.cat((cos, sin), dim=-1)
         return cache

+    def _forward(
+        self,
+        positions: torch.Tensor,
+        query: torch.Tensor,
+        key: torch.Tensor,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """PyTorch-native implementation equivalent to forward()."""
+        query = query.view(*query.shape[:-1], -1, self.head_size)
+        key = key.view(*key.shape[:-1], -1, self.head_size)
+
+        query_rot = query[..., :self.rotary_dim]
+        key_rot = key[..., :self.rotary_dim]
+        if self.rotary_dim < self.head_size:
+            query_pass = query[..., self.rotary_dim:]
+            key_pass = key[..., self.rotary_dim:]
+
+        cos_sin = self.cos_sin_cache[positions]
+        cos, sin = cos_sin.chunk(2, dim=-1)
+        if self.is_neox_style:
+            # NOTE(woosuk): Here we assume that the positions tensor has the
+            # shape [batch_size, seq_len].
+            cos = cos.repeat(1, 1, 2).unsqueeze(-2)
+            sin = sin.repeat(1, 1, 2).unsqueeze(-2)
+        else:
+            cos = cos.repeat_interleave(2, dim=-1).unsqueeze(-2)
+            sin = sin.repeat_interleave(2, dim=-1).unsqueeze(-2)
+
+        rotate_fn = _rotate_neox if self.is_neox_style else _rotate_gptj
+        query_rot = query_rot * cos + rotate_fn(query_rot) * sin
+        key_rot = key_rot * cos + rotate_fn(key_rot) * sin
+
+        if self.rotary_dim < self.head_size:
+            query = torch.cat((query_rot, query_pass), dim=-1)
+            key = torch.cat((key_rot, key_pass), dim=-1)
+        else:
+            query = query_rot
+            key = key_rot
+        query = query.flatten(-2)
+        key = key.flatten(-2)
+        return query, key
+
     def forward(
         self,
         positions: torch.Tensor,
...
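One subtlety in _forward above is the broadcasting: indexing cos_sin_cache with a [batch_size, seq_len] positions tensor yields per-token cos/sin values, and unsqueeze(-2) inserts a singleton head axis so they broadcast across num_heads. A shape-only walkthrough (illustrative torch sketch with made-up sizes):

import torch

batch_size, seq_len, num_heads, head_size = 2, 5, 3, 8
rotary_dim = head_size
max_position = 16

cos_sin_cache = torch.randn(max_position, rotary_dim)
positions = torch.randint(0, max_position, (batch_size, seq_len))

cos_sin = cos_sin_cache[positions]       # [batch_size, seq_len, rotary_dim]
cos, sin = cos_sin.chunk(2, dim=-1)      # each [batch_size, seq_len, rotary_dim // 2]
cos = cos.repeat(1, 1, 2).unsqueeze(-2)  # [batch_size, seq_len, 1, rotary_dim] (NeoX style)
query = torch.randn(batch_size, seq_len, num_heads, head_size)
assert (query[..., :rotary_dim] * cos).shape == (batch_size, seq_len, num_heads, rotary_dim)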