"...source/git@developer.sourcefind.cn:OpenDAS/torchaudio.git" did not exist on "49b23e1583346029bdd28e67b2fd146c9569a789"
Unverified Commit cd298bc5 authored by OlivierDehaene, committed by GitHub

feat: Support sampling seeding (#37)


Co-authored-by: Yannic Kilcher <yk@users.noreply.github.com>
parent 1539d3cb
@@ -36,6 +36,8 @@ message NextTokenChooserParameters {
    float top_p = 3;
    /// apply sampling on the logits
    bool do_sample = 4;
    /// random seed for sampling
    optional uint64 seed = 5;
}
message StoppingCriteriaParameters {
@@ -82,6 +84,8 @@ message GeneratedText {
    repeated float logprobs = 6;
    /// Finish reason
    string finish_reason = 7;
    /// Seed
    optional uint64 seed = 8;
}
message GenerateRequest {
...
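The `seed` fields are declared `optional`, so presence is tracked explicitly on the wire and a value of 0 stays distinguishable from "not set". A minimal sketch of how the generated Python bindings expose this, using the `generate_pb2` module imported elsewhere in this diff:

```python
# Hedged sketch of proto3 `optional` presence semantics, which from_pb relies on below.
from text_generation.pb import generate_pb2

with_seed = generate_pb2.NextTokenChooserParameters(do_sample=True, seed=0)
without_seed = generate_pb2.NextTokenChooserParameters(do_sample=True)

print(with_seed.HasField("seed"))     # True, even though the value is 0
print(without_seed.HasField("seed"))  # False, mapped to None on the Python side
```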
use std::fs;
fn main() -> Result<(), Box<dyn std::error::Error>> {
    println!("cargo:rerun-if-changed=../../proto/generate.proto");
    fs::create_dir("src/pb").unwrap_or(());
    tonic_build::configure()
        .build_client(true)
...
@@ -191,6 +191,7 @@ fn send_generated(finished: Vec<GeneratedText>, entries: &mut IntMap<u64, Entry>
        tokens: output.tokens,
        logprobs: output.logprobs,
        finish_reason: output.finish_reason,
        seed: output.seed,
        queued: entry.time,
        start: entry.batch_time.unwrap(), // unwrap is always valid
        end: Instant::now(),
@@ -208,6 +209,7 @@ pub(crate) struct InferResponse {
    pub(crate) tokens: Vec<String>,
    pub(crate) logprobs: Vec<f32>,
    pub(crate) finish_reason: String,
    pub(crate) seed: Option<u64>,
    pub(crate) queued: Instant,
    pub(crate) start: Instant,
    pub(crate) end: Instant,
...
@@ -166,6 +166,7 @@ impl From<&GenerateParameters> for NextTokenChooserParameters {
            top_k: parameters.top_k as u32,
            top_p: parameters.top_p,
            do_sample: parameters.do_sample,
            seed: parameters.seed,
        }
    }
}
...
@@ -25,6 +25,8 @@ pub(crate) struct GenerateParameters {
    pub stop: Vec<String>,
    #[serde(default)]
    pub details: bool,
    #[serde(default)]
    pub seed: Option<u64>,
}
fn default_temperature() -> f32 {
@@ -56,6 +58,7 @@ fn default_parameters() -> GenerateParameters {
        max_new_tokens: default_max_new_tokens(),
        stop: vec![],
        details: false,
        seed: None,
    }
}
@@ -70,6 +73,7 @@ pub(crate) struct GenerateRequest {
pub(crate) struct Details {
    pub finish_reason: String,
    pub generated_tokens: u32,
    pub seed: Option<u64>,
    pub tokens: Vec<(u32, String, f32)>,
}
...
@@ -55,6 +55,7 @@ async fn health(state: Extension<ServerState>) -> Result<(), (StatusCode, Json<E
            max_new_tokens: 1,
            stop: vec![],
            details: false,
            seed: None,
        },
    },
)
@@ -70,7 +71,8 @@ async fn health(state: Extension<ServerState>) -> Result<(), (StatusCode, Json<E
        validation_time,
        queue_time,
        inference_time,
        time_per_token,
        seed
    )
)]
async fn generate(
@@ -118,6 +120,7 @@ async fn generate(
            .map(|((id, text), logprob)| (id, text, logprob))
            .collect();
        Some(Details {
            seed: response.seed,
            finish_reason: response.finish_reason,
            generated_tokens: response.generated_tokens,
            tokens,
@@ -162,6 +165,7 @@ async fn generate(
    tracing::Span::current().record("queue_time", format!("{:?}", queue_time));
    tracing::Span::current().record("inference_time", format!("{:?}", inference_time));
    tracing::Span::current().record("time_per_token", format!("{:?}", time_per_token));
    tracing::Span::current().record("seed", format!("{:?}", response.seed));
    tracing::info!("Output: {}", response.output_text);
    // Send response
...
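With `seed` added to `GenerateParameters` and echoed back through `Details`, a client can pin sampling and read back the seed that was actually used. A hedged usage sketch; the route and port are assumptions not shown in this diff, while the parameter names come from the structs above:

```python
import requests

resp = requests.post(
    "http://localhost:3000/generate",  # assumed route and port
    json={
        "inputs": "The quick brown fox",
        "parameters": {
            "do_sample": True,
            "seed": 42,       # new optional parameter added by this commit
            "details": True,  # Details now carries the seed back to the caller
        },
    },
)
print(resp.json())  # the "details" object should echo back "seed": 42
```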
@@ -234,7 +234,9 @@ class BLOOMSharded(BLOOM):
            if name == "word_embeddings.weight":
                model.lm_head._parameters["weight"] = tensor
    def forward(
        self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
    ):
        outputs = self.model.forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
...
@@ -7,7 +7,7 @@ from typing import Optional, Tuple, List, Type
from text_generation.models import Model
from text_generation.models.types import GeneratedText, Batch
from text_generation.pb import generate_pb2
from text_generation.utils import NextTokenChooser, StoppingCriteria, Sampling
@dataclass
@@ -296,7 +296,10 @@ class CausalLM(Model):
        )
        with context_manager():
            logits, past = self.forward(
                batch.input_ids,
                batch.attention_mask,
                batch.position_ids,
                batch.past_key_values,
            )
        # List of indices to cache
@@ -373,6 +376,12 @@ class CausalLM(Model):
                    1
                ).tolist()
                # Get seed
                if isinstance(next_token_chooser.choice, Sampling):
                    seed = next_token_chooser.choice.seed
                else:
                    seed = None
                # Add to the list of finished generations with the original request
                generated_texts.append(
                    GeneratedText(
@@ -383,6 +392,7 @@ class CausalLM(Model):
                        token_ids=token_ids.squeeze(1).tolist(),
                        logprobs=logprobs,
                        reason=reason,
                        seed=seed,
                    )
                )
            # add to the next batch
...
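The seed is only meaningful when the chooser actually sampled; greedy decoding reports `None`. A standalone sketch of the same pattern, assuming the `NextTokenChooser` and `Sampling` from `text_generation.utils` as changed later in this diff:

```python
from text_generation.utils import NextTokenChooser, Sampling

chooser = NextTokenChooser(do_sample=True, seed=1234)
if isinstance(chooser.choice, Sampling):
    seed = chooser.choice.seed  # 1234, read back from the torch.Generator
else:
    seed = None                 # Greedy decoding: nothing to report
print(seed)
```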
@@ -333,7 +333,9 @@ class GalacticaSharded(Galactica):
            if name == "model.decoder.embed_tokens.weight":
                model.lm_head._parameters["weight"] = tensor
    def forward(
        self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
    ):
        outputs = self.model.forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
...
@@ -39,12 +39,16 @@ class SantaCoder(CausalLM):
            }
        )
        self.model = (
            AutoModelForCausalLM.from_pretrained(
                model_name,
                torch_dtype=dtype,
                load_in_8bit=quantize,
                trust_remote_code=True,  # required
            )
            .to(device)
            .eval()
        )
        super(CausalLM, self).__init__(
            tokenizer=tokenizer,
...
@@ -7,7 +7,7 @@ from typing import Optional, Tuple, List, Type
from text_generation.models import Model
from text_generation.models.types import GeneratedText, Batch
from text_generation.pb import generate_pb2
from text_generation.utils import NextTokenChooser, StoppingCriteria, Sampling
@dataclass
@@ -451,6 +451,13 @@ class Seq2SeqLM(Model):
                logprobs = [float("nan")] + decoder_logprobs[
                    -decoder_input_length:
                ].tolist()
                # Get seed
                if isinstance(next_token_chooser.choice, Sampling):
                    seed = next_token_chooser.choice.seed
                else:
                    seed = None
                # Add to the list of finished generations with the original request
                generated_texts.append(
                    GeneratedText(
@@ -461,6 +468,7 @@ class Seq2SeqLM(Model):
                        token_ids=token_ids.tolist(),
                        logprobs=logprobs,
                        reason=reason,
                        seed=seed,
                    )
                )
            # add to the next batch
...
@@ -2,7 +2,7 @@ import torch
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import List, Optional
from transformers import PreTrainedTokenizerBase
@@ -39,6 +39,7 @@ class GeneratedText:
    token_ids: List[int]
    logprobs: List[float]
    reason: str
    seed: Optional[int]
    def to_pb(self) -> generate_pb2.GeneratedText:
        return generate_pb2.GeneratedText(
@@ -49,4 +50,5 @@ class GeneratedText:
            token_ids=self.token_ids,
            logprobs=self.logprobs,
            finish_reason=self.reason,
            seed=self.seed,
        )
@@ -24,11 +24,24 @@ from text_generation.pb import generate_pb2
class Sampling:
    def __init__(self, seed: Optional[int] = None):
        self.generator = torch.Generator()
        if seed is not None:
            self.generator.manual_seed(seed)
        else:
            self.generator.seed()
    def __call__(self, logits):
        probs = torch.nn.functional.softmax(logits, dim=-1)
        next_tokens = torch.multinomial(
            probs, num_samples=1, generator=self.generator
        ).squeeze(1)
        return next_tokens
    @property
    def seed(self) -> int:
        return self.generator.initial_seed()
class Greedy:
    def __call__(self, logits):
@@ -36,7 +49,9 @@ class Greedy:
class NextTokenChooser:
    def __init__(
        self, temperature=1.0, top_k=None, top_p=None, do_sample=False, seed=None
    ):
        warpers = LogitsProcessorList()
        # the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files
        # all samplers can be found in `generation_utils_samplers.py`
@@ -53,7 +68,7 @@ class NextTokenChooser:
            sampling = True
        self.warpers = warpers
        self.choice = Sampling(seed) if sampling else Greedy()
    def __call__(self, input_ids, scores):
        # Warp logits
@@ -66,11 +81,14 @@ class NextTokenChooser:
    @classmethod
    def from_pb(cls, pb: generate_pb2.NextTokenChooserParameters) -> "NextTokenChooser":
        # handle protobuf making default values 0
        seed = pb.seed if pb.HasField("seed") else None
        return NextTokenChooser(
            temperature=pb.temperature,
            top_k=pb.top_k,
            top_p=pb.top_p,
            do_sample=pb.do_sample,
            seed=seed,
        )
...
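The `Sampling` wrapper keeps its own `torch.Generator`, so two samplers seeded identically draw the same token from the same logits, and an auto-chosen seed can still be read back and returned to the client for replay. A minimal sketch using the class above:

```python
import torch
from text_generation.utils import Sampling

logits = torch.randn(1, 50257)              # fake vocabulary-sized logits
a, b = Sampling(seed=42), Sampling(seed=42)
assert torch.equal(a(logits), b(logits))    # same seed, same draw
print(a.seed)                               # 42, via generator.initial_seed()

c = Sampling()                              # no seed given: torch picks one
print(c.seed)                               # still recoverable for later replay
```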