OpenDAS / text-generation-inference / Commits

Commit 05e9a796 (unverified)
Authored Mar 24, 2023 by OlivierDehaene, committed via GitHub on Mar 24, 2023
Parent: 23e10288

    feat(server): flash neoX (#133)

Showing 10 changed files with 1307 additions and 25 deletions (+1307 -25).
Changed files:
  .github/workflows/build.yaml                                    +4   -0
  .github/workflows/tests.yaml                                    +4   -0
  Dockerfile                                                       +6   -3
  server/Makefile                                                 +12   -5
  server/text_generation_server/models/__init__.py                +18   -2
  server/text_generation_server/models/causal_lm.py                +0   -1
  server/text_generation_server/models/flash_neox.py             +601   -0
  server/text_generation_server/models/flash_neox_modeling.py    +637   -0
  server/text_generation_server/utils/tokens.py                    +1   -1
  server/text_generation_server/utils/watermark.py                +24  -13
.github/workflows/build.yaml  (+4 -0)

@@ -20,6 +20,10 @@ on:
     branches:
       - 'main'
 
+concurrency:
+  group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
+  cancel-in-progress: true
+
 jobs:
   build-and-push-image:
     runs-on: ubuntu-latest
.github/workflows/tests.yaml  (+4 -0)

@@ -11,6 +11,10 @@ on:
       - "Cargo.lock"
       - "rust-toolchain.toml"
 
+concurrency:
+  group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
+  cancel-in-progress: true
+
 jobs:
   run_tests:
     runs-on: ubuntu-20.04
Dockerfile  (+6 -3)

@@ -43,7 +43,7 @@ ENV LANG=C.UTF-8 \
     CONDA_DEFAULT_ENV=text-generation \
     PATH=$PATH:/opt/miniconda/envs/text-generation/bin:/opt/miniconda/bin:/usr/local/cuda/bin
 
-RUN apt-get update && apt-get install -y unzip curl libssl-dev && rm -rf /var/lib/apt/lists/*
+RUN apt-get update && apt-get install -y git curl libssl-dev && rm -rf /var/lib/apt/lists/*
 
 RUN cd ~ && \
     curl -L -O https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh && \
@@ -53,10 +53,13 @@ RUN cd ~ && \
 
 WORKDIR /usr/src
 
+# Install torch
+RUN pip install torch --extra-index-url https://download.pytorch.org/whl/cu118 --no-cache-dir
+
 COPY server/Makefile server/Makefile
 
-# Install specific version of torch
-RUN cd server && make install-torch
+# Install specific version of flash attention
+RUN cd server && make install-flash-attention
 
 # Install specific version of transformers
 RUN cd server && BUILD_EXTENSIONS="True" make install-transformers
server/Makefile  (+12 -5)

 transformers_commit := 2b57aa18da658e7d2f42ef6bd5b56751af582fef
+flash_att_commit := 4d87e4d875077ad9efd25030efa4ab0ba92c19e1
 
 gen-server:
 	# Compile protos
@@ -12,13 +13,19 @@ install-transformers:
 	# Install specific version of transformers with custom cuda kernels
 	pip uninstall transformers -y || true
 	rm -rf transformers || true
-	rm -rf transformers-$(transformers_commit) || true
-	curl -L -O https://github.com/OlivierDehaene/transformers/archive/$(transformers_commit).zip
-	unzip $(transformers_commit).zip
-	rm $(transformers_commit).zip
-	mv transformers-$(transformers_commit) transformers
+	git clone https://github.com/OlivierDehaene/transformers.git
+	cd transformers && git checkout $(transformers_commit)
 	cd transformers && python setup.py install
 
+install-flash-attention:
+	# Install specific version of flash attention
+	pip install packaging
+	pip uninstall flash_attn rotary_emb dropout_layer_norm -y || true
+	rm -rf flash-attention || true
+	git clone https://github.com/HazyResearch/flash-attention.git
+	cd flash-attention && git checkout $(flash_att_commit)
+	cd flash-attention && python setup.py install && cd csrc/layer_norm && python setup.py install && cd ../rotary && python setup.py install
+
 install-torch:
 	# Install specific version of torch
 	pip install torch --extra-index-url https://download.pytorch.org/whl/cu118 --no-cache-dir
server/text_generation_server/models/__init__.py  (+18 -2)

+import os
 import torch
 
+from loguru import logger
 from transformers import AutoConfig
 from typing import Optional
@@ -12,6 +14,14 @@ from text_generation_server.models.santacoder import SantaCoder
 from text_generation_server.models.gpt_neox import GPTNeoxSharded
 from text_generation_server.models.t5 import T5Sharded
 
+try:
+    from text_generation_server.models.flash_neox import FlashNeoX, FlashNeoXSharded
+    FLASH_NEOX = torch.cuda.is_available() and int(os.environ.get("FLASH_NEOX", 0)) == 1
+except ImportError:
+    if int(os.environ.get("FLASH_NEOX", 0)) == 1:
+        logger.exception("Could not import FlashNeoX")
+    FLASH_NEOX = False
+
 __all__ = [
     "Model",
     "BLOOM",
@@ -26,6 +36,10 @@ __all__ = [
     "get_model",
 ]
 
+if FLASH_NEOX:
+    __all__.append(FlashNeoX)
+    __all__.append(FlashNeoXSharded)
+
 # The flag below controls whether to allow TF32 on matmul. This flag defaults to False
 # in PyTorch 1.12 and later.
 torch.backends.cuda.matmul.allow_tf32 = True
@@ -59,9 +73,11 @@ def get_model(
     if config.model_type == "gpt_neox":
         if sharded:
-            return GPTNeoxSharded(model_id, revision, quantize=quantize)
+            neox_cls = FlashNeoXSharded if FLASH_NEOX else GPTNeoxSharded
+            return neox_cls(model_id, revision, quantize=quantize)
         else:
-            return CausalLM(model_id, revision, quantize=quantize)
+            neox_cls = FlashNeoX if FLASH_NEOX else CausalLM
+            return neox_cls(model_id, revision, quantize=quantize)
 
     if config.model_type == "t5":
         if sharded:
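For context, a minimal sketch (not part of the commit) of the opt-in pattern used above: the flash path is only taken when the FLASH_NEOX environment variable is set to 1, a CUDA device is available, and the optional kernels import cleanly. The helper name flash_neox_enabled is hypothetical; flash_attn_cuda is the kernel module imported by the new modeling file below.

import os

import torch


def flash_neox_enabled() -> bool:
    # Hypothetical helper mirroring the gating in models/__init__.py above.
    if int(os.environ.get("FLASH_NEOX", 0)) != 1:
        return False  # opt-in flag not set
    if not torch.cuda.is_available():
        return False  # the flash kernels are CUDA-only
    try:
        import flash_attn_cuda  # noqa: F401  custom kernel used by flash_neox_modeling.py
    except ImportError:
        return False  # fall back to the regular GPTNeoxSharded / CausalLM path
    return True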
server/text_generation_server/models/causal_lm.py  (+0 -1)

@@ -64,7 +64,6 @@ class CausalLMBatch(Batch):
         inputs = []
         next_token_choosers = []
         stopping_criterias = []
-        input_lengths = []
 
         # Parse batch
         padding_right_offset = 0
server/text_generation_server/models/flash_neox.py  (new file, +601 -0)

(The diff for this new file is collapsed in this view.)
server/text_generation_server/models/flash_neox_modeling.py  (new file, +637 -0)

import torch
import torch.distributed

import torch.nn.functional as F

from torch import nn
from transformers.activations import ACT2FN
from transformers.modeling_utils import PreTrainedModel
from transformers.models.gpt_neox import GPTNeoXConfig

# Flash attention imports
import rotary_emb
import flash_attn_cuda
import dropout_layer_norm

from flash_attn.layers.rotary import RotaryEmbedding


class TensorParallelColumnLinear(nn.Linear):
    def __init__(
        self,
        in_features,
        out_features,
        process_group: torch.distributed.ProcessGroup,
        bias=True,
        device=None,
        dtype=None,
    ):
        self.process_group = process_group
        self.tp_world_size = process_group.size()
        assert out_features % self.tp_world_size == 0
        out_features = out_features // self.tp_world_size

        super().__init__(
            in_features=in_features,
            out_features=out_features,
            bias=bias,
            device=device,
            dtype=dtype,
        )

    @staticmethod
    def linear(input, weight, bias):
        return F.linear(input, weight, bias)

    def forward(self, input):
        return self.linear(input, self.weight, self.bias)


class TensorParallelRowLinear(nn.Linear):
    def __init__(
        self,
        in_features,
        out_features,
        process_group: torch.distributed.ProcessGroup,
        bias=True,
        device=None,
        dtype=None,
    ):
        self.process_group = process_group
        self.tp_world_size = process_group.size()
        assert in_features % self.tp_world_size == 0
        in_features = in_features // self.tp_world_size

        super().__init__(
            in_features=in_features,
            out_features=out_features,
            bias=bias,
            device=device,
            dtype=dtype,
        )

    @staticmethod
    def linear(input, weight, bias):
        return F.linear(input, weight, bias)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        out = self.linear(input, self.weight, self.bias)
        torch.distributed.all_reduce(out, group=self.process_group)

        return out


class TensorParallelEmbedding(nn.Embedding):
    def __init__(
        self,
        num_embeddings,
        embedding_dim,
        process_group: torch.distributed.ProcessGroup,
        padding_idx=None,
        max_norm=None,
        norm_type=2.0,
        scale_grad_by_freq=False,
        sparse=False,
        _weight=None,
        device=None,
        dtype=None,
    ):
        self.process_group = process_group
        self.tp_rank = process_group.rank()
        self.tp_world_size = process_group.size()

        self.original_num_embeddings = num_embeddings

        assert num_embeddings % self.tp_world_size == 0
        block_size = num_embeddings // self.tp_world_size
        # inputs in `[min_id, max_id[` are handled by `self` to get embeddings
        self.min_id = self.tp_rank * block_size
        self.max_id = (self.tp_rank + 1) * block_size

        super().__init__(
            block_size,
            embedding_dim,
            padding_idx=padding_idx,
            max_norm=max_norm,
            norm_type=norm_type,
            scale_grad_by_freq=scale_grad_by_freq,
            sparse=sparse,
            _weight=_weight,
            device=device,
            dtype=dtype,
        )

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        # Sanity check
        if torch.any(
            torch.logical_or(0 > input, input >= self.original_num_embeddings)
        ):
            raise IndexError(
                f"Input is required to be in [0, {self.original_num_embeddings}[, got min: {torch.min(input)} and max: {torch.max(input)}"
            )

        # `0` if input is in the correct interval, else `1`
        input_mask = torch.logical_or(self.min_id > input, input >= self.max_id)
        # translate for [0, self.max_id - self.min_id[
        input = input - self.min_id
        # default all out of bounds values to `0`
        input[input_mask] = 0
        out = super().forward(input)
        out[input_mask] = 0.0
        torch.distributed.all_reduce(out, group=self.process_group)
        return out


class PositionRotaryEmbedding(RotaryEmbedding):
    def _update_cos_sin_cache(self, dtype, device, seqlen):
        # Reset the tables if the sequence length has changed,
        # or if we're on a new device (possibly due to tracing for instance)
        if (
            seqlen > self._seq_len_cached
            or self._cos_cached.device != device
            or self._cos_cached.dtype != dtype
        ):
            self._seq_len_cached = seqlen
            t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
            # Don't do einsum, it converts fp32 to fp16
            # freqs = torch.einsum("i,j->ij", t, self.inv_freq)
            freqs = torch.outer(t, self.inv_freq.to(device=t.device))
            self._cos_cached = torch.cos(freqs).to(dtype)
            self._sin_cached = torch.sin(freqs).to(dtype)

    def get_cos_sin(self, position_ids: torch.Tensor, max_s: int, dtype: torch.dtype):
        """
        Return cos and sin for the asked position ids
        """
        self._update_cos_sin_cache(dtype, position_ids.device, max_s)

        cos = torch.index_select(self._cos_cached, 0, position_ids)
        sin = torch.index_select(self._sin_cached, 0, position_ids)
        return cos.unsqueeze(1), sin.unsqueeze(1)

    def forward(self, qkv: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor):
        rotary_dim = cos.shape[-1]
        q1 = qkv[:, 0, :, :rotary_dim]
        q2 = qkv[:, 0, :, rotary_dim : 2 * rotary_dim]
        k1 = qkv[:, 1, :, :rotary_dim]
        k2 = qkv[:, 1, :, rotary_dim : 2 * rotary_dim]

        rotary_emb.apply_rotary(q1, q2, cos, sin, q1, q2, False)
        rotary_emb.apply_rotary(k1, k2, cos, sin, k1, k2, False)
        return qkv


class FlashNeoxAttention(torch.nn.Module):
    def __init__(
        self, num_heads, hidden_size, rotary_pct, rotary_emb_base, process_group=None
    ):
        super().__init__()
        self.num_heads = num_heads
        self.hidden_size = hidden_size
        self.head_size = hidden_size // num_heads

        rotary_ndims = int(self.head_size * rotary_pct)
        self.rotary_emb = PositionRotaryEmbedding(rotary_ndims, base=rotary_emb_base)
        self.softmax_scale = self.head_size ** (-0.5)

        if process_group is None:
            self.query_key_value = nn.Linear(hidden_size, 3 * hidden_size)
            self.dense = nn.Linear(hidden_size, hidden_size)
        else:
            self.num_heads = self.num_heads // process_group.size()
            self.query_key_value = TensorParallelColumnLinear(
                hidden_size,
                3 * hidden_size,
                process_group=process_group,
            )
            self.dense = TensorParallelRowLinear(
                hidden_size,
                hidden_size,
                process_group=process_group,
            )
        self.swap_dims = True

    # TODO: remove and swap dims when loading weights
    def _swap_dims(self):
        """Swap dims for the first inference to avoid an additional permute"""
        self.query_key_value.weight = torch.nn.Parameter(
            self.query_key_value.weight.view(
                self.num_heads, 3, self.head_size, self.hidden_size
            )
            .permute(1, 0, 2, 3)
            .reshape(-1, self.hidden_size)
        )
        self.query_key_value.bias = torch.nn.Parameter(
            self.query_key_value.bias.view(self.num_heads, 3, self.head_size)
            .permute(1, 0, 2)
            .reshape(-1)
        )
        self.swap_dims = False

    def forward(
        self,
        hidden_states,
        cos,
        sin,
        cu_seqlens,
        max_s,
        layer_past,
        layer_past_present_indices,
        cu_seqlens_q,
    ):
        if self.swap_dims:
            self._swap_dims()

        qkv = self.query_key_value(hidden_states)
        qkv = qkv.view(-1, 3, self.num_heads, self.head_size)
        qkv_rot = self.rotary_emb(qkv, cos, sin)

        # Prefill
        if layer_past_present_indices is None:
            # Copy to layer past
            layer_past[...] = qkv_rot[:, 1:]

            # output
            attn_output = torch.empty_like(qkv[:, 0])
            # flash attention
            flash_attn_cuda.fwd(
                qkv[:, 0],
                qkv[:, 1],
                qkv[:, 2],
                attn_output,
                cu_seqlens,
                cu_seqlens,
                max_s,
                max_s,
                0.0,
                self.softmax_scale,
                False,
                True,
                False,
                0,
                None,
            )
        # Decode
        else:
            query = qkv_rot[:, 0]
            # Add present to the layer_past tensor at the correct indices
            layer_past[layer_past_present_indices] = qkv_rot[:, 1:]

            # output
            attn_output = torch.empty_like(query)
            # flash attention
            flash_attn_cuda.fwd(
                query,
                layer_past[:, 0],
                layer_past[:, 1],
                attn_output,
                cu_seqlens_q,
                cu_seqlens,
                1,
                max_s,
                0.0,
                self.softmax_scale,
                False,
                False,
                False,
                0,
                None,
            )

        return self.dense(attn_output.view(-1, self.num_heads * self.head_size))


class FlashMLP(nn.Module):
    def __init__(self, act, hidden_size, intermediate_size, process_group=None):
        super().__init__()
        self.act = (
            ACT2FN[act]
            if "gelu" not in act
            else lambda x: torch.nn.functional.gelu(x, approximate="tanh")
        )

        if process_group is None:
            self.dense_h_to_4h = nn.Linear(hidden_size, intermediate_size)
            self.dense_4h_to_h = nn.Linear(intermediate_size, hidden_size)
        else:
            self.dense_h_to_4h = TensorParallelColumnLinear(
                hidden_size,
                intermediate_size,
                process_group=process_group,
            )
            self.dense_4h_to_h = TensorParallelRowLinear(
                intermediate_size,
                hidden_size,
                process_group=process_group,
            )
        self.heuristic = "auto"
        self.process_group = process_group

    def forward(self, hidden_states):
        hidden_states = self.dense_h_to_4h(hidden_states)
        hidden_states = self.act(hidden_states)
        hidden_states = self.dense_4h_to_h(hidden_states)
        return hidden_states


class FlashNeoXLayer(nn.Module):
    def __init__(
        self,
        num_heads,
        act,
        hidden_size,
        intermediate_size,
        rotary_pct,
        rotary_emb_base,
        layer_norm_eps,
        use_parallel_residual,
        process_group=None,
    ):
        super().__init__()
        self.use_parallel_residual = use_parallel_residual
        self.input_layernorm = nn.LayerNorm(hidden_size, eps=layer_norm_eps)
        self.post_attention_layernorm = nn.LayerNorm(hidden_size, eps=layer_norm_eps)
        self.attention = FlashNeoxAttention(
            num_heads, hidden_size, rotary_pct, rotary_emb_base, process_group
        )
        self.mlp = FlashMLP(act, hidden_size, intermediate_size, process_group)

    def forward(
        self,
        hidden_states,
        residual,
        cos,
        sin,
        cu_seqlens,
        max_s,
        layer_past,
        layer_past_present_indices,
        cu_seqlens_q,
    ):
        if self.use_parallel_residual:
            # faster input layer norm
            ln1_hidden_states, *rest = dropout_layer_norm.dropout_add_ln_fwd(
                hidden_states,
                None,
                self.input_layernorm.weight,
                self.input_layernorm.bias,
                None,
                None,
                None,
                None,
                0.0,
                self.input_layernorm.eps,
                1.0,
                0,
                None,
                False,
                False,
            )

            attn_output = self.attention(
                ln1_hidden_states,
                cos,
                sin,
                cu_seqlens,
                max_s,
                layer_past,
                layer_past_present_indices,
                cu_seqlens_q,
            )

            # faster post attention layer norm
            ln2_hidden_states, *rest = dropout_layer_norm.dropout_add_ln_fwd(
                hidden_states,
                None,
                self.post_attention_layernorm.weight,
                self.post_attention_layernorm.bias,
                None,
                None,
                None,
                None,
                0.0,
                self.post_attention_layernorm.eps,
                1.0,
                0,
                None,
                False,
                False,
            )

            mlp_output = self.mlp(ln2_hidden_states)
            return mlp_output + attn_output + hidden_states, None
        else:
            # faster input layer norm
            hidden_states, residual, *rest = dropout_layer_norm.dropout_add_ln_fwd(
                hidden_states,
                residual,
                self.input_layernorm.weight,
                self.input_layernorm.bias,
                None,
                None,
                None,
                None,
                0.0,
                self.input_layernorm.eps,
                1.0,
                0,
                None,
                False,
                False,
            )

            hidden_states = self.attention(
                hidden_states,
                cos,
                sin,
                cu_seqlens,
                max_s,
                layer_past,
                layer_past_present_indices,
                cu_seqlens_q,
            )

            # faster post attention layer norm
            hidden_states, residual, *rest = dropout_layer_norm.dropout_add_ln_fwd(
                hidden_states,
                residual,
                self.post_attention_layernorm.weight,
                self.post_attention_layernorm.bias,
                None,
                None,
                None,
                None,
                0.0,
                self.post_attention_layernorm.eps,
                1.0,
                0,
                None,
                False,
                False,
            )

            mlp_output = self.mlp(hidden_states)
            return mlp_output, residual


class FlashGPTNeoXPreTrainedModel(PreTrainedModel):
    config_class = GPTNeoXConfig
    base_model_prefix = "gpt_neox"
    supports_gradient_checkpointing = False
    _no_split_modules = None


class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
    def __init__(self, config, process_group=None):
        super().__init__(config)
        self.config = config

        self.tp_embeddings = False
        if process_group is not None:
            self.tp_rank = process_group.rank()
            self.tp_world_size = process_group.size()
            if config.vocab_size % self.tp_world_size == 0:
                self.tp_embeddings = True

        if self.tp_embeddings:
            self.embed_in = TensorParallelEmbedding(
                config.vocab_size, config.hidden_size, process_group=process_group
            )
        else:
            self.embed_in = nn.Embedding(config.vocab_size, config.hidden_size)

        self.layers = nn.ModuleList(
            [
                FlashNeoXLayer(
                    config.num_attention_heads,
                    config.hidden_act,
                    config.hidden_size,
                    config.intermediate_size,
                    config.rotary_pct,
                    config.rotary_emb_base,
                    config.layer_norm_eps,
                    config.use_parallel_residual,
                    process_group,
                )
                for _ in range(config.num_hidden_layers)
            ]
        )
        self.final_layer_norm = nn.LayerNorm(
            config.hidden_size, eps=config.layer_norm_eps
        )

        self.gradient_checkpointing = False

        self.head_size = self.layers[0].attention.head_size
        self.num_heads = self.layers[0].attention.num_heads

    def forward(
        self,
        input_ids,
        position_ids,
        cu_seqlens,
        max_s,
        past_key_values=None,
    ):
        hidden_states = self.embed_in(input_ids)

        # Prefill
        if past_key_values is None:
            # Create past tensor
            past_key_values = hidden_states.new_empty(
                (
                    len(self.layers),
                    len(hidden_states),
                    2,
                    self.num_heads,
                    self.head_size,
                )
            )
            layer_past_present_indices = None
            cu_seqlens_q = None
        # Decode
        else:
            # Create indices from cumulative sequence lengths
            layer_past_present_indices = cu_seqlens[1:] - 1
            cu_seqlens_q = torch.arange(
                len(cu_seqlens), dtype=torch.int32, device=hidden_states.device
            )

        # Get rotary cos and sin for this forward
        # Avoid to index in each layer
        cos, sin = self.layers[0].attention.rotary_emb.get_cos_sin(
            position_ids, max_s, hidden_states.dtype
        )

        residual = None
        for i, layer in enumerate(self.layers):
            hidden_states, residual = layer(
                hidden_states,
                residual,
                cos,
                sin,
                cu_seqlens,
                max_s,
                past_key_values[i],
                layer_past_present_indices,
                cu_seqlens_q,
            )

        # Faster final layer norm
        hidden_states, *rest = dropout_layer_norm.dropout_add_ln_fwd(
            hidden_states,
            residual,
            self.final_layer_norm.weight,
            self.final_layer_norm.bias,
            None,
            None,
            None,
            None,
            0.0,
            self.final_layer_norm.eps,
            1.0,
            0,
            None,
            False,
            False,
        )

        return hidden_states, past_key_values


class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        if config.tp_parallel:
            process_group = torch.distributed.distributed_c10d._get_default_group()
        else:
            process_group = None

        self.gpt_neox = FlashGPTNeoXModel(config, process_group)

        if self.gpt_neox.tp_embeddings:
            self.embed_out = nn.Linear(
                config.hidden_size,
                config.vocab_size // process_group.size(),
                bias=False,
            )
        else:
            self.embed_out = nn.Linear(
                config.hidden_size, config.vocab_size, bias=False
            )

    def forward(
        self,
        input_ids,
        position_ids,
        cu_seqlens,
        max_s,
        past_key_values=None,
    ):
        hidden_states, present = self.gpt_neox(
            input_ids, position_ids, cu_seqlens, max_s, past_key_values
        )
        return self.embed_out(hidden_states), present
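A small, self-contained illustration (assumption: plain PyTorch, no custom kernels) of how the model above addresses a ragged batch: sequences are packed into one flat dimension, cu_seqlens stores the cumulative lengths, and during decode cu_seqlens[1:] - 1 points at the last packed position of each sequence, which is where FlashNeoxAttention writes the new key/value pair. The sequence lengths are made-up example values.

import torch

# Three requests with lengths 3, 5 and 2 packed into a single flat batch.
seq_lens = torch.tensor([3, 5, 2])
cu_seqlens = torch.zeros(len(seq_lens) + 1, dtype=torch.int32)
cu_seqlens[1:] = torch.cumsum(seq_lens, dim=0)
print(cu_seqlens)                  # tensor([ 0,  3,  8, 10], dtype=torch.int32)

# Same expression as in FlashGPTNeoXModel.forward (decode branch): the last packed
# position of each sequence, i.e. where the new present key/value is stored.
layer_past_present_indices = cu_seqlens[1:] - 1
print(layer_past_present_indices)  # tensor([2, 7, 9], dtype=torch.int32)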
server/text_generation_server/utils/tokens.py  (+1 -1)

@@ -24,7 +24,7 @@ class Sampling:
         self.seed = seed
 
     def __call__(self, logits):
-        probs = torch.nn.functional.softmax(logits)
+        probs = torch.nn.functional.softmax(logits, -1)
        next_tokens = torch.multinomial(probs, num_samples=1, generator=self.generator)
        return next_tokens
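A short sketch (illustrative only, the shapes are made up) of why the explicit dimension matters here: with batched logits the probabilities must be normalized over the vocabulary axis, and passing -1 makes that explicit instead of relying on the implicit-dimension behaviour PyTorch warns about.

import torch

logits = torch.randn(4, 32)                            # (batch, vocab) example shape
probs = torch.nn.functional.softmax(logits, -1)        # normalize over the vocab axis
next_tokens = torch.multinomial(probs, num_samples=1)  # one sampled token id per row

assert torch.allclose(probs.sum(-1), torch.ones(4))    # each row now sums to 1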
server/text_generation_server/utils/watermark.py  (+24 -13)

@@ -17,6 +17,7 @@ import os
 import torch
 from transformers import LogitsProcessor
+from typing import List, Union
 
 GAMMA = os.getenv("WATERMARK_GAMMA", 0.5)
 DELTA = os.getenv("WATERMARK_DELTA", 2.0)
@@ -36,7 +37,15 @@ class WatermarkLogitsProcessor(LogitsProcessor):
         self.rng = torch.Generator(device=device)
         self.hash_key = hash_key
 
-    def _seed_rng(self, input_ids: torch.LongTensor) -> None:
+    def _seed_rng(self, input_ids: Union[List[int], torch.LongTensor]):
+        if isinstance(input_ids, list):
+            assert (
+                len(input_ids) >= 1
+            ), "requires at least a 1 token prefix sequence to seed rng"
+            prev_token = input_ids[-1]
+        else:
+            assert len(input_ids) == 1
+            input_ids = input_ids[0]
             assert (
                 input_ids.shape[-1] >= 1
             ), "requires at least a 1 token prefix sequence to seed rng"
@@ -44,15 +53,16 @@ class WatermarkLogitsProcessor(LogitsProcessor):
         self.rng.manual_seed(self.hash_key * prev_token)
 
     def _get_greenlist_ids(
-        self, input_ids: torch.LongTensor, max_value: int
-    ) -> list[int]:
+        self,
+        input_ids: Union[List[int], torch.LongTensor],
+        max_value: int,
+        device: torch.device,
+    ) -> List[int]:
         # seed the rng using the previous tokens/prefix
         self._seed_rng(input_ids)
         greenlist_size = int(max_value * self.gamma)
         vocab_permutation = torch.randperm(
-            max_value, device=input_ids.device, generator=self.rng
+            max_value, device=device, generator=self.rng
         )
         greenlist_ids = vocab_permutation[:greenlist_size]
         return greenlist_ids
@@ -73,10 +83,11 @@ class WatermarkLogitsProcessor(LogitsProcessor):
         return scores
 
     def __call__(
-        self, input_ids: torch.LongTensor, scores: torch.FloatTensor
+        self, input_ids: Union[List[int], torch.LongTensor], scores: torch.FloatTensor
     ) -> torch.FloatTensor:
-        assert len(input_ids) == 1
         greenlist_ids = self._get_greenlist_ids(
-            input_ids[0], scores.shape[-1]
+            input_ids, scores.shape[-1], scores.device
         )
         green_tokens_mask = self._calc_greenlist_mask(
             scores=scores, greenlist_token_ids=greenlist_ids
         )
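For intuition, a standalone sketch (not the class above; the constants are example values, not the processor's defaults) of the greenlist computation these changes touch: the generator is seeded from the previous token, the vocabulary is permuted, and the first gamma fraction of the permutation becomes the greenlist whose logits are boosted.

import torch

hash_key = 15485863   # example value; the class keeps this as self.hash_key
gamma = 0.5           # example value; the class keeps this as self.gamma
vocab_size = 10       # stands in for max_value == scores.shape[-1]
prev_token = 7        # last prompt token, whether input_ids is a List[int] or a LongTensor

rng = torch.Generator(device="cpu")
rng.manual_seed(hash_key * prev_token)

vocab_permutation = torch.randperm(vocab_size, device="cpu", generator=rng)
greenlist_ids = vocab_permutation[: int(vocab_size * gamma)]
print(greenlist_ids)  # the token ids that receive the watermark bias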