"git@developer.sourcefind.cn:OpenDAS/apex.git" did not exist on "038ed9999b897625b708cbb2591dc702dc39e513"
Unverified Commit 54fec931 authored by OlivierDehaene, committed by GitHub
Browse files

fix(server): fix seeding with multiple shards (#44)

parent 03bdf182
...@@ -1834,6 +1834,7 @@ dependencies = [ ...@@ -1834,6 +1834,7 @@ dependencies = [
"futures", "futures",
"nohash-hasher", "nohash-hasher",
"parking_lot", "parking_lot",
"rand",
"serde", "serde",
"serde_json", "serde_json",
"text-generation-client", "text-generation-client",
......
...@@ -37,7 +37,7 @@ message NextTokenChooserParameters { ...@@ -37,7 +37,7 @@ message NextTokenChooserParameters {
/// apply sampling on the logits /// apply sampling on the logits
bool do_sample = 4; bool do_sample = 4;
/// random seed for sampling /// random seed for sampling
optional uint64 seed = 5; uint64 seed = 5;
} }
message StoppingCriteriaParameters { message StoppingCriteriaParameters {
......
...@@ -19,6 +19,7 @@ clap = { version = "4.0.15", features = ["derive", "env"] } ...@@ -19,6 +19,7 @@ clap = { version = "4.0.15", features = ["derive", "env"] }
futures = "0.3.24" futures = "0.3.24"
nohash-hasher = "0.2.0" nohash-hasher = "0.2.0"
parking_lot = "0.12.1" parking_lot = "0.12.1"
rand = "0.8.5"
serde = "1.0.145" serde = "1.0.145"
serde_json = "1.0.85" serde_json = "1.0.85"
thiserror = "1.0.37" thiserror = "1.0.37"
......
...@@ -166,7 +166,8 @@ impl From<&GenerateParameters> for NextTokenChooserParameters { ...@@ -166,7 +166,8 @@ impl From<&GenerateParameters> for NextTokenChooserParameters {
top_k: parameters.top_k as u32, top_k: parameters.top_k as u32,
top_p: parameters.top_p, top_p: parameters.top_p,
do_sample: parameters.do_sample, do_sample: parameters.do_sample,
seed: parameters.seed, // FIXME: remove unwrap
seed: parameters.seed.unwrap(),
} }
} }
} }
......
...@@ -2,6 +2,8 @@ ...@@ -2,6 +2,8 @@
use crate::{ErrorResponse, GenerateRequest}; use crate::{ErrorResponse, GenerateRequest};
use axum::http::StatusCode; use axum::http::StatusCode;
use axum::Json; use axum::Json;
use rand::rngs::ThreadRng;
use rand::Rng;
use thiserror::Error; use thiserror::Error;
use tokenizers::tokenizer::Tokenizer; use tokenizers::tokenizer::Tokenizer;
use tokio::sync::{mpsc, oneshot}; use tokio::sync::{mpsc, oneshot};
...@@ -92,18 +94,22 @@ fn validation_worker( ...@@ -92,18 +94,22 @@ fn validation_worker(
max_input_length: usize, max_input_length: usize,
mut receiver: mpsc::Receiver<ValidationRequest>, mut receiver: mpsc::Receiver<ValidationRequest>,
) { ) {
// Seed rng
let mut rng = rand::thread_rng();
// Loop over requests // Loop over requests
while let Some((request, response_tx)) = receiver.blocking_recv() { while let Some((request, response_tx)) = receiver.blocking_recv() {
response_tx response_tx
.send(validate(request, &tokenizer, max_input_length)) .send(validate(request, &tokenizer, max_input_length, &mut rng))
.unwrap_or(()) .unwrap_or(())
} }
} }
fn validate( fn validate(
request: GenerateRequest, mut request: GenerateRequest,
tokenizer: &Tokenizer, tokenizer: &Tokenizer,
max_input_length: usize, max_input_length: usize,
rng: &mut ThreadRng,
) -> Result<(usize, GenerateRequest), ValidationError> { ) -> Result<(usize, GenerateRequest), ValidationError> {
if request.parameters.temperature <= 0.0 { if request.parameters.temperature <= 0.0 {
return Err(ValidationError::Temperature); return Err(ValidationError::Temperature);
...@@ -124,6 +130,11 @@ fn validate( ...@@ -124,6 +130,11 @@ fn validate(
)); ));
} }
// If seed is None, assign a random one
if request.parameters.seed.is_none() {
request.parameters.seed = Some(rng.gen());
}
// Get the number of tokens in the input // Get the number of tokens in the input
match tokenizer.encode(request.inputs.clone(), true) { match tokenizer.encode(request.inputs.clone(), true) {
Ok(inputs) => { Ok(inputs) => {
......
This diff is collapsed.
...@@ -15,7 +15,7 @@ grpcio-status = "^1.51.1" ...@@ -15,7 +15,7 @@ grpcio-status = "^1.51.1"
grpcio-reflection = "^1.51.1" grpcio-reflection = "^1.51.1"
grpc-interceptor = "^0.15.0" grpc-interceptor = "^0.15.0"
typer = "^0.6.1" typer = "^0.6.1"
accelerate = "^0.12.0" accelerate = "^0.15.0"
bitsandbytes = "^0.35.1" bitsandbytes = "^0.35.1"
safetensors = "^0.2.4" safetensors = "^0.2.4"
loguru = "^0.6.0" loguru = "^0.6.0"
......
...@@ -33,8 +33,6 @@ try: ...@@ -33,8 +33,6 @@ try:
except Exception as e: except Exception as e:
HAS_BITS_AND_BYTES = False HAS_BITS_AND_BYTES = False
torch.manual_seed(0)
class BloomCausalLMBatch(CausalLMBatch): class BloomCausalLMBatch(CausalLMBatch):
@classmethod @classmethod
......
...@@ -36,7 +36,6 @@ try: ...@@ -36,7 +36,6 @@ try:
except Exception as e: except Exception as e:
HAS_BITS_AND_BYTES = False HAS_BITS_AND_BYTES = False
torch.manual_seed(0)
# CREDIT: Papers with code => https://github.com/paperswithcode/galai/blob/main/galai/utils.py # CREDIT: Papers with code => https://github.com/paperswithcode/galai/blob/main/galai/utils.py
......
...@@ -24,12 +24,10 @@ from text_generation.pb import generate_pb2 ...@@ -24,12 +24,10 @@ from text_generation.pb import generate_pb2
class Sampling: class Sampling:
def __init__(self, seed: Optional[int] = None, device: str = "cpu"): def __init__(self, seed: int, device: str = "cpu"):
self.generator = torch.Generator(device) self.generator = torch.Generator(device)
if seed is not None: self.generator.manual_seed(seed)
self.generator.manual_seed(seed) self.seed = seed
else:
self.generator.seed()
def __call__(self, logits): def __call__(self, logits):
probs = torch.nn.functional.softmax(logits, dim=-1) probs = torch.nn.functional.softmax(logits, dim=-1)
...@@ -38,10 +36,6 @@ class Sampling: ...@@ -38,10 +36,6 @@ class Sampling:
).squeeze(1) ).squeeze(1)
return next_tokens return next_tokens
@property
def seed(self) -> int:
return self.generator.initial_seed()
class Greedy: class Greedy:
def __call__(self, logits): def __call__(self, logits):
...@@ -55,7 +49,7 @@ class NextTokenChooser: ...@@ -55,7 +49,7 @@ class NextTokenChooser:
top_k=None, top_k=None,
top_p=None, top_p=None,
do_sample=False, do_sample=False,
seed=None, seed=0,
device="cpu", device="cpu",
): ):
warpers = LogitsProcessorList() warpers = LogitsProcessorList()
...@@ -89,14 +83,12 @@ class NextTokenChooser: ...@@ -89,14 +83,12 @@ class NextTokenChooser:
def from_pb( def from_pb(
cls, pb: generate_pb2.NextTokenChooserParameters, device: torch.device cls, pb: generate_pb2.NextTokenChooserParameters, device: torch.device
) -> "NextTokenChooser": ) -> "NextTokenChooser":
# handle protobuf making default values 0
seed = pb.seed if pb.HasField("seed") else None
return NextTokenChooser( return NextTokenChooser(
temperature=pb.temperature, temperature=pb.temperature,
top_k=pb.top_k, top_k=pb.top_k,
top_p=pb.top_p, top_p=pb.top_p,
do_sample=pb.do_sample, do_sample=pb.do_sample,
seed=seed, seed=pb.seed,
device=str(device), device=str(device),
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment