OpenDAS / text-generation-inference / Commits / d6a93fe9

Unverified commit d6a93fe9, authored Mar 24, 2023 by OlivierDehaene, committed by GitHub on Mar 24, 2023

fix(server): fix flash-neox scores warping (#137)

parent 05e9a796

Showing 8 changed files with 79 additions and 60 deletions (+79 -60)
clients/python/pyproject.toml                                  +1  -1
clients/python/tests/test_types.py                             +2  -0
clients/python/text_generation/client.py                       +2  -8
clients/python/text_generation/types.py                        +2  -0
server/text_generation_server/models/__init__.py               +1  -0
server/text_generation_server/models/flash_neox.py             +37 -21
server/text_generation_server/models/flash_neox_modeling.py    +33 -29
server/text_generation_server/utils/watermark.py               +1  -1
clients/python/pyproject.toml  (view file @ d6a93fe9)

 [tool.poetry]
 name = "text-generation"
-version = "0.4.0"
+version = "0.4.1"
 description = "Hugging Face Text Generation Python Client"
 license = "Apache-2.0"
 authors = ["Olivier Dehaene <olivier@huggingface.co>"]
...
clients/python/tests/test_types.py  (view file @ d6a93fe9)
...
@@ -14,6 +14,8 @@ def test_parameters_validation():
     Parameters(best_of=2, do_sample=True)
     with pytest.raises(ValidationError):
         Parameters(best_of=2)
+    with pytest.raises(ValidationError):
+        Parameters(best_of=2, seed=1)

     # Test repetition_penalty
     Parameters(repetition_penalty=1)
...
clients/python/text_generation/client.py  (view file @ d6a93fe9)
...
@@ -150,7 +150,6 @@ class Client:
         prompt: str,
         do_sample: bool = False,
         max_new_tokens: int = 20,
-        best_of: Optional[int] = None,
         repetition_penalty: Optional[float] = None,
         return_full_text: bool = False,
         seed: Optional[int] = None,
...
@@ -172,8 +171,6 @@ class Client:
             Activate logits sampling
         max_new_tokens (`int`):
             Maximum number of generated tokens
-        best_of (`int`):
-            Generate best_of sequences and return the one if the highest token logprobs
         repetition_penalty (`float`):
             The parameter for repetition penalty. 1.0 means no penalty. See [this
             paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
...
@@ -203,7 +200,7 @@ class Client:
         """
         # Validate parameters
         parameters = Parameters(
-            best_of=best_of,
+            best_of=None,
             details=True,
             do_sample=do_sample,
             max_new_tokens=max_new_tokens,
...
@@ -388,7 +385,6 @@ class AsyncClient:
         prompt: str,
         do_sample: bool = False,
         max_new_tokens: int = 20,
-        best_of: Optional[int] = None,
         repetition_penalty: Optional[float] = None,
         return_full_text: bool = False,
         seed: Optional[int] = None,
...
@@ -410,8 +406,6 @@ class AsyncClient:
             Activate logits sampling
         max_new_tokens (`int`):
             Maximum number of generated tokens
-        best_of (`int`):
-            Generate best_of sequences and return the one if the highest token logprobs
         repetition_penalty (`float`):
             The parameter for repetition penalty. 1.0 means no penalty. See [this
             paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
...
@@ -441,7 +435,7 @@ class AsyncClient:
         """
        # Validate parameters
         parameters = Parameters(
-            best_of=best_of,
+            best_of=None,
             details=True,
             do_sample=do_sample,
             max_new_tokens=max_new_tokens,
...
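For reference, a minimal usage sketch of the client after this change; the endpoint URL is hypothetical, and `best_of` is simply no longer accepted by `Client.generate`:

    from text_generation import Client

    # assumes a text-generation-inference server listening locally (hypothetical URL)
    client = Client("http://127.0.0.1:8080")
    response = client.generate("What is Deep Learning?", max_new_tokens=20, do_sample=False)
    print(response.generated_text)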
clients/python/text_generation/types.py  (view file @ d6a93fe9)
...
@@ -43,6 +43,8 @@ class Parameters(BaseModel):
         if field_value is not None:
             if field_value <= 0:
                 raise ValidationError("`best_of` must be strictly positive")
+            if field_value > 1 and values["seed"] is not None:
+                raise ValidationError("`seed` must not be set when `best_of` is > 1")
             sampling = (
                 values["do_sample"]
                 | (values["temperature"] is not None)
...
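A standalone sketch of the rule the validator now enforces (not the library code itself); presumably a fixed seed would make every `best_of` candidate sample identically, so the combination is rejected:

    from typing import Optional

    def check_best_of(best_of: Optional[int], seed: Optional[int]) -> None:
        if best_of is not None:
            if best_of <= 0:
                raise ValueError("`best_of` must be strictly positive")
            if best_of > 1 and seed is not None:
                raise ValueError("`seed` must not be set when `best_of` is > 1")

    check_best_of(best_of=2, seed=None)  # accepted
    # check_best_of(best_of=2, seed=1)   # raises ValueError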
server/text_generation_server/models/__init__.py  (view file @ d6a93fe9)
...
@@ -16,6 +16,7 @@ from text_generation_server.models.t5 import T5Sharded
 try:
     from text_generation_server.models.flash_neox import FlashNeoX, FlashNeoXSharded
+    FLASH_NEOX = torch.cuda.is_available() and int(os.environ.get("FLASH_NEOX", 0)) == 1
 except ImportError:
     if int(os.environ.get("FLASH_NEOX", 0)) == 1:
...
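A small sketch of the opt-in gate above, assuming the environment variable is read at server start; the launcher invocation in the comment is hypothetical and not part of this diff:

    import os

    import torch

    # The flash-neox path is used only when CUDA is available and FLASH_NEOX=1
    # is set, e.g. `FLASH_NEOX=1 text-generation-launcher ...` (hypothetical).
    FLASH_NEOX = torch.cuda.is_available() and int(os.environ.get("FLASH_NEOX", 0)) == 1
    print("flash-neox enabled:", FLASH_NEOX)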
server/text_generation_server/models/flash_neox.py  (view file @ d6a93fe9)

 import torch
 import torch.distributed

+from torch.nn import functional as F
+
 from accelerate import init_empty_weights
 from dataclasses import dataclass
 from opentelemetry import trace
...
@@ -48,6 +50,7 @@ class FlashNeoXBatch(Batch):
     # All tokens
     all_input_ids: List[List[int]]
+    all_input_ids_tensor: List[torch.Tensor]

     # Lengths of all generations present in the batch
     input_lengths: List[int]
...
@@ -75,6 +78,7 @@ class FlashNeoXBatch(Batch):
         input_lengths = []
         all_input_ids = []
+        all_input_ids_tensor = []
         next_token_choosers = []
         stopping_criterias = []
...
@@ -84,15 +88,14 @@ class FlashNeoXBatch(Batch):
         # Parse batch
         for r in pb.requests:
-            tokenized_input = tokenizer(
-                r.inputs, return_tensors="pt"
-            )["input_ids"].squeeze(0)
-            input_ids.append(tokenized_input)
-            all_input_ids.append(tokenized_input.tolist())
+            tokenized_input = tokenizer(r.inputs)["input_ids"]

             input_length = len(tokenized_input)
             max_seqlen = max(max_seqlen, input_length)
             input_lengths.append(input_length)
+            all_input_ids.append(tokenized_input)
+
+            tokenized_input = torch.tensor(tokenized_input, device=device)
+            input_ids.append(tokenized_input)

             # Position ids
             position_ids.append(torch.arange(0, input_length, dtype=torch.int32))
...
@@ -101,14 +104,18 @@ class FlashNeoXBatch(Batch):
             cu_seqlens.append(cumulative_length + input_length)

             next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
-            stopping_criterias.append(
-                StoppingCriteria.from_pb(r.stopping_parameters, tokenizer)
-            )
+            stopping_criteria = StoppingCriteria.from_pb(r.stopping_parameters, tokenizer)
+            stopping_criterias.append(stopping_criteria)
+            all_input_ids_tensor.append(
+                F.pad(tokenized_input, (0, stopping_criteria.max_new_tokens))
+            )

             # Update
             cumulative_length += input_length

-        input_ids = torch.concat(input_ids).unsqueeze(1)
+        input_ids = torch.concat(input_ids)
         position_ids = torch.concat(position_ids)
         cu_seqlens = torch.tensor(cu_seqlens, dtype=torch.int32)
...
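The F.pad call above pre-allocates room for the tokens that will be generated, so each request's tensor can be updated in place during decoding instead of being rebuilt every step. A toy illustration with hypothetical values:

    import torch
    import torch.nn.functional as F

    prompt_ids = torch.tensor([101, 2023, 2003])       # hypothetical token ids
    max_new_tokens = 4
    # right-pad with zeros: one slot per future generated token
    all_input_ids_tensor = F.pad(prompt_ids, (0, max_new_tokens))
    # tensor([ 101, 2023, 2003,    0,    0,    0,    0])

    # during decoding, the slot at position `input_length` receives the new token id
    input_length = prompt_ids.shape[0]
    all_input_ids_tensor[input_length] = 42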
@@ -122,6 +129,7 @@ class FlashNeoXBatch(Batch):
             past_key_values=None,
             input_lengths=input_lengths,
             all_input_ids=all_input_ids,
+            all_input_ids_tensor=all_input_ids_tensor,
             next_token_choosers=next_token_choosers,
             stopping_criterias=stopping_criterias,
         )
...
@@ -133,6 +141,7 @@ class FlashNeoXBatch(Batch):
         requests = []
         input_lengths = []
         all_input_ids = []
+        all_input_ids_tensor = []
         next_token_choosers = []
         stopping_criterias = []
...
@@ -150,6 +159,7 @@ class FlashNeoXBatch(Batch):
             requests.extend(batch.requests)
             input_lengths.extend(batch.input_lengths)
             all_input_ids.extend(batch.all_input_ids)
+            all_input_ids_tensor.extend(batch.all_input_ids_tensor)
             next_token_choosers.extend(batch.next_token_choosers)
             stopping_criterias.extend(batch.stopping_criterias)
...
@@ -181,6 +191,7 @@ class FlashNeoXBatch(Batch):
             past_key_values=past_key_values,
             input_lengths=input_lengths,
             all_input_ids=all_input_ids,
+            all_input_ids_tensor=all_input_ids_tensor,
             next_token_choosers=next_token_choosers,
             stopping_criterias=stopping_criterias,
         )
...
@@ -255,11 +266,10 @@ class FlashNeoX(Model):
     ) -> Tuple[List[Generation], Optional[FlashNeoXBatch]]:
         # Better to send to device here to avoid device issues in concatenate
         position_ids = batch.position_ids.to(self.device, non_blocking=True)
-        cu_seqlens = batch.cu_seqlens.to(self.device, non_blocking=True)
-        input_ids = batch.input_ids.squeeze(1).to(self.device)
+        cu_seqlens = batch.cu_seqlens.to(self.device)

         out, present = self.forward(
-            input_ids,
+            batch.input_ids,
             position_ids,
             cu_seqlens,
             batch.max_seqlen,
...
@@ -277,6 +287,7 @@ class FlashNeoX(Model):
         next_batch_past_key_values = []
         next_batch_input_lengths = []
         next_batch_all_input_ids = []
+        next_batch_all_input_ids_tensor = []

         # Cumulative length
         cumulative_length = 0
...
@@ -291,6 +302,7 @@ class FlashNeoX(Model):
             batch.next_token_choosers,
             batch.stopping_criterias,
             batch.all_input_ids,
+            batch.all_input_ids_tensor,
         )

         # For each member of the batch
...
@@ -300,6 +312,7 @@ class FlashNeoX(Model):
             next_token_chooser,
             stopping_criteria,
             all_input_ids,
+            all_input_ids_tensor,
         ) in enumerate(iterator):
             # Indexing metadata
             start_index = cumulative_length
...
@@ -315,20 +328,19 @@ class FlashNeoX(Model):
                 logits = out[i].unsqueeze(0)

             # Select next token
-            next_token_id, logprobs = next_token_chooser(all_input_ids, logits)
-            # Copy to cpu to avoid other copies when indexing and calling .item()
-            next_token_id = next_token_id.to("cpu", non_blocking=True)
-            logprobs = logprobs.to("cpu")
+            next_token_id, logprobs = next_token_chooser(
+                all_input_ids_tensor[None, :input_length], logits
+            )
             next_token_id_squeezed = next_token_id.squeeze()
             next_token_id_item = next_token_id_squeezed.item()

             # Append next token to all tokens
             all_input_ids.append(next_token_id_item)
+            all_input_ids_tensor[input_length] = next_token_id_item
             new_input_length = input_length + 1

             # Generated token
-            next_token_logprob = logprobs[-1, next_token_id]
+            next_token_logprob = logprobs[-1, next_token_id_item]
             next_token_text = self.decode_token(
                 next_token_id_item,
             )
...
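This hunk appears to be the core of the scores-warping fix: the logits processors and warpers inside the next-token chooser expect `input_ids` shaped `(batch, sequence_length)`, and they must see the real context rather than a plain Python list or the zero padding reserved for future tokens. Slicing the pre-allocated tensor as `all_input_ids_tensor[None, :input_length]` satisfies both. A rough illustration with a single Hugging Face warper and hypothetical values:

    import torch
    from transformers import RepetitionPenaltyLogitsProcessor

    all_input_ids_tensor = torch.tensor([5, 7, 7, 0, 0])  # prompt of length 3, padded to 5
    input_length = 3
    logits = torch.randn(1, 10)                           # (batch=1, vocab=10)

    processor = RepetitionPenaltyLogitsProcessor(penalty=1.3)
    warped = processor(all_input_ids_tensor[None, :input_length], logits)
    # only the logits of tokens 5 and 7 are penalized; the padding zeros are excluded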
@@ -372,13 +384,14 @@ class FlashNeoX(Model):
                 )

             next_batch_input_lengths.append(new_input_length)
             next_batch_all_input_ids.append(all_input_ids)
+            next_batch_all_input_ids_tensor.append(all_input_ids_tensor)
             next_batch_max_seqlen = max(next_batch_max_seqlen, new_input_length)

             # Prefill
             if stopping_criteria.current_tokens == 1:
                 # Remove generated token to only have prefill and add nan for first prompt token
                 prefill_logprobs = [float("nan")] + logprobs.gather(
-                    1, torch.tensor(all_input_ids[1:]).unsqueeze(1)
+                    1, all_input_ids_tensor[1:input_length].unsqueeze(1)
                 ).squeeze(1)[:-1].tolist()
                 prefill_token_ids = all_input_ids[:-1]
                 prefill_texts = self.tokenizer.batch_decode(
...
@@ -431,12 +444,14 @@ class FlashNeoX(Model):
             )

         next_batch_cu_seqlens = torch.tensor(next_batch_cu_seqlens, dtype=torch.int32)
         if len(next_batch_keep_indices) > 1:
-            next_batch_input_ids = torch.concat(next_batch_input_ids)
+            next_batch_input_ids = torch.concat(next_batch_input_ids).squeeze(1)
             next_batch_past_key_values = torch.concat(next_batch_past_key_values, dim=1)
         else:
-            next_batch_input_ids = next_batch_input_ids[0]
+            next_batch_input_ids = next_batch_input_ids[0].view(1)
             next_batch_past_key_values = next_batch_past_key_values[0]

+        print(next_batch_input_ids.shape)
+
         next_batch = FlashNeoXBatch(
             batch_id=batch.batch_id,
             requests=next_batch_requests,
...
@@ -447,6 +462,7 @@ class FlashNeoX(Model):
             past_key_values=next_batch_past_key_values,
             input_lengths=next_batch_input_lengths,
             all_input_ids=next_batch_all_input_ids,
+            all_input_ids_tensor=next_batch_all_input_ids_tensor,
             next_token_choosers=next_batch_next_token_choosers,
             stopping_criterias=next_batch_stopping_criterias,
         )
...
server/text_generation_server/models/flash_neox_modeling.py  (view file @ d6a93fe9)

 import torch
 import torch.distributed

-import torch.nn.functional as F
-
 from torch import nn
 from transformers.activations import ACT2FN
 from transformers.modeling_utils import PreTrainedModel
...
@@ -16,7 +14,29 @@ import dropout_layer_norm
 from flash_attn.layers.rotary import RotaryEmbedding


-class TensorParallelColumnLinear(nn.Linear):
+class FastLinear(nn.Linear):
+    def __init__(
+        self,
+        in_features: int,
+        out_features: int,
+        bias: bool = True,
+        device=None,
+        dtype=None,
+    ) -> None:
+        super(FastLinear, self).__init__(in_features, out_features, bias, device, dtype)
+        self.swap_dims = True
+
+    def forward(self, input: torch.Tensor) -> torch.Tensor:
+        if self.swap_dims:
+            self.weight = nn.Parameter(self.weight.T)
+            self.swap_dims = False
+
+        if self.bias is not None:
+            return torch.addmm(self.bias, input, self.weight)
+        return torch.matmul(input, self.weight)
+
+
+class TensorParallelColumnLinear(FastLinear):
     def __init__(
         self,
         in_features,
...
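FastLinear transposes the weight once, lazily on the first forward pass, so later calls can use a plain addmm/matmul instead of applying the stored (out_features, in_features) weight transposed on every call as F.linear does. A rough equivalence check under that assumption:

    import torch
    from torch import nn

    linear = nn.Linear(4, 8)
    x = torch.randn(2, 4)

    w_t = linear.weight.detach().T                      # transpose once, up front
    reference = linear(x)                               # F.linear path
    fast = torch.addmm(linear.bias, x, w_t)             # addmm on the pre-transposed weight
    print(torch.allclose(reference, fast, atol=1e-6))   # True (up to float tolerance)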
@@ -39,15 +59,11 @@ class TensorParallelColumnLinear(nn.Linear):
             dtype=dtype,
         )

-    @staticmethod
-    def linear(input, weight, bias):
-        return F.linear(input, weight, bias)
-
     def forward(self, input):
-        return self.linear(input, self.weight, self.bias)
+        return super(TensorParallelColumnLinear, self).forward(input)


-class TensorParallelRowLinear(nn.Linear):
+class TensorParallelRowLinear(FastLinear):
     def __init__(
         self,
         in_features,
...
@@ -70,12 +86,8 @@ class TensorParallelRowLinear(nn.Linear):
             dtype=dtype,
         )

-    @staticmethod
-    def linear(input, weight, bias):
-        return F.linear(input, weight, bias)
-
     def forward(self, input: torch.Tensor) -> torch.Tensor:
-        out = self.linear(input, self.weight, self.bias)
+        out = super(TensorParallelRowLinear, self).forward(input)
         torch.distributed.all_reduce(out, group=self.process_group)

         return out
...
@@ -122,14 +134,6 @@ class TensorParallelEmbedding(nn.Embedding):
         )

     def forward(self, input: torch.Tensor) -> torch.Tensor:
-        # Sanity check
-        if torch.any(
-            torch.logical_or(0 > input, input >= self.original_num_embeddings)
-        ):
-            raise IndexError(
-                f"Input is required to be in [0, {self.original_num_embeddings}[, got min: {torch.min(input)} and max: {torch.max(input)}"
-            )
-
         # `0` if input is in the correct interval, else `1`
         input_mask = torch.logical_or(self.min_id > input, input >= self.max_id)
         # translate for [0, self.max_id - self.min_id[
...
@@ -196,8 +200,8 @@ class FlashNeoxAttention(torch.nn.Module):
         self.softmax_scale = self.head_size ** (-0.5)

         if process_group is None:
-            self.query_key_value = nn.Linear(hidden_size, 3 * hidden_size)
-            self.dense = nn.Linear(hidden_size, hidden_size)
+            self.query_key_value = FastLinear(hidden_size, 3 * hidden_size)
+            self.dense = FastLinear(hidden_size, hidden_size)
         else:
             self.num_heads = self.num_heads // process_group.size()
             self.query_key_value = TensorParallelColumnLinear(
...
@@ -312,8 +316,8 @@ class FlashMLP(nn.Module):
         )

         if process_group is None:
-            self.dense_h_to_4h = nn.Linear(hidden_size, intermediate_size)
-            self.dense_4h_to_h = nn.Linear(intermediate_size, hidden_size)
+            self.dense_h_to_4h = FastLinear(hidden_size, intermediate_size)
+            self.dense_4h_to_h = FastLinear(intermediate_size, hidden_size)
         else:
             self.dense_h_to_4h = TensorParallelColumnLinear(
                 hidden_size,
...
@@ -556,7 +560,7 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
             # Create indices from cumulative sequence lengths
             layer_past_present_indices = cu_seqlens[1:] - 1
             cu_seqlens_q = torch.arange(
-                len(cu_seqlens), dtype=torch.int32, device=hidden_states.device
+                cu_seqlens.shape[0], dtype=torch.int32, device=hidden_states.device
             )

         # Get rotary cos and sin for this forward
...
@@ -613,13 +617,13 @@ class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel):
         self.gpt_neox = FlashGPTNeoXModel(config, process_group)

         if self.gpt_neox.tp_embeddings:
-            self.embed_out = nn.Linear(
+            self.embed_out = FastLinear(
                 config.hidden_size,
                 config.vocab_size // process_group.size(),
                 bias=False,
             )
         else:
-            self.embed_out = nn.Linear(
+            self.embed_out = FastLinear(
                 config.hidden_size, config.vocab_size, bias=False
             )
...
server/text_generation_server/utils/watermark.py  (view file @ d6a93fe9)
...
@@ -44,8 +44,8 @@ class WatermarkLogitsProcessor(LogitsProcessor):
             ), "requires at least a 1 token prefix sequence to seed rng"
             prev_token = input_ids[-1]
         else:
             input_ids = input_ids[0]
-            assert len(input_ids) == 1
+            assert (
+                input_ids.shape[-1] >= 1
+            ), "requires at least a 1 token prefix sequence to seed rng"
...