Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
text-generation-inference
Commits
54fec931
Unverified
Commit
54fec931
authored
Jan 31, 2023
by
OlivierDehaene
Committed by
GitHub
Jan 31, 2023
Browse files
fix(server): fix seeding with multiple shards (#44)
parent
03bdf182
Changes
10
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
91 additions
and
86 deletions
+91
-86
Cargo.lock
Cargo.lock
+1
-0
proto/generate.proto
proto/generate.proto
+1
-1
router/Cargo.toml
router/Cargo.toml
+1
-0
router/src/db.rs
router/src/db.rs
+2
-1
router/src/validation.rs
router/src/validation.rs
+13
-2
server/poetry.lock
server/poetry.lock
+67
-65
server/pyproject.toml
server/pyproject.toml
+1
-1
server/text_generation/models/bloom.py
server/text_generation/models/bloom.py
+0
-2
server/text_generation/models/galactica.py
server/text_generation/models/galactica.py
+0
-1
server/text_generation/utils.py
server/text_generation/utils.py
+5
-13
No files found.
Cargo.lock
View file @
54fec931
...
...
@@ -1834,6 +1834,7 @@ dependencies = [
"futures",
"nohash-hasher",
"parking_lot",
"rand",
"serde",
"serde_json",
"text-generation-client",
...
...
proto/generate.proto
View file @
54fec931
...
...
@@ -37,7 +37,7 @@ message NextTokenChooserParameters {
/// apply sampling on the logits
bool
do_sample
=
4
;
/// random seed for sampling
optional
uint64
seed
=
5
;
uint64
seed
=
5
;
}
message
StoppingCriteriaParameters
{
...
...
router/Cargo.toml
View file @
54fec931
...
...
@@ -19,6 +19,7 @@ clap = { version = "4.0.15", features = ["derive", "env"] }
futures
=
"0.3.24"
nohash-hasher
=
"0.2.0"
parking_lot
=
"0.12.1"
rand
=
"0.8.5"
serde
=
"1.0.145"
serde_json
=
"1.0.85"
thiserror
=
"1.0.37"
...
...
router/src/db.rs
View file @
54fec931
...
...
@@ -166,7 +166,8 @@ impl From<&GenerateParameters> for NextTokenChooserParameters {
top_k
:
parameters
.top_k
as
u32
,
top_p
:
parameters
.top_p
,
do_sample
:
parameters
.do_sample
,
seed
:
parameters
.seed
,
// FIXME: remove unwrap
seed
:
parameters
.seed
.unwrap
(),
}
}
}
...
...
router/src/validation.rs
View file @
54fec931
...
...
@@ -2,6 +2,8 @@
use
crate
::{
ErrorResponse
,
GenerateRequest
};
use
axum
::
http
::
StatusCode
;
use
axum
::
Json
;
use
rand
::
rngs
::
ThreadRng
;
use
rand
::
Rng
;
use
thiserror
::
Error
;
use
tokenizers
::
tokenizer
::
Tokenizer
;
use
tokio
::
sync
::{
mpsc
,
oneshot
};
...
...
@@ -92,18 +94,22 @@ fn validation_worker(
max_input_length
:
usize
,
mut
receiver
:
mpsc
::
Receiver
<
ValidationRequest
>
,
)
{
// Seed rng
let
mut
rng
=
rand
::
thread_rng
();
// Loop over requests
while
let
Some
((
request
,
response_tx
))
=
receiver
.blocking_recv
()
{
response_tx
.send
(
validate
(
request
,
&
tokenizer
,
max_input_length
))
.send
(
validate
(
request
,
&
tokenizer
,
max_input_length
,
&
mut
rng
))
.unwrap_or
(())
}
}
fn
validate
(
request
:
GenerateRequest
,
mut
request
:
GenerateRequest
,
tokenizer
:
&
Tokenizer
,
max_input_length
:
usize
,
rng
:
&
mut
ThreadRng
,
)
->
Result
<
(
usize
,
GenerateRequest
),
ValidationError
>
{
if
request
.parameters.temperature
<=
0.0
{
return
Err
(
ValidationError
::
Temperature
);
...
...
@@ -124,6 +130,11 @@ fn validate(
));
}
// If seed is None, assign a random one
if
request
.parameters.seed
.is_none
()
{
request
.parameters.seed
=
Some
(
rng
.gen
());
}
// Get the number of tokens in the input
match
tokenizer
.encode
(
request
.inputs
.clone
(),
true
)
{
Ok
(
inputs
)
=>
{
...
...
server/poetry.lock
View file @
54fec931
This diff is collapsed.
Click to expand it.
server/pyproject.toml
View file @
54fec931
...
...
@@ -15,7 +15,7 @@ grpcio-status = "^1.51.1"
grpcio-reflection
=
"^1.51.1"
grpc-interceptor
=
"^0.15.0"
typer
=
"^0.6.1"
accelerate
=
"^0.1
2
.0"
accelerate
=
"^0.1
5
.0"
bitsandbytes
=
"^0.35.1"
safetensors
=
"^0.2.4"
loguru
=
"^0.6.0"
...
...
server/text_generation/models/bloom.py
View file @
54fec931
...
...
@@ -33,8 +33,6 @@ try:
except
Exception
as
e
:
HAS_BITS_AND_BYTES
=
False
torch
.
manual_seed
(
0
)
class
BloomCausalLMBatch
(
CausalLMBatch
):
@
classmethod
...
...
server/text_generation/models/galactica.py
View file @
54fec931
...
...
@@ -36,7 +36,6 @@ try:
except
Exception
as
e
:
HAS_BITS_AND_BYTES
=
False
torch
.
manual_seed
(
0
)
# CREDIT: Papers with code => https://github.com/paperswithcode/galai/blob/main/galai/utils.py
...
...
server/text_generation/utils.py
View file @
54fec931
...
...
@@ -24,12 +24,10 @@ from text_generation.pb import generate_pb2
class
Sampling
:
def
__init__
(
self
,
seed
:
Optional
[
int
]
=
None
,
device
:
str
=
"cpu"
):
def
__init__
(
self
,
seed
:
int
,
device
:
str
=
"cpu"
):
self
.
generator
=
torch
.
Generator
(
device
)
if
seed
is
not
None
:
self
.
generator
.
manual_seed
(
seed
)
else
:
self
.
generator
.
seed
()
self
.
generator
.
manual_seed
(
seed
)
self
.
seed
=
seed
def
__call__
(
self
,
logits
):
probs
=
torch
.
nn
.
functional
.
softmax
(
logits
,
dim
=-
1
)
...
...
@@ -38,10 +36,6 @@ class Sampling:
).
squeeze
(
1
)
return
next_tokens
@
property
def
seed
(
self
)
->
int
:
return
self
.
generator
.
initial_seed
()
class
Greedy
:
def
__call__
(
self
,
logits
):
...
...
@@ -55,7 +49,7 @@ class NextTokenChooser:
top_k
=
None
,
top_p
=
None
,
do_sample
=
False
,
seed
=
None
,
seed
=
0
,
device
=
"cpu"
,
):
warpers
=
LogitsProcessorList
()
...
...
@@ -89,14 +83,12 @@ class NextTokenChooser:
def
from_pb
(
cls
,
pb
:
generate_pb2
.
NextTokenChooserParameters
,
device
:
torch
.
device
)
->
"NextTokenChooser"
:
# handle protobuf making default values 0
seed
=
pb
.
seed
if
pb
.
HasField
(
"seed"
)
else
None
return
NextTokenChooser
(
temperature
=
pb
.
temperature
,
top_k
=
pb
.
top_k
,
top_p
=
pb
.
top_p
,
do_sample
=
pb
.
do_sample
,
seed
=
seed
,
seed
=
pb
.
seed
,
device
=
str
(
device
),
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment