Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
text-generation-inference
Commits
3f2542bb
Unverified
Commit
3f2542bb
authored
Apr 05, 2023
by
OlivierDehaene
Committed by
GitHub
Apr 05, 2023
Browse files
fix(server): fix escape characters in stop sequence (#155)
parent
9122e7bd
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
90 additions
and
69 deletions
+90
-69
server/tests/utils/test_tokens.py
server/tests/utils/test_tokens.py
+9
-0
server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py
...erver/models/custom_modeling/flash_santacoder_modeling.py
+57
-59
server/text_generation_server/models/flash_santacoder.py
server/text_generation_server/models/flash_santacoder.py
+23
-10
server/text_generation_server/utils/tokens.py
server/text_generation_server/utils/tokens.py
+1
-0
No files found.
server/tests/utils/test_tokens.py
View file @
3f2542bb
...
...
@@ -14,6 +14,15 @@ def test_stop_sequence_criteria():
assert
not
criteria
(
"/test; "
)
def
test_stop_sequence_criteria_escape
():
criteria
=
StopSequenceCriteria
(
"<|stop|>"
)
assert
not
criteria
(
"<"
)
assert
not
criteria
(
"<|stop"
)
assert
criteria
(
"<|stop|>"
)
assert
not
criteria
(
"<|stop|> "
)
def
test_stopping_criteria
():
criteria
=
StoppingCriteria
(
0
,
[
StopSequenceCriteria
(
"/test;"
)],
max_new_tokens
=
5
)
assert
criteria
(
65827
,
"/test"
)
==
(
False
,
None
)
...
...
server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py
View file @
3f2542bb
...
...
@@ -162,15 +162,17 @@ class FlashMQAttention(torch.nn.Module):
class
MLP
(
nn
.
Module
):
def
__init__
(
self
,
act
,
hidden_size
,
intermediate_size
,
process_group
=
None
):
def
__init__
(
self
,
act
,
hidden_size
,
intermediate_size
,
process_group
=
None
):
super
().
__init__
()
self
.
act
=
(
ACT2FN
[
act
]
if
"gelu"
not
in
act
else
lambda
x
:
torch
.
nn
.
functional
.
gelu
(
x
,
approximate
=
"tanh"
if
act
in
[
"gelu_fast"
,
"gelu_pytorch_tanh"
]
else
None
)
else
lambda
x
:
torch
.
nn
.
functional
.
gelu
(
x
,
approximate
=
"tanh"
if
act
in
[
"gelu_fast"
,
"gelu_pytorch_tanh"
]
else
None
,
)
)
if
process_group
is
None
:
...
...
@@ -232,9 +234,7 @@ class Block(nn.Module):
cu_seqlens_q
,
)
hidden_states
,
residual
=
self
.
ln_2
(
hidden_states
,
residual
)
hidden_states
,
residual
=
self
.
ln_2
(
hidden_states
,
residual
)
mlp_output
=
self
.
mlp
(
hidden_states
)
...
...
@@ -258,16 +258,16 @@ class FlashSantacoderModel(nn.Module):
config
.
num_attention_heads
,
config
.
activation_function
,
config
.
hidden_size
,
config
.
n_inner
if
config
.
n_inner
is
not
None
else
4
*
config
.
hidden_size
,
config
.
n_inner
if
config
.
n_inner
is
not
None
else
4
*
config
.
hidden_size
,
config
.
layer_norm_epsilon
,
process_group
,
)
for
_
in
range
(
config
.
num_hidden_layers
)
]
)
self
.
ln_f
=
FastLayerNorm
(
config
.
hidden_size
,
eps
=
config
.
layer_norm_epsilon
)
self
.
ln_f
=
FastLayerNorm
(
config
.
hidden_size
,
eps
=
config
.
layer_norm_epsilon
)
self
.
head_size
=
self
.
h
[
0
].
attn
.
head_size
self
.
num_heads
=
self
.
h
[
0
].
attn
.
num_heads
...
...
@@ -335,9 +335,7 @@ class FlashSantacoderForCausalLM(nn.Module):
self
.
transformer
=
FlashSantacoderModel
(
config
,
process_group
)
self
.
lm_head
=
FastLinear
(
config
.
hidden_size
,
config
.
vocab_size
,
bias
=
False
)
self
.
lm_head
=
FastLinear
(
config
.
hidden_size
,
config
.
vocab_size
,
bias
=
False
)
def
post_load_weights
(
self
):
self
.
transformer
.
post_load_weights
()
...
...
server/text_generation_server/models/flash_santacoder.py
View file @
3f2542bb
...
...
@@ -9,7 +9,7 @@ from typing import Optional, List
from
text_generation_server.models
import
FlashCausalLM
from
text_generation_server.models.custom_modeling.flash_santacoder_modeling
import
(
FlashSantacoderForCausalLM
FlashSantacoderForCausalLM
,
)
from
text_generation_server.utils
import
(
weight_files
,
...
...
@@ -37,8 +37,9 @@ class FlashSantacoder(FlashCausalLM):
)
config
=
AutoConfig
.
from_pretrained
(
model_id
,
revision
=
revision
,
trust_remote_code
=
True
# Needed as the config is not part of Transformers
model_id
,
revision
=
revision
,
trust_remote_code
=
True
,
# Needed as the config is not part of Transformers
)
# We do not use from_pretrained as we modified the model internal module layout
...
...
@@ -91,7 +92,12 @@ class FlashSantacoder(FlashCausalLM):
current_parameter_tensor
=
None
if
current_parameter_tensor
is
not
None
:
if
"c_fc.weight"
in
key
or
"c_proj.weight"
in
key
or
"q_attn.weight"
in
key
or
"kv_attn.weight"
in
key
:
if
(
"c_fc.weight"
in
key
or
"c_proj.weight"
in
key
or
"q_attn.weight"
in
key
or
"kv_attn.weight"
in
key
):
# Transpose as we use nn.Linear instead of Conv1D
value
=
value
.
T
...
...
@@ -99,11 +105,18 @@ class FlashSantacoder(FlashCausalLM):
# Init qkv
if
"attn.weight"
in
final_key
:
module
.
_parameters
[
param_name
]
=
value
.
new_empty
(
(
model
.
transformer
.
head_size
*
(
model
.
transformer
.
num_heads
+
2
),
value
.
shape
[
1
])
(
model
.
transformer
.
head_size
*
(
model
.
transformer
.
num_heads
+
2
),
value
.
shape
[
1
],
)
)
elif
"attn.bias"
in
final_key
:
module
.
_parameters
[
param_name
]
=
value
.
new_empty
(
(
model
.
transformer
.
head_size
*
(
model
.
transformer
.
num_heads
+
2
))
(
model
.
transformer
.
head_size
*
(
model
.
transformer
.
num_heads
+
2
)
)
)
# Copy to correct slice
...
...
@@ -113,11 +126,11 @@ class FlashSantacoder(FlashCausalLM):
module
.
_parameters
[
param_name
][:
value
.
shape
[
0
]]
=
value
elif
"kv_attn.weight"
in
key
:
module
.
_parameters
[
param_name
][
model
.
transformer
.
head_size
*
model
.
transformer
.
num_heads
:
model
.
transformer
.
head_size
*
model
.
transformer
.
num_heads
:
]
=
value
elif
"kv_attn.bias"
in
key
:
module
.
_parameters
[
param_name
][
model
.
transformer
.
head_size
*
model
.
transformer
.
num_heads
:
model
.
transformer
.
head_size
*
model
.
transformer
.
num_heads
:
]
=
value
else
:
if
current_parameter_tensor
.
shape
!=
value
.
shape
:
...
...
server/text_generation_server/utils/tokens.py
View file @
3f2542bb
...
...
@@ -110,6 +110,7 @@ class NextTokenChooser:
class
StopSequenceCriteria
:
def
__init__
(
self
,
stop_sequence
:
str
):
stop_sequence
=
re
.
escape
(
stop_sequence
)
self
.
regex
=
re
.
compile
(
f
".*
{
stop_sequence
}
$"
)
def
__call__
(
self
,
output
:
str
)
->
bool
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment