OpenDAS / text-generation-inference · Commits · cd298bc5

feat: Support sampling seeding (#37)

Unverified commit cd298bc5, authored Jan 30, 2023 by OlivierDehaene, committed via GitHub on Jan 30, 2023.
Co-authored-by: Yannic Kilcher <yk@users.noreply.github.com>
Parent: 1539d3cb

Showing 13 changed files with 78 additions and 16 deletions (+78 -16)
proto/generate.proto                           +4  -0
router/client/build.rs                         +1  -0
router/src/batcher.rs                          +2  -0
router/src/db.rs                               +1  -0
router/src/lib.rs                              +4  -0
router/src/server.rs                           +5  -1
server/text_generation/models/bloom.py         +3  -1
server/text_generation/models/causal_lm.py     +12 -2
server/text_generation/models/galactica.py     +3  -1
server/text_generation/models/santacoder.py    +10 -6
server/text_generation/models/seq2seq_lm.py    +9  -1
server/text_generation/models/types.py         +3  -1
server/text_generation/utils.py                +21 -3
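This commit threads an optional sampling seed from the HTTP request (router) down to the Python token sampler (server), so that sampled generations can be reproduced and the seed that was actually used can be reported back to the client. A minimal standalone PyTorch sketch of the underlying idea (illustrative only, not code from this repository): two generators seeded identically draw identical tokens from the same distribution.

    import torch

    # Toy next-token distribution over a 4-token vocabulary.
    logits = torch.tensor([[2.0, 1.0, 0.5, 0.1]])
    probs = torch.nn.functional.softmax(logits, dim=-1)

    # Two generators with the same manual seed...
    gen_a = torch.Generator().manual_seed(42)
    gen_b = torch.Generator().manual_seed(42)

    # ...draw exactly the same token, which is what makes seeded sampling reproducible.
    draw_a = torch.multinomial(probs, num_samples=1, generator=gen_a)
    draw_b = torch.multinomial(probs, num_samples=1, generator=gen_b)
    assert torch.equal(draw_a, draw_b)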
proto/generate.proto  (view file @ cd298bc5)

@@ -36,6 +36,8 @@ message NextTokenChooserParameters {
     float top_p = 3;
     /// apply sampling on the logits
     bool do_sample = 4;
+    /// random seed for sampling
+    optional uint64 seed = 5;
 }

 message StoppingCriteriaParameters {

@@ -82,6 +84,8 @@ message GeneratedText {
     repeated float logprobs = 6;
     /// Finish reason
     string finish_reason = 7;
+    /// Seed
+    optional uint64 seed = 8;
 }

 message GenerateRequest {
router/client/build.rs  (view file @ cd298bc5)

 use std::fs;

 fn main() -> Result<(), Box<dyn std::error::Error>> {
+    println!("cargo:rerun-if-changed=../../proto/generate.proto");
     fs::create_dir("src/pb").unwrap_or(());

     tonic_build::configure()
         .build_client(true)
router/src/batcher.rs  (view file @ cd298bc5)

@@ -191,6 +191,7 @@ fn send_generated(finished: Vec<GeneratedText>, entries: &mut IntMap<u64, Entry>
         tokens: output.tokens,
         logprobs: output.logprobs,
         finish_reason: output.finish_reason,
+        seed: output.seed,
         queued: entry.time,
         start: entry.batch_time.unwrap(), // unwrap is always valid
         end: Instant::now(),

@@ -208,6 +209,7 @@ pub(crate) struct InferResponse {
     pub(crate) tokens: Vec<String>,
     pub(crate) logprobs: Vec<f32>,
     pub(crate) finish_reason: String,
+    pub(crate) seed: Option<u64>,
     pub(crate) queued: Instant,
     pub(crate) start: Instant,
     pub(crate) end: Instant,
router/src/db.rs  (view file @ cd298bc5)

@@ -166,6 +166,7 @@ impl From<&GenerateParameters> for NextTokenChooserParameters {
             top_k: parameters.top_k as u32,
             top_p: parameters.top_p,
             do_sample: parameters.do_sample,
+            seed: parameters.seed,
         }
     }
 }
router/src/lib.rs  (view file @ cd298bc5)

@@ -25,6 +25,8 @@ pub(crate) struct GenerateParameters {
     pub stop: Vec<String>,
     #[serde(default)]
     pub details: bool,
+    #[serde(default)]
+    pub seed: Option<u64>,
 }

 fn default_temperature() -> f32 {

@@ -56,6 +58,7 @@ fn default_parameters() -> GenerateParameters {
         max_new_tokens: default_max_new_tokens(),
         stop: vec![],
         details: false,
+        seed: None,
     }
 }

@@ -70,6 +73,7 @@ pub(crate) struct GenerateRequest {
 pub(crate) struct Details {
     pub finish_reason: String,
     pub generated_tokens: u32,
+    pub seed: Option<u64>,
     pub tokens: Vec<(u32, String, f32)>,
 }
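Since the new `seed` field on `GenerateParameters` is deserialized with `#[serde(default)]`, a client only has to add it to the `parameters` object of a generate request to get reproducible sampling, and can read the seed that was actually used back from `details.seed`. A hedged Python sketch of such a request; the endpoint URL, port, and the `inputs` field name are assumptions for illustration and are not part of this diff.

    import json
    import urllib.request

    payload = {
        "inputs": "What is Deep Learning?",  # assumed prompt field, not shown in this diff
        "parameters": {"do_sample": True, "seed": 42, "details": True},
    }
    req = urllib.request.Request(
        "http://localhost:3000/generate",  # assumed local router address and route
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"},
    )
    # Uncomment against a running router; `details.seed` should echo the seed that was used.
    # with urllib.request.urlopen(req) as resp:
    #     print(json.load(resp))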
router/src/server.rs  (view file @ cd298bc5)

@@ -55,6 +55,7 @@ async fn health(state: Extension<ServerState>) -> Result<(), (StatusCode, Json<E
             max_new_tokens: 1,
             stop: vec![],
             details: false,
+            seed: None,
         },
     },
 )

@@ -70,7 +71,8 @@ async fn health(state: Extension<ServerState>) -> Result<(), (StatusCode, Json<E
         validation_time,
         queue_time,
         inference_time,
-        time_per_token
+        time_per_token,
+        seed
     )
 )]
 async fn generate(

@@ -118,6 +120,7 @@ async fn generate(
         .map(|((id, text), logprob)| (id, text, logprob))
         .collect();
     Some(Details {
+        seed: response.seed,
         finish_reason: response.finish_reason,
         generated_tokens: response.generated_tokens,
         tokens,

@@ -162,6 +165,7 @@ async fn generate(
     tracing::Span::current().record("queue_time", format!("{:?}", queue_time));
     tracing::Span::current().record("inference_time", format!("{:?}", inference_time));
     tracing::Span::current().record("time_per_token", format!("{:?}", time_per_token));
+    tracing::Span::current().record("seed", format!("{:?}", response.seed));
     tracing::info!("Output: {}", response.output_text);

     // Send response
server/text_generation/models/bloom.py  (view file @ cd298bc5)

@@ -234,7 +234,9 @@ class BLOOMSharded(BLOOM):
                 if name == "word_embeddings.weight":
                     model.lm_head._parameters["weight"] = tensor

-    def forward(self, input_ids, attention_mask, position_ids, past_key_values: Optional = None):
+    def forward(
+        self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
+    ):
         outputs = self.model.forward(
             input_ids=input_ids,
             attention_mask=attention_mask,
server/text_generation/models/causal_lm.py  (view file @ cd298bc5)

@@ -7,7 +7,7 @@ from typing import Optional, Tuple, List, Type
 from text_generation.models import Model
 from text_generation.models.types import GeneratedText, Batch
 from text_generation.pb import generate_pb2
-from text_generation.utils import NextTokenChooser, StoppingCriteria
+from text_generation.utils import NextTokenChooser, StoppingCriteria, Sampling


 @dataclass

@@ -296,7 +296,10 @@ class CausalLM(Model):
         )

         with context_manager():
             logits, past = self.forward(
-                batch.input_ids, batch.attention_mask, batch.position_ids, batch.past_key_values
+                batch.input_ids,
+                batch.attention_mask,
+                batch.position_ids,
+                batch.past_key_values,
             )

         # List of indices to cache

@@ -373,6 +376,12 @@ class CausalLM(Model):
                     1
                 ).tolist()

+                # Get seed
+                if isinstance(next_token_chooser.choice, Sampling):
+                    seed = next_token_chooser.choice.seed
+                else:
+                    seed = None
+
                 # Add to the list of finished generations with the original request
                 generated_texts.append(
                     GeneratedText(

@@ -383,6 +392,7 @@ class CausalLM(Model):
                         token_ids=token_ids.squeeze(1).tolist(),
                         logprobs=logprobs,
                         reason=reason,
+                        seed=seed,
                     )
                 )
             # add to the next batch
server/text_generation/models/galactica.py  (view file @ cd298bc5)

@@ -333,7 +333,9 @@ class GalacticaSharded(Galactica):
                 if name == "model.decoder.embed_tokens.weight":
                     model.lm_head._parameters["weight"] = tensor

-    def forward(self, input_ids, attention_mask, position_ids, past_key_values: Optional = None):
+    def forward(
+        self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
+    ):
         outputs = self.model.forward(
             input_ids=input_ids,
             attention_mask=attention_mask,
server/text_generation/models/santacoder.py  (view file @ cd298bc5)

@@ -39,12 +39,16 @@ class SantaCoder(CausalLM):
             }
         )

-        self.model = AutoModelForCausalLM.from_pretrained(
-            model_name,
-            torch_dtype=dtype,
-            load_in_8bit=quantize,
-            trust_remote_code=True,  # required
-        ).to(device).eval()
+        self.model = (
+            AutoModelForCausalLM.from_pretrained(
+                model_name,
+                torch_dtype=dtype,
+                load_in_8bit=quantize,
+                trust_remote_code=True,  # required
+            )
+            .to(device)
+            .eval()
+        )

         super(CausalLM, self).__init__(
             tokenizer=tokenizer,
server/text_generation/models/seq2seq_lm.py  (view file @ cd298bc5)

@@ -7,7 +7,7 @@ from typing import Optional, Tuple, List, Type
 from text_generation.models import Model
 from text_generation.models.types import GeneratedText, Batch
 from text_generation.pb import generate_pb2
-from text_generation.utils import NextTokenChooser, StoppingCriteria
+from text_generation.utils import NextTokenChooser, StoppingCriteria, Sampling


 @dataclass

@@ -451,6 +451,13 @@ class Seq2SeqLM(Model):
                 logprobs = [float("nan")] + decoder_logprobs[-decoder_input_length:].tolist()

+                # Get seed
+                if isinstance(next_token_chooser.choice, Sampling):
+                    seed = next_token_chooser.choice.seed
+                else:
+                    seed = None
+
                 # Add to the list of finished generations with the original request
                 generated_texts.append(
                     GeneratedText(

@@ -461,6 +468,7 @@ class Seq2SeqLM(Model):
                         token_ids=token_ids.tolist(),
                         logprobs=logprobs,
                         reason=reason,
+                        seed=seed,
                     )
                 )
             # add to the next batch
server/text_generation/models/types.py  (view file @ cd298bc5)

@@ -2,7 +2,7 @@ import torch
 from abc import ABC, abstractmethod
 from dataclasses import dataclass
-from typing import List
+from typing import List, Optional

 from transformers import PreTrainedTokenizerBase

@@ -39,6 +39,7 @@ class GeneratedText:
     token_ids: List[int]
     logprobs: List[float]
     reason: str
+    seed: Optional[int]

     def to_pb(self) -> generate_pb2.GeneratedText:
         return generate_pb2.GeneratedText(

@@ -49,4 +50,5 @@ class GeneratedText:
             token_ids=self.token_ids,
             logprobs=self.logprobs,
             finish_reason=self.reason,
+            seed=self.seed,
         )
server/text_generation/utils.py  (view file @ cd298bc5)

@@ -24,11 +24,24 @@ from text_generation.pb import generate_pb2
 class Sampling:
+    def __init__(self, seed: Optional[int] = None):
+        self.generator = torch.Generator()
+        if seed is not None:
+            self.generator.manual_seed(seed)
+        else:
+            self.generator.seed()
+
     def __call__(self, logits):
         probs = torch.nn.functional.softmax(logits, dim=-1)
-        next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
+        next_tokens = torch.multinomial(
+            probs, num_samples=1, generator=self.generator
+        ).squeeze(1)
         return next_tokens

+    @property
+    def seed(self) -> int:
+        return self.generator.initial_seed()
+

 class Greedy:
     def __call__(self, logits):
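A brief usage sketch for the `Sampling` class added above (illustrative only, not part of the commit; it assumes the server package is importable): two instances seeded identically pick the same token, and an unseeded instance still exposes the seed it drew so it can be reported back to the client.

    import torch

    from text_generation.utils import Sampling  # as added in this commit

    logits = torch.tensor([[1.0, 2.0, 3.0, 4.0]])

    first = Sampling(seed=1234)
    second = Sampling(seed=1234)
    assert torch.equal(first(logits), second(logits))  # same seed -> same sampled token

    unseeded = Sampling()  # no seed given: torch draws one
    print(unseeded.seed)   # the drawn seed is still reported via the `seed` property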
@@ -36,7 +49,9 @@ class Greedy:
 class NextTokenChooser:
-    def __init__(self, temperature=1.0, top_k=None, top_p=None, do_sample=False):
+    def __init__(
+        self, temperature=1.0, top_k=None, top_p=None, do_sample=False, seed=None
+    ):
         warpers = LogitsProcessorList()
         # the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files
         # all samplers can be found in `generation_utils_samplers.py`

@@ -53,7 +68,7 @@ class NextTokenChooser:
             sampling = True

         self.warpers = warpers
-        self.choice = Sampling() if sampling else Greedy()
+        self.choice = Sampling(seed) if sampling else Greedy()

     def __call__(self, input_ids, scores):
         # Warp logits

@@ -66,11 +81,14 @@ class NextTokenChooser:
     @classmethod
     def from_pb(cls, pb: generate_pb2.NextTokenChooserParameters) -> "NextTokenChooser":
+        # handle protobuf making default values 0
+        seed = pb.seed if pb.HasField("seed") else None
         return NextTokenChooser(
             temperature=pb.temperature,
             top_k=pb.top_k,
             top_p=pb.top_p,
             do_sample=pb.do_sample,
+            seed=seed,
         )
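One detail worth noting in `from_pb` above: proto3 scalar fields default to 0, so the field is declared `optional uint64 seed` in generate.proto and checked with `HasField` so that "no seed requested" can be told apart from an explicit seed of 0. A small illustrative sketch, assuming the compiled `generate_pb2` module produced from proto/generate.proto is on the path:

    from text_generation.pb import generate_pb2

    params = generate_pb2.NextTokenChooserParameters(do_sample=True)
    assert not params.HasField("seed")  # unset: the server will draw a random seed

    params.seed = 0                     # an explicit seed of 0 is a valid request...
    assert params.HasField("seed")      # ...and is distinguishable from "unset"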