Unverified commit 38fcafcf, authored by Nicolas Patry, committed by GitHub
Browse files

Adding a test for FD. (#2516)

* Adding a test for FD.

* Fixing flashdecoding (empty batch doesn't work).

* Fixing the invalid popping.

* Fixing radix with block_size > 1

* Last reference.

* Use an actual hash.

* Update hash for slice.len() == 1

* Update the locks.

* Increasing docker timeout.
parent 77746552
This diff is collapsed.
......@@ -364,7 +364,7 @@ impl State {
// Add it back to the front
tracing::debug!("Over budget: not enough free blocks");
self.entries.push_front((id, entry));
break;
continue;
}
Some(block_allocation) => {
tracing::debug!("Allocation: {block_allocation:?}");
......@@ -436,6 +436,12 @@ impl State {
batch_entries.insert(id, entry);
}
// Empty batch
if batch_requests.is_empty() {
tracing::debug!("Filterered out all entries");
return None;
}
// Final batch size
let size = batch_requests.len() as u32;
next_batch_span.record("batch_size", size);
......
use crate::block_allocator::{Allocator, BlockAllocation};
use slotmap::{DefaultKey, SlotMap};
use std::hash::{Hash, Hasher};
use std::{
collections::{BTreeSet, HashMap},
sync::Arc,
};
/// Computes the lookup key used for a block of token ids.
///
/// A single-token block maps to the token id itself (widened to `u64`);
/// any longer block is fed as a slice through `std::hash::DefaultHasher`.
///
/// # Panics
///
/// Panics if `slice` is empty.
fn hash(slice: &[u32]) -> u64 {
    assert!(!slice.is_empty());
    match slice {
        // One token: the id itself is the key, no hashing required.
        [token] => u64::from(*token),
        // Several tokens: hash the slice as a whole (length prefix included).
        _ => {
            let mut hasher = std::hash::DefaultHasher::new();
            slice.hash(&mut hasher);
            hasher.finish()
        }
    }
}
pub struct RadixAllocator {
allocation_id: u64,
......@@ -44,6 +56,10 @@ impl RadixAllocator {
// the free list if we cannot allocate enough blocks. This is only
// temporary, the trie needs to be able to report whether it can
// allocate the requested amount. Just not implemented yet.
tracing::debug!(
"Free blocks {} need {n_blocks_needed}",
self.free_blocks.len()
);
self.free_blocks.extend(
self.cache_blocks
.evict(n_blocks_needed - self.free_blocks.len()),
......@@ -94,6 +110,9 @@ impl Allocator for RadixAllocator {
match self.alloc_or_reclaim(suffix_blocks as usize) {
Some(suffix_blocks) => blocks.extend(suffix_blocks),
None => {
tracing::debug!("Cannot allocate {:?}", self.cache_blocks);
tracing::debug!("Found {prefix_len} prefix tokens need {suffix_blocks} suffix blocks for {tokens} tokens");
tracing::debug!("Block size {}", self.block_size);
self.cache_blocks
.decref(prefix_node)
.expect("Failed to decrement refcount");
......@@ -211,7 +230,6 @@ struct RadixAllocation {
pub enum TrieError {
InvalidNodeId,
RefCountUnderflow,
BlockTokenCountMismatch,
}
pub type NodeId = DefaultKey;
......@@ -268,7 +286,9 @@ impl RadixTrie {
fn find_(&mut self, mut node_id: NodeId, key: &[u32], blocks: &mut Vec<u32>) -> NodeId {
let node = &self.nodes[node_id];
if let Some(&child_id) = node.children.get(&key[0]) {
if key.len() >= self.block_size {
let node_key = hash(&key[..self.block_size]);
if let Some(&child_id) = node.children.get(&node_key) {
self.update_access_time(child_id);
let child = self.nodes.get(child_id).expect("Invalid child identifier");
let shared_prefix_len = shared_prefix(&child.key, key, self.block_size);
......@@ -280,6 +300,7 @@ impl RadixTrie {
node_id = self.find_(child_id, key, blocks);
}
}
}
node_id
}
......@@ -344,9 +365,11 @@ impl RadixTrie {
// evict n_blocks and return `None` if we can't. We are now needlessly
// evicting prefixes from the cache in such a case.
let mut evicted = Vec::new();
tracing::debug!("Evicting in search of {n_blocks}");
while let Some((last_access, node_id)) = self.leaves.pop_first() {
let blocks_needed = n_blocks - evicted.len();
let blocks_needed = n_blocks.saturating_sub(evicted.len());
tracing::debug!("Evicting node {node_id:?} ");
let node = self.nodes.get(node_id).expect("Leave does not exist");
assert_eq!(
......@@ -368,8 +391,11 @@ impl RadixTrie {
// the required number of blocks and leave the remaining blocks
// untouched.
let node = self.nodes.get_mut(node_id).expect("Leave does not exist");
node.key.truncate(node.blocks.len() - blocks_needed);
evicted.extend(node.blocks.split_off(node.blocks.len() - blocks_needed));
let truncate_blocks = node.blocks.len() - blocks_needed;
let truncate_tokens = truncate_blocks * self.block_size;
node.key.truncate(truncate_tokens);
evicted.extend(node.blocks.split_off(truncate_blocks));
self.leaves.insert((last_access, node_id));
break;
}
......@@ -400,11 +426,10 @@ impl RadixTrie {
// the part of the prefix that is already in the trie to detect
// mismatches.
if tokens.len() != blocks.len() * self.block_size {
return Err(TrieError::BlockTokenCountMismatch);
}
assert_eq!(tokens.len(), blocks.len() * self.block_size);
if let Some(&child_id) = self.nodes[node_id].children.get(&tokens[0]) {
let node_key = hash(&tokens[..self.block_size]);
if let Some(&child_id) = self.nodes[node_id].children.get(&node_key) {
self.update_access_time(child_id);
let child = self
.nodes
......@@ -452,14 +477,15 @@ impl RadixTrie {
.get_mut(node_id)
.expect("Node to-be split does not exist");
let mut parent_key = node.key.split_off(prefix_len);
let mut parent_blocks = node.blocks.split_off(prefix_len);
let prefix_blocks = prefix_len / self.block_size;
let mut parent_blocks = node.blocks.split_off(prefix_blocks);
// Move first part of the prefix to the parent. We swap to avoid
// an allocation + copy for both splits of the key/blocks.
std::mem::swap(&mut node.key, &mut parent_key);
std::mem::swap(&mut node.blocks, &mut parent_blocks);
let node_key = node.key[0];
let node_key = hash(&node.key[..self.block_size]);
let grandparent_id = node.parent.expect("Node does not have a parent");
let parent_id = self.add_node(grandparent_id, parent_key, parent_blocks);
......@@ -484,7 +510,7 @@ impl RadixTrie {
) -> NodeId {
let key = key.into();
let blocks = blocks.into();
let first = key[0];
let first = hash(&key[..self.block_size]);
let child = TrieNode::new(key, blocks, self.time, Some(parent_id));
let child_id = self.nodes.insert(child);
......@@ -496,10 +522,10 @@ impl RadixTrie {
}
/// Add a node to the parent.
fn add_node_to_parent(&mut self, parent_id: NodeId, first: u32, child_id: NodeId) {
fn add_node_to_parent(&mut self, parent_id: NodeId, hash: u64, child_id: NodeId) {
// Unwrap here, passing in an unknown id is a programming error.
let parent = self.nodes.get_mut(parent_id).expect("Unknown parent node");
if parent.children.insert(first, child_id).is_none() {
if parent.children.insert(hash, child_id).is_none() {
// Only increase reference count if child does not replace another child.
self.incref(parent_id)
.expect("Failed to increase parent refcount");
......@@ -517,7 +543,9 @@ impl RadixTrie {
);
let parent_id = node.parent.expect("Attempted to remove root node");
let parent = self.nodes.get_mut(parent_id).expect("Unknown parent node");
parent.children.remove(&node.key[0]);
let node_key = hash(&node.key[..self.block_size]);
parent.children.remove(&node_key);
self.decref(parent_id)
.expect("Failed to decrease parent refcount");
node
......@@ -571,7 +599,7 @@ impl RadixTrie {
#[derive(Debug)]
struct TrieNode {
blocks: Vec<u32>,
children: HashMap<u32, NodeId>,
children: HashMap<u64, NodeId>,
key: Vec<u32>,
last_accessed: u64,
parent: Option<NodeId>,
......
......@@ -853,11 +853,11 @@
]
},
"locked": {
"lastModified": 1726021481,
"narHash": "sha256-4J4E+Fh+77XIYnq2RVtg+ENWXpu6t74P0jKN/f2RQmI=",
"lastModified": 1726280639,
"narHash": "sha256-YfLRPlFZWrT2oRLNAoqf7G3+NnUTDdlIJk6tmBU7kXM=",
"owner": "oxalica",
"repo": "rust-overlay",
"rev": "1c2c120246c51a644c20ba2a36a33d3bd4860d70",
"rev": "e9f8641c92f26fd1e076e705edb12147c384171d",
"type": "github"
},
"original": {
......@@ -978,11 +978,11 @@
"nixpkgs": "nixpkgs_6"
},
"locked": {
"lastModified": 1725950569,
"narHash": "sha256-nJHA1SvIQbXySpL2ueNbzQOhnkQASa5tOLz/kdW0PWA=",
"lastModified": 1726229792,
"narHash": "sha256-9xsLmjc9nr7a4PTddKv2DOi82ompTtJNyjO6R67y5tE=",
"owner": "danieldk",
"repo": "tgi-nix",
"rev": "d40f3c22e9bcc5e16c94d4605cf6a7d74dd07f46",
"rev": "1a902f4818e94c3f8d95f6000db17bc3fadd0ce7",
"type": "github"
},
"original": {
......
......@@ -342,6 +342,7 @@ def launcher(event_loop):
max_total_tokens: Optional[int] = None,
lora_adapters: Optional[List[str]] = None,
cuda_graphs: Optional[List[int]] = None,
attention: Optional[str] = None,
):
port = random.randint(8000, 10_000)
master_port = random.randint(10_000, 20_000)
......@@ -401,6 +402,8 @@ def launcher(event_loop):
if not use_flash_attention:
env["USE_FLASH_ATTENTION"] = "false"
if attention is not None:
env["ATTENTION"] = attention
with tempfile.TemporaryFile("w+") as tmp:
# We'll output stdout/stderr to a temporary file. Using a pipe
......@@ -437,6 +440,7 @@ def launcher(event_loop):
max_total_tokens: Optional[int] = None,
lora_adapters: Optional[List[str]] = None,
cuda_graphs: Optional[List[int]] = None,
attention: Optional[str] = None,
):
port = random.randint(8000, 10_000)
......@@ -491,6 +495,8 @@ def launcher(event_loop):
}
if not use_flash_attention:
env["USE_FLASH_ATTENTION"] = "false"
if attention is not None:
env["ATTENTION"] = attention
if HF_TOKEN is not None:
env["HF_TOKEN"] = HF_TOKEN
......@@ -522,6 +528,7 @@ def launcher(event_loop):
devices=devices,
volumes=volumes,
ports={"80/tcp": port},
healthcheck={"timeout": int(10 * 1e9)},
shm_size="1G",
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment