Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
text-generation-inference
Commits
611e21cb
Unverified
Commit
611e21cb
authored
Dec 16, 2022
by
OlivierDehaene
Committed by
GitHub
Dec 16, 2022
Browse files
fix(server): Fix stop sequences (#11)
parent
3e2e6240
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
77 additions
and
76 deletions
+77
-76
launcher/tests/integration_tests.rs
launcher/tests/integration_tests.rs
+22
-11
server/tests/test_utils.py
server/tests/test_utils.py
+19
-32
server/text_generation/models/causal_lm.py
server/text_generation/models/causal_lm.py
+6
-1
server/text_generation/models/seq2seq_lm.py
server/text_generation/models/seq2seq_lm.py
+6
-1
server/text_generation/utils.py
server/text_generation/utils.py
+24
-31
No files found.
launcher/tests/integration_tests.rs
View file @
611e21cb
use
std
::
fs
::
File
;
use
float_eq
::
assert_float_eq
;
use
serde
::
Deserialize
;
use
serde_json
::
Value
;
use
serde_json
::
Value
;
use
std
::
fs
::
File
;
use
std
::
io
::{
BufRead
,
BufReader
};
use
std
::
io
::{
BufRead
,
BufReader
};
use
std
::
path
::
PathBuf
;
use
std
::
path
::
PathBuf
;
use
std
::
thread
;
use
std
::
thread
;
use
std
::
thread
::
sleep
;
use
std
::
thread
::
sleep
;
use
std
::
time
::
Duration
;
use
std
::
time
::
Duration
;
use
float_eq
::
assert_float_eq
;
use
subprocess
::{
Popen
,
PopenConfig
,
Redirection
};
use
subprocess
::{
Popen
,
PopenConfig
,
Redirection
};
use
serde
::
Deserialize
;
#[derive(Deserialize)]
#[derive(Deserialize)]
struct
Details
{
struct
Details
{
...
@@ -22,7 +22,6 @@ struct GeneratedText {
...
@@ -22,7 +22,6 @@ struct GeneratedText {
details
:
Details
,
details
:
Details
,
}
}
fn
start_launcher
(
model_name
:
String
,
num_shard
:
usize
,
port
:
usize
,
master_port
:
usize
)
->
Popen
{
fn
start_launcher
(
model_name
:
String
,
num_shard
:
usize
,
port
:
usize
,
master_port
:
usize
)
->
Popen
{
let
argv
=
vec!
[
let
argv
=
vec!
[
"text-generation-launcher"
.to_string
(),
"text-generation-launcher"
.to_string
(),
...
@@ -46,7 +45,7 @@ fn start_launcher(model_name: String, num_shard: usize, port: usize, master_port
...
@@ -46,7 +45,7 @@ fn start_launcher(model_name: String, num_shard: usize, port: usize, master_port
..
Default
::
default
()
..
Default
::
default
()
},
},
)
)
.expect
(
"Could not start launcher"
);
.expect
(
"Could not start launcher"
);
// Redirect STDOUT and STDERR to the console
// Redirect STDOUT and STDERR to the console
let
launcher_stdout
=
launcher
.stdout
.take
()
.unwrap
();
let
launcher_stdout
=
launcher
.stdout
.take
()
.unwrap
();
...
@@ -63,7 +62,7 @@ fn start_launcher(model_name: String, num_shard: usize, port: usize, master_port
...
@@ -63,7 +62,7 @@ fn start_launcher(model_name: String, num_shard: usize, port: usize, master_port
}
}
});
});
for
_
in
0
..
3
0
{
for
_
in
0
..
6
0
{
let
health
=
reqwest
::
blocking
::
get
(
format!
(
"http://localhost:{}/health"
,
port
));
let
health
=
reqwest
::
blocking
::
get
(
format!
(
"http://localhost:{}/health"
,
port
));
if
health
.is_ok
()
{
if
health
.is_ok
()
{
return
launcher
;
return
launcher
;
...
@@ -76,7 +75,12 @@ fn start_launcher(model_name: String, num_shard: usize, port: usize, master_port
...
@@ -76,7 +75,12 @@ fn start_launcher(model_name: String, num_shard: usize, port: usize, master_port
panic!
(
"failed to launch {}"
,
model_name
)
panic!
(
"failed to launch {}"
,
model_name
)
}
}
fn
test_model
(
model_name
:
String
,
num_shard
:
usize
,
port
:
usize
,
master_port
:
usize
)
->
GeneratedText
{
fn
test_model
(
model_name
:
String
,
num_shard
:
usize
,
port
:
usize
,
master_port
:
usize
,
)
->
GeneratedText
{
let
mut
launcher
=
start_launcher
(
model_name
,
num_shard
,
port
,
master_port
);
let
mut
launcher
=
start_launcher
(
model_name
,
num_shard
,
port
,
master_port
);
let
data
=
r#"
let
data
=
r#"
...
@@ -101,7 +105,6 @@ fn test_model(model_name: String, num_shard: usize, port: usize, master_port: us
...
@@ -101,7 +105,6 @@ fn test_model(model_name: String, num_shard: usize, port: usize, master_port: us
results
.pop
()
.unwrap
()
results
.pop
()
.unwrap
()
}
}
fn
read_json
(
name
:
&
str
)
->
GeneratedText
{
fn
read_json
(
name
:
&
str
)
->
GeneratedText
{
let
mut
d
=
PathBuf
::
from
(
env!
(
"CARGO_MANIFEST_DIR"
));
let
mut
d
=
PathBuf
::
from
(
env!
(
"CARGO_MANIFEST_DIR"
));
d
.push
(
"tests/"
);
d
.push
(
"tests/"
);
...
@@ -117,9 +120,17 @@ fn read_json(name: &str) -> GeneratedText {
...
@@ -117,9 +120,17 @@ fn read_json(name: &str) -> GeneratedText {
fn
compare_results
(
result
:
GeneratedText
,
expected
:
GeneratedText
)
{
fn
compare_results
(
result
:
GeneratedText
,
expected
:
GeneratedText
)
{
assert_eq!
(
result
.generated_text
,
expected
.generated_text
);
assert_eq!
(
result
.generated_text
,
expected
.generated_text
);
assert_eq!
(
result
.details.finish_reason
,
expected
.details.finish_reason
);
assert_eq!
(
result
.details.finish_reason
,
expected
.details.finish_reason
);
assert_eq!
(
result
.details.generated_tokens
,
expected
.details.generated_tokens
);
assert_eq!
(
result
.details.generated_tokens
,
for
(
token
,
expected_token
)
in
result
.details.tokens
.into_iter
()
.zip
(
expected
.details.tokens
.into_iter
())
{
expected
.details.generated_tokens
);
for
(
token
,
expected_token
)
in
result
.details
.tokens
.into_iter
()
.zip
(
expected
.details.tokens
.into_iter
())
{
assert_eq!
(
token
.0
,
expected_token
.0
);
assert_eq!
(
token
.0
,
expected_token
.0
);
assert_eq!
(
token
.1
,
expected_token
.1
);
assert_eq!
(
token
.1
,
expected_token
.1
);
if
let
Some
(
logprob
)
=
token
.2
{
if
let
Some
(
logprob
)
=
token
.2
{
...
...
server/tests/test_utils.py
View file @
611e21cb
...
@@ -11,46 +11,33 @@ from text_generation.utils import (
...
@@ -11,46 +11,33 @@ from text_generation.utils import (
def
test_stop_sequence_criteria
():
def
test_stop_sequence_criteria
():
criteria
=
StopSequenceCriteria
(
[
1
,
2
,
3
]
)
criteria
=
StopSequenceCriteria
(
"/test;"
)
assert
not
criteria
(
1
)
assert
not
criteria
(
"/"
)
assert
criteria
.
current_token_idx
==
1
assert
not
criteria
(
"/test"
)
assert
not
criteria
(
2
)
assert
criteria
(
"/test;"
)
assert
criteria
.
current_token_idx
==
2
assert
not
criteria
(
"/test; "
)
assert
criteria
(
3
)
assert
criteria
.
current_token_idx
==
3
def
test_stop_sequence_criteria_reset
():
def
test_stopping_criteria
():
criteria
=
StopSequenceCriteria
([
1
,
2
,
3
])
criteria
=
StoppingCriteria
(
0
,
[
StopSequenceCriteria
(
"/test;"
)],
max_new_tokens
=
5
)
assert
criteria
(
65827
,
"/test"
)
==
(
False
,
None
)
assert
not
criteria
(
1
)
assert
criteria
(
30
,
";"
)
==
(
True
,
"stop_sequence"
)
assert
criteria
.
current_token_idx
==
1
assert
not
criteria
(
2
)
assert
criteria
.
current_token_idx
==
2
assert
not
criteria
(
4
)
assert
criteria
.
current_token_idx
==
0
def
test_stop_sequence_criteria_empty
():
with
pytest
.
raises
(
ValueError
):
StopSequenceCriteria
([])
def
test_stopping_criteria
():
def
test_stopping_criteria_eos
():
criteria
=
StoppingCriteria
([
StopSequenceCriteria
([
1
,
2
,
3
])],
max_new_tokens
=
5
)
criteria
=
StoppingCriteria
(
0
,
[
StopSequenceCriteria
(
"/test;"
)],
max_new_tokens
=
5
)
assert
criteria
([
1
])
==
(
False
,
None
)
assert
criteria
(
1
,
""
)
==
(
False
,
None
)
assert
criteria
([
1
,
2
])
==
(
False
,
None
)
assert
criteria
(
0
,
""
)
==
(
True
,
"eos_token"
)
assert
criteria
([
1
,
2
,
3
])
==
(
True
,
"stop_sequence"
)
def
test_stopping_criteria_max
():
def
test_stopping_criteria_max
():
criteria
=
StoppingCriteria
([
StopSequenceCriteria
(
[
1
,
2
,
3
]
)],
max_new_tokens
=
5
)
criteria
=
StoppingCriteria
(
0
,
[
StopSequenceCriteria
(
"/test;"
)],
max_new_tokens
=
5
)
assert
criteria
(
[
1
]
)
==
(
False
,
None
)
assert
criteria
(
1
,
""
)
==
(
False
,
None
)
assert
criteria
(
[
1
,
1
]
)
==
(
False
,
None
)
assert
criteria
(
1
,
""
)
==
(
False
,
None
)
assert
criteria
(
[
1
,
1
,
1
]
)
==
(
False
,
None
)
assert
criteria
(
1
,
""
)
==
(
False
,
None
)
assert
criteria
(
[
1
,
1
,
1
,
1
]
)
==
(
False
,
None
)
assert
criteria
(
1
,
""
)
==
(
False
,
None
)
assert
criteria
(
[
1
,
1
,
1
,
1
,
1
]
)
==
(
True
,
"length"
)
assert
criteria
(
1
,
""
)
==
(
True
,
"length"
)
def
test_weight_hub_files
():
def
test_weight_hub_files
():
...
...
server/text_generation/models/causal_lm.py
View file @
611e21cb
...
@@ -345,7 +345,12 @@ class CausalLM(Model):
...
@@ -345,7 +345,12 @@ class CausalLM(Model):
all_logprobs
=
torch
.
cat
([
all_logprobs
,
next_token_logprob
])
all_logprobs
=
torch
.
cat
([
all_logprobs
,
next_token_logprob
])
# Evaluate stopping criteria
# Evaluate stopping criteria
stop
,
reason
=
stopping_criteria
(
all_input_ids
)
stop
,
reason
=
stopping_criteria
(
next_token
.
squeeze
(),
self
.
tokenizer
.
decode
(
next_token
.
squeeze
(),
clean_up_tokenization_spaces
=
False
),
)
if
stop
:
if
stop
:
# Decode all tokens
# Decode all tokens
output_text
=
self
.
tokenizer
.
decode
(
output_text
=
self
.
tokenizer
.
decode
(
...
...
server/text_generation/models/seq2seq_lm.py
View file @
611e21cb
...
@@ -441,7 +441,12 @@ class Seq2SeqLM(Model):
...
@@ -441,7 +441,12 @@ class Seq2SeqLM(Model):
decoder_logprobs
=
torch
.
cat
([
decoder_logprobs
,
next_token_logprob
])
decoder_logprobs
=
torch
.
cat
([
decoder_logprobs
,
next_token_logprob
])
# Evaluate stopping criteria
# Evaluate stopping criteria
stop
,
reason
=
stopping_criteria
(
decoder_input_ids
)
stop
,
reason
=
stopping_criteria
(
next_token
.
squeeze
(),
self
.
tokenizer
.
decode
(
next_token
.
squeeze
(),
clean_up_tokenization_spaces
=
False
),
)
if
stop
:
if
stop
:
# Slice with decoder_input_length to remove padding
# Slice with decoder_input_length to remove padding
# Decode all tokens
# Decode all tokens
...
...
server/text_generation/utils.py
View file @
611e21cb
import
concurrent
import
concurrent
import
os
import
os
import
re
import
torch
import
torch
import
torch.distributed
import
torch.distributed
...
@@ -74,43 +75,39 @@ class NextTokenChooser:
...
@@ -74,43 +75,39 @@ class NextTokenChooser:
class
StopSequenceCriteria
:
class
StopSequenceCriteria
:
def
__init__
(
self
,
tokens
:
List
[
int
]):
def
__init__
(
self
,
stop_sequence
:
str
):
if
not
tokens
:
self
.
regex
=
re
.
compile
(
f
".*
{
stop_sequence
}
$"
)
raise
ValueError
(
"tokens cannot be empty"
)
def
__call__
(
self
,
output
:
str
)
->
bool
:
self
.
tokens
=
tokens
if
self
.
regex
.
findall
(
output
):
self
.
current_token_idx
=
0
def
__call__
(
self
,
last_token
:
int
)
->
bool
:
if
last_token
==
self
.
tokens
[
self
.
current_token_idx
]:
# Increase idx to go to next token
self
.
current_token_idx
+=
1
else
:
# Reset to first token of the stopping sequence
self
.
current_token_idx
=
0
if
self
.
current_token_idx
==
len
(
self
.
tokens
):
# We matched the entire sequence without resetting
return
True
return
True
return
False
return
False
class
StoppingCriteria
:
class
StoppingCriteria
:
def
__init__
(
def
__init__
(
self
,
stop_sequence_criterias
:
List
[
StopSequenceCriteria
],
max_new_tokens
=
20
self
,
eos_token_id
:
int
,
stop_sequence_criterias
:
List
[
StopSequenceCriteria
],
max_new_tokens
=
20
,
):
):
self
.
eos_token_id
=
eos_token_id
self
.
stop_sequence_criterias
=
stop_sequence_criterias
self
.
stop_sequence_criterias
=
stop_sequence_criterias
self
.
max_new_tokens
=
max_new_tokens
self
.
max_new_tokens
=
max_new_tokens
self
.
current_tokens
=
0
self
.
current_tokens
=
0
self
.
current_output
=
""
def
__call__
(
self
,
all_ids
)
->
Tuple
[
bool
,
Optional
[
str
]]:
def
__call__
(
self
,
last_token
:
int
,
last_output
:
str
)
->
Tuple
[
bool
,
Optional
[
str
]]:
self
.
current_tokens
+=
1
self
.
current_tokens
+=
1
if
self
.
current_tokens
>=
self
.
max_new_tokens
:
if
self
.
current_tokens
>=
self
.
max_new_tokens
:
return
True
,
"length"
return
True
,
"length"
last_token
=
all_ids
[
-
1
]
if
last_token
==
self
.
eos_token_id
:
return
True
,
"eos_token"
self
.
current_output
+=
last_output
for
stop_sequence_criteria
in
self
.
stop_sequence_criterias
:
for
stop_sequence_criteria
in
self
.
stop_sequence_criterias
:
if
stop_sequence_criteria
(
last_token
):
if
stop_sequence_criteria
(
self
.
current_output
):
return
True
,
"stop_sequence"
return
True
,
"stop_sequence"
return
False
,
None
return
False
,
None
...
@@ -119,16 +116,12 @@ class StoppingCriteria:
...
@@ -119,16 +116,12 @@ class StoppingCriteria:
def
from_pb
(
def
from_pb
(
cls
,
pb
:
generate_pb2
.
StoppingCriteriaParameters
,
tokenizer
:
AutoTokenizer
cls
,
pb
:
generate_pb2
.
StoppingCriteriaParameters
,
tokenizer
:
AutoTokenizer
)
->
"StoppingCriteria"
:
)
->
"StoppingCriteria"
:
stop_sequence_criterias
=
[]
stop_sequence_criterias
=
[
for
stop_sequence
in
pb
.
stop_sequences
:
StopSequenceCriteria
(
sequence
)
for
sequence
in
pb
.
stop_sequences
tokens
=
tokenizer
(
]
stop_sequence
,
padding
=
False
,
return_attention_mask
=
False
return
StoppingCriteria
(
).
input_ids
tokenizer
.
eos_token_id
,
stop_sequence_criterias
,
pb
.
max_new_tokens
if
tokens
:
)
stop_sequence_criterias
.
append
(
StopSequenceCriteria
(
tokens
))
stop_sequence_criterias
.
append
(
StopSequenceCriteria
([
tokenizer
.
eos_token_id
]))
return
StoppingCriteria
(
stop_sequence_criterias
,
pb
.
max_new_tokens
)
def
initialize_torch_distributed
():
def
initialize_torch_distributed
():
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment