Unverified Commit adaf1a39 authored by Graham King's avatar Graham King Committed by GitHub
Browse files

feat: Metric for detokenization latency (#6160)


Signed-off-by: default avatarGraham King <grahamk@nvidia.com>
parent 1488ef2e
...@@ -54,7 +54,7 @@ pub fn decode(c: &mut Criterion) { ...@@ -54,7 +54,7 @@ pub fn decode(c: &mut Criterion) {
let tokenizer: Arc<dyn Tokenizer> = let tokenizer: Arc<dyn Tokenizer> =
Arc::new(HuggingFaceTokenizer::from_file(TEST_TOKENIZER).unwrap()); Arc::new(HuggingFaceTokenizer::from_file(TEST_TOKENIZER).unwrap());
let ds = DecodeStream::new(tokenizer, &[], false); let ds = DecodeStream::new(tokenizer, &[], false);
Decoder::new(ds, StopConditions::default(), false) Decoder::new(ds, StopConditions::default(), false, None)
}, },
|mut decoder| { |mut decoder| {
for tok in black_box(TEST_TOKS) { for tok in black_box(TEST_TOKS) {
...@@ -78,7 +78,7 @@ pub fn decode_big(c: &mut Criterion) { ...@@ -78,7 +78,7 @@ pub fn decode_big(c: &mut Criterion) {
let tokenizer: Arc<dyn Tokenizer> = let tokenizer: Arc<dyn Tokenizer> =
Arc::new(HuggingFaceTokenizer::from_file(TEST_TOKENIZER).unwrap()); Arc::new(HuggingFaceTokenizer::from_file(TEST_TOKENIZER).unwrap());
let ds = DecodeStream::new(tokenizer, &[], false); let ds = DecodeStream::new(tokenizer, &[], false);
Decoder::new(ds, StopConditions::default(), false) Decoder::new(ds, StopConditions::default(), false, None)
}, },
|mut decoder| { |mut decoder| {
for tok in black_box(&BIG_TEST_TOKS) { for tok in black_box(&BIG_TEST_TOKS) {
......
...@@ -15,11 +15,10 @@ ...@@ -15,11 +15,10 @@
//! Further post-processing can happen in the response stream. One example is the jailing mechanism for partial //! Further post-processing can happen in the response stream. One example is the jailing mechanism for partial
//! hidden stop condition matches, which can be handled in the response stream rather than the backend. //! hidden stop condition matches, which can be handled in the response stream rather than the backend.
use std::{collections::HashSet, sync::Arc}; use std::{collections::HashSet, sync::Arc, time::Instant};
use anyhow::Result; use anyhow::Result;
use futures::stream::{self, StreamExt}; use futures::stream::{self, StreamExt};
use tracing as log;
use crate::model_card::ModelDeploymentCard; use crate::model_card::ModelDeploymentCard;
use dynamo_runtime::{ use dynamo_runtime::{
...@@ -39,6 +38,7 @@ use crate::protocols::{ ...@@ -39,6 +38,7 @@ use crate::protocols::{
PreprocessedRequest, PreprocessedRequest,
}, },
preprocessor::PreprocessedEmbeddingRequest, preprocessor::PreprocessedEmbeddingRequest,
timing::RequestTracker,
}, },
}; };
use crate::tokenizers::{DecodeStream, HuggingFaceTokenizer, Tokenizer}; use crate::tokenizers::{DecodeStream, HuggingFaceTokenizer, Tokenizer};
...@@ -99,6 +99,7 @@ impl Backend { ...@@ -99,6 +99,7 @@ impl Backend {
stop_conditions: StopConditions, stop_conditions: StopConditions,
skip_special_tokens: bool, skip_special_tokens: bool,
include_stop_str_in_output: bool, include_stop_str_in_output: bool,
tracker: Option<Arc<RequestTracker>>,
) -> anyhow::Result<DecoderUnfoldState> { ) -> anyhow::Result<DecoderUnfoldState> {
let Some(tokenizer) = self.tokenizer.as_ref() else { let Some(tokenizer) = self.tokenizer.as_ref() else {
anyhow::bail!("Backend built from blank ModelDeploymentCard, no tokenizer"); anyhow::bail!("Backend built from blank ModelDeploymentCard, no tokenizer");
...@@ -107,6 +108,7 @@ impl Backend { ...@@ -107,6 +108,7 @@ impl Backend {
tokenizer.decode_stream(prompt_token_ids, skip_special_tokens), tokenizer.decode_stream(prompt_token_ids, skip_special_tokens),
stop_conditions, stop_conditions,
include_stop_str_in_output, include_stop_str_in_output,
tracker,
); );
Ok(DecoderUnfoldState { Ok(DecoderUnfoldState {
...@@ -144,6 +146,7 @@ impl ...@@ -144,6 +146,7 @@ impl
.sampling_options .sampling_options
.include_stop_str_in_output .include_stop_str_in_output
.unwrap_or(false); .unwrap_or(false);
let tracker = request.tracker.clone();
let next_stream = next.generate(request).await?; let next_stream = next.generate(request).await?;
...@@ -154,6 +157,7 @@ impl ...@@ -154,6 +157,7 @@ impl
stop_conditions, stop_conditions,
skip_special_tokens, skip_special_tokens,
include_stop_str_in_output, include_stop_str_in_output,
tracker,
)?; )?;
let processed_stream = stream::unfold(state, |mut state| async move { let processed_stream = stream::unfold(state, |mut state| async move {
...@@ -226,7 +230,7 @@ impl ...@@ -226,7 +230,7 @@ impl
if state.validate_engine_decode { if state.validate_engine_decode {
if data.finish_reason != finish_reason { if data.finish_reason != finish_reason {
log::warn!( tracing::warn!(
"finish reason mismatch: expected {:?}, got {:?}", "finish reason mismatch: expected {:?}, got {:?}",
data.finish_reason, data.finish_reason,
finish_reason finish_reason
...@@ -234,7 +238,11 @@ impl ...@@ -234,7 +238,11 @@ impl
} }
if data.text.is_some() && data.text != text { if data.text.is_some() && data.text != text {
log::warn!("text mismatch: expected {:?}, got {:?}", data.text, text); tracing::warn!(
"text mismatch: expected {:?}, got {:?}",
data.text,
text
);
} }
} }
...@@ -326,6 +334,7 @@ impl ...@@ -326,6 +334,7 @@ impl
#[allow(dead_code)] #[allow(dead_code)]
pub struct Decoder { pub struct Decoder {
decode_stream: DecodeStream, decode_stream: DecodeStream,
tracker: Option<Arc<RequestTracker>>,
// do not trigger stop conditions until at least this many tokens have been generated // do not trigger stop conditions until at least this many tokens have been generated
min_tokens: u32, min_tokens: u32,
...@@ -398,6 +407,7 @@ impl Decoder { ...@@ -398,6 +407,7 @@ impl Decoder {
decode_stream: DecodeStream, decode_stream: DecodeStream,
stop_condition: StopConditions, stop_condition: StopConditions,
include_stop_str_in_output: bool, include_stop_str_in_output: bool,
tracker: Option<Arc<RequestTracker>>,
) -> Self { ) -> Self {
let hidden_stop_ids: HashSet<TokenIdType> = stop_condition let hidden_stop_ids: HashSet<TokenIdType> = stop_condition
.stop_token_ids_hidden .stop_token_ids_hidden
...@@ -425,6 +435,7 @@ impl Decoder { ...@@ -425,6 +435,7 @@ impl Decoder {
Self { Self {
decode_stream, decode_stream,
tracker,
hidden_stop_ids, hidden_stop_ids,
hidden_stop_sequences, hidden_stop_sequences,
visible_stop_sequences, visible_stop_sequences,
...@@ -447,7 +458,11 @@ impl Decoder { ...@@ -447,7 +458,11 @@ impl Decoder {
self.generated_tokens += 1; self.generated_tokens += 1;
// decode the token // decode the token
let detokenize_start = Instant::now();
let token = self.decode_stream.step(token_id)?; let token = self.decode_stream.step(token_id)?;
if let Some(tracker) = &self.tracker {
tracker.record_detokenize_latency(detokenize_start.elapsed());
}
// stop conditions to not apply until the minimum number of tokens have been generated // stop conditions to not apply until the minimum number of tokens have been generated
if self.generated_tokens < self.min_tokens { if self.generated_tokens < self.min_tokens {
...@@ -468,18 +483,12 @@ impl Decoder { ...@@ -468,18 +483,12 @@ impl Decoder {
&& let Some(token) = &token && let Some(token) = &token
{ {
let pre_append = self.jail.len(); let pre_append = self.jail.len();
log::debug!("pre_append: {}", pre_append);
log::debug!("jail: {}", self.jail);
self.jail.push_str(token); self.jail.push_str(token);
log::debug!("post_append: {}", self.jail.len());
log::debug!("jail: {}", self.jail);
// Check hidden stop sequences first (excluded from output) // Check hidden stop sequences first (excluded from output)
for seq in &self.hidden_stop_sequences { for seq in &self.hidden_stop_sequences {
log::debug!("stop seq: {}", seq);
if let Some(offset) = galil_seiferas::gs_find(self.jail.as_bytes(), seq.as_bytes()) if let Some(offset) = galil_seiferas::gs_find(self.jail.as_bytes(), seq.as_bytes())
{ {
log::debug!("offset: {}", offset);
// return only new bytes after pre_append .. offset (excluding stop sequence) // return only new bytes after pre_append .. offset (excluding stop sequence)
// example: seq = "ox", token = "boxes", return "b" // example: seq = "ox", token = "boxes", return "b"
// note: this changes when we start jailing tokens for partial matches // note: this changes when we start jailing tokens for partial matches
......
...@@ -328,8 +328,11 @@ pub struct ResponseMetricCollector { ...@@ -328,8 +328,11 @@ pub struct ResponseMetricCollector {
osl: usize, osl: usize,
// we track if cached_tokens has been observed to ensure we only increment once per request // we track if cached_tokens has been observed to ensure we only increment once per request
cached_tokens_observed: bool, cached_tokens_observed: bool,
// we track if tokenizer latency has been observed to ensure we only increment once per request // we track if tokenize latency has been observed to ensure we only increment once per request
tokenizer_latency_observed: bool, tokenize_latency_observed: bool,
// latest accumulated detokenize latency and sample count reported by tracker
detokenize_latency_total: Duration,
detokenize_count_total: u64,
// Prefill worker info for TTFT attribution (set from LLMMetricAnnotation) // Prefill worker info for TTFT attribution (set from LLMMetricAnnotation)
prefill_worker_id: Option<u64>, prefill_worker_id: Option<u64>,
prefill_dp_rank: Option<u32>, prefill_dp_rank: Option<u32>,
...@@ -987,7 +990,9 @@ impl ResponseMetricCollector { ...@@ -987,7 +990,9 @@ impl ResponseMetricCollector {
start_time: Instant::now(), start_time: Instant::now(),
osl: 0, osl: 0,
cached_tokens_observed: false, cached_tokens_observed: false,
tokenizer_latency_observed: false, tokenize_latency_observed: false,
detokenize_latency_total: Duration::ZERO,
detokenize_count_total: 0,
prefill_worker_id: None, prefill_worker_id: None,
prefill_dp_rank: None, prefill_dp_rank: None,
prefill_worker_type: None, prefill_worker_type: None,
...@@ -1052,17 +1057,30 @@ impl ResponseMetricCollector { ...@@ -1052,17 +1057,30 @@ impl ResponseMetricCollector {
} }
} }
/// Observe tokenizer latency in milliseconds, once per request. /// Observe tokenize/detokenize latencies in milliseconds.
pub fn observe_tokenizer_latency(&mut self, tokenizer_latency: Option<Duration>) { /// Tokenize is observed once per request; detokenize is accumulated and observed at request end.
if let Some(latency) = tokenizer_latency pub fn observe_tokenize_latencies(
&& !self.tokenizer_latency_observed &mut self,
tokenize_latency: Option<Duration>,
detokenize_latency: Option<Duration>,
detokenize_count: Option<u64>,
) {
if let Some(latency) = tokenize_latency
&& !self.tokenize_latency_observed
{ {
self.tokenizer_latency_observed = true; self.tokenize_latency_observed = true;
self.metrics self.metrics
.tokenizer_latency .tokenizer_latency
.with_label_values(&[frontend_service::operation::TOKENIZE]) .with_label_values(&[frontend_service::operation::TOKENIZE])
.observe(latency.as_secs_f64() * 1000.0); .observe(latency.as_secs_f64() * 1000.0);
} }
if let Some(latency) = detokenize_latency {
self.detokenize_latency_total = latency;
}
if let Some(count) = detokenize_count {
self.detokenize_count_total = count;
}
} }
/// Observe a response with input sequence length and number of new tokens /// Observe a response with input sequence length and number of new tokens
...@@ -1155,6 +1173,15 @@ impl ResponseMetricCollector { ...@@ -1155,6 +1173,15 @@ impl ResponseMetricCollector {
impl Drop for ResponseMetricCollector { impl Drop for ResponseMetricCollector {
fn drop(&mut self) { fn drop(&mut self) {
if !self.detokenize_latency_total.is_zero() && self.detokenize_count_total > 0 {
let avg_detokenize_latency_ms = (self.detokenize_latency_total.as_secs_f64() * 1000.0)
/ self.detokenize_count_total as f64;
self.metrics
.tokenizer_latency
.with_label_values(&[frontend_service::operation::DETOKENIZE])
.observe(avg_detokenize_latency_ms);
}
// Publish final OSL when the collector is dropped // Publish final OSL when the collector is dropped
self.metrics self.metrics
.output_sequence_length .output_sequence_length
...@@ -1179,7 +1206,11 @@ pub fn process_response_and_observe_metrics<T>( ...@@ -1179,7 +1206,11 @@ pub fn process_response_and_observe_metrics<T>(
if let Ok(Some(metrics)) = LLMMetricAnnotation::from_annotation(annotated) { if let Ok(Some(metrics)) = LLMMetricAnnotation::from_annotation(annotated) {
response_collector.observe_current_osl(metrics.output_tokens); response_collector.observe_current_osl(metrics.output_tokens);
response_collector.observe_cached_tokens(metrics.cached_tokens); response_collector.observe_cached_tokens(metrics.cached_tokens);
response_collector.observe_tokenizer_latency(metrics.tokenizer_latency); response_collector.observe_tokenize_latencies(
metrics.tokenize_latency,
metrics.detokenize_total_latency,
metrics.detokenize_count,
);
response_collector.set_worker_info( response_collector.set_worker_info(
metrics.prefill_worker_id, metrics.prefill_worker_id,
metrics.prefill_dp_rank, metrics.prefill_dp_rank,
...@@ -1229,7 +1260,11 @@ pub fn process_response_using_event_converter_and_observe_metrics<T: Serialize>( ...@@ -1229,7 +1260,11 @@ pub fn process_response_using_event_converter_and_observe_metrics<T: Serialize>(
if let Ok(Some(metrics)) = LLMMetricAnnotation::from_annotation(&annotated) { if let Ok(Some(metrics)) = LLMMetricAnnotation::from_annotation(&annotated) {
response_collector.observe_current_osl(metrics.output_tokens); response_collector.observe_current_osl(metrics.output_tokens);
response_collector.observe_cached_tokens(metrics.cached_tokens); response_collector.observe_cached_tokens(metrics.cached_tokens);
response_collector.observe_tokenizer_latency(metrics.tokenizer_latency); response_collector.observe_tokenize_latencies(
metrics.tokenize_latency,
metrics.detokenize_total_latency,
metrics.detokenize_count,
);
response_collector.set_worker_info( response_collector.set_worker_info(
metrics.prefill_worker_id, metrics.prefill_worker_id,
metrics.prefill_dp_rank, metrics.prefill_dp_rank,
...@@ -1735,7 +1770,9 @@ mod tests { ...@@ -1735,7 +1770,9 @@ mod tests {
decode_worker_id: None, decode_worker_id: None,
decode_dp_rank: None, decode_dp_rank: None,
decode_worker_type: None, decode_worker_type: None,
tokenizer_latency: Some(Duration::from_millis(8)), tokenize_latency: Some(Duration::from_millis(8)),
detokenize_total_latency: Some(Duration::from_micros(100)),
detokenize_count: Some(2),
}; };
let annotation = llm_metrics.to_annotation::<()>().unwrap(); let annotation = llm_metrics.to_annotation::<()>().unwrap();
...@@ -1753,6 +1790,9 @@ mod tests { ...@@ -1753,6 +1790,9 @@ mod tests {
// Should return Ok(None) for metrics annotation events // Should return Ok(None) for metrics annotation events
assert!(matches!(result, Ok(None))); assert!(matches!(result, Ok(None)));
// Drop collector so the detokenize observation fires in Drop
drop(collector);
// Should have observed the cached tokens from the metrics annotation event // Should have observed the cached tokens from the metrics annotation event
let metric_families = registry.gather(); let metric_families = registry.gather();
let histogram_family = metric_families let histogram_family = metric_families
...@@ -1770,11 +1810,31 @@ mod tests { ...@@ -1770,11 +1810,31 @@ mod tests {
.iter() .iter()
.find(|mf| mf.name() == expected_tokenizer_metric_name) .find(|mf| mf.name() == expected_tokenizer_metric_name)
.expect("histogram should be registered"); .expect("histogram should be registered");
assert_eq!(
histogram_family.get_metric()[0] // Find the tokenize and detokenize observations by label
.get_histogram() let tokenize_metric = histogram_family
.get_sample_count(), .get_metric()
1 .iter()
.find(|m| m.get_label().iter().any(|l| l.value() == "tokenize"))
.expect("tokenize metric should exist");
assert_eq!(tokenize_metric.get_histogram().get_sample_count(), 1);
// 8ms
assert!(
(tokenize_metric.get_histogram().get_sample_sum() - 8.0).abs() < 0.001,
"tokenize latency should be 8.0ms"
);
let detokenize_metric = histogram_family
.get_metric()
.iter()
.find(|m| m.get_label().iter().any(|l| l.value() == "detokenize"))
.expect("detokenize metric should exist");
assert_eq!(detokenize_metric.get_histogram().get_sample_count(), 1);
// Average: 100us total / 2 samples = 50us = 0.05ms
assert!(
(detokenize_metric.get_histogram().get_sample_sum() - 0.05).abs() < 0.001,
"detokenize average latency should be 0.05ms, got {}",
detokenize_metric.get_histogram().get_sample_sum()
); );
} }
...@@ -1813,7 +1873,9 @@ mod tests { ...@@ -1813,7 +1873,9 @@ mod tests {
decode_worker_id: None, decode_worker_id: None,
decode_dp_rank: None, decode_dp_rank: None,
decode_worker_type: None, decode_worker_type: None,
tokenizer_latency: Some(Duration::from_millis(8)), tokenize_latency: Some(Duration::from_millis(8)),
detokenize_total_latency: Some(Duration::from_micros(100)),
detokenize_count: Some(2),
}; };
let annotation = llm_metrics.to_annotation::<()>().unwrap(); let annotation = llm_metrics.to_annotation::<()>().unwrap();
...@@ -1824,6 +1886,9 @@ mod tests { ...@@ -1824,6 +1886,9 @@ mod tests {
let mut http_queue_guard = None; let mut http_queue_guard = None;
process_response_and_observe_metrics(&annotated, &mut collector, &mut http_queue_guard); process_response_and_observe_metrics(&annotated, &mut collector, &mut http_queue_guard);
// Drop collector so the detokenize observation fires in Drop
drop(collector);
// Should have observed the cached tokens from the metrics annotation event // Should have observed the cached tokens from the metrics annotation event
let metric_families = registry.gather(); let metric_families = registry.gather();
let histogram_family = metric_families let histogram_family = metric_families
...@@ -1841,11 +1906,26 @@ mod tests { ...@@ -1841,11 +1906,26 @@ mod tests {
.iter() .iter()
.find(|mf| mf.name() == expected_tokenizer_metric_name) .find(|mf| mf.name() == expected_tokenizer_metric_name)
.expect("histogram should be registered"); .expect("histogram should be registered");
assert_eq!(
histogram_family.get_metric()[0] // Find the tokenize and detokenize observations by label
.get_histogram() let tokenize_metric = histogram_family
.get_sample_count(), .get_metric()
1 .iter()
.find(|m| m.get_label().iter().any(|l| l.value() == "tokenize"))
.expect("tokenize metric should exist");
assert_eq!(tokenize_metric.get_histogram().get_sample_count(), 1);
let detokenize_metric = histogram_family
.get_metric()
.iter()
.find(|m| m.get_label().iter().any(|l| l.value() == "detokenize"))
.expect("detokenize metric should exist");
assert_eq!(detokenize_metric.get_histogram().get_sample_count(), 1);
// Average: 100us total / 2 samples = 50us = 0.05ms
assert!(
(detokenize_metric.get_histogram().get_sample_sum() - 0.05).abs() < 0.001,
"detokenize average latency should be 0.05ms, got {}",
detokenize_metric.get_histogram().get_sample_sum()
); );
} }
} }
...@@ -94,7 +94,11 @@ pub struct LLMMetricAnnotation { ...@@ -94,7 +94,11 @@ pub struct LLMMetricAnnotation {
#[serde(default, skip_serializing_if = "Option::is_none")] #[serde(default, skip_serializing_if = "Option::is_none")]
pub decode_worker_type: Option<String>, pub decode_worker_type: Option<String>,
#[serde(default, skip_serializing_if = "Option::is_none")] #[serde(default, skip_serializing_if = "Option::is_none")]
pub tokenizer_latency: Option<Duration>, pub tokenize_latency: Option<Duration>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub detokenize_total_latency: Option<Duration>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub detokenize_count: Option<u64>,
} }
impl LLMMetricAnnotation { impl LLMMetricAnnotation {
...@@ -525,7 +529,7 @@ impl OpenAIPreprocessor { ...@@ -525,7 +529,7 @@ impl OpenAIPreprocessor {
let encode_start = Instant::now(); let encode_start = Instant::now();
let encoding = self.tokenizer.encode(prompt)?; let encoding = self.tokenizer.encode(prompt)?;
if let Some(t) = tracker { if let Some(t) = tracker {
t.record_tokenizer_latency(encode_start.elapsed()); t.record_tokenize_latency(encode_start.elapsed());
} }
Ok(encoding) Ok(encoding)
} }
...@@ -715,7 +719,9 @@ impl OpenAIPreprocessor { ...@@ -715,7 +719,9 @@ impl OpenAIPreprocessor {
decode_worker_id, decode_worker_id,
decode_dp_rank, decode_dp_rank,
decode_worker_type, decode_worker_type,
tokenizer_latency: tracker.as_ref().and_then(|t| t.tokenizer_latency()), tokenize_latency: tracker.as_ref().and_then(|t| t.tokenize_latency()),
detokenize_total_latency: tracker.as_ref().and_then(|t| t.detokenize_total_latency()),
detokenize_count: tracker.as_ref().map(|t| t.detokenize_count()),
}; };
if let Ok(metrics_annotated) = llm_metrics.to_annotation::<()>() { if let Ok(metrics_annotated) = llm_metrics.to_annotation::<()>() {
...@@ -776,7 +782,11 @@ impl OpenAIPreprocessor { ...@@ -776,7 +782,11 @@ impl OpenAIPreprocessor {
decode_worker_id, decode_worker_id,
decode_dp_rank, decode_dp_rank,
decode_worker_type, decode_worker_type,
tokenizer_latency: tracker.as_ref().and_then(|t| t.tokenizer_latency()), tokenize_latency: tracker.as_ref().and_then(|t| t.tokenize_latency()),
detokenize_total_latency: tracker
.as_ref()
.and_then(|t| t.detokenize_total_latency()),
detokenize_count: tracker.as_ref().map(|t| t.detokenize_count()),
}; };
// Create annotation string // Create annotation string
......
...@@ -153,7 +153,13 @@ pub struct RequestTracker { ...@@ -153,7 +153,13 @@ pub struct RequestTracker {
phase_semaphore: Arc<Semaphore>, phase_semaphore: Arc<Semaphore>,
/// How long it took to tokenize the input /// How long it took to tokenize the input
tokenizer_latency: OnceLock<Duration>, tokenize_latency: OnceLock<Duration>,
/// Accumulated time spent detokenizing output tokens for this request (nanoseconds)
detokenize_total_ns: AtomicU64,
/// Number of detokenize samples accumulated for this request
detokenize_count: AtomicU64,
} }
impl RequestTracker { impl RequestTracker {
...@@ -184,7 +190,9 @@ impl RequestTracker { ...@@ -184,7 +190,9 @@ impl RequestTracker {
decode_worker_type: OnceLock::new(), decode_worker_type: OnceLock::new(),
phase: Mutex::new(RequestPhase::Aggregated), phase: Mutex::new(RequestPhase::Aggregated),
phase_semaphore: Arc::new(Semaphore::new(1)), phase_semaphore: Arc::new(Semaphore::new(1)),
tokenizer_latency: OnceLock::new(), tokenize_latency: OnceLock::new(),
detokenize_total_ns: AtomicU64::new(0),
detokenize_count: AtomicU64::new(0),
} }
} }
...@@ -338,12 +346,40 @@ impl RequestTracker { ...@@ -338,12 +346,40 @@ impl RequestTracker {
} }
} }
pub fn record_tokenizer_latency(&self, l: Duration) { pub fn record_tokenize_latency(&self, l: Duration) {
let _ = self.tokenizer_latency.set(l); let _ = self.tokenize_latency.set(l);
}
pub fn tokenize_latency(&self) -> Option<Duration> {
self.tokenize_latency.get().copied()
}
pub fn record_detokenize_latency(&self, l: Duration) {
// u128 -> u64 is safe because max u64 in nanos is over 500 years
let delta_ns = u64::try_from(l.as_nanos()).unwrap_or(u64::MAX);
// On an x86 system these atomics are very cheap
let _ = self.detokenize_total_ns.fetch_update(
Ordering::Relaxed,
Ordering::Relaxed,
// Saturating add to avoid wrapping to a nonsensical average on overflow.
|current| Some(current.saturating_add(delta_ns)),
);
self.detokenize_count.fetch_add(1, Ordering::Relaxed);
}
pub fn detokenize_total_latency(&self) -> Option<Duration> {
let total_ns = self.detokenize_total_ns.load(Ordering::Relaxed);
let count = self.detokenize_count.load(Ordering::Relaxed);
if count == 0 {
// We recorded no observations
None
} else {
Some(Duration::from_nanos(total_ns))
}
} }
pub fn tokenizer_latency(&self) -> Option<Duration> { pub fn detokenize_count(&self) -> u64 {
self.tokenizer_latency.get().copied() self.detokenize_count.load(Ordering::Relaxed)
} }
/// Get worker ID information if any worker IDs have been recorded. /// Get worker ID information if any worker IDs have been recorded.
......
...@@ -58,7 +58,7 @@ fn make_decoder( ...@@ -58,7 +58,7 @@ fn make_decoder(
stop: stop_sequences.map(|v| v.into_iter().map(String::from).collect()), stop: stop_sequences.map(|v| v.into_iter().map(String::from).collect()),
..Default::default() ..Default::default()
}; };
Decoder::new(decode_stream, stop_conditions, include_stop_str) Decoder::new(decode_stream, stop_conditions, include_stop_str, None)
} }
#[test] #[test]
......
...@@ -220,7 +220,6 @@ pub mod frontend_service { ...@@ -220,7 +220,6 @@ pub mod frontend_service {
pub const TOKENIZE: &str = "tokenize"; pub const TOKENIZE: &str = "tokenize";
/// Detokenization operation /// Detokenization operation
/// Currently unused, will be added next.
pub const DETOKENIZE: &str = "detokenize"; pub const DETOKENIZE: &str = "detokenize";
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment