Unverified Commit cd298bc5 authored by OlivierDehaene, committed by GitHub

feat: Support sampling seeding (#37)


Co-authored-by: Yannic Kilcher <yk@users.noreply.github.com>
parent 1539d3cb
......@@ -36,6 +36,8 @@ message NextTokenChooserParameters {
float top_p = 3;
/// apply sampling on the logits
bool do_sample = 4;
/// random seed for sampling
optional uint64 seed = 5;
}
message StoppingCriteriaParameters {
......@@ -82,6 +84,8 @@ message GeneratedText {
repeated float logprobs = 6;
/// Finish reason
string finish_reason = 7;
/// Seed
optional uint64 seed = 8;
}
message GenerateRequest {
......
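Because the new `seed` field is declared `optional`, the generated bindings track whether a request actually set it, which is what the Python server later checks with `HasField`. A minimal sketch of that behaviour, assuming the Python bindings generated from this proto (`text_generation.pb.generate_pb2`):

from text_generation.pb import generate_pb2

params = generate_pb2.NextTokenChooserParameters(
    temperature=0.9,
    top_k=50,
    top_p=0.95,
    do_sample=True,
    seed=42,
)
assert params.HasField("seed")          # explicit seed: presence is recorded

unseeded = generate_pb2.NextTokenChooserParameters(do_sample=True)
assert not unseeded.HasField("seed")    # no seed given, not merely seed == 0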
use std::fs;
fn main() -> Result<(), Box<dyn std::error::Error>> {
println!("cargo:rerun-if-changed=../../proto/generate.proto");
fs::create_dir("src/pb").unwrap_or(());
tonic_build::configure()
.build_client(true)
......
......@@ -191,6 +191,7 @@ fn send_generated(finished: Vec<GeneratedText>, entries: &mut IntMap<u64, Entry>
tokens: output.tokens,
logprobs: output.logprobs,
finish_reason: output.finish_reason,
seed: output.seed,
queued: entry.time,
start: entry.batch_time.unwrap(), // unwrap is always valid
end: Instant::now(),
......@@ -208,6 +209,7 @@ pub(crate) struct InferResponse {
pub(crate) tokens: Vec<String>,
pub(crate) logprobs: Vec<f32>,
pub(crate) finish_reason: String,
pub(crate) seed: Option<u64>,
pub(crate) queued: Instant,
pub(crate) start: Instant,
pub(crate) end: Instant,
......
......@@ -166,6 +166,7 @@ impl From<&GenerateParameters> for NextTokenChooserParameters {
top_k: parameters.top_k as u32,
top_p: parameters.top_p,
do_sample: parameters.do_sample,
seed: parameters.seed,
}
}
}
......
......@@ -25,6 +25,8 @@ pub(crate) struct GenerateParameters {
pub stop: Vec<String>,
#[serde(default)]
pub details: bool,
#[serde(default)]
pub seed: Option<u64>,
}
fn default_temperature() -> f32 {
......@@ -56,6 +58,7 @@ fn default_parameters() -> GenerateParameters {
max_new_tokens: default_max_new_tokens(),
stop: vec![],
details: false,
seed: None,
}
}
......@@ -70,6 +73,7 @@ pub(crate) struct GenerateRequest {
pub(crate) struct Details {
pub finish_reason: String,
pub generated_tokens: u32,
pub seed: Option<u64>,
pub tokens: Vec<(u32, String, f32)>,
}
......
......@@ -55,6 +55,7 @@ async fn health(state: Extension<ServerState>) -> Result<(), (StatusCode, Json<E
max_new_tokens: 1,
stop: vec![],
details: false,
seed: None,
},
},
)
......@@ -70,7 +71,8 @@ async fn health(state: Extension<ServerState>) -> Result<(), (StatusCode, Json<E
validation_time,
queue_time,
inference_time,
time_per_token
time_per_token,
seed
)
)]
async fn generate(
......@@ -118,6 +120,7 @@ async fn generate(
.map(|((id, text), logprob)| (id, text, logprob))
.collect();
Some(Details {
seed: response.seed,
finish_reason: response.finish_reason,
generated_tokens: response.generated_tokens,
tokens,
......@@ -162,6 +165,7 @@ async fn generate(
tracing::Span::current().record("queue_time", format!("{:?}", queue_time));
tracing::Span::current().record("inference_time", format!("{:?}", inference_time));
tracing::Span::current().record("time_per_token", format!("{:?}", time_per_token));
tracing::Span::current().record("seed", format!("{:?}", response.seed));
tracing::info!("Output: {}", response.output_text);
// Send response
......
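Taken together, a client can now pin the sampling seed per request and read it back from the generation details. A rough sketch of such a call; the host, port, route and top-level JSON field names are assumptions, while the `parameters` keys and the `seed` in the details mirror the structs above:

import requests

response = requests.post(
    "http://localhost:8080/generate",   # assumed local router address and route
    json={
        "inputs": "The quick brown fox",
        "parameters": {
            "do_sample": True,
            "seed": 42,              # fix the sampling seed for this request
            "max_new_tokens": 16,
            "details": True,         # ask for details, which now include the seed
        },
    },
)
print(response.json())  # the details object reports the seed that was used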
......@@ -234,7 +234,9 @@ class BLOOMSharded(BLOOM):
if name == "word_embeddings.weight":
model.lm_head._parameters["weight"] = tensor
def forward(self, input_ids, attention_mask, position_ids, past_key_values: Optional = None):
def forward(
self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
):
outputs = self.model.forward(
input_ids=input_ids,
attention_mask=attention_mask,
......
......@@ -7,7 +7,7 @@ from typing import Optional, Tuple, List, Type
from text_generation.models import Model
from text_generation.models.types import GeneratedText, Batch
from text_generation.pb import generate_pb2
from text_generation.utils import NextTokenChooser, StoppingCriteria
from text_generation.utils import NextTokenChooser, StoppingCriteria, Sampling
@dataclass
......@@ -296,7 +296,10 @@ class CausalLM(Model):
)
with context_manager():
logits, past = self.forward(
batch.input_ids, batch.attention_mask, batch.position_ids, batch.past_key_values
batch.input_ids,
batch.attention_mask,
batch.position_ids,
batch.past_key_values,
)
# List of indices to cache
......@@ -373,6 +376,12 @@ class CausalLM(Model):
1
).tolist()
# Get seed
if isinstance(next_token_chooser.choice, Sampling):
seed = next_token_chooser.choice.seed
else:
seed = None
# Add to the list of finished generations with the original request
generated_texts.append(
GeneratedText(
......@@ -383,6 +392,7 @@ class CausalLM(Model):
token_ids=token_ids.squeeze(1).tolist(),
logprobs=logprobs,
reason=reason,
seed=seed,
)
)
# add to the next batch
......
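The branch above works because only a sampling chooser owns a seeded generator; greedy decoding has nothing to report, so its requests map to `seed = None`. A small sketch of that relationship, assuming the updated `NextTokenChooser` from this commit:

from text_generation.utils import NextTokenChooser, Sampling

sampled = NextTokenChooser(do_sample=True, seed=7)
greedy = NextTokenChooser()  # default arguments fall back to greedy decoding

seed = sampled.choice.seed if isinstance(sampled.choice, Sampling) else None
assert seed == 7                                # reported back in GeneratedText
assert not isinstance(greedy.choice, Sampling)  # greedy request: seed stays None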
......@@ -333,7 +333,9 @@ class GalacticaSharded(Galactica):
if name == "model.decoder.embed_tokens.weight":
model.lm_head._parameters["weight"] = tensor
def forward(self, input_ids, attention_mask, position_ids, past_key_values: Optional = None):
def forward(
self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
):
outputs = self.model.forward(
input_ids=input_ids,
attention_mask=attention_mask,
......
......@@ -39,12 +39,16 @@ class SantaCoder(CausalLM):
}
)
self.model = AutoModelForCausalLM.from_pretrained(
model_name,
torch_dtype=dtype,
load_in_8bit=quantize,
trust_remote_code=True, # required
).to(device).eval()
self.model = (
AutoModelForCausalLM.from_pretrained(
model_name,
torch_dtype=dtype,
load_in_8bit=quantize,
trust_remote_code=True, # required
)
.to(device)
.eval()
)
super(CausalLM, self).__init__(
tokenizer=tokenizer,
......
......@@ -7,7 +7,7 @@ from typing import Optional, Tuple, List, Type
from text_generation.models import Model
from text_generation.models.types import GeneratedText, Batch
from text_generation.pb import generate_pb2
from text_generation.utils import NextTokenChooser, StoppingCriteria
from text_generation.utils import NextTokenChooser, StoppingCriteria, Sampling
@dataclass
......@@ -451,6 +451,13 @@ class Seq2SeqLM(Model):
logprobs = [float("nan")] + decoder_logprobs[
-decoder_input_length:
].tolist()
# Get seed
if isinstance(next_token_chooser.choice, Sampling):
seed = next_token_chooser.choice.seed
else:
seed = None
# Add to the list of finished generations with the original request
generated_texts.append(
GeneratedText(
......@@ -461,6 +468,7 @@ class Seq2SeqLM(Model):
token_ids=token_ids.tolist(),
logprobs=logprobs,
reason=reason,
seed=seed,
)
)
# add to the next batch
......
......@@ -2,7 +2,7 @@ import torch
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import List
from typing import List, Optional
from transformers import PreTrainedTokenizerBase
......@@ -39,6 +39,7 @@ class GeneratedText:
token_ids: List[int]
logprobs: List[float]
reason: str
seed: Optional[int]
def to_pb(self) -> generate_pb2.GeneratedText:
return generate_pb2.GeneratedText(
......@@ -49,4 +50,5 @@ class GeneratedText:
token_ids=self.token_ids,
logprobs=self.logprobs,
finish_reason=self.reason,
seed=self.seed,
)
......@@ -24,11 +24,24 @@ from text_generation.pb import generate_pb2
class Sampling:
def __init__(self, seed: Optional[int] = None):
self.generator = torch.Generator()
if seed is not None:
self.generator.manual_seed(seed)
else:
self.generator.seed()
def __call__(self, logits):
probs = torch.nn.functional.softmax(logits, dim=-1)
next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
next_tokens = torch.multinomial(
probs, num_samples=1, generator=self.generator
).squeeze(1)
return next_tokens
@property
def seed(self) -> int:
return self.generator.initial_seed()
class Greedy:
def __call__(self, logits):
......@@ -36,7 +49,9 @@ class Greedy:
class NextTokenChooser:
def __init__(self, temperature=1.0, top_k=None, top_p=None, do_sample=False):
def __init__(
self, temperature=1.0, top_k=None, top_p=None, do_sample=False, seed=None
):
warpers = LogitsProcessorList()
# the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files
# all samplers can be found in `generation_utils_samplers.py`
......@@ -53,7 +68,7 @@ class NextTokenChooser:
sampling = True
self.warpers = warpers
self.choice = Sampling() if sampling else Greedy()
self.choice = Sampling(seed) if sampling else Greedy()
def __call__(self, input_ids, scores):
# Warp logits
......@@ -66,11 +81,14 @@ class NextTokenChooser:
@classmethod
def from_pb(cls, pb: generate_pb2.NextTokenChooserParameters) -> "NextTokenChooser":
# handle protobuf making default values 0
seed = pb.seed if pb.HasField("seed") else None
return NextTokenChooser(
temperature=pb.temperature,
top_k=pb.top_k,
top_p=pb.top_p,
do_sample=pb.do_sample,
seed=seed,
)
......
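The practical effect of routing the seed into the generator passed to `torch.multinomial` is reproducibility: two `Sampling` instances built with the same seed draw the same tokens from the same logits. A quick sketch of that property, using only the class shown above:

import torch
from text_generation.utils import Sampling

logits = torch.randn(4, 32000)      # hypothetical batch of vocabulary logits
first, second = Sampling(seed=1234), Sampling(seed=1234)

assert torch.equal(first(logits), second(logits))  # same seed, same draws
assert first.seed == 1234                          # exposed via the seed property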