Unverified Commit 54fec931 authored by OlivierDehaene, committed by GitHub

fix(server): fix seeding with multiple shards (#44)

parent 03bdf182
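
This commit targets tensor-parallel deployments, where every shard samples the next token from the same logits and the draws have to agree. Previously `seed` was optional end to end: when it was `None`, each shard's `Sampling` called `generator.seed()` and picked its own non-deterministic seed, so shards could sample different tokens and drift apart. The fix makes the seed a required field and assigns it once, in the router's validation worker, so all shards receive the same value. A minimal sketch of the failure mode and the fix (illustrative Python; the two generators standing in for two shards and the toy vocabulary size are assumptions, not code from this repo):

import torch

def sample(logits, generator):
    # Same sampling path as Sampling.__call__ in this diff.
    probs = torch.nn.functional.softmax(logits, dim=-1)
    return torch.multinomial(probs, num_samples=1, generator=generator)

logits = torch.randn(1, 32)  # 32 is a stand-in vocabulary size

# Old behaviour with seed=None: each shard seeded its own generator
# non-deterministically, so draws could disagree across shards.
shard_a, shard_b = torch.Generator("cpu"), torch.Generator("cpu")
shard_a.seed()
shard_b.seed()

# New behaviour: the router picks one seed and every shard applies it,
# so the shards draw identical tokens.
shard_a.manual_seed(12345)
shard_b.manual_seed(12345)
assert torch.equal(sample(logits, shard_a), sample(logits, shard_b))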
@@ -1834,6 +1834,7 @@ dependencies = [
  "futures",
  "nohash-hasher",
  "parking_lot",
+ "rand",
  "serde",
  "serde_json",
  "text-generation-client",
@@ -37,7 +37,7 @@ message NextTokenChooserParameters {
     /// apply sampling on the logits
     bool do_sample = 4;
     /// random seed for sampling
-    optional uint64 seed = 5;
+    uint64 seed = 5;
 }
 
 message StoppingCriteriaParameters {
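
Dropping `optional` changes the field's presence semantics: a plain proto3 scalar always decodes to a value (0 by default) and carries no presence bit, so `HasField` can no longer be used on it. That is why the final hunk below deletes the `pb.HasField("seed")` workaround; the router now guarantees a concrete seed is always sent. A small illustration (assuming the `generate_pb2` module generated from this proto, as imported in the Python hunks below):

from text_generation.pb import generate_pb2

params = generate_pb2.NextTokenChooserParameters()
print(params.seed)  # 0: proto3 scalars default to zero when unset
# Without `optional` there is no presence bit to query, so this now raises
# ValueError instead of distinguishing "unset" from "set to 0":
# params.HasField("seed")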
@@ -19,6 +19,7 @@ clap = { version = "4.0.15", features = ["derive", "env"] }
 futures = "0.3.24"
 nohash-hasher = "0.2.0"
 parking_lot = "0.12.1"
+rand = "0.8.5"
 serde = "1.0.145"
 serde_json = "1.0.85"
 thiserror = "1.0.37"
@@ -166,7 +166,8 @@ impl From<&GenerateParameters> for NextTokenChooserParameters {
             top_k: parameters.top_k as u32,
             top_p: parameters.top_p,
             do_sample: parameters.do_sample,
-            seed: parameters.seed,
+            // FIXME: remove unwrap
+            seed: parameters.seed.unwrap(),
         }
     }
 }
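
Note that the `unwrap()` flagged by the FIXME is only safe because the validation path below always fills in a missing seed before this conversion runs; if a `None` ever slipped through, the router would panic here rather than return an error.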
@@ -2,6 +2,8 @@
 use crate::{ErrorResponse, GenerateRequest};
 use axum::http::StatusCode;
 use axum::Json;
+use rand::rngs::ThreadRng;
+use rand::Rng;
 use thiserror::Error;
 use tokenizers::tokenizer::Tokenizer;
 use tokio::sync::{mpsc, oneshot};
@@ -92,18 +94,22 @@ fn validation_worker(
     max_input_length: usize,
     mut receiver: mpsc::Receiver<ValidationRequest>,
 ) {
+    // Seed rng
+    let mut rng = rand::thread_rng();
+
     // Loop over requests
     while let Some((request, response_tx)) = receiver.blocking_recv() {
         response_tx
-            .send(validate(request, &tokenizer, max_input_length))
+            .send(validate(request, &tokenizer, max_input_length, &mut rng))
            .unwrap_or(())
     }
 }
 
 fn validate(
-    request: GenerateRequest,
+    mut request: GenerateRequest,
     tokenizer: &Tokenizer,
     max_input_length: usize,
+    rng: &mut ThreadRng,
 ) -> Result<(usize, GenerateRequest), ValidationError> {
     if request.parameters.temperature <= 0.0 {
         return Err(ValidationError::Temperature);
@@ -124,6 +130,11 @@ fn validate(
         ));
     }
 
+    // If seed is None, assign a random one
+    if request.parameters.seed.is_none() {
+        request.parameters.seed = Some(rng.gen());
+    }
+
     // Get the number of tokens in the input
     match tokenizer.encode(request.inputs.clone(), true) {
         Ok(inputs) => {
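
The important property of this hunk is that the seed is drawn exactly once per request, at validation time, before the request fans out to the shards. A rough Python equivalent of the Rust logic above (the dict-based `parameters` is a stand-in for `GenerateParameters`; `random.getrandbits(64)` mirrors the full-range `u64` that `rng.gen()` produces):

import random

def assign_seed(parameters: dict) -> dict:
    # If seed is None, assign a random one -- mirrors the Rust hunk above.
    if parameters.get("seed") is None:
        parameters["seed"] = random.getrandbits(64)
    return parameters

# Every shard then receives the same, now-concrete seed.
print(assign_seed({"seed": None}))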
@@ -15,7 +15,7 @@ grpcio-status = "^1.51.1"
 grpcio-reflection = "^1.51.1"
 grpc-interceptor = "^0.15.0"
 typer = "^0.6.1"
-accelerate = "^0.12.0"
+accelerate = "^0.15.0"
 bitsandbytes = "^0.35.1"
 safetensors = "^0.2.4"
 loguru = "^0.6.0"
@@ -33,8 +33,6 @@ try:
 except Exception as e:
     HAS_BITS_AND_BYTES = False
 
-torch.manual_seed(0)
-
 
 class BloomCausalLMBatch(CausalLMBatch):
     @classmethod
@@ -36,7 +36,6 @@ try:
 except Exception as e:
     HAS_BITS_AND_BYTES = False
 
-torch.manual_seed(0)
 
 # CREDIT: Papers with code => https://github.com/paperswithcode/galai/blob/main/galai/utils.py
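
Both module-level `torch.manual_seed(0)` calls can go: they seeded the global RNG, but `Sampling` draws from its own `torch.Generator`, which a global seed never touches. Determinism now comes from the per-request seed instead of process-wide state.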
@@ -24,12 +24,10 @@ from text_generation.pb import generate_pb2
 class Sampling:
-    def __init__(self, seed: Optional[int] = None, device: str = "cpu"):
+    def __init__(self, seed: int, device: str = "cpu"):
         self.generator = torch.Generator(device)
-        if seed is not None:
-            self.generator.manual_seed(seed)
-        else:
-            self.generator.seed()
+        self.generator.manual_seed(seed)
+        self.seed = seed
 
     def __call__(self, logits):
         probs = torch.nn.functional.softmax(logits, dim=-1)
@@ -38,10 +36,6 @@
         ).squeeze(1)
         return next_tokens
 
-    @property
-    def seed(self) -> int:
-        return self.generator.initial_seed()
-
 
 class Greedy:
     def __call__(self, logits):
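
Assembled from the two hunks above, the reworked `Sampling` reduces to roughly the following (a sketch; the `__call__` body is reconstructed from the surrounding context lines):

import torch

class Sampling:
    def __init__(self, seed: int, device: str = "cpu"):
        self.generator = torch.Generator(device)
        self.generator.manual_seed(seed)
        # A plain attribute replaces the old initial_seed() property.
        self.seed = seed

    def __call__(self, logits):
        probs = torch.nn.functional.softmax(logits, dim=-1)
        return torch.multinomial(
            probs, num_samples=1, generator=self.generator
        ).squeeze(1)

# Two shards constructing Sampling with the same seed draw the same token.
logits = torch.randn(1, 32)
assert torch.equal(Sampling(7)(logits), Sampling(7)(logits))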
@@ -55,7 +49,7 @@ class NextTokenChooser:
         top_k=None,
         top_p=None,
         do_sample=False,
-        seed=None,
+        seed=0,
         device="cpu",
     ):
         warpers = LogitsProcessorList()
@@ -89,14 +83,12 @@
     def from_pb(
         cls, pb: generate_pb2.NextTokenChooserParameters, device: torch.device
     ) -> "NextTokenChooser":
-        # handle protobuf making default values 0
-        seed = pb.seed if pb.HasField("seed") else None
-
         return NextTokenChooser(
             temperature=pb.temperature,
             top_k=pb.top_k,
             top_p=pb.top_p,
             do_sample=pb.do_sample,
-            seed=seed,
+            seed=pb.seed,
             device=str(device),
         )