OpenDAS / text-generation-inference, commit 05e9a796 (unverified)

feat(server): flash neoX (#133)

Authored Mar 24, 2023 by OlivierDehaene; committed via GitHub on Mar 24, 2023.
Parent commit: 23e10288

Showing 10 changed files with 1307 additions and 25 deletions (+1307 -25):
  .github/workflows/build.yaml                                  +4    -0
  .github/workflows/tests.yaml                                  +4    -0
  Dockerfile                                                    +6    -3
  server/Makefile                                               +12   -5
  server/text_generation_server/models/__init__.py              +18   -2
  server/text_generation_server/models/causal_lm.py             +0    -1
  server/text_generation_server/models/flash_neox.py            +601  -0
  server/text_generation_server/models/flash_neox_modeling.py   +637  -0
  server/text_generation_server/utils/tokens.py                 +1    -1
  server/text_generation_server/utils/watermark.py              +24   -13
.github/workflows/build.yaml

@@ -20,6 +20,10 @@ on:
     branches:
       - 'main'

+concurrency:
+  group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
+  cancel-in-progress: true
+
 jobs:
   build-and-push-image:
     runs-on: ubuntu-latest
.github/workflows/tests.yaml

@@ -11,6 +11,10 @@ on:
       - "Cargo.lock"
       - "rust-toolchain.toml"

+concurrency:
+  group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
+  cancel-in-progress: true
+
 jobs:
   run_tests:
     runs-on: ubuntu-20.04
Dockerfile

@@ -43,7 +43,7 @@ ENV LANG=C.UTF-8 \
     CONDA_DEFAULT_ENV=text-generation \
     PATH=$PATH:/opt/miniconda/envs/text-generation/bin:/opt/miniconda/bin:/usr/local/cuda/bin

-RUN apt-get update && apt-get install -y unzip curl libssl-dev && rm -rf /var/lib/apt/lists/*
+RUN apt-get update && apt-get install -y git curl libssl-dev && rm -rf /var/lib/apt/lists/*

 RUN cd ~ && \
     curl -L -O https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh && \

@@ -53,10 +53,13 @@ RUN cd ~ && \

 WORKDIR /usr/src

+# Install torch
+RUN pip install torch --extra-index-url https://download.pytorch.org/whl/cu118 --no-cache-dir
+
 COPY server/Makefile server/Makefile

-# Install specific version of torch
-RUN cd server && make install-torch
+# Install specific version of flash attention
+RUN cd server && make install-flash-attention

 # Install specific version of transformers
 RUN cd server && BUILD_EXTENSIONS="True" make install-transformers
server/Makefile

 transformers_commit := 2b57aa18da658e7d2f42ef6bd5b56751af582fef
+flash_att_commit := 4d87e4d875077ad9efd25030efa4ab0ba92c19e1

 gen-server:
 	# Compile protos

@@ -12,13 +13,19 @@ install-transformers:
 	# Install specific version of transformers with custom cuda kernels
 	pip uninstall transformers -y || true
 	rm -rf transformers || true
-	rm -rf transformers-$(transformers_commit) || true
-	curl -L -O https://github.com/OlivierDehaene/transformers/archive/$(transformers_commit).zip
-	unzip $(transformers_commit).zip
-	rm $(transformers_commit).zip
-	mv transformers-$(transformers_commit) transformers
+	git clone https://github.com/OlivierDehaene/transformers.git
+	cd transformers && git checkout $(transformers_commit)
 	cd transformers && python setup.py install

+install-flash-attention:
+	# Install specific version of flash attention
+	pip install packaging
+	pip uninstall flash_attn rotary_emb dropout_layer_norm -y || true
+	rm -rf flash-attention || true
+	git clone https://github.com/HazyResearch/flash-attention.git
+	cd flash-attention && git checkout $(flash_att_commit)
+	cd flash-attention && python setup.py install && cd csrc/layer_norm && python setup.py install && cd ../rotary && python setup.py install
+
 install-torch:
 	# Install specific version of torch
 	pip install torch --extra-index-url https://download.pytorch.org/whl/cu118 --no-cache-dir
server/text_generation_server/models/__init__.py

+import os
 import torch

+from loguru import logger
 from transformers import AutoConfig
 from typing import Optional

@@ -12,6 +14,14 @@ from text_generation_server.models.santacoder import SantaCoder
 from text_generation_server.models.gpt_neox import GPTNeoxSharded
 from text_generation_server.models.t5 import T5Sharded

+try:
+    from text_generation_server.models.flash_neox import FlashNeoX, FlashNeoXSharded
+
+    FLASH_NEOX = torch.cuda.is_available() and int(os.environ.get("FLASH_NEOX", 0)) == 1
+except ImportError:
+    if int(os.environ.get("FLASH_NEOX", 0)) == 1:
+        logger.exception("Could not import FlashNeoX")
+    FLASH_NEOX = False
+
 __all__ = [
     "Model",
     "BLOOM",

@@ -26,6 +36,10 @@ __all__ = [
     "get_model",
 ]

+if FLASH_NEOX:
+    __all__.append(FlashNeoX)
+    __all__.append(FlashNeoXSharded)
+
 # The flag below controls whether to allow TF32 on matmul. This flag defaults to False
 # in PyTorch 1.12 and later.
 torch.backends.cuda.matmul.allow_tf32 = True

@@ -59,9 +73,11 @@ def get_model(
     if config.model_type == "gpt_neox":
         if sharded:
-            return GPTNeoxSharded(model_id, revision, quantize=quantize)
+            neox_cls = FlashNeoXSharded if FLASH_NEOX else GPTNeoxSharded
+            return neox_cls(model_id, revision, quantize=quantize)
         else:
-            return CausalLM(model_id, revision, quantize=quantize)
+            neox_cls = FlashNeoX if FLASH_NEOX else CausalLM
+            return neox_cls(model_id, revision, quantize=quantize)

     if config.model_type == "t5":
         if sharded:
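Aside (not part of the diff): the FLASH_NEOX gate above combines a CUDA check, an opt-in environment variable and a guarded import. The following is a minimal, self-contained sketch of the same selection pattern; the placeholder classes stand in for the real model implementations and are not taken from the repository.

    import os
    import torch

    # Placeholder classes; only the selection logic from get_model is illustrated.
    class CausalLM: ...
    class GPTNeoxSharded: ...
    class FlashNeoX: ...
    class FlashNeoXSharded: ...

    FLASH_NEOX = torch.cuda.is_available() and int(os.environ.get("FLASH_NEOX", 0)) == 1

    def pick_neox_class(sharded: bool) -> type:
        # Mirrors the gpt_neox branch: the flash implementation is used only when
        # the kernels imported successfully and FLASH_NEOX=1 was set.
        if sharded:
            return FlashNeoXSharded if FLASH_NEOX else GPTNeoxSharded
        return FlashNeoX if FLASH_NEOX else CausalLM

    print(pick_neox_class(sharded=False).__name__)  # CausalLM unless opted in on a GPU machine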
server/text_generation_server/models/causal_lm.py

@@ -64,7 +64,6 @@ class CausalLMBatch(Batch):
         inputs = []
         next_token_choosers = []
         stopping_criterias = []
-        input_lengths = []

         # Parse batch
         padding_right_offset = 0
server/text_generation_server/models/flash_neox.py (new file, mode 100644)

import torch
import torch.distributed

from accelerate import init_empty_weights
from dataclasses import dataclass
from opentelemetry import trace
from safetensors import safe_open
from transformers import AutoTokenizer, PreTrainedTokenizerBase, AutoConfig
from typing import Optional, Tuple, List, Type, Union

from text_generation_server.models import Model
from text_generation_server.models.flash_neox_modeling import (
    FlashGPTNeoXForCausalLM,
    TensorParallelEmbedding,
    TensorParallelRowLinear,
    TensorParallelColumnLinear,
)
from text_generation_server.models.types import (
    Batch,
    PrefillTokens,
    Generation,
    GeneratedText,
)
from text_generation_server.pb import generate_pb2
from text_generation_server.utils import (
    NextTokenChooser,
    StoppingCriteria,
    Sampling,
    initialize_torch_distributed,
    weight_files,
)

tracer = trace.get_tracer(__name__)


@dataclass
class FlashNeoXBatch(Batch):
    batch_id: int
    requests: List[generate_pb2.Request]

    # Decoder values
    input_ids: torch.Tensor
    position_ids: torch.Tensor
    # cumulative sequence lengths
    cu_seqlens: torch.Tensor
    max_seqlen: int
    past_key_values: Optional[torch.Tensor]

    # All tokens
    all_input_ids: List[List[int]]

    # Lengths of all generations present in the batch
    input_lengths: List[int]

    # Generation helpers
    next_token_choosers: List[NextTokenChooser]
    stopping_criterias: List[StoppingCriteria]

    def to_pb(self) -> generate_pb2.Batch:
        return generate_pb2.Batch(
            id=self.batch_id, requests=self.requests, size=len(self)
        )

    @classmethod
    def from_pb(
        cls,
        pb: generate_pb2.Batch,
        tokenizer: PreTrainedTokenizerBase,
        device: torch.device,
    ) -> "CausalLMBatch":
        input_ids = []
        position_ids = []
        cu_seqlens = [0]
        max_seqlen = 0

        input_lengths = []
        all_input_ids = []

        next_token_choosers = []
        stopping_criterias = []

        # Cumulative length
        cumulative_length = 0

        # Parse batch
        for r in pb.requests:
            tokenized_input = tokenizer(r.inputs, return_tensors="pt")[
                "input_ids"
            ].squeeze(0)
            input_ids.append(tokenized_input)
            all_input_ids.append(tokenized_input.tolist())

            input_length = len(tokenized_input)
            max_seqlen = max(max_seqlen, input_length)
            input_lengths.append(input_length)

            # Position ids
            position_ids.append(torch.arange(0, input_length, dtype=torch.int32))

            # Add cumulative lengths of all previous inputs
            cu_seqlens.append(cumulative_length + input_length)

            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
            stopping_criterias.append(
                StoppingCriteria.from_pb(r.stopping_parameters, tokenizer)
            )

            # Update
            cumulative_length += input_length

        input_ids = torch.concat(input_ids).unsqueeze(1)
        position_ids = torch.concat(position_ids)
        cu_seqlens = torch.tensor(cu_seqlens, dtype=torch.int32)

        return cls(
            batch_id=pb.id,
            requests=pb.requests,
            input_ids=input_ids,
            position_ids=position_ids,
            cu_seqlens=cu_seqlens,
            max_seqlen=max_seqlen,
            past_key_values=None,
            input_lengths=input_lengths,
            all_input_ids=all_input_ids,
            next_token_choosers=next_token_choosers,
            stopping_criterias=stopping_criterias,
        )

    @classmethod
    @tracer.start_as_current_span("concatenate")
    def concatenate(cls, batches: List["CausalLMBatch"]) -> "CausalLMBatch":
        # Batch attributes
        requests = []
        input_lengths = []
        all_input_ids = []
        next_token_choosers = []
        stopping_criterias = []

        # Batch tensors
        input_ids = []
        position_ids = []
        cu_seqlens = [torch.tensor([0], dtype=torch.int32)]
        max_seqlen = 0
        past_key_values = []

        # Cumulative length
        cumulative_length = torch.tensor(0)

        for i, batch in enumerate(batches):
            requests.extend(batch.requests)
            input_lengths.extend(batch.input_lengths)
            all_input_ids.extend(batch.all_input_ids)
            next_token_choosers.extend(batch.next_token_choosers)
            stopping_criterias.extend(batch.stopping_criterias)

            # Add cumulative lengths of all previous inputs
            cu_seqlens.append(batch.cu_seqlens[1:] + cumulative_length)

            input_ids.append(batch.input_ids)
            position_ids.append(batch.position_ids)
            past_key_values.append(batch.past_key_values)

            max_seqlen = max(max_seqlen, batch.max_seqlen)

            # Update
            cumulative_length += batch.cu_seqlens[-1]

        input_ids = torch.concat(input_ids)
        position_ids = torch.concat(position_ids)
        # Concat on dim=1 as first dim represents the model layers
        past_key_values = torch.concat(past_key_values, dim=1)
        cu_seqlens = torch.concat(cu_seqlens)

        return FlashNeoXBatch(
            batch_id=batches[0].batch_id,
            requests=requests,
            input_ids=input_ids,
            position_ids=position_ids,
            cu_seqlens=cu_seqlens,
            max_seqlen=max_seqlen,
            past_key_values=past_key_values,
            input_lengths=input_lengths,
            all_input_ids=all_input_ids,
            next_token_choosers=next_token_choosers,
            stopping_criterias=stopping_criterias,
        )

    def __len__(self):
        return len(self.requests)
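Aside (not part of flash_neox.py): the batch above packs every request into one unpadded token stream and tracks request boundaries with cumulative sequence lengths. A small self-contained sketch of that bookkeeping, with made-up token ids:

    import torch

    # Three prompts of different lengths, packed into one 1-D stream with no padding.
    sequences = [[101, 7592, 2088], [101, 2129], [101, 2054, 2003, 2023]]

    input_ids = torch.concat([torch.tensor(s) for s in sequences])
    position_ids = torch.concat(
        [torch.arange(0, len(s), dtype=torch.int32) for s in sequences]
    )

    # cu_seqlens[i] is the offset where sequence i starts; cu_seqlens[-1] is the total length.
    cu_seqlens = [0]
    for s in sequences:
        cu_seqlens.append(cu_seqlens[-1] + len(s))
    cu_seqlens = torch.tensor(cu_seqlens, dtype=torch.int32)  # tensor([0, 3, 5, 9])

    # Recover the slice that belongs to sequence i, e.g. to pick its logits after prefill.
    i = 1
    start, end = cu_seqlens[i].item(), cu_seqlens[i + 1].item()
    assert torch.equal(input_ids[start:end], torch.tensor(sequences[i]))

server/text_generation_server/models/flash_neox.py (continued)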
class FlashNeoX(Model):
    def __init__(self, model_id: str, revision: Optional[str] = None, quantize=False):
        if torch.cuda.is_available():
            device = torch.device("cuda")
            dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
        else:
            raise NotImplementedError("FlashNeoX is only available on GPU")

        if quantize:
            raise NotImplementedError("FlashNeoX does not support quantization")

        tokenizer = AutoTokenizer.from_pretrained(
            model_id, revision=revision, padding_side="left"
        )
        self.model = (
            FlashGPTNeoXForCausalLM.from_pretrained(
                model_id,
                revision=revision,
                torch_dtype=dtype,
            )
            .eval()
            .cuda()
        )
        tokenizer.pad_token_id = (
            self.model.config.pad_token_id
            if self.model.config.pad_token_id is not None
            else self.model.config.eos_token_id
        )

        super(FlashNeoX, self).__init__(
            tokenizer=tokenizer,
            device=device,
        )

    @property
    def batch_type(self) -> Type[FlashNeoXBatch]:
        return FlashNeoXBatch

    def decode(self, generated_ids: Union[torch.Tensor, List[int]]) -> str:
        return self.tokenizer.decode(
            generated_ids, skip_special_tokens=True, cleanup_tokenization_spaces=False
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        cu_seqlens: torch.Tensor,
        max_s: int,
        past_key_values: Optional = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # Model Forward
        return self.model.forward(
            input_ids=input_ids,
            position_ids=position_ids,
            cu_seqlens=cu_seqlens,
            max_s=max_s,
            past_key_values=past_key_values,
        )

    @tracer.start_as_current_span("generate_token")
    def generate_token(
        self, batch: FlashNeoXBatch
    ) -> Tuple[List[Generation], Optional[FlashNeoXBatch]]:
        # Better to send to device here to avoid device issues in concatenate
        position_ids = batch.position_ids.to(self.device, non_blocking=True)
        cu_seqlens = batch.cu_seqlens.to(self.device, non_blocking=True)
        input_ids = batch.input_ids.squeeze(1).to(self.device)

        out, present = self.forward(
            input_ids,
            position_ids,
            cu_seqlens,
            batch.max_seqlen,
            batch.past_key_values,
        )

        # List of indices to cache
        next_batch_keep_indices = []

        # New values for next forward
        next_batch_input_ids = []
        next_batch_position_ids = []
        next_batch_cu_seqlens = [0]
        next_batch_max_seqlen = 0
        next_batch_past_key_values = []
        next_batch_input_lengths = []
        next_batch_all_input_ids = []

        # Cumulative length
        cumulative_length = 0

        # Results
        generations: List[Generation] = []

        # Zipped iterator
        iterator = zip(
            batch.requests,
            batch.input_lengths,
            batch.next_token_choosers,
            batch.stopping_criterias,
            batch.all_input_ids,
        )

        # For each member of the batch
        for i, (
            request,
            input_length,
            next_token_chooser,
            stopping_criteria,
            all_input_ids,
        ) in enumerate(iterator):
            # Indexing metadata
            start_index = cumulative_length
            end_index = cumulative_length + input_length

            if batch.past_key_values is None:
                # Prefill mode
                # out is of shape [cumulative_sequence_lengths, vocab_size]
                logits = out[start_index:end_index]
            else:
                # Decode mode
                # out is of shape [batch_size, vocab_size]
                logits = out[i].unsqueeze(0)

            # Select next token
            next_token_id, logprobs = next_token_chooser(all_input_ids, logits)
            # Copy to cpu to avoid other copies when indexing and calling .item()
            next_token_id = next_token_id.to("cpu", non_blocking=True)
            logprobs = logprobs.to("cpu")

            next_token_id_squeezed = next_token_id.squeeze()
            next_token_id_item = next_token_id_squeezed.item()

            # Append next token to all tokens
            all_input_ids.append(next_token_id_item)
            new_input_length = input_length + 1

            # Generated token
            next_token_logprob = logprobs[-1, next_token_id]
            next_token_text = self.decode_token(
                next_token_id_item,
            )

            # Evaluate stopping criteria
            stop, reason = stopping_criteria(
                next_token_id_item,
                next_token_text,
            )

            if stop:
                # Decode generated tokens
                output_text = self.decode(
                    all_input_ids[-stopping_criteria.current_tokens :]
                )
                # Get seed
                if isinstance(next_token_chooser.choice, Sampling):
                    seed = next_token_chooser.choice.seed
                else:
                    seed = None

                generated_text = GeneratedText(
                    output_text, stopping_criteria.current_tokens, reason, seed
                )
            else:
                # Keep request in the batch
                next_batch_keep_indices.append(i)
                generated_text = None

                # Get sequence present
                seq_present = present[:, start_index:end_index]
                # Pad it for next iter attention
                past = torch.nn.functional.pad(seq_present, (0, 0, 0, 0, 0, 0, 0, 1))
                next_batch_past_key_values.append(past)

                next_batch_input_ids.append(next_token_id)
                next_batch_position_ids.append(input_length)
                # Cumulative sum
                next_batch_cu_seqlens.append(
                    next_batch_cu_seqlens[-1] + new_input_length
                )
                next_batch_input_lengths.append(new_input_length)
                next_batch_all_input_ids.append(all_input_ids)
                next_batch_max_seqlen = max(next_batch_max_seqlen, new_input_length)

            # Prefill
            if stopping_criteria.current_tokens == 1:
                # Remove generated token to only have prefill and add nan for first prompt token
                prefill_logprobs = [float("nan")] + logprobs.gather(
                    1, torch.tensor(all_input_ids[1:]).unsqueeze(1)
                ).squeeze(1)[:-1].tolist()
                prefill_token_ids = all_input_ids[:-1]
                prefill_texts = self.tokenizer.batch_decode(
                    prefill_token_ids,
                    clean_up_tokenization_spaces=False,
                    skip_special_tokens=False,
                )
                prefill_tokens = PrefillTokens(
                    prefill_token_ids, prefill_logprobs, prefill_texts
                )
            else:
                prefill_tokens = None

            generation = Generation(
                request.id,
                prefill_tokens,
                next_token_id_item,
                next_token_logprob,
                next_token_text,
                next_token_id_item in self.all_special_ids,
                generated_text,
            )

            generations.append(generation)
            cumulative_length += input_length

        # We finished all generations in the batch; there is no next batch
        if not next_batch_keep_indices:
            return generations, None

        # If we finished at least one generation, we need to evict the indices of the generations that finished
        # from the values of the next batch
        if len(next_batch_keep_indices) != len(batch):
            # Apply indices to requests, token_choosers and stopping_criterias that need to be cached
            next_batch_requests = [batch.requests[i] for i in next_batch_keep_indices]
            next_batch_next_token_choosers = [
                batch.next_token_choosers[i] for i in next_batch_keep_indices
            ]
            next_batch_stopping_criterias = [
                batch.stopping_criterias[i] for i in next_batch_keep_indices
            ]
        else:
            next_batch_requests = batch.requests
            next_batch_next_token_choosers = batch.next_token_choosers
            next_batch_stopping_criterias = batch.stopping_criterias

        # Create final next batch tensors
        next_batch_position_ids = torch.tensor(
            next_batch_position_ids, dtype=torch.int32
        )
        next_batch_cu_seqlens = torch.tensor(next_batch_cu_seqlens, dtype=torch.int32)
        if len(next_batch_keep_indices) > 1:
            next_batch_input_ids = torch.concat(next_batch_input_ids)
            next_batch_past_key_values = torch.concat(next_batch_past_key_values, dim=1)
        else:
            next_batch_input_ids = next_batch_input_ids[0]
            next_batch_past_key_values = next_batch_past_key_values[0]

        next_batch = FlashNeoXBatch(
            batch_id=batch.batch_id,
            requests=next_batch_requests,
            input_ids=next_batch_input_ids,
            position_ids=next_batch_position_ids,
            cu_seqlens=next_batch_cu_seqlens,
            max_seqlen=next_batch_max_seqlen,
            past_key_values=next_batch_past_key_values,
            input_lengths=next_batch_input_lengths,
            all_input_ids=next_batch_all_input_ids,
            next_token_choosers=next_batch_next_token_choosers,
            stopping_criterias=next_batch_stopping_criterias,
        )
        return generations, next_batch


class FlashNeoXSharded(FlashNeoX):
    def __init__(
        self, model_id: str, revision: Optional[str] = None, quantize: bool = False
    ):
        self.process_group, self.rank, self.world_size = initialize_torch_distributed()
        self.master = self.rank == 0
        if torch.cuda.is_available():
            device = torch.device(f"cuda:{self.rank}")
            dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
        else:
            raise NotImplementedError("FlashNeoX is only available on GPU")

        if quantize:
            raise NotImplementedError("FlashNeoX does not support quantization")

        tokenizer = AutoTokenizer.from_pretrained(
            model_id, revision=revision, padding_side="left"
        )

        config = AutoConfig.from_pretrained(
            model_id, revision=revision, tp_parallel=True
        )

        torch.distributed.barrier(group=self.process_group)
        filenames = weight_files(model_id, revision=revision, extension=".safetensors")

        with init_empty_weights():
            model = FlashGPTNeoXForCausalLM(config)

        torch.distributed.barrier(group=self.process_group)
        self.load_weights(
            model,
            filenames,
            quantize=quantize,
            device=device,
            rank=self.rank,
            world_size=self.world_size,
        )
        self.model = model.eval().to(dtype)
        torch.distributed.barrier(group=self.process_group)
        super(FlashNeoX, self).__init__(
            tokenizer=tokenizer,
            device=device,
        )

    @staticmethod
    def load_weights(
        model,
        filenames: List[str],
        quantize: bool,
        device: torch.device,
        rank: int,
        world_size: int,
    ):
        parameters = dict(model.named_parameters())
        for file in filenames:
            with safe_open(
                file, framework="pt", device=str(device) if not quantize else "cpu"
            ) as f:
                for name in f.keys():
                    module_name, param_name = name.rsplit(".", 1)
                    module = model.get_submodule(module_name)

                    current_parameter_tensor = parameters.get(name, None)

                    slice_ = f.get_slice(name)

                    if isinstance(module, TensorParallelColumnLinear):
                        size = slice_.get_shape()[0]
                        block_size = size // world_size
                        start = rank * block_size
                        stop = (rank + 1) * block_size
                        tensor = slice_[start:stop]
                    elif isinstance(module, TensorParallelRowLinear):
                        if param_name == "weight":
                            size = slice_.get_shape()[1]
                            block_size = size // world_size
                            start = rank * block_size
                            stop = (rank + 1) * block_size
                            tensor = slice_[:, start:stop]
                        else:
                            tensor = slice_[:]
                            # XXX: Hack for Rowlinear to add the bias only once.
                            if rank != 0:
                                tensor = torch.zeros_like(tensor)
                    elif isinstance(module, TensorParallelEmbedding):
                        size = slice_.get_shape()[0]
                        block_size = size // world_size
                        start = rank * block_size
                        stop = (rank + 1) * block_size
                        tensor = slice_[start:stop]
                    elif name == "embed_out.weight" and model.gpt_neox.tp_embeddings:
                        size = slice_.get_shape()[0]
                        block_size = size // world_size
                        start = rank * block_size
                        stop = (rank + 1) * block_size
                        tensor = slice_[start:stop]
                    else:
                        try:
                            tensor = slice_[:]
                        except:
                            tensor = f.get_tensor(name)

                    if (
                        current_parameter_tensor is not None
                        and current_parameter_tensor.shape != tensor.shape
                    ):
                        raise ValueError(
                            f"Name {name} -- Current {current_parameter_tensor.shape} and got {tensor.shape}"
                        )

                    tensor = tensor.contiguous()

                    if current_parameter_tensor is not None:
                        module._parameters[param_name] = tensor
                    else:
                        module._buffers[param_name] = tensor

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        cu_seqlens: torch.Tensor,
        max_s: int,
        past_key_values: Optional = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        if self.model.gpt_neox.tp_embeddings:
            logits, present = self.model.forward(
                input_ids=input_ids,
                position_ids=position_ids,
                cu_seqlens=cu_seqlens,
                max_s=max_s,
                past_key_values=past_key_values,
            )

            # Logits are sharded, so we need to gather them
            world_logits = [torch.empty_like(logits) for _ in range(self.world_size)]
            torch.distributed.all_gather(world_logits, logits, group=self.process_group)
            world_logits = torch.cat(world_logits, dim=1)

            return world_logits, present
        # While the model itself is sharded, the embeddings might not as they might not be dividable by num-shard
        else:
            return super(FlashNeoXSharded, self).forward(
                input_ids, position_ids, cu_seqlens, max_s, past_key_values
            )
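Aside (not part of the file): in generate_token, the slice of the present tensor belonging to each still-running request is padded with one extra token slot before being cached, so the next decode step can write the new key/value pair in place (via the layer_past indexing in flash_neox_modeling.py below). A minimal sketch of what that F.pad call does, with arbitrary illustrative sizes:

    import torch

    num_layers, seq_len, num_heads, head_size = 2, 5, 4, 8

    # Per-sequence cache layout used above: [num_layers, seq_len, 2 (key/value), num_heads, head_size]
    seq_present = torch.randn(num_layers, seq_len, 2, num_heads, head_size)

    # F.pad takes (left, right) pairs starting from the LAST dimension, so
    # (0, 0, 0, 0, 0, 0, 0, 1) leaves head_size, num_heads and the k/v dim alone
    # and appends one empty token slot on dim 1 (the sequence dimension).
    past = torch.nn.functional.pad(seq_present, (0, 0, 0, 0, 0, 0, 0, 1))
    assert past.shape == (num_layers, seq_len + 1, 2, num_heads, head_size)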
server/text_generation_server/models/flash_neox_modeling.py (new file, mode 100644)

import torch
import torch.distributed

import torch.nn.functional as F

from torch import nn
from transformers.activations import ACT2FN
from transformers.modeling_utils import PreTrainedModel
from transformers.models.gpt_neox import GPTNeoXConfig

# Flash attention imports
import rotary_emb
import flash_attn_cuda
import dropout_layer_norm

from flash_attn.layers.rotary import RotaryEmbedding


class TensorParallelColumnLinear(nn.Linear):
    def __init__(
        self,
        in_features,
        out_features,
        process_group: torch.distributed.ProcessGroup,
        bias=True,
        device=None,
        dtype=None,
    ):
        self.process_group = process_group
        self.tp_world_size = process_group.size()
        assert out_features % self.tp_world_size == 0
        out_features = out_features // self.tp_world_size

        super().__init__(
            in_features=in_features,
            out_features=out_features,
            bias=bias,
            device=device,
            dtype=dtype,
        )

    @staticmethod
    def linear(input, weight, bias):
        return F.linear(input, weight, bias)

    def forward(self, input):
        return self.linear(input, self.weight, self.bias)


class TensorParallelRowLinear(nn.Linear):
    def __init__(
        self,
        in_features,
        out_features,
        process_group: torch.distributed.ProcessGroup,
        bias=True,
        device=None,
        dtype=None,
    ):
        self.process_group = process_group
        self.tp_world_size = process_group.size()
        assert in_features % self.tp_world_size == 0
        in_features = in_features // self.tp_world_size

        super().__init__(
            in_features=in_features,
            out_features=out_features,
            bias=bias,
            device=device,
            dtype=dtype,
        )

    @staticmethod
    def linear(input, weight, bias):
        return F.linear(input, weight, bias)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        out = self.linear(input, self.weight, self.bias)
        torch.distributed.all_reduce(out, group=self.process_group)

        return out


class TensorParallelEmbedding(nn.Embedding):
    def __init__(
        self,
        num_embeddings,
        embedding_dim,
        process_group: torch.distributed.ProcessGroup,
        padding_idx=None,
        max_norm=None,
        norm_type=2.0,
        scale_grad_by_freq=False,
        sparse=False,
        _weight=None,
        device=None,
        dtype=None,
    ):
        self.process_group = process_group
        self.tp_rank = process_group.rank()
        self.tp_world_size = process_group.size()

        self.original_num_embeddings = num_embeddings

        assert num_embeddings % self.tp_world_size == 0
        block_size = num_embeddings // self.tp_world_size
        # inputs in `[min_id, max_id[` are handled by `self` to get embeddings
        self.min_id = self.tp_rank * block_size
        self.max_id = (self.tp_rank + 1) * block_size

        super().__init__(
            block_size,
            embedding_dim,
            padding_idx=padding_idx,
            max_norm=max_norm,
            norm_type=norm_type,
            scale_grad_by_freq=scale_grad_by_freq,
            sparse=sparse,
            _weight=_weight,
            device=device,
            dtype=dtype,
        )

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        # Sanity check
        if torch.any(
            torch.logical_or(0 > input, input >= self.original_num_embeddings)
        ):
            raise IndexError(
                f"Input is required to be in [0, {self.original_num_embeddings}[, got min: {torch.min(input)} and max: {torch.max(input)}"
            )

        # `0` if input is in the correct interval, else `1`
        input_mask = torch.logical_or(self.min_id > input, input >= self.max_id)
        # translate for [0, self.max_id - self.min_id[
        input = input - self.min_id
        # default all out of bounds values to `0`
        input[input_mask] = 0
        out = super().forward(input)
        out[input_mask] = 0.0
        torch.distributed.all_reduce(out, group=self.process_group)
        return out


class PositionRotaryEmbedding(RotaryEmbedding):
    def _update_cos_sin_cache(self, dtype, device, seqlen):
        # Reset the tables if the sequence length has changed,
        # or if we're on a new device (possibly due to tracing for instance)
        if (
            seqlen > self._seq_len_cached
            or self._cos_cached.device != device
            or self._cos_cached.dtype != dtype
        ):
            self._seq_len_cached = seqlen
            t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
            # Don't do einsum, it converts fp32 to fp16
            # freqs = torch.einsum("i,j->ij", t, self.inv_freq)
            freqs = torch.outer(t, self.inv_freq.to(device=t.device))
            self._cos_cached = torch.cos(freqs).to(dtype)
            self._sin_cached = torch.sin(freqs).to(dtype)

    def get_cos_sin(self, position_ids: torch.Tensor, max_s: int, dtype: torch.dtype):
        """
        Return cos and sin for the asked position ids
        """
        self._update_cos_sin_cache(dtype, position_ids.device, max_s)

        cos = torch.index_select(self._cos_cached, 0, position_ids)
        sin = torch.index_select(self._sin_cached, 0, position_ids)
        return cos.unsqueeze(1), sin.unsqueeze(1)

    def forward(self, qkv: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor):
        rotary_dim = cos.shape[-1]
        q1 = qkv[:, 0, :, :rotary_dim]
        q2 = qkv[:, 0, :, rotary_dim : 2 * rotary_dim]
        k1 = qkv[:, 1, :, :rotary_dim]
        k2 = qkv[:, 1, :, rotary_dim : 2 * rotary_dim]

        rotary_emb.apply_rotary(q1, q2, cos, sin, q1, q2, False)
        rotary_emb.apply_rotary(k1, k2, cos, sin, k1, k2, False)
        return qkv
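Aside (not part of the file): TensorParallelColumnLinear shards output features and TensorParallelRowLinear shards input features, so a column-split up-projection followed by a row-split down-projection plus an all-reduce reproduces the unsharded result. A single-process sketch of that equivalence, where plain tensor chunking and a Python sum stand in for the real process group (biases omitted for clarity):

    import torch
    import torch.nn.functional as F

    torch.manual_seed(0)
    hidden, intermediate, world_size = 8, 16, 2

    x = torch.randn(3, hidden)
    w_up = torch.randn(intermediate, hidden)    # dense_h_to_4h.weight (column-parallel)
    w_down = torch.randn(hidden, intermediate)  # dense_4h_to_h.weight (row-parallel)

    # Column parallelism splits the output features of the up-projection,
    # row parallelism splits the input features of the down-projection.
    up_shards = w_up.chunk(world_size, dim=0)
    down_shards = w_down.chunk(world_size, dim=1)

    # Each "rank" computes a partial result; summing the partials plays the role
    # of the all_reduce in TensorParallelRowLinear.forward.
    partials = [
        F.linear(F.gelu(F.linear(x, up)), down)
        for up, down in zip(up_shards, down_shards)
    ]
    sharded_out = sum(partials)

    full_out = F.linear(F.gelu(F.linear(x, w_up)), w_down)
    assert torch.allclose(sharded_out, full_out, atol=1e-5)

server/text_generation_server/models/flash_neox_modeling.py (continued)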
class FlashNeoxAttention(torch.nn.Module):
    def __init__(
        self, num_heads, hidden_size, rotary_pct, rotary_emb_base, process_group=None
    ):
        super().__init__()
        self.num_heads = num_heads
        self.hidden_size = hidden_size
        self.head_size = hidden_size // num_heads

        rotary_ndims = int(self.head_size * rotary_pct)
        self.rotary_emb = PositionRotaryEmbedding(rotary_ndims, base=rotary_emb_base)
        self.softmax_scale = self.head_size ** (-0.5)

        if process_group is None:
            self.query_key_value = nn.Linear(hidden_size, 3 * hidden_size)
            self.dense = nn.Linear(hidden_size, hidden_size)
        else:
            self.num_heads = self.num_heads // process_group.size()
            self.query_key_value = TensorParallelColumnLinear(
                hidden_size,
                3 * hidden_size,
                process_group=process_group,
            )
            self.dense = TensorParallelRowLinear(
                hidden_size,
                hidden_size,
                process_group=process_group,
            )
        self.swap_dims = True

    # TODO: remove and swap dims when loading weights
    def _swap_dims(self):
        """Swap dims for the first inference to avoid an additional permute"""
        self.query_key_value.weight = torch.nn.Parameter(
            self.query_key_value.weight.view(
                self.num_heads, 3, self.head_size, self.hidden_size
            )
            .permute(1, 0, 2, 3)
            .reshape(-1, self.hidden_size)
        )
        self.query_key_value.bias = torch.nn.Parameter(
            self.query_key_value.bias.view(self.num_heads, 3, self.head_size)
            .permute(1, 0, 2)
            .reshape(-1)
        )
        self.swap_dims = False

    def forward(
        self,
        hidden_states,
        cos,
        sin,
        cu_seqlens,
        max_s,
        layer_past,
        layer_past_present_indices,
        cu_seqlens_q,
    ):
        if self.swap_dims:
            self._swap_dims()

        qkv = self.query_key_value(hidden_states)
        qkv = qkv.view(-1, 3, self.num_heads, self.head_size)
        qkv_rot = self.rotary_emb(qkv, cos, sin)

        # Prefill
        if layer_past_present_indices is None:
            # Copy to layer past
            layer_past[...] = qkv_rot[:, 1:]

            # output
            attn_output = torch.empty_like(qkv[:, 0])
            # flash attention
            flash_attn_cuda.fwd(
                qkv[:, 0],
                qkv[:, 1],
                qkv[:, 2],
                attn_output,
                cu_seqlens,
                cu_seqlens,
                max_s,
                max_s,
                0.0,
                self.softmax_scale,
                False,
                True,
                False,
                0,
                None,
            )
        # Decode
        else:
            query = qkv_rot[:, 0]
            # Add present to the layer_past tensor at the correct indices
            layer_past[layer_past_present_indices] = qkv_rot[:, 1:]

            # output
            attn_output = torch.empty_like(query)
            # flash attention
            flash_attn_cuda.fwd(
                query,
                layer_past[:, 0],
                layer_past[:, 1],
                attn_output,
                cu_seqlens_q,
                cu_seqlens,
                1,
                max_s,
                0.0,
                self.softmax_scale,
                False,
                False,
                False,
                0,
                None,
            )

        return self.dense(attn_output.view(-1, self.num_heads * self.head_size))


class FlashMLP(nn.Module):
    def __init__(self, act, hidden_size, intermediate_size, process_group=None):
        super().__init__()
        self.act = (
            ACT2FN[act]
            if "gelu" not in act
            else lambda x: torch.nn.functional.gelu(x, approximate="tanh")
        )

        if process_group is None:
            self.dense_h_to_4h = nn.Linear(hidden_size, intermediate_size)
            self.dense_4h_to_h = nn.Linear(intermediate_size, hidden_size)
        else:
            self.dense_h_to_4h = TensorParallelColumnLinear(
                hidden_size,
                intermediate_size,
                process_group=process_group,
            )
            self.dense_4h_to_h = TensorParallelRowLinear(
                intermediate_size,
                hidden_size,
                process_group=process_group,
            )
        self.heuristic = "auto"
        self.process_group = process_group

    def forward(self, hidden_states):
        hidden_states = self.dense_h_to_4h(hidden_states)
        hidden_states = self.act(hidden_states)
        hidden_states = self.dense_4h_to_h(hidden_states)
        return hidden_states


class FlashNeoXLayer(nn.Module):
    def __init__(
        self,
        num_heads,
        act,
        hidden_size,
        intermediate_size,
        rotary_pct,
        rotary_emb_base,
        layer_norm_eps,
        use_parallel_residual,
        process_group=None,
    ):
        super().__init__()
        self.use_parallel_residual = use_parallel_residual
        self.input_layernorm = nn.LayerNorm(hidden_size, eps=layer_norm_eps)
        self.post_attention_layernorm = nn.LayerNorm(hidden_size, eps=layer_norm_eps)
        self.attention = FlashNeoxAttention(
            num_heads, hidden_size, rotary_pct, rotary_emb_base, process_group
        )
        self.mlp = FlashMLP(act, hidden_size, intermediate_size, process_group)

    def forward(
        self,
        hidden_states,
        residual,
        cos,
        sin,
        cu_seqlens,
        max_s,
        layer_past,
        layer_past_present_indices,
        cu_seqlens_q,
    ):
        if self.use_parallel_residual:
            # faster input layer norm
            ln1_hidden_states, *rest = dropout_layer_norm.dropout_add_ln_fwd(
                hidden_states,
                None,
                self.input_layernorm.weight,
                self.input_layernorm.bias,
                None,
                None,
                None,
                None,
                0.0,
                self.input_layernorm.eps,
                1.0,
                0,
                None,
                False,
                False,
            )

            attn_output = self.attention(
                ln1_hidden_states,
                cos,
                sin,
                cu_seqlens,
                max_s,
                layer_past,
                layer_past_present_indices,
                cu_seqlens_q,
            )

            # faster post attention layer norm
            ln2_hidden_states, *rest = dropout_layer_norm.dropout_add_ln_fwd(
                hidden_states,
                None,
                self.post_attention_layernorm.weight,
                self.post_attention_layernorm.bias,
                None,
                None,
                None,
                None,
                0.0,
                self.post_attention_layernorm.eps,
                1.0,
                0,
                None,
                False,
                False,
            )

            mlp_output = self.mlp(ln2_hidden_states)
            return mlp_output + attn_output + hidden_states, None
        else:
            # faster input layer norm
            hidden_states, residual, *rest = dropout_layer_norm.dropout_add_ln_fwd(
                hidden_states,
                residual,
                self.input_layernorm.weight,
                self.input_layernorm.bias,
                None,
                None,
                None,
                None,
                0.0,
                self.input_layernorm.eps,
                1.0,
                0,
                None,
                False,
                False,
            )

            hidden_states = self.attention(
                hidden_states,
                cos,
                sin,
                cu_seqlens,
                max_s,
                layer_past,
                layer_past_present_indices,
                cu_seqlens_q,
            )

            # faster post attention layer norm
            hidden_states, residual, *rest = dropout_layer_norm.dropout_add_ln_fwd(
                hidden_states,
                residual,
                self.post_attention_layernorm.weight,
                self.post_attention_layernorm.bias,
                None,
                None,
                None,
                None,
                0.0,
                self.post_attention_layernorm.eps,
                1.0,
                0,
                None,
                False,
                False,
            )

            mlp_output = self.mlp(hidden_states)
            return mlp_output, residual


class FlashGPTNeoXPreTrainedModel(PreTrainedModel):
    config_class = GPTNeoXConfig
    base_model_prefix = "gpt_neox"
    supports_gradient_checkpointing = False
    _no_split_modules = None


class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
    def __init__(self, config, process_group=None):
        super().__init__(config)
        self.config = config

        self.tp_embeddings = False
        if process_group is not None:
            self.tp_rank = process_group.rank()
            self.tp_world_size = process_group.size()
            if config.vocab_size % self.tp_world_size == 0:
                self.tp_embeddings = True

        if self.tp_embeddings:
            self.embed_in = TensorParallelEmbedding(
                config.vocab_size, config.hidden_size, process_group=process_group
            )
        else:
            self.embed_in = nn.Embedding(config.vocab_size, config.hidden_size)

        self.layers = nn.ModuleList(
            [
                FlashNeoXLayer(
                    config.num_attention_heads,
                    config.hidden_act,
                    config.hidden_size,
                    config.intermediate_size,
                    config.rotary_pct,
                    config.rotary_emb_base,
                    config.layer_norm_eps,
                    config.use_parallel_residual,
                    process_group,
                )
                for _ in range(config.num_hidden_layers)
            ]
        )
        self.final_layer_norm = nn.LayerNorm(
            config.hidden_size, eps=config.layer_norm_eps
        )

        self.gradient_checkpointing = False

        self.head_size = self.layers[0].attention.head_size
        self.num_heads = self.layers[0].attention.num_heads

    def forward(
        self,
        input_ids,
        position_ids,
        cu_seqlens,
        max_s,
        past_key_values=None,
    ):
        hidden_states = self.embed_in(input_ids)

        # Prefill
        if past_key_values is None:
            # Create past tensor
            past_key_values = hidden_states.new_empty(
                (
                    len(self.layers),
                    len(hidden_states),
                    2,
                    self.num_heads,
                    self.head_size,
                )
            )
            layer_past_present_indices = None
            cu_seqlens_q = None
        # Decode
        else:
            # Create indices from cumulative sequence lengths
            layer_past_present_indices = cu_seqlens[1:] - 1
            cu_seqlens_q = torch.arange(
                len(cu_seqlens), dtype=torch.int32, device=hidden_states.device
            )

        # Get rotary cos and sin for this forward
        # Avoid to index in each layer
        cos, sin = self.layers[0].attention.rotary_emb.get_cos_sin(
            position_ids, max_s, hidden_states.dtype
        )

        residual = None
        for i, layer in enumerate(self.layers):
            hidden_states, residual = layer(
                hidden_states,
                residual,
                cos,
                sin,
                cu_seqlens,
                max_s,
                past_key_values[i],
                layer_past_present_indices,
                cu_seqlens_q,
            )

        # Faster final layer norm
        hidden_states, *rest = dropout_layer_norm.dropout_add_ln_fwd(
            hidden_states,
            residual,
            self.final_layer_norm.weight,
            self.final_layer_norm.bias,
            None,
            None,
            None,
            None,
            0.0,
            self.final_layer_norm.eps,
            1.0,
            0,
            None,
            False,
            False,
        )

        return hidden_states, past_key_values


class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        if config.tp_parallel:
            process_group = torch.distributed.distributed_c10d._get_default_group()
        else:
            process_group = None

        self.gpt_neox = FlashGPTNeoXModel(config, process_group)

        if self.gpt_neox.tp_embeddings:
            self.embed_out = nn.Linear(
                config.hidden_size,
                config.vocab_size // process_group.size(),
                bias=False,
            )
        else:
            self.embed_out = nn.Linear(
                config.hidden_size, config.vocab_size, bias=False
            )

    def forward(
        self,
        input_ids,
        position_ids,
        cu_seqlens,
        max_s,
        past_key_values=None,
    ):
        hidden_states, present = self.gpt_neox(
            input_ids, position_ids, cu_seqlens, max_s, past_key_values
        )
        return self.embed_out(hidden_states), present
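Aside (not part of the file): during decode, FlashGPTNeoXModel derives the cache write positions and the query offsets directly from cu_seqlens. A tiny sketch with made-up lengths (three requests whose padded cache regions hold 4, 2 and 5 slots):

    import torch

    # cu_seqlens after the cache was padded by one slot per request (see the
    # F.pad call in flash_neox.py): request i owns slots cu_seqlens[i] .. cu_seqlens[i+1]-1.
    cu_seqlens = torch.tensor([0, 4, 6, 11], dtype=torch.int32)

    # The freshly generated key/value of each request is written into its last slot.
    layer_past_present_indices = cu_seqlens[1:] - 1

    # Every request contributes exactly one query token during decode, so the
    # cumulative query lengths are simply 0, 1, 2, 3.
    cu_seqlens_q = torch.arange(len(cu_seqlens), dtype=torch.int32)

    print(layer_past_present_indices.tolist())  # [3, 5, 10]
    print(cu_seqlens_q.tolist())                # [0, 1, 2, 3]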
server/text_generation_server/utils/tokens.py

@@ -24,7 +24,7 @@ class Sampling:
         self.seed = seed

     def __call__(self, logits):
-        probs = torch.nn.functional.softmax(logits)
+        probs = torch.nn.functional.softmax(logits, -1)
         next_tokens = torch.multinomial(probs, num_samples=1, generator=self.generator)
         return next_tokens
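Aside (not part of the diff): calling softmax without an explicit dim relies on PyTorch's deprecated implicit dimension selection; passing -1 always normalizes over the vocabulary axis, whether the logits arrive as [vocab] or [batch, vocab]. A trivial sketch with random logits:

    import torch

    generator = torch.Generator().manual_seed(42)

    logits_1d = torch.randn(10)      # e.g. a single decode step in the flash path
    logits_2d = torch.randn(2, 10)   # e.g. a small padded batch

    for logits in (logits_1d, logits_2d):
        probs = torch.nn.functional.softmax(logits, -1)
        next_tokens = torch.multinomial(probs, num_samples=1, generator=generator)
        print(next_tokens.tolist())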
server/text_generation_server/utils/watermark.py

@@ -17,6 +17,7 @@ import os
 import torch
 from transformers import LogitsProcessor
+from typing import List, Union

 GAMMA = os.getenv("WATERMARK_GAMMA", 0.5)
 DELTA = os.getenv("WATERMARK_DELTA", 2.0)

@@ -36,23 +37,32 @@ class WatermarkLogitsProcessor(LogitsProcessor):
         self.rng = torch.Generator(device=device)
         self.hash_key = hash_key

-    def _seed_rng(self, input_ids: torch.LongTensor) -> None:
-        assert (
-            input_ids.shape[-1] >= 1
-        ), "requires at least a 1 token prefix sequence to seed rng"
-        prev_token = input_ids[-1].item()
+    def _seed_rng(self, input_ids: Union[List[int], torch.LongTensor]):
+        if isinstance(input_ids, list):
+            assert (
+                len(input_ids) >= 1
+            ), "requires at least a 1 token prefix sequence to seed rng"
+            prev_token = input_ids[-1]
+        else:
+            assert len(input_ids) == 1
+            input_ids = input_ids[0]
+            assert (
+                input_ids.shape[-1] >= 1
+            ), "requires at least a 1 token prefix sequence to seed rng"
+            prev_token = input_ids[-1].item()
         self.rng.manual_seed(self.hash_key * prev_token)

     def _get_greenlist_ids(
-        self, input_ids: torch.LongTensor, max_value: int
-    ) -> list[int]:
+        self,
+        input_ids: Union[List[int], torch.LongTensor],
+        max_value: int,
+        device: torch.device,
+    ) -> List[int]:
         # seed the rng using the previous tokens/prefix
         self._seed_rng(input_ids)

         greenlist_size = int(max_value * self.gamma)
         vocab_permutation = torch.randperm(
-            max_value, device=input_ids.device, generator=self.rng
+            max_value, device=device, generator=self.rng
         )
         greenlist_ids = vocab_permutation[:greenlist_size]
         return greenlist_ids

@@ -73,10 +83,11 @@ class WatermarkLogitsProcessor(LogitsProcessor):
         return scores

     def __call__(
-        self, input_ids: torch.LongTensor, scores: torch.FloatTensor
+        self, input_ids: Union[List[int], torch.LongTensor], scores: torch.FloatTensor
     ) -> torch.FloatTensor:
-        assert len(input_ids) == 1
         greenlist_ids = self._get_greenlist_ids(
-            input_ids[0], scores.shape[-1]
+            input_ids, scores.shape[-1], scores.device
         )
         green_tokens_mask = self._calc_greenlist_mask(
             scores=scores, greenlist_token_ids=greenlist_ids
         )
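Aside (not part of the diff): the greenlist is reproducible because the generator is reseeded from the previous token before every permutation. A standalone sketch of that selection; the hash key and vocabulary size below are illustrative values, not the repository's defaults:

    import torch

    def greenlist_ids(prev_token: int, vocab_size: int, gamma: float, hash_key: int,
                      device: torch.device) -> torch.Tensor:
        # Seed the RNG from the previous token so the "green" set can be recomputed
        # at detection time, then keep the first gamma fraction of a random permutation.
        rng = torch.Generator(device=device)
        rng.manual_seed(hash_key * prev_token)
        greenlist_size = int(vocab_size * gamma)
        vocab_permutation = torch.randperm(vocab_size, device=device, generator=rng)
        return vocab_permutation[:greenlist_size]

    # Works with the last token of either a List[int] or a LongTensor prefix.
    prefix = [101, 2054, 2003]
    ids = greenlist_ids(prefix[-1], vocab_size=50432, gamma=0.5, hash_key=15485863,
                        device=torch.device("cpu"))
    print(ids.shape)  # torch.Size([25216])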