Unverified commit 38fcafcf, authored by Nicolas Patry, committed by GitHub
Browse files

Adding a test for FD. (#2516)

* Adding a test for FD.

* Fixing flashdecoding (empty batch doesn't work).

* Fixing the invalid popping.

* Fixing radix with block_size > 1

* Last reference.

* Use an actual hash.

* Update hash for slice.len() == 1

* Update the locks.

* Increasing docker timeout.
parent 77746552
This diff is collapsed.
...@@ -364,7 +364,7 @@ impl State { ...@@ -364,7 +364,7 @@ impl State {
// Add it back to the front // Add it back to the front
tracing::debug!("Over budget: not enough free blocks"); tracing::debug!("Over budget: not enough free blocks");
self.entries.push_front((id, entry)); self.entries.push_front((id, entry));
break; continue;
} }
Some(block_allocation) => { Some(block_allocation) => {
tracing::debug!("Allocation: {block_allocation:?}"); tracing::debug!("Allocation: {block_allocation:?}");
...@@ -436,6 +436,12 @@ impl State { ...@@ -436,6 +436,12 @@ impl State {
batch_entries.insert(id, entry); batch_entries.insert(id, entry);
} }
// Empty batch
if batch_requests.is_empty() {
tracing::debug!("Filterered out all entries");
return None;
}
// Final batch size // Final batch size
let size = batch_requests.len() as u32; let size = batch_requests.len() as u32;
next_batch_span.record("batch_size", size); next_batch_span.record("batch_size", size);
......
use crate::block_allocator::{Allocator, BlockAllocation}; use crate::block_allocator::{Allocator, BlockAllocation};
use slotmap::{DefaultKey, SlotMap}; use slotmap::{DefaultKey, SlotMap};
use std::hash::{Hash, Hasher};
use std::{ use std::{
collections::{BTreeSet, HashMap}, collections::{BTreeSet, HashMap},
sync::Arc, sync::Arc,
}; };
/// Reduce a non-empty run of token ids to a single `u64` key.
///
/// A one-element slice maps to that element widened to `u64`; longer slices
/// are fed through the standard library's `DefaultHasher`.
///
/// # Panics
///
/// Panics if `slice` is empty.
fn hash(slice: &[u32]) -> u64 {
    assert!(!slice.is_empty());
    // Single-token fast path: the token id itself is the key.
    if let [token] = slice {
        return u64::from(*token);
    }
    let mut hasher = std::hash::DefaultHasher::new();
    slice.hash(&mut hasher);
    hasher.finish()
}
pub struct RadixAllocator { pub struct RadixAllocator {
allocation_id: u64, allocation_id: u64,
...@@ -44,6 +56,10 @@ impl RadixAllocator { ...@@ -44,6 +56,10 @@ impl RadixAllocator {
// the free list if we cannot allocate enough blocks. This is only // the free list if we cannot allocate enough blocks. This is only
// temporary, the trie needs to be able to report whether it can // temporary, the trie needs to be able to report whether it can
// allocate the requested amount. Just not implemented yet. // allocate the requested amount. Just not implemented yet.
tracing::debug!(
"Free blocks {} need {n_blocks_needed}",
self.free_blocks.len()
);
self.free_blocks.extend( self.free_blocks.extend(
self.cache_blocks self.cache_blocks
.evict(n_blocks_needed - self.free_blocks.len()), .evict(n_blocks_needed - self.free_blocks.len()),
...@@ -94,6 +110,9 @@ impl Allocator for RadixAllocator { ...@@ -94,6 +110,9 @@ impl Allocator for RadixAllocator {
match self.alloc_or_reclaim(suffix_blocks as usize) { match self.alloc_or_reclaim(suffix_blocks as usize) {
Some(suffix_blocks) => blocks.extend(suffix_blocks), Some(suffix_blocks) => blocks.extend(suffix_blocks),
None => { None => {
tracing::debug!("Cannot allocate {:?}", self.cache_blocks);
tracing::debug!("Found {prefix_len} prefix tokens need {suffix_blocks} suffix blocks for {tokens} tokens");
tracing::debug!("Block size {}", self.block_size);
self.cache_blocks self.cache_blocks
.decref(prefix_node) .decref(prefix_node)
.expect("Failed to decrement refcount"); .expect("Failed to decrement refcount");
...@@ -211,7 +230,6 @@ struct RadixAllocation { ...@@ -211,7 +230,6 @@ struct RadixAllocation {
pub enum TrieError { pub enum TrieError {
InvalidNodeId, InvalidNodeId,
RefCountUnderflow, RefCountUnderflow,
BlockTokenCountMismatch,
} }
pub type NodeId = DefaultKey; pub type NodeId = DefaultKey;
...@@ -268,16 +286,19 @@ impl RadixTrie { ...@@ -268,16 +286,19 @@ impl RadixTrie {
fn find_(&mut self, mut node_id: NodeId, key: &[u32], blocks: &mut Vec<u32>) -> NodeId { fn find_(&mut self, mut node_id: NodeId, key: &[u32], blocks: &mut Vec<u32>) -> NodeId {
let node = &self.nodes[node_id]; let node = &self.nodes[node_id];
if let Some(&child_id) = node.children.get(&key[0]) { if key.len() >= self.block_size {
self.update_access_time(child_id); let node_key = hash(&key[..self.block_size]);
let child = self.nodes.get(child_id).expect("Invalid child identifier"); if let Some(&child_id) = node.children.get(&node_key) {
let shared_prefix_len = shared_prefix(&child.key, key, self.block_size); self.update_access_time(child_id);
assert_eq!(shared_prefix_len % self.block_size, 0); let child = self.nodes.get(child_id).expect("Invalid child identifier");
blocks.extend(&child.blocks[..shared_prefix_len / self.block_size]); let shared_prefix_len = shared_prefix(&child.key, key, self.block_size);
assert_eq!(shared_prefix_len % self.block_size, 0);
let key = &key[shared_prefix_len..]; blocks.extend(&child.blocks[..shared_prefix_len / self.block_size]);
if !key.is_empty() {
node_id = self.find_(child_id, key, blocks); let key = &key[shared_prefix_len..];
if !key.is_empty() {
node_id = self.find_(child_id, key, blocks);
}
} }
} }
...@@ -344,9 +365,11 @@ impl RadixTrie { ...@@ -344,9 +365,11 @@ impl RadixTrie {
// evict n_blocks and return `None` if we can't. We are now needlessly // evict n_blocks and return `None` if we can't. We are now needlessly
// evicting prefixes from the cache in such a case. // evicting prefixes from the cache in such a case.
let mut evicted = Vec::new(); let mut evicted = Vec::new();
tracing::debug!("Evicting in search of {n_blocks}");
while let Some((last_access, node_id)) = self.leaves.pop_first() { while let Some((last_access, node_id)) = self.leaves.pop_first() {
let blocks_needed = n_blocks - evicted.len(); let blocks_needed = n_blocks.saturating_sub(evicted.len());
tracing::debug!("Evicting node {node_id:?} ");
let node = self.nodes.get(node_id).expect("Leave does not exist"); let node = self.nodes.get(node_id).expect("Leave does not exist");
assert_eq!( assert_eq!(
...@@ -368,8 +391,11 @@ impl RadixTrie { ...@@ -368,8 +391,11 @@ impl RadixTrie {
// the required number of blocks and leave the remaining blocks // the required number of blocks and leave the remaining blocks
// untouched. // untouched.
let node = self.nodes.get_mut(node_id).expect("Leave does not exist"); let node = self.nodes.get_mut(node_id).expect("Leave does not exist");
node.key.truncate(node.blocks.len() - blocks_needed);
evicted.extend(node.blocks.split_off(node.blocks.len() - blocks_needed)); let truncate_blocks = node.blocks.len() - blocks_needed;
let truncate_tokens = truncate_blocks * self.block_size;
node.key.truncate(truncate_tokens);
evicted.extend(node.blocks.split_off(truncate_blocks));
self.leaves.insert((last_access, node_id)); self.leaves.insert((last_access, node_id));
break; break;
} }
...@@ -400,11 +426,10 @@ impl RadixTrie { ...@@ -400,11 +426,10 @@ impl RadixTrie {
// the part of the prefix that is already in the trie to detect // the part of the prefix that is already in the trie to detect
// mismatches. // mismatches.
if tokens.len() != blocks.len() * self.block_size { assert_eq!(tokens.len(), blocks.len() * self.block_size);
return Err(TrieError::BlockTokenCountMismatch);
}
if let Some(&child_id) = self.nodes[node_id].children.get(&tokens[0]) { let node_key = hash(&tokens[..self.block_size]);
if let Some(&child_id) = self.nodes[node_id].children.get(&node_key) {
self.update_access_time(child_id); self.update_access_time(child_id);
let child = self let child = self
.nodes .nodes
...@@ -452,14 +477,15 @@ impl RadixTrie { ...@@ -452,14 +477,15 @@ impl RadixTrie {
.get_mut(node_id) .get_mut(node_id)
.expect("Node to-be split does not exist"); .expect("Node to-be split does not exist");
let mut parent_key = node.key.split_off(prefix_len); let mut parent_key = node.key.split_off(prefix_len);
let mut parent_blocks = node.blocks.split_off(prefix_len); let prefix_blocks = prefix_len / self.block_size;
let mut parent_blocks = node.blocks.split_off(prefix_blocks);
// Move first part of the prefix to the parent. We swap to avoid // Move first part of the prefix to the parent. We swap to avoid
// an allocation + copy for both splits of the key/blocks. // an allocation + copy for both splits of the key/blocks.
std::mem::swap(&mut node.key, &mut parent_key); std::mem::swap(&mut node.key, &mut parent_key);
std::mem::swap(&mut node.blocks, &mut parent_blocks); std::mem::swap(&mut node.blocks, &mut parent_blocks);
let node_key = node.key[0]; let node_key = hash(&node.key[..self.block_size]);
let grandparent_id = node.parent.expect("Node does not have a parent"); let grandparent_id = node.parent.expect("Node does not have a parent");
let parent_id = self.add_node(grandparent_id, parent_key, parent_blocks); let parent_id = self.add_node(grandparent_id, parent_key, parent_blocks);
...@@ -484,7 +510,7 @@ impl RadixTrie { ...@@ -484,7 +510,7 @@ impl RadixTrie {
) -> NodeId { ) -> NodeId {
let key = key.into(); let key = key.into();
let blocks = blocks.into(); let blocks = blocks.into();
let first = key[0]; let first = hash(&key[..self.block_size]);
let child = TrieNode::new(key, blocks, self.time, Some(parent_id)); let child = TrieNode::new(key, blocks, self.time, Some(parent_id));
let child_id = self.nodes.insert(child); let child_id = self.nodes.insert(child);
...@@ -496,10 +522,10 @@ impl RadixTrie { ...@@ -496,10 +522,10 @@ impl RadixTrie {
} }
/// Add a node to the parent. /// Add a node to the parent.
fn add_node_to_parent(&mut self, parent_id: NodeId, first: u32, child_id: NodeId) { fn add_node_to_parent(&mut self, parent_id: NodeId, hash: u64, child_id: NodeId) {
// Unwrap here, passing in an unknown id is a programming error. // Unwrap here, passing in an unknown id is a programming error.
let parent = self.nodes.get_mut(parent_id).expect("Unknown parent node"); let parent = self.nodes.get_mut(parent_id).expect("Unknown parent node");
if parent.children.insert(first, child_id).is_none() { if parent.children.insert(hash, child_id).is_none() {
// Only increase reference count if child does not replace another child. // Only increase reference count if child does not replace another child.
self.incref(parent_id) self.incref(parent_id)
.expect("Failed to increase parent refcount"); .expect("Failed to increase parent refcount");
...@@ -517,7 +543,9 @@ impl RadixTrie { ...@@ -517,7 +543,9 @@ impl RadixTrie {
); );
let parent_id = node.parent.expect("Attempted to remove root node"); let parent_id = node.parent.expect("Attempted to remove root node");
let parent = self.nodes.get_mut(parent_id).expect("Unknown parent node"); let parent = self.nodes.get_mut(parent_id).expect("Unknown parent node");
parent.children.remove(&node.key[0]);
let node_key = hash(&node.key[..self.block_size]);
parent.children.remove(&node_key);
self.decref(parent_id) self.decref(parent_id)
.expect("Failed to decrease parent refcount"); .expect("Failed to decrease parent refcount");
node node
...@@ -571,7 +599,7 @@ impl RadixTrie { ...@@ -571,7 +599,7 @@ impl RadixTrie {
#[derive(Debug)] #[derive(Debug)]
struct TrieNode { struct TrieNode {
blocks: Vec<u32>, blocks: Vec<u32>,
children: HashMap<u32, NodeId>, children: HashMap<u64, NodeId>,
key: Vec<u32>, key: Vec<u32>,
last_accessed: u64, last_accessed: u64,
parent: Option<NodeId>, parent: Option<NodeId>,
......
...@@ -853,11 +853,11 @@ ...@@ -853,11 +853,11 @@
] ]
}, },
"locked": { "locked": {
"lastModified": 1726021481, "lastModified": 1726280639,
"narHash": "sha256-4J4E+Fh+77XIYnq2RVtg+ENWXpu6t74P0jKN/f2RQmI=", "narHash": "sha256-YfLRPlFZWrT2oRLNAoqf7G3+NnUTDdlIJk6tmBU7kXM=",
"owner": "oxalica", "owner": "oxalica",
"repo": "rust-overlay", "repo": "rust-overlay",
"rev": "1c2c120246c51a644c20ba2a36a33d3bd4860d70", "rev": "e9f8641c92f26fd1e076e705edb12147c384171d",
"type": "github" "type": "github"
}, },
"original": { "original": {
...@@ -978,11 +978,11 @@ ...@@ -978,11 +978,11 @@
"nixpkgs": "nixpkgs_6" "nixpkgs": "nixpkgs_6"
}, },
"locked": { "locked": {
"lastModified": 1725950569, "lastModified": 1726229792,
"narHash": "sha256-nJHA1SvIQbXySpL2ueNbzQOhnkQASa5tOLz/kdW0PWA=", "narHash": "sha256-9xsLmjc9nr7a4PTddKv2DOi82ompTtJNyjO6R67y5tE=",
"owner": "danieldk", "owner": "danieldk",
"repo": "tgi-nix", "repo": "tgi-nix",
"rev": "d40f3c22e9bcc5e16c94d4605cf6a7d74dd07f46", "rev": "1a902f4818e94c3f8d95f6000db17bc3fadd0ce7",
"type": "github" "type": "github"
}, },
"original": { "original": {
......
...@@ -342,6 +342,7 @@ def launcher(event_loop): ...@@ -342,6 +342,7 @@ def launcher(event_loop):
max_total_tokens: Optional[int] = None, max_total_tokens: Optional[int] = None,
lora_adapters: Optional[List[str]] = None, lora_adapters: Optional[List[str]] = None,
cuda_graphs: Optional[List[int]] = None, cuda_graphs: Optional[List[int]] = None,
attention: Optional[str] = None,
): ):
port = random.randint(8000, 10_000) port = random.randint(8000, 10_000)
master_port = random.randint(10_000, 20_000) master_port = random.randint(10_000, 20_000)
...@@ -401,6 +402,8 @@ def launcher(event_loop): ...@@ -401,6 +402,8 @@ def launcher(event_loop):
if not use_flash_attention: if not use_flash_attention:
env["USE_FLASH_ATTENTION"] = "false" env["USE_FLASH_ATTENTION"] = "false"
if attention is not None:
env["ATTENTION"] = attention
with tempfile.TemporaryFile("w+") as tmp: with tempfile.TemporaryFile("w+") as tmp:
# We'll output stdout/stderr to a temporary file. Using a pipe # We'll output stdout/stderr to a temporary file. Using a pipe
...@@ -437,6 +440,7 @@ def launcher(event_loop): ...@@ -437,6 +440,7 @@ def launcher(event_loop):
max_total_tokens: Optional[int] = None, max_total_tokens: Optional[int] = None,
lora_adapters: Optional[List[str]] = None, lora_adapters: Optional[List[str]] = None,
cuda_graphs: Optional[List[int]] = None, cuda_graphs: Optional[List[int]] = None,
attention: Optional[str] = None,
): ):
port = random.randint(8000, 10_000) port = random.randint(8000, 10_000)
...@@ -491,6 +495,8 @@ def launcher(event_loop): ...@@ -491,6 +495,8 @@ def launcher(event_loop):
} }
if not use_flash_attention: if not use_flash_attention:
env["USE_FLASH_ATTENTION"] = "false" env["USE_FLASH_ATTENTION"] = "false"
if attention is not None:
env["ATTENTION"] = attention
if HF_TOKEN is not None: if HF_TOKEN is not None:
env["HF_TOKEN"] = HF_TOKEN env["HF_TOKEN"] = HF_TOKEN
...@@ -522,6 +528,7 @@ def launcher(event_loop): ...@@ -522,6 +528,7 @@ def launcher(event_loop):
devices=devices, devices=devices,
volumes=volumes, volumes=volumes,
ports={"80/tcp": port}, ports={"80/tcp": port},
healthcheck={"timeout": int(10 * 1e9)},
shm_size="1G", shm_size="1G",
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment