Unverified commit e1054247, authored Nov 19, 2023 by ljss, committed by GitHub Nov 18, 2023

[Optimization] Implement fused add rmsnorm (#1667)
Parent: 8d17774f
Showing 9 changed files with 166 additions and 61 deletions (+166 -61)
csrc/layernorm.cpp                        +10  -0
csrc/layernorm_kernels.cu                 +55  -0
vllm/model_executor/layers/layernorm.py   +15  -1
vllm/model_executor/models/baichuan.py    +15  -10
vllm/model_executor/models/internlm.py    +15  -10
vllm/model_executor/models/llama.py       +15  -10
vllm/model_executor/models/mistral.py     +15  -10
vllm/model_executor/models/qwen.py        +13  -10
vllm/model_executor/models/yi.py          +13  -10
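This commit fuses the residual addition and RMS normalization that previously ran as two separate steps in each decoder layer into a single in-place CUDA kernel. For every token, the kernel adds the residual to the hidden state, stores the sum back into the residual tensor (so it can feed the next residual connection), and overwrites the hidden state with the RMS-normalized, weight-scaled result. A minimal PyTorch sketch of the intended semantics follows; the function name is illustrative and not part of the commit:

import torch

def fused_add_rms_norm_ref(x: torch.Tensor,
                           residual: torch.Tensor,
                           weight: torch.Tensor,
                           eps: float) -> None:
    # Mirrors the in-place behaviour of the fused kernel:
    #   residual <- x + residual
    #   x        <- rms_norm(x + residual) * weight
    summed = x.float() + residual.float()                # residual add (in fp32)
    variance = summed.pow(2).mean(dim=-1, keepdim=True)  # per-token mean square
    normed = summed * torch.rsqrt(variance + eps)        # RMS normalization
    residual.copy_(summed.to(residual.dtype))
    x.copy_((normed * weight.float()).to(x.dtype))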
csrc/layernorm.cpp
...
@@ -6,9 +6,19 @@ void rms_norm(
   torch::Tensor& weight,
   float epsilon);
 
+void fused_add_rms_norm(
+  torch::Tensor& input,
+  torch::Tensor& residual,
+  torch::Tensor& weight,
+  float epsilon);
+
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   m.def(
     "rms_norm",
     &rms_norm,
     "Apply Root Mean Square (RMS) Normalization to the input tensor.");
+  m.def(
+    "fused_add_rms_norm",
+    &fused_add_rms_norm,
+    "In-place fused Add and RMS Normalization");
 }
csrc/layernorm_kernels.cu
...
@@ -34,6 +34,36 @@ __global__ void rms_norm_kernel(
   }
 }
 
+// TODO: Further optimize this kernel.
+template<typename scalar_t>
+__global__ void fused_add_rms_norm_kernel(
+  scalar_t* __restrict__ input,           // [..., hidden_size]
+  scalar_t* __restrict__ residual,        // [..., hidden_size]
+  const scalar_t* __restrict__ weight,    // [hidden_size]
+  const float epsilon,
+  const int num_tokens,
+  const int hidden_size) {
+  __shared__ float s_variance;
+  float variance = 0.0f;
+
+  for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
+    float x = (float) input[blockIdx.x * hidden_size + idx];
+    x += (float) residual[blockIdx.x * hidden_size + idx];
+    variance += x * x;
+    residual[blockIdx.x * hidden_size + idx] = (scalar_t) x;
+  }
+  variance = blockReduceSum<float>(variance);
+  if (threadIdx.x == 0) {
+    s_variance = rsqrtf(variance / hidden_size + epsilon);
+  }
+  __syncthreads();
+
+  for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
+    float x = (float) residual[blockIdx.x * hidden_size + idx];
+    input[blockIdx.x * hidden_size + idx] = ((scalar_t) (x * s_variance)) * weight[idx];
+  }
+}
+
 } // namespace vllm
 
 void rms_norm(
...
@@ -60,3 +90,28 @@ void rms_norm(
       hidden_size);
   });
 }
+
+void fused_add_rms_norm(
+  torch::Tensor& input,    // [..., hidden_size]
+  torch::Tensor& residual, // [..., hidden_size]
+  torch::Tensor& weight,   // [hidden_size]
+  float epsilon) {
+  int hidden_size = input.size(-1);
+  int num_tokens = input.numel() / hidden_size;
+
+  dim3 grid(num_tokens);
+  dim3 block(std::min(hidden_size, 1024));
+  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+  VLLM_DISPATCH_FLOATING_TYPES(
+    input.scalar_type(),
+    "fused_add_rms_norm_kernel",
+    [&] {
+      vllm::fused_add_rms_norm_kernel<scalar_t><<<grid, block, 0, stream>>>(
+        input.data_ptr<scalar_t>(),
+        residual.data_ptr<scalar_t>(),
+        weight.data_ptr<scalar_t>(),
+        epsilon,
+        num_tokens,
+        hidden_size);
+    });
+}
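A quick way to exercise the new launcher against the existing unfused path is to add the residual explicitly, run the old rms_norm kernel, and compare with the fused call. This is only a sketch: it assumes the compiled extension is importable as layernorm_ops (as vllm/model_executor/layers/layernorm.py does below), that a CUDA device is available, and the tolerances are illustrative:

import torch
from vllm import layernorm_ops  # assumed import path for the compiled extension

def check_fused_add_rms_norm(num_tokens=16, hidden_size=4096, eps=1e-6):
    x = torch.randn(num_tokens, hidden_size, dtype=torch.float16, device="cuda")
    residual = torch.randn_like(x)
    weight = torch.randn(hidden_size, dtype=torch.float16, device="cuda")

    # Unfused baseline: explicit add, then the existing rms_norm kernel.
    summed = x + residual
    baseline = torch.empty_like(summed)
    layernorm_ops.rms_norm(baseline, summed, weight, eps)

    # Fused path: updates x and residual in place.
    layernorm_ops.fused_add_rms_norm(x, residual, weight, eps)

    torch.testing.assert_close(residual, summed, atol=1e-3, rtol=1e-3)
    torch.testing.assert_close(x, baseline, atol=1e-2, rtol=1e-2)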
vllm/model_executor/layers/layernorm.py
 """Custom normalization layers."""
+from typing import Optional, Tuple, Union
+
 import torch
 import torch.nn as nn
...
@@ -21,7 +23,19 @@ class RMSNorm(nn.Module):
         self.weight = nn.Parameter(torch.ones(hidden_size))
         self.variance_epsilon = eps
 
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
+    def forward(
+        self,
+        x: torch.Tensor,
+        residual: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+        if residual is not None:
+            layernorm_ops.fused_add_rms_norm(
+                x,
+                residual,
+                self.weight.data,
+                self.variance_epsilon,
+            )
+            return x, residual
         out = torch.empty_like(x)
         layernorm_ops.rms_norm(
             out,
...
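After this change RMSNorm.forward has two modes: called with x alone it behaves exactly as before and returns a single tensor; called with a residual it invokes the fused kernel in place and returns the (hidden_states, residual) pair that the updated decoder layers thread through. A short usage sketch (shapes, dtypes, and the standalone setup are illustrative):

import torch
from vllm.model_executor.layers.layernorm import RMSNorm

norm = RMSNorm(4096, eps=1e-6).to(device="cuda", dtype=torch.float16)

x = torch.randn(8, 4096, dtype=torch.float16, device="cuda")
out = norm(x)                    # plain RMSNorm; returns one tensor

residual = torch.randn_like(x)
x, residual = norm(x, residual)  # fused add + RMSNorm; both tensors updated in place
# residual now holds the pre-norm sum, ready for the next residual connection.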
vllm/model_executor/models/baichuan.py
...
@@ -225,10 +225,15 @@ class BaiChuanDecoderLayer(nn.Module):
         kv_cache: KVCache,
         input_metadata: InputMetadata,
         cache_event: Optional[torch.cuda.Event],
-    ) -> torch.Tensor:
+        residual: Optional[torch.Tensor],
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
         # Self Attention
-        residual = hidden_states
-        hidden_states = self.input_layernorm(hidden_states)
+        if residual is None:
+            residual = hidden_states
+            hidden_states = self.input_layernorm(hidden_states)
+        else:
+            hidden_states, residual = self.input_layernorm(
+                hidden_states, residual)
         hidden_states = self.self_attn(
             positions=positions,
             hidden_states=hidden_states,
...
@@ -236,14 +241,12 @@ class BaiChuanDecoderLayer(nn.Module):
             input_metadata=input_metadata,
             cache_event=cache_event,
         )
-        hidden_states = residual + hidden_states
 
         # Fully Connected
-        residual = hidden_states
-        hidden_states = self.post_attention_layernorm(hidden_states)
+        hidden_states, residual = self.post_attention_layernorm(
+            hidden_states, residual)
         hidden_states = self.mlp(hidden_states)
-        hidden_states = residual + hidden_states
-        return hidden_states
+        return hidden_states, residual
 
 
 class BaiChuanModel(nn.Module):
...
@@ -276,20 +279,22 @@ class BaiChuanModel(nn.Module):
         cache_events: Optional[List[torch.cuda.Event]],
     ) -> torch.Tensor:
         hidden_states = self.embed_tokens(input_ids)
+        residual = None
         for i in range(len(self.layers)):
             if cache_events is None:
                 cache_event = None
             else:
                 cache_event = cache_events[i]
             layer = self.layers[i]
-            hidden_states = layer(
+            hidden_states, residual = layer(
                 positions,
                 hidden_states,
                 kv_caches[i],
                 input_metadata,
                 cache_event,
+                residual,
             )
-        hidden_states = self.norm(hidden_states)
+        hidden_states, _ = self.norm(hidden_states, residual)
         return hidden_states
...
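The internlm, llama, mistral, qwen, and yi diffs below apply exactly the same restructuring; only the attribute names of the norms differ (input_layernorm/post_attention_layernorm, ln_1/ln_2, ln1/ln2). The essential control flow is the loop sketched here; run_decoder_stack, layer, and final_norm are generic stand-ins for the per-model modules, and the position, KV-cache, and metadata arguments are omitted for brevity:

from typing import Callable, Optional, Sequence, Tuple
import torch

def run_decoder_stack(
    hidden_states: torch.Tensor,
    layers: Sequence[Callable[..., Tuple[torch.Tensor, torch.Tensor]]],
    final_norm: Callable[..., Tuple[torch.Tensor, torch.Tensor]],
) -> torch.Tensor:
    residual: Optional[torch.Tensor] = None
    for layer in layers:
        # Each layer now returns (hidden_states, residual); the residual add
        # is deferred into the next fused add+RMSNorm call instead of being
        # done eagerly with `residual + hidden_states`.
        hidden_states, residual = layer(hidden_states, residual)
    # The final norm performs the last pending add before the LM head.
    hidden_states, _ = final_norm(hidden_states, residual)
    return hidden_states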
vllm/model_executor/models/internlm.py
...
@@ -155,10 +155,15 @@ class InternLMDecoderLayer(nn.Module):
         kv_cache: KVCache,
         input_metadata: InputMetadata,
         cache_event: Optional[torch.cuda.Event],
-    ) -> torch.Tensor:
+        residual: Optional[torch.Tensor],
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
         # Self Attention
-        residual = hidden_states
-        hidden_states = self.input_layernorm(hidden_states)
+        if residual is None:
+            residual = hidden_states
+            hidden_states = self.input_layernorm(hidden_states)
+        else:
+            hidden_states, residual = self.input_layernorm(
+                hidden_states, residual)
         hidden_states = self.self_attn(
             positions=positions,
             hidden_states=hidden_states,
...
@@ -166,14 +171,12 @@ class InternLMDecoderLayer(nn.Module):
             input_metadata=input_metadata,
             cache_event=cache_event,
         )
-        hidden_states = residual + hidden_states
 
         # Fully Connected
-        residual = hidden_states
-        hidden_states = self.post_attention_layernorm(hidden_states)
+        hidden_states, residual = self.post_attention_layernorm(
+            hidden_states, residual)
         hidden_states = self.mlp(hidden_states)
-        hidden_states = residual + hidden_states
-        return hidden_states
+        return hidden_states, residual
 
 
 class InternLMModel(nn.Module):
...
@@ -208,20 +211,22 @@ class InternLMModel(nn.Module):
         cache_events: Optional[List[torch.cuda.Event]],
     ) -> torch.Tensor:
         hidden_states = self.embed_tokens(input_ids)
+        residual = None
         for i in range(len(self.layers)):
             if cache_events is None:
                 cache_event = None
             else:
                 cache_event = cache_events[i]
             layer = self.layers[i]
-            hidden_states = layer(
+            hidden_states, residual = layer(
                 positions,
                 hidden_states,
                 kv_caches[i],
                 input_metadata,
                 cache_event,
+                residual,
             )
-        hidden_states = self.norm(hidden_states)
+        hidden_states, _ = self.norm(hidden_states, residual)
         return hidden_states
...
vllm/model_executor/models/llama.py
...
@@ -197,10 +197,15 @@ class LlamaDecoderLayer(nn.Module):
         kv_cache: KVCache,
         input_metadata: InputMetadata,
         cache_event: Optional[torch.cuda.Event],
-    ) -> torch.Tensor:
+        residual: Optional[torch.Tensor],
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
         # Self Attention
-        residual = hidden_states
-        hidden_states = self.input_layernorm(hidden_states)
+        if residual is None:
+            residual = hidden_states
+            hidden_states = self.input_layernorm(hidden_states)
+        else:
+            hidden_states, residual = self.input_layernorm(
+                hidden_states, residual)
         hidden_states = self.self_attn(
             positions=positions,
             hidden_states=hidden_states,
...
@@ -208,14 +213,12 @@ class LlamaDecoderLayer(nn.Module):
             input_metadata=input_metadata,
             cache_event=cache_event,
         )
-        hidden_states = residual + hidden_states
 
         # Fully Connected
-        residual = hidden_states
-        hidden_states = self.post_attention_layernorm(hidden_states)
+        hidden_states, residual = self.post_attention_layernorm(
+            hidden_states, residual)
         hidden_states = self.mlp(hidden_states)
-        hidden_states = residual + hidden_states
-        return hidden_states
+        return hidden_states, residual
 
 
 class LlamaModel(nn.Module):
...
@@ -248,20 +251,22 @@ class LlamaModel(nn.Module):
         cache_events: Optional[List[torch.cuda.Event]],
     ) -> torch.Tensor:
         hidden_states = self.embed_tokens(input_ids)
+        residual = None
         for i in range(len(self.layers)):
             if cache_events is None:
                 cache_event = None
             else:
                 cache_event = cache_events[i]
             layer = self.layers[i]
-            hidden_states = layer(
+            hidden_states, residual = layer(
                 positions,
                 hidden_states,
                 kv_caches[i],
                 input_metadata,
                 cache_event,
+                residual,
             )
-        hidden_states = self.norm(hidden_states)
+        hidden_states, _ = self.norm(hidden_states, residual)
         return hidden_states
...
vllm/model_executor/models/mistral.py
...
@@ -191,10 +191,15 @@ class MistralDecoderLayer(nn.Module):
         kv_cache: KVCache,
         input_metadata: InputMetadata,
         cache_event: Optional[torch.cuda.Event],
-    ) -> torch.Tensor:
+        residual: Optional[torch.Tensor],
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
         # Self Attention
-        residual = hidden_states
-        hidden_states = self.input_layernorm(hidden_states)
+        if residual is None:
+            residual = hidden_states
+            hidden_states = self.input_layernorm(hidden_states)
+        else:
+            hidden_states, residual = self.input_layernorm(
+                hidden_states, residual)
         hidden_states = self.self_attn(
             positions=positions,
             hidden_states=hidden_states,
...
@@ -202,14 +207,12 @@ class MistralDecoderLayer(nn.Module):
             input_metadata=input_metadata,
             cache_event=cache_event,
         )
-        hidden_states = residual + hidden_states
 
         # Fully Connected
-        residual = hidden_states
-        hidden_states = self.post_attention_layernorm(hidden_states)
+        hidden_states, residual = self.post_attention_layernorm(
+            hidden_states, residual)
         hidden_states = self.mlp(hidden_states)
-        hidden_states = residual + hidden_states
-        return hidden_states
+        return hidden_states, residual
 
 
 class MistralModel(nn.Module):
...
@@ -243,20 +246,22 @@ class MistralModel(nn.Module):
         cache_events: Optional[List[torch.cuda.Event]],
     ) -> torch.Tensor:
         hidden_states = self.embed_tokens(input_ids)
+        residual = None
         for i in range(len(self.layers)):
             if cache_events is None:
                 cache_event = None
             else:
                 cache_event = cache_events[i]
             layer = self.layers[i]
-            hidden_states = layer(
+            hidden_states, residual = layer(
                 positions,
                 hidden_states,
                 kv_caches[i],
                 input_metadata,
                 cache_event,
+                residual,
             )
-        hidden_states = self.norm(hidden_states)
+        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states
...
vllm/model_executor/models/qwen.py
...
@@ -159,10 +159,14 @@ class QWenBlock(nn.Module):
         kv_cache: KVCache,
         input_metadata: InputMetadata,
         cache_event: Optional[torch.cuda.Event],
-    ) -> torch.Tensor:
+        residual: Optional[torch.Tensor],
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
         # Self Attention
-        residual = hidden_states
-        hidden_states = self.ln_1(hidden_states)
+        if residual is None:
+            residual = hidden_states
+            hidden_states = self.ln_1(hidden_states)
+        else:
+            hidden_states, residual = self.ln_1(hidden_states, residual)
         hidden_states = self.attn(
             positions=positions,
             hidden_states=hidden_states,
...
@@ -170,14 +174,11 @@ class QWenBlock(nn.Module):
             input_metadata=input_metadata,
             cache_event=cache_event,
         )
-        hidden_states = residual + hidden_states
 
         # Fully Connected
-        residual = hidden_states
-        hidden_states = self.ln_2(hidden_states)
+        hidden_states, residual = self.ln_2(hidden_states, residual)
         hidden_states = self.mlp(hidden_states)
-        hidden_states = residual + hidden_states
-        return hidden_states
+        return hidden_states, residual
 
 
 class QWenModel(nn.Module):
...
@@ -210,20 +211,22 @@ class QWenModel(nn.Module):
         cache_events: Optional[List[torch.cuda.Event]],
     ) -> torch.Tensor:
         hidden_states = self.wte(input_ids)
+        residual = None
         for i in range(len(self.h)):
             if cache_events is None:
                 cache_event = None
             else:
                 cache_event = cache_events[i]
             layer = self.h[i]
-            hidden_states = layer(
+            hidden_states, residual = layer(
                 positions,
                 hidden_states,
                 kv_caches[i],
                 input_metadata,
                 cache_event,
+                residual,
             )
-        hidden_states = self.ln_f(hidden_states)
+        hidden_states, _ = self.ln_f(hidden_states, residual)
         return hidden_states
...
vllm/model_executor/models/yi.py
...
@@ -195,10 +195,14 @@ class YiDecoderLayer(nn.Module):
         kv_cache: KVCache,
         input_metadata: InputMetadata,
         cache_event: Optional[torch.cuda.Event],
-    ) -> torch.Tensor:
+        residual: Optional[torch.Tensor],
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
         # Self Attention
-        residual = hidden_states
-        hidden_states = self.ln1(hidden_states)
+        if residual is None:
+            residual = hidden_states
+            hidden_states = self.ln1(hidden_states)
+        else:
+            hidden_states, residual = self.ln1(hidden_states, residual)
         hidden_states = self.self_attn(
             positions=positions,
             hidden_states=hidden_states,
...
@@ -206,14 +210,11 @@ class YiDecoderLayer(nn.Module):
             input_metadata=input_metadata,
             cache_event=cache_event,
         )
-        hidden_states = residual + hidden_states
 
         # Fully Connected
-        residual = hidden_states
-        hidden_states = self.ln2(hidden_states)
+        hidden_states, residual = self.ln2(hidden_states, residual)
         hidden_states = self.mlp(hidden_states)
-        hidden_states = residual + hidden_states
-        return hidden_states
+        return hidden_states, residual
 
 
 class YiModel(nn.Module):
...
@@ -246,20 +247,22 @@ class YiModel(nn.Module):
         cache_events: Optional[List[torch.cuda.Event]],
     ) -> torch.Tensor:
         hidden_states = self.embed_tokens(input_ids)
+        residual = None
         for i in range(len(self.layers)):
             if cache_events is None:
                 cache_event = None
             else:
                 cache_event = cache_events[i]
             layer = self.layers[i]
-            hidden_states = layer(
+            hidden_states, residual = layer(
                 positions,
                 hidden_states,
                 kv_caches[i],
                 input_metadata,
                 cache_event,
+                residual,
             )
-        hidden_states = self.norm(hidden_states)
+        hidden_states, _ = self.norm(hidden_states, residual)
         return hidden_states
...