OpenDAS / text-generation-inference / Commits / c9bdaa8b

Unverified commit c9bdaa8b, authored Mar 28, 2023 by OlivierDehaene, committed by GitHub on Mar 28, 2023

feat(server): reduce mlp and attn in one op for flash neox (#145)
parent f0000689

Showing 1 changed file with 102 additions and 104 deletions:

  server/text_generation_server/models/flash_neox_modeling.py  (+102, -104)
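In brief: the tensor-parallel row linears gain a reduce flag so they can return rank-local partial results, and when GPT-NeoX's parallel residual is enabled the layer sums the attention and MLP outputs first and performs a single all-reduce on that sum, instead of one all-reduce inside each module. A minimal sketch of the pattern, assuming an initialized torch.distributed process group; the function and argument names here are illustrative, not the repository's API:

import torch
import torch.distributed as dist

def parallel_residual_block(hidden_states, attn, mlp, process_group=None):
    # attn and mlp are row-parallel modules built with reduce=False, so each
    # rank returns only its local partial sum of the output projection.
    attn_out = attn(hidden_states)
    mlp_out = mlp(hidden_states)

    # Sum the two partial outputs first...
    intermediate = attn_out + mlp_out

    # ...then reduce once across ranks instead of once per module.
    if process_group is not None:
        dist.all_reduce(intermediate, group=process_group)

    # Parallel residual: both branches are added to the original input.
    return intermediate + hidden_states

This fusion only applies on the parallel-residual path, where the attention and MLP outputs are summed before the residual add; on the sequential path each module keeps its own reduce (reduce=True). The full diff of server/text_generation_server/models/flash_neox_modeling.py follows.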
# coding=utf-8
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.distributed
...
@@ -16,6 +36,42 @@ import dropout_layer_norm
 from flash_attn.layers.rotary import RotaryEmbedding

+class FastLayerNorm(nn.LayerNorm):
+    def forward(self, hidden_states, residual=None):
+        if hidden_states.shape[-1] > 6144:
+            if residual is not None:
+                hidden_states += residual
+            residual = hidden_states
+
+            return super(FastLayerNorm, self).forward(hidden_states), residual
+        else:
+            (
+                normed_hidden_states,
+                residual,
+                *rest,
+            ) = dropout_layer_norm.dropout_add_ln_fwd(
+                hidden_states,
+                residual,
+                self.weight,
+                self.bias,
+                None,
+                None,
+                None,
+                None,
+                0.0,
+                self.eps,
+                1.0,
+                0,
+                None,
+                False,
+                False,
+            )
+            if residual is None:
+                residual = hidden_states
+
+            return normed_hidden_states, residual

 class FastLinear(nn.Linear):
     def __init__(
         self,
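The FastLayerNorm class added above folds the residual addition into the layer norm: for hidden sizes the fused kernel handles (at most 6144) it calls dropout_layer_norm.dropout_add_ln_fwd with zero dropout, and otherwise it falls back to an in-place add followed by nn.LayerNorm. Both branches return the normalized tensor together with the post-add residual, so callers can thread that residual into the next block. A hedged usage sketch, assuming a CUDA build of dropout_layer_norm; shapes, dtype, and device are illustrative:

import torch

# Hidden size 4096 stays on the fused dropout_add_ln_fwd path (<= 6144).
ln = FastLayerNorm(4096).half().cuda()
hidden = torch.randn(8, 4096, dtype=torch.float16, device="cuda")

# First call: no residual yet, so the input itself becomes the residual.
normed, residual = ln(hidden)

# Later calls: the residual is added to the input inside the fused kernel
# before normalization, and the post-add tensor comes back as the new residual.
normed, residual = ln(normed, residual)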
...
@@ -59,9 +115,6 @@ class TensorParallelColumnLinear(FastLinear):
             dtype=dtype,
         )

-    def forward(self, input):
-        return super(TensorParallelColumnLinear, self).forward(input)
-

 class TensorParallelRowLinear(FastLinear):
     def __init__(
...
@@ -69,12 +122,14 @@ class TensorParallelRowLinear(FastLinear):
         in_features,
         out_features,
         process_group: torch.distributed.ProcessGroup,
+        reduce=True,
         bias=True,
         device=None,
         dtype=None,
     ):
         self.process_group = process_group
         self.tp_world_size = process_group.size()
+        self.reduce = reduce
         assert in_features % self.tp_world_size == 0
         in_features = in_features // self.tp_world_size
...
@@ -88,7 +143,8 @@ class TensorParallelRowLinear(FastLinear):
     def forward(self, input: torch.Tensor) -> torch.Tensor:
         out = super(TensorParallelRowLinear, self).forward(input)
-        torch.distributed.all_reduce(out, group=self.process_group)
+        if self.reduce:
+            torch.distributed.all_reduce(out, group=self.process_group)

         return out
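With the new reduce flag, TensorParallelRowLinear can hand back its rank-local partial output instead of all-reducing inside forward, leaving the collective to the caller. A minimal sketch of a caller combining two such partial outputs, assuming torch.distributed has already been initialized; the sizes are illustrative and bias is omitted so the partial results stay purely rank-local:

import torch
import torch.distributed as dist

# Assumes dist.init_process_group(...) has been called by the launcher.
group = dist.group.WORLD
world_size = group.size()

# reduce=False: forward returns this rank's partial matmul result.
attn_dense = TensorParallelRowLinear(
    4096, 4096, process_group=group, reduce=False, bias=False
)
mlp_down = TensorParallelRowLinear(
    16384, 4096, process_group=group, reduce=False, bias=False
)

x_attn = torch.randn(2, 4096 // world_size)
x_mlp = torch.randn(2, 16384 // world_size)

# Sum the partial outputs, then all-reduce once for both modules.
partial = attn_dense(x_attn) + mlp_down(x_mlp)
dist.all_reduce(partial, group=group)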
...
@@ -196,7 +252,13 @@ class PositionRotaryEmbedding(RotaryEmbedding):
 class FlashNeoxAttention(torch.nn.Module):
     def __init__(
-        self, num_heads, hidden_size, rotary_pct, rotary_emb_base, process_group=None
+        self,
+        num_heads,
+        hidden_size,
+        rotary_pct,
+        rotary_emb_base,
+        process_group=None,
+        reduce=True,
     ):
         super().__init__()
         self.num_heads = num_heads
...
@@ -218,9 +280,7 @@ class FlashNeoxAttention(torch.nn.Module):
             process_group=process_group,
         )
         self.dense = TensorParallelRowLinear(
-            hidden_size,
-            hidden_size,
-            process_group=process_group,
+            hidden_size, hidden_size, process_group=process_group, reduce=reduce
         )

     def shuffle_qkv_dims(self):
...
@@ -309,7 +369,9 @@ class FlashNeoxAttention(torch.nn.Module):
 class FlashMLP(nn.Module):
-    def __init__(self, act, hidden_size, intermediate_size, process_group=None):
+    def __init__(
+        self, act, hidden_size, intermediate_size, process_group=None, reduce=True
+    ):
         super().__init__()
         self.act = (
             ACT2FN[act]
...
@@ -330,6 +392,7 @@ class FlashMLP(nn.Module):
             intermediate_size,
             hidden_size,
             process_group=process_group,
+            reduce=reduce,
         )
         self.process_group = process_group
...
@@ -355,12 +418,24 @@ class FlashNeoXLayer(nn.Module):
     ):
         super().__init__()
         self.use_parallel_residual = use_parallel_residual
-        self.input_layernorm = nn.LayerNorm(hidden_size, eps=layer_norm_eps)
-        self.post_attention_layernorm = nn.LayerNorm(hidden_size, eps=layer_norm_eps)
+        self.input_layernorm = FastLayerNorm(hidden_size, eps=layer_norm_eps)
+        self.post_attention_layernorm = FastLayerNorm(hidden_size, eps=layer_norm_eps)
         self.attention = FlashNeoxAttention(
-            num_heads, hidden_size, rotary_pct, rotary_emb_base, process_group
+            num_heads,
+            hidden_size,
+            rotary_pct,
+            rotary_emb_base,
+            process_group,
+            reduce=not use_parallel_residual,
         )
-        self.mlp = FlashMLP(act, hidden_size, intermediate_size, process_group)
+        self.mlp = FlashMLP(
+            act,
+            hidden_size,
+            intermediate_size,
+            process_group,
+            reduce=not use_parallel_residual,
+        )
+        self.process_group = process_group

     def forward(
         self,
...
@@ -375,24 +450,7 @@ class FlashNeoXLayer(nn.Module):
         cu_seqlens_q,
     ):
         if self.use_parallel_residual:
-            # faster input layer norm
-            ln1_hidden_states, *rest = dropout_layer_norm.dropout_add_ln_fwd(
-                hidden_states,
-                None,
-                self.input_layernorm.weight,
-                self.input_layernorm.bias,
-                None,
-                None,
-                None,
-                None,
-                0.0,
-                self.input_layernorm.eps,
-                1.0,
-                0,
-                None,
-                False,
-                False,
-            )
+            ln1_hidden_states, _ = self.input_layernorm(hidden_states)

             attn_output = self.attention(
                 ln1_hidden_states,
...
@@ -405,46 +463,18 @@ class FlashNeoXLayer(nn.Module):
                 cu_seqlens_q,
             )

-            # faster post attention layer norm
-            ln2_hidden_states, *rest = dropout_layer_norm.dropout_add_ln_fwd(
-                hidden_states,
-                None,
-                self.post_attention_layernorm.weight,
-                self.post_attention_layernorm.bias,
-                None,
-                None,
-                None,
-                None,
-                0.0,
-                self.post_attention_layernorm.eps,
-                1.0,
-                0,
-                None,
-                False,
-                False,
-            )
+            ln2_hidden_states, _ = self.post_attention_layernorm(hidden_states)

             mlp_output = self.mlp(ln2_hidden_states)
-            return mlp_output + attn_output + hidden_states, None
+            intermediate = mlp_output + attn_output
+
+            # Only reduce once and after the addition instead of once per layer
+            if self.process_group is not None:
+                torch.distributed.all_reduce(intermediate, group=self.process_group)
+
+            return intermediate + hidden_states, None
         else:
-            # faster input layer norm
-            hidden_states, residual, *rest = dropout_layer_norm.dropout_add_ln_fwd(
-                hidden_states,
-                residual,
-                self.input_layernorm.weight,
-                self.input_layernorm.bias,
-                None,
-                None,
-                None,
-                None,
-                0.0,
-                self.input_layernorm.eps,
-                1.0,
-                0,
-                None,
-                False,
-                False,
-            )
+            hidden_states, residual = self.input_layernorm(hidden_states, residual)

             hidden_states = self.attention(
                 hidden_states,
...
@@ -457,23 +487,8 @@ class FlashNeoXLayer(nn.Module):
                 cu_seqlens_q,
             )

-            # faster post attention layer norm
-            hidden_states, residual, *rest = dropout_layer_norm.dropout_add_ln_fwd(
-                hidden_states,
-                residual,
-                self.post_attention_layernorm.weight,
-                self.post_attention_layernorm.bias,
-                None,
-                None,
-                None,
-                None,
-                0.0,
-                self.post_attention_layernorm.eps,
-                1.0,
-                0,
-                None,
-                False,
-                False,
-            )
+            hidden_states, residual = self.post_attention_layernorm(
+                hidden_states, residual
+            )

             mlp_output = self.mlp(hidden_states)
...
@@ -523,7 +538,7 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
                 for _ in range(config.num_hidden_layers)
             ]
         )
-        self.final_layer_norm = nn.LayerNorm(
+        self.final_layer_norm = FastLayerNorm(
             config.hidden_size, eps=config.layer_norm_eps
         )
...
@@ -603,24 +618,7 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
                 cu_seqlens_q,
             )

-        # Faster final layer norm
-        hidden_states, *rest = dropout_layer_norm.dropout_add_ln_fwd(
-            hidden_states,
-            residual,
-            self.final_layer_norm.weight,
-            self.final_layer_norm.bias,
-            None,
-            None,
-            None,
-            None,
-            0.0,
-            self.final_layer_norm.eps,
-            1.0,
-            0,
-            None,
-            False,
-            False,
-        )
+        hidden_states, _ = self.final_layer_norm(hidden_states, residual)

         return hidden_states, past_key_values
...