Unverified Commit e74bd41e authored by OlivierDehaene, committed by GitHub

feat(server): add paged attention to flash models (#516)

Closes #478
parent 70f485bf
......@@ -88,7 +88,6 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-ins
RUN /opt/conda/bin/conda install -c "nvidia/label/cuda-11.8.0" cuda==11.8 && \
/opt/conda/bin/conda clean -ya
# Build Flash Attention CUDA kernels
FROM kernel-builder as flash-att-builder
......@@ -109,6 +108,16 @@ COPY server/custom_kernels/ .
# Build specific version of transformers
RUN python setup.py build
# Build vllm CUDA kernels
FROM kernel-builder as vllm-builder
WORKDIR /usr/src
COPY server/Makefile-vllm Makefile
# Build specific version of vllm
RUN make build-vllm
# Text Generation Inference base image
FROM nvidia/cuda:11.8.0-base-ubuntu20.04 as base
......@@ -137,9 +146,12 @@ COPY --from=flash-att-builder /usr/src/flash-attention/build/lib.linux-x86_64-cp
COPY --from=flash-att-builder /usr/src/flash-attention/csrc/layer_norm/build/lib.linux-x86_64-cpython-39 /opt/conda/lib/python3.9/site-packages
COPY --from=flash-att-builder /usr/src/flash-attention/csrc/rotary/build/lib.linux-x86_64-cpython-39 /opt/conda/lib/python3.9/site-packages
# Copy build artifacts from transformers builder
# Copy build artifacts from custom kernels builder
COPY --from=custom-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-39/custom_kernels /usr/src/custom-kernels/src/custom_kernels
# Copy build artifacts from vllm builder
COPY --from=vllm-builder /usr/src/vllm/build/lib.linux-x86_64-cpython-39 /opt/conda/lib/python3.9/site-packages
# Install flash-attention dependencies
RUN pip install einops --no-cache-dir
......
......@@ -43,8 +43,8 @@ to power LLMs api-inference widgets.
- Tensor Parallelism for faster inference on multiple GPUs
- Token streaming using Server-Sent Events (SSE)
- [Continuous batching of incoming requests](https://github.com/huggingface/text-generation-inference/tree/main/router) for increased total throughput
- Optimized transformers code for inference using [flash-attention](https://github.com/HazyResearch/flash-attention) on the most popular architectures
- Quantization with [bitsandbytes](https://github.com/TimDettmers/bitsandbytes)
- Optimized transformers code for inference using [flash-attention](https://github.com/HazyResearch/flash-attention) and [Paged Attention](https://github.com/vllm-project/vllm) on the most popular architectures
- Quantization with [bitsandbytes](https://github.com/TimDettmers/bitsandbytes) and [GPT-Q](https://arxiv.org/abs/2210.17323)
- [Safetensors](https://github.com/huggingface/safetensors) weight loading
- Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
- Logits warper (temperature scaling, top-p, top-k, repetition penalty, more details see [transformers.LogitsProcessor](https://huggingface.co/docs/transformers/internal/generation_utils#transformers.LogitsProcessor))
......
......@@ -13,6 +13,7 @@ async def flash_neox(flash_neox_handle):
return flash_neox_handle.client
@pytest.mark.skip
@pytest.mark.asyncio
async def test_flash_neox(flash_neox, response_snapshot):
response = await flash_neox.generate(
......@@ -25,6 +26,7 @@ async def test_flash_neox(flash_neox, response_snapshot):
assert response == response_snapshot
@pytest.mark.skip
@pytest.mark.asyncio
async def test_flash_neox_load(flash_neox, generate_load, response_snapshot):
responses = await generate_load(
......
......@@ -115,12 +115,6 @@ struct Args {
#[clap(default_value = "1512", long, env)]
max_total_tokens: usize,
/// The maximum allowed batch size during dynamic batching.
/// Using `max_batch_total_tokens` should be favored in general
/// as it's a finer way to control RAM usage.
#[clap(long, env)]
max_batch_size: Option<usize>,
/// This represents the ratio of waiting queries vs running queries where
/// you want to start considering pausing the running queries to include the waiting
/// ones into the same batch.
......@@ -134,6 +128,12 @@ struct Args {
#[clap(default_value = "1.2", long, env)]
waiting_served_ratio: f32,
/// Limits the number of tokens for the prefill operation.
/// Since this operation takes the most memory and is compute-bound, it is useful
/// to limit the number of requests that can be sent.
#[clap(default_value = "4096", long, env)]
max_batch_prefill_tokens: u32,
/// **IMPORTANT** This is one critical control to allow maximum usage
/// of the available hardware.
///
......@@ -146,19 +146,12 @@ struct Args {
/// For `max_batch_total_tokens=1000`, you could fit `10` queries of `total_tokens=100`
/// or a single query of `1000` tokens.
///
/// So you don't have to control `max_batch_size`
/// or `max_total_tokens` that finely. In fact you could mostly relax them if you
/// want maximum flexibility. However, if your users ask for the full amount of
/// total tokens, they are likely to wait a very long time for a spot
/// in the batch (since they will be alone), so setting `max_batch_size`
/// and `max_total_tokens` can still be useful to prevent those long waiting times.
///
/// Overall this number should be the largest possible amount that fits the
/// remaining memory (after the model is loaded). Since the actual memory overhead
/// depends on other parameters, such as whether you're using quantization, flash attention,
/// or the model implementation, text-generation-inference cannot infer this number
/// automatically.
#[clap(default_value = "32000", long, env)]
#[clap(default_value = "16000", long, env)]
max_batch_total_tokens: u32,
/// This setting defines how many tokens can be passed before forcing the waiting
......@@ -180,9 +173,9 @@ struct Args {
/// for end users.
#[clap(default_value = "20", long, env)]
max_waiting_tokens: usize,
#[clap(default_value = "3000", long, short, env)]
/// The port to listen on.
#[clap(default_value = "3000", long, short, env)]
port: u16,
/// The name of the socket for gRPC communication between the webserver
......@@ -329,6 +322,12 @@ fn shard_manager(
// Copy current process env
let mut env: Vec<(OsString, OsString)> = env::vars_os().collect();
// Use cuda allocator. It leads to less memory fragmentation
env.push((
"PYTORCH_CUDA_ALLOC_CONF".into(),
"backend:cudaMallocAsync".into(),
));
// Torch Distributed Env vars
env.push(("RANK".into(), rank.to_string().into()));
env.push(("WORLD_SIZE".into(), world_size.to_string().into()));
......@@ -446,7 +445,7 @@ fn shard_manager(
// We received a shutdown signal
if *shutdown.lock().unwrap() {
p.terminate().unwrap();
p.kill().unwrap();
let _ = p.wait_timeout(Duration::from_secs(90));
tracing::info!("Shard {rank} terminated");
return;
......@@ -822,6 +821,10 @@ fn spawn_webserver(
args.max_input_length.to_string(),
"--max-total-tokens".to_string(),
args.max_total_tokens.to_string(),
"--max-batch-prefill-tokens".to_string(),
args.max_batch_prefill_tokens.to_string(),
"--max-batch-total-tokens".to_string(),
args.max_batch_total_tokens.to_string(),
"--waiting-served-ratio".to_string(),
args.waiting_served_ratio.to_string(),
"--max-waiting-tokens".to_string(),
......@@ -834,15 +837,6 @@ fn spawn_webserver(
args.model_id,
];
// Deprecate max_batch_size
if let Some(max_batch_size) = args.max_batch_size {
argv.push("--max-batch-size".to_string());
argv.push(max_batch_size.to_string())
} else {
argv.push("--max-batch-total-tokens".to_string());
argv.push(args.max_batch_total_tokens.to_string())
}
// Model optional revision
if let Some(ref revision) = args.revision {
argv.push("--revision".to_string());
......
......@@ -11,6 +11,8 @@ service TextGenerationService {
rpc ClearCache (ClearCacheRequest) returns (ClearCacheResponse);
/// Remove requests from a cached batch
rpc FilterBatch (FilterBatchRequest) returns (FilterBatchResponse);
/// Warmup the model and compute max cache size
rpc Warmup (WarmupRequest) returns (WarmupResponse);
/// Prefill batch and decode first token
rpc Prefill (PrefillRequest) returns (PrefillResponse);
/// Decode token for a list of prefilled batches
......@@ -192,3 +194,13 @@ message DecodeResponse {
/// Next batch (cached)
optional CachedBatch batch = 2;
}
message WarmupRequest {
/// Batch to warmup on
Batch batch = 1;
/// Maximum number of tokens that the client will send
uint32 max_total_tokens = 2;
}
/// Empty response
message WarmupResponse {}
......@@ -3,6 +3,7 @@ use crate::pb::generate::v1::text_generation_service_client::TextGenerationServi
use crate::pb::generate::v1::*;
use crate::Result;
use grpc_metadata::InjectTelemetryContext;
use std::cmp::min;
use tonic::transport::{Channel, Uri};
use tracing::instrument;
......@@ -94,6 +95,63 @@ impl Client {
Ok(filtered_batch.batch)
}
/// Warmup on a max size batch
///
/// Returns the maximum number of tokens supported by the hardware
#[instrument(skip(self))]
pub async fn warmup(
&mut self,
max_input_length: u32,
max_prefill_tokens: u32,
max_total_tokens: u32,
) -> Result<()> {
let mut n_tokens = 0;
let mut requests = Vec::new();
// Create requests
while n_tokens < max_prefill_tokens {
requests.push(Request {
id: 0,
// We truncate the input on the server side to be sure that it has the correct size
inputs: "_test ".to_string().repeat(max_input_length as usize),
truncate: min(max_input_length, max_prefill_tokens - n_tokens),
// Set sampling parameters so that these ops are also taken into account in the max memory
parameters: Some(NextTokenChooserParameters {
temperature: 0.9,
top_k: 10,
top_p: 0.9,
typical_p: 0.9,
do_sample: false,
seed: 0,
repetition_penalty: 1.2,
watermark: true,
}),
stopping_parameters: Some(StoppingCriteriaParameters {
max_new_tokens: 2,
stop_sequences: vec![],
ignore_eos_token: false,
}),
prefill_logprobs: true,
});
n_tokens += max_input_length;
}
let batch = Batch {
id: 0,
size: requests.len() as u32,
requests,
max_tokens: 0,
};
let request = tonic::Request::new(WarmupRequest {
batch: Some(batch),
max_total_tokens,
})
.inject_context();
self.stub.warmup(request).await?.into_inner();
Ok(())
}
/// Generate one token for each request in the given batch
///
/// Returns Generation for each request in batch
......
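The `warmup` call builds a synthetic batch large enough to saturate `max_prefill_tokens`, so the server can size its paged KV cache against a worst-case prefill before serving traffic. Below is a rough Python sketch of how that batch is composed, mirroring the Rust loop above (`warmup_batch_sizes` is a hypothetical helper name):

```python
def warmup_batch_sizes(max_input_length: int, max_prefill_tokens: int):
    """Return the `truncate` value of each synthetic warmup request.

    Mirrors the Rust loop: requests of `max_input_length` tokens are appended
    until the prefill budget is covered; the last request is truncated to the
    remaining budget.
    """
    truncates = []
    n_tokens = 0
    while n_tokens < max_prefill_tokens:
        truncates.append(min(max_input_length, max_prefill_tokens - n_tokens))
        n_tokens += max_input_length
    return truncates

# e.g. max_input_length=1000, max_prefill_tokens=4096 -> [1000, 1000, 1000, 1000, 96]
```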
......@@ -87,6 +87,27 @@ impl ShardedClient {
join_all(futures).await.pop().unwrap()
}
/// Warmup on a max size batch
///
/// Returns the maximum number of tokens supported by the hardware
#[instrument(skip(self))]
pub async fn warmup(
&mut self,
max_input_length: u32,
max_prefill_tokens: u32,
max_total_tokens: u32,
) -> Result<()> {
let futures: Vec<_> = self
.clients
.iter_mut()
.map(|client| {
Box::pin(client.warmup(max_input_length, max_prefill_tokens, max_total_tokens))
})
.collect();
// all shards return the same message
join_all(futures).await.pop().unwrap()
}
/// Generate one token for each request in the given batch
///
/// Returns Generation for each request in batch
......
......@@ -45,6 +45,7 @@ impl Infer {
client: ShardedClient,
validation: Validation,
waiting_served_ratio: f32,
max_batch_prefill_tokens: u32,
max_batch_total_tokens: u32,
max_waiting_tokens: usize,
max_concurrent_requests: usize,
......@@ -61,6 +62,7 @@ impl Infer {
tokio::spawn(batching_task(
client,
waiting_served_ratio,
max_batch_prefill_tokens,
max_batch_total_tokens,
max_waiting_tokens,
queue.clone(),
......@@ -240,9 +242,11 @@ impl Infer {
/// Will be launched in a background Tokio task
///
/// Batches requests and sends them to the inference server
#[allow(clippy::too_many_arguments)]
async fn batching_task(
mut client: ShardedClient,
waiting_served_ratio: f32,
max_batch_prefill_tokens: u32,
max_batch_total_tokens: u32,
max_waiting_tokens: usize,
queue: Queue,
......@@ -257,8 +261,9 @@ async fn batching_task(
// Get the next batch from the queue
// This batch might be smaller than the maximum batch size if there are not enough requests
// waiting in the queue
while let Some((mut entries, batch, span)) =
queue.next_batch(None, max_batch_total_tokens).await
while let Some((mut entries, batch, span)) = queue
.next_batch(None, max_batch_prefill_tokens, max_batch_total_tokens)
.await
{
let mut cached_batch = prefill(&mut client, batch, &mut entries, &generation_health)
.instrument(span)
......@@ -284,11 +289,12 @@ async fn batching_task(
Some((batch_size as f32 * waiting_served_ratio).floor() as usize)
};
let token_budget = max_batch_total_tokens - batch_max_tokens;
let token_budget = max_batch_total_tokens.saturating_sub(batch_max_tokens);
// Try to get a new batch
if let Some((mut new_entries, new_batch, span)) =
queue.next_batch(min_size, token_budget).await
if let Some((mut new_entries, new_batch, span)) = queue
.next_batch(min_size, max_batch_prefill_tokens, token_budget)
.await
{
// Tracking metrics
if min_size.is_some() {
......
......@@ -32,10 +32,10 @@ struct Args {
max_input_length: usize,
#[clap(default_value = "1512", long, env)]
max_total_tokens: usize,
#[clap(long, env)]
max_batch_size: Option<usize>,
#[clap(default_value = "1.2", long, env)]
waiting_served_ratio: f32,
#[clap(default_value = "4096", long, env)]
max_batch_prefill_tokens: u32,
#[clap(default_value = "32000", long, env)]
max_batch_total_tokens: u32,
#[clap(default_value = "20", long, env)]
......@@ -78,9 +78,9 @@ fn main() -> Result<(), std::io::Error> {
max_stop_sequences,
max_input_length,
max_total_tokens,
max_batch_size,
waiting_served_ratio,
mut max_batch_total_tokens,
max_batch_prefill_tokens,
max_batch_total_tokens,
max_waiting_tokens,
port,
master_shard_uds_path,
......@@ -141,12 +141,6 @@ fn main() -> Result<(), std::io::Error> {
.block_on(async {
init_logging(otlp_endpoint, json_output);
if let Some(max_batch_size) = max_batch_size {
tracing::warn!("`max-batch-size` is deprecated. Use `max-batch-total-tokens` instead");
max_batch_total_tokens = (max_batch_size * max_total_tokens) as u32;
tracing::warn!("Overriding `max-batch-total-tokens` value with `max-batch-size` * `max-total-tokens` = {max_batch_total_tokens}");
}
if tokenizer.is_none() {
tracing::warn!(
"Could not find a fast tokenizer implementation for {tokenizer_name}"
......@@ -161,10 +155,16 @@ fn main() -> Result<(), std::io::Error> {
sha: None,
pipeline_tag: None,
},
false => get_model_info(&tokenizer_name, &revision, authorization_token).await.unwrap_or_else(|| {
tracing::warn!("Could not retrieve model info from the Hugging Face hub.");
HubModelInfo { model_id: tokenizer_name.to_string(), sha: None, pipeline_tag: None }
}),
false => get_model_info(&tokenizer_name, &revision, authorization_token)
.await
.unwrap_or_else(|| {
tracing::warn!("Could not retrieve model info from the Hugging Face hub.");
HubModelInfo {
model_id: tokenizer_name.to_string(),
sha: None,
pipeline_tag: None,
}
}),
};
// if pipeline-tag == text-generation we default to return_full_text = true
......@@ -190,6 +190,17 @@ fn main() -> Result<(), std::io::Error> {
.info()
.await
.expect("Unable to get shard info");
// Warmup model
tracing::info!("Warming up model");
sharded_client
.warmup(
max_input_length as u32,
max_batch_prefill_tokens,
max_batch_total_tokens,
)
.await
.expect("Unable to warmup model");
tracing::info!("Connected");
// Binds on localhost
......@@ -206,6 +217,7 @@ fn main() -> Result<(), std::io::Error> {
max_input_length,
max_total_tokens,
waiting_served_ratio,
max_batch_prefill_tokens,
max_batch_total_tokens,
max_waiting_tokens,
sharded_client,
......@@ -219,7 +231,7 @@ fn main() -> Result<(), std::io::Error> {
ngrok_username,
ngrok_password,
)
.await;
.await;
Ok(())
})
}
......
......@@ -58,6 +58,7 @@ impl Queue {
pub(crate) async fn next_batch(
&self,
min_size: Option<usize>,
prefill_token_budget: u32,
token_budget: u32,
) -> Option<NextBatch> {
// Create response channel
......@@ -67,6 +68,7 @@ impl Queue {
self.queue_sender
.send(QueueCommand::NextBatch {
min_size,
prefill_token_budget,
token_budget,
response_sender,
span: Span::current(),
......@@ -90,11 +92,12 @@ async fn queue_task(requires_padding: bool, receiver: flume::Receiver<QueueComma
}
QueueCommand::NextBatch {
min_size,
prefill_token_budget,
token_budget,
response_sender,
span,
} => span.in_scope(|| {
let next_batch = state.next_batch(min_size, token_budget);
let next_batch = state.next_batch(min_size, prefill_token_budget, token_budget);
response_sender.send(next_batch).unwrap();
metrics::gauge!("tgi_queue_size", state.entries.len() as f64);
}),
......@@ -140,7 +143,12 @@ impl State {
}
// Get the next batch
fn next_batch(&mut self, min_size: Option<usize>, token_budget: u32) -> Option<NextBatch> {
fn next_batch(
&mut self,
min_size: Option<usize>,
prefill_token_budget: u32,
token_budget: u32,
) -> Option<NextBatch> {
if self.entries.is_empty() {
return None;
}
......@@ -184,7 +192,9 @@ impl State {
decode_tokens += entry.request.stopping_parameters.max_new_tokens;
if (prefill_tokens + decode_tokens) > token_budget {
if prefill_tokens > prefill_token_budget
|| (prefill_tokens + decode_tokens) > token_budget
{
// Entry is over budget
// Add it back to the front
self.entries.push_front((id, entry));
......@@ -259,6 +269,7 @@ enum QueueCommand {
Append(Box<Entry>, Span),
NextBatch {
min_size: Option<usize>,
prefill_token_budget: u32,
token_budget: u32,
response_sender: oneshot::Sender<Option<NextBatch>>,
span: Span,
......@@ -328,8 +339,8 @@ mod tests {
fn test_next_batch_empty() {
let mut state = State::new(false);
assert!(state.next_batch(None, 1).is_none());
assert!(state.next_batch(Some(1), 1).is_none());
assert!(state.next_batch(None, 1, 1).is_none());
assert!(state.next_batch(Some(1), 1, 1).is_none());
}
#[test]
......@@ -340,7 +351,7 @@ mod tests {
state.append(entry1);
state.append(entry2);
let (entries, batch, _) = state.next_batch(None, 2).unwrap();
let (entries, batch, _) = state.next_batch(None, 2, 2).unwrap();
assert_eq!(entries.len(), 2);
assert!(entries.contains_key(&0));
assert!(entries.contains_key(&1));
......@@ -356,7 +367,7 @@ mod tests {
let (entry3, _guard3) = default_entry();
state.append(entry3);
assert!(state.next_batch(Some(2), 2).is_none());
assert!(state.next_batch(Some(2), 2, 2).is_none());
assert_eq!(state.next_id, 3);
assert_eq!(state.entries.len(), 1);
......@@ -372,7 +383,7 @@ mod tests {
state.append(entry1);
state.append(entry2);
let (entries, batch, _) = state.next_batch(None, 1).unwrap();
let (entries, batch, _) = state.next_batch(None, 1, 1).unwrap();
assert_eq!(entries.len(), 1);
assert!(entries.contains_key(&0));
assert_eq!(batch.id, 0);
......@@ -385,7 +396,7 @@ mod tests {
let (entry3, _guard3) = default_entry();
state.append(entry3);
let (entries, batch, _) = state.next_batch(None, 3).unwrap();
let (entries, batch, _) = state.next_batch(None, 3, 3).unwrap();
assert_eq!(entries.len(), 2);
assert!(entries.contains_key(&1));
assert!(entries.contains_key(&2));
......@@ -408,8 +419,8 @@ mod tests {
async fn test_queue_next_batch_empty() {
let queue = Queue::new(false);
assert!(queue.next_batch(None, 1).await.is_none());
assert!(queue.next_batch(Some(1), 1).await.is_none());
assert!(queue.next_batch(None, 1, 1).await.is_none());
assert!(queue.next_batch(Some(1), 1, 1).await.is_none());
}
#[tokio::test]
......@@ -420,7 +431,7 @@ mod tests {
queue.append(entry1);
queue.append(entry2);
let (entries, batch, _) = queue.next_batch(None, 2).await.unwrap();
let (entries, batch, _) = queue.next_batch(None, 2, 2).await.unwrap();
assert_eq!(entries.len(), 2);
assert!(entries.contains_key(&0));
assert!(entries.contains_key(&1));
......@@ -433,11 +444,11 @@ mod tests {
queue.append(entry3);
// Not enough requests pending
assert!(queue.next_batch(Some(2), 2).await.is_none());
assert!(queue.next_batch(Some(2), 2, 2).await.is_none());
// Not enough token budget
assert!(queue.next_batch(Some(1), 0).await.is_none());
assert!(queue.next_batch(Some(1), 0, 0).await.is_none());
// Ok
let (entries2, batch2, _) = queue.next_batch(Some(1), 2).await.unwrap();
let (entries2, batch2, _) = queue.next_batch(Some(1), 2, 2).await.unwrap();
assert_eq!(entries2.len(), 1);
assert!(entries2.contains_key(&2));
assert!(entries2.get(&2).unwrap().batch_time.is_some());
......@@ -453,7 +464,7 @@ mod tests {
queue.append(entry1);
queue.append(entry2);
let (entries, batch, _) = queue.next_batch(None, 1).await.unwrap();
let (entries, batch, _) = queue.next_batch(None, 1, 1).await.unwrap();
assert_eq!(entries.len(), 1);
assert!(entries.contains_key(&0));
assert_eq!(batch.id, 0);
......@@ -462,7 +473,7 @@ mod tests {
let (entry3, _guard3) = default_entry();
queue.append(entry3);
let (entries, batch, _) = queue.next_batch(None, 3).await.unwrap();
let (entries, batch, _) = queue.next_batch(None, 3, 3).await.unwrap();
assert_eq!(entries.len(), 2);
assert!(entries.contains_key(&1));
assert!(entries.contains_key(&2));
......@@ -476,6 +487,6 @@ mod tests {
let (entry, _) = default_entry();
queue.append(entry);
assert!(queue.next_batch(None, 1).await.is_none());
assert!(queue.next_batch(None, 1, 1).await.is_none());
}
}
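`next_batch` now enforces two budgets when assembling a batch: the summed prefill tokens must stay under `prefill_token_budget`, and prefill plus worst-case decode tokens must stay under `token_budget`; the first entry that would exceed either budget is pushed back to the front of the queue. A small Python sketch of that admission check (illustrative only; `admit_requests` and the tuple shape are assumptions):

```python
def admit_requests(pending, prefill_token_budget: int, token_budget: int):
    """Greedily admit requests until either budget would be exceeded.

    `pending` is a list of (input_length, max_new_tokens) tuples; returns the
    admitted prefix, mirroring the budget check in `State::next_batch`.
    """
    admitted = []
    prefill_tokens = 0
    decode_tokens = 0
    for input_length, max_new_tokens in pending:
        prefill_tokens += input_length
        decode_tokens += max_new_tokens
        if prefill_tokens > prefill_token_budget or (prefill_tokens + decode_tokens) > token_budget:
            break  # this request stays at the front of the queue
        admitted.append((input_length, max_new_tokens))
    return admitted
```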
......@@ -514,6 +514,7 @@ pub async fn run(
max_input_length: usize,
max_total_tokens: usize,
waiting_served_ratio: f32,
max_batch_prefill_tokens: u32,
max_batch_total_tokens: u32,
max_waiting_tokens: usize,
client: ShardedClient,
......@@ -582,6 +583,7 @@ pub async fn run(
client,
validation,
waiting_served_ratio,
max_batch_prefill_tokens,
max_batch_total_tokens,
max_waiting_tokens,
max_concurrent_requests,
......
vllm_commit := d284b831c17f42a8ea63369a06138325f73c4cf9
vllm:
# Clone vllm
git clone https://github.com/OlivierDehaene/vllm.git
build-vllm: vllm
cd vllm && git fetch && git checkout $(vllm_commit)
cd vllm && python setup.py build
install-vllm: build-vllm
pip uninstall vllm -y || true
cd vllm && python setup.py install
\ No newline at end of file
......@@ -22,7 +22,9 @@ class Cache:
del batch
def clear(self):
self.cache.clear()
keys = list(self.cache.keys())
for k in keys:
self.delete(k)
def __len__(self):
return len(self.cache.keys())
......@@ -122,7 +122,7 @@ class CausalLMBatch(Batch):
position_ids.masked_fill_(tokenized_inputs["attention_mask"] == 0, 1)
all_input_ids = tokenized_inputs["input_ids"].T.split(1, dim=1)
max_tokens = len(inputs) * max_input_length + max_decode_tokens
max_tokens = len(inputs) * (max_input_length + max_decode_tokens)
return cls(
batch_id=pb.id,
......
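The parenthesization fix above changes the per-batch token accounting: the old expression reserved decode tokens only once for the whole batch, while the new one reserves them per request. A quick worked example with hypothetical numbers:

```python
# Hypothetical batch: 4 requests, each padded to 512 input tokens
# and each allowed up to 128 new tokens.
inputs = 4
max_input_length = 512
max_decode_tokens = 128

old = inputs * max_input_length + max_decode_tokens    # 2176: under-counts decode slots
new = inputs * (max_input_length + max_decode_tokens)  # 2560: one decode budget per request
```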
......@@ -23,12 +23,16 @@ import torch.distributed
from torch import nn
from transformers.activations import ACT2FN
from typing import Optional
from typing import Optional, List, Tuple
# Flash attention imports
import flash_attn_cuda
import dropout_layer_norm
# vllm imports
import vllm_cache_ops
import vllm_attention_ops
from text_generation_server.utils.layers import (
TensorParallelRowLinear,
TensorParallelColumnLinear,
......@@ -106,7 +110,7 @@ class FlashLlamaAttention(torch.nn.Module):
prefix=f"{prefix}.rotary_emb", weights=weights
)
self.softmax_scale = self.head_size ** (-0.5)
self.softmax_scale = self.head_size**-0.5
self.num_heads = self.num_heads // weights.process_group.size()
self.query_key_value = TensorParallelColumnLinear.load_multi(
......@@ -122,20 +126,22 @@ class FlashLlamaAttention(torch.nn.Module):
weights=weights,
bias=False,
)
self.kv_head_mapping = torch.arange(
0, self.num_heads, dtype=torch.int32, device=weights.device
)
def forward(
self,
hidden_states,
cos,
sin,
start_seq,
end_seq,
start_seq_q,
end_seq_q,
start_seq_prefill,
end_seq_prefill,
kv_cache,
block_tables,
slots,
input_lengths,
max_s,
layer_past,
past_present_indices,
prefill,
):
qkv = self.query_key_value(hidden_states)
qkv = qkv.view(-1, 3, self.num_heads, self.head_size)
......@@ -144,23 +150,25 @@ class FlashLlamaAttention(torch.nn.Module):
self.rotary_emb(qkv[:, 0], cos, sin)
self.rotary_emb(qkv[:, 1], cos, sin)
# Prefill
if prefill:
# Copy to layer past
layer_past[...] = qkv[:, 1:]
vllm_cache_ops.reshape_and_cache(
qkv[:, 1], qkv[:, 2], kv_cache[0], kv_cache[1], slots
)
# output
attn_output = torch.empty_like(qkv[:, 0])
# output tensor
attn_output = torch.empty_like(qkv[:, 0])
# Prefill
if start_seq_prefill is not None:
# flash attention
flash_attn_cuda.fwd(
qkv[:, 0],
qkv[:, 1],
qkv[:, 2],
attn_output,
start_seq,
end_seq,
start_seq,
end_seq,
start_seq_prefill,
end_seq_prefill,
start_seq_prefill,
end_seq_prefill,
max_s,
max_s,
0.0,
......@@ -173,31 +181,19 @@ class FlashLlamaAttention(torch.nn.Module):
)
# Decode
else:
query = qkv[:, 0]
# Add present to the layer_past tensor at the correct indices
layer_past[past_present_indices] = qkv[:, 1:]
# output
attn_output = torch.empty_like(query)
# flash attention
flash_attn_cuda.fwd(
query,
layer_past[:, 0],
layer_past[:, 1],
# kv_cache[1] => [num_blocks, num_heads, head_size, block_size]
block_size = kv_cache[1].shape[3]
vllm_attention_ops.single_query_cached_kv_attention(
attn_output,
start_seq_q,
end_seq_q,
start_seq,
end_seq,
1,
max_s,
0.0,
qkv[:, 0],
kv_cache[0],
kv_cache[1],
self.kv_head_mapping,
self.softmax_scale,
False,
False,
False,
0,
None,
block_tables,
input_lengths,
block_size,
max_s,
)
return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
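The decode branch above no longer appends into a pre-allocated `layer_past` tensor; it reads keys and values out of a block-paged cache addressed by `block_tables` and `input_lengths`, which the fused `vllm_attention_ops.single_query_cached_kv_attention` kernel does on the GPU. Below is a minimal, unfused PyTorch sketch of the same lookup, with simplified cache layouts and ignoring `kv_head_mapping` (one KV head per query head). It is illustrative only, not the kernel used by this PR:

```python
import torch

def paged_decode_attention(query, key_cache, value_cache, block_tables, input_lengths, softmax_scale):
    """Illustrative single-query attention over a block-paged KV cache.

    Simplified layouts (not the exact kernel layouts):
      query:         [num_seqs, num_heads, head_size]
      key_cache:     [num_blocks, num_heads, head_size, block_size]
      value_cache:   [num_blocks, num_heads, head_size, block_size]
      block_tables:  [num_seqs, max_blocks_per_seq] physical block ids
      input_lengths: [num_seqs] tokens currently cached per sequence
    """
    num_heads, head_size, block_size = key_cache.shape[1], key_cache.shape[2], key_cache.shape[3]
    outputs = []
    for i in range(query.shape[0]):
        length = int(input_lengths[i])
        n_blocks = (length + block_size - 1) // block_size
        blocks = block_tables[i, :n_blocks].long()
        # Gather this sequence's keys/values: [num_heads, head_size, length]
        k = key_cache[blocks].permute(1, 2, 0, 3).reshape(num_heads, head_size, -1)[..., :length]
        v = value_cache[blocks].permute(1, 2, 0, 3).reshape(num_heads, head_size, -1)[..., :length]
        q = query[i]  # [num_heads, head_size]
        scores = torch.einsum("hd,hdt->ht", q, k) * softmax_scale
        probs = torch.softmax(scores, dim=-1)
        outputs.append(torch.einsum("ht,hdt->hd", probs, v))
    return torch.stack(outputs)
```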
......@@ -265,14 +261,13 @@ class FlashLlamaLayer(nn.Module):
residual,
cos,
sin,
start_seq,
end_seq,
start_seq_q,
end_seq_q,
start_seq_prefill,
end_seq_prefill,
kv_cache,
block_tables,
slots,
input_lengths,
max_s,
layer_past,
past_present_indices,
prefill,
):
normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
......@@ -281,14 +276,13 @@ class FlashLlamaLayer(nn.Module):
normed_hidden_states,
cos,
sin,
start_seq,
end_seq,
start_seq_q,
end_seq_q,
start_seq_prefill,
end_seq_prefill,
kv_cache,
block_tables,
slots,
input_lengths,
max_s,
layer_past,
past_present_indices,
prefill,
)
# faster post attention rms norm
......@@ -333,40 +327,18 @@ class FlashLlamaModel(torch.nn.Module):
def forward(
self,
input_ids,
position_ids,
start_seq,
end_seq,
start_seq_q,
end_seq_q,
max_s,
past_present_indices,
past_key_values=None,
pre_allocate_past_size: Optional[int] = None,
):
input_ids: torch.Tensor,
position_ids: torch.Tensor,
start_seq_prefill: Optional[torch.Tensor],
end_seq_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
input_lengths: torch.Tensor,
max_s: int,
) -> torch.Tensor:
hidden_states = self.embed_tokens(input_ids)
# Prefill
if past_key_values is None:
assert pre_allocate_past_size is not None
prefill = True
# Create past tensor
# We create a tensor of the same size as input_ids as we don't want to slice at every layer
past_key_values = hidden_states.new_empty(
(
len(input_ids),
len(self.layers),
2,
self.num_heads,
self.head_size,
)
)
# Decode
else:
prefill = False
# Get rotary cos and sin for this forward
# Avoid to index in each layer
cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(
......@@ -380,34 +352,18 @@ class FlashLlamaModel(torch.nn.Module):
residual,
cos,
sin,
start_seq,
end_seq,
start_seq_q,
end_seq_q,
start_seq_prefill,
end_seq_prefill,
kv_cache[i],
block_tables,
slots,
input_lengths,
max_s,
past_key_values[:, i],
past_present_indices,
prefill,
)
if prefill:
present = past_key_values
# Create padded past tensor
past_key_values = hidden_states.new_empty(
(
pre_allocate_past_size,
len(self.layers),
2,
self.num_heads,
self.head_size,
)
)
# We slice only once instead of at every layer
past_key_values[past_present_indices] = present
hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states, past_key_values
return hidden_states
class FlashLlamaForCausalLM(torch.nn.Module):
......@@ -423,31 +379,29 @@ class FlashLlamaForCausalLM(torch.nn.Module):
def forward(
self,
input_ids,
position_ids,
start_seq,
end_seq,
start_seq_q,
end_seq_q,
max_s,
past_present_indices,
past_key_values: Optional[torch.Tensor] = None,
pre_allocate_past_size: Optional[int] = None,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
start_seq_prefill: Optional[torch.Tensor],
end_seq_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
input_lengths: torch.Tensor,
max_s: int,
lm_head_indices: Optional[torch.Tensor] = None,
):
hidden_states, present = self.model(
) -> torch.Tensor:
hidden_states = self.model(
input_ids,
position_ids,
start_seq,
end_seq,
start_seq_q,
end_seq_q,
start_seq_prefill,
end_seq_prefill,
kv_cache,
block_tables,
slots,
input_lengths,
max_s,
past_present_indices,
past_key_values,
pre_allocate_past_size,
)
if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices]
logits = self.lm_head(hidden_states)
return logits, present
return logits
......@@ -25,11 +25,15 @@ from torch import nn
from transformers.activations import ACT2FN
from transformers.modeling_utils import PreTrainedModel
from transformers.models.gpt_neox import GPTNeoXConfig
from typing import Optional
from typing import Optional, List, Tuple
# Flash attention imports
import flash_attn_cuda
# vllm imports
import vllm_cache_ops
import vllm_attention_ops
from text_generation_server.utils.layers import (
TensorParallelRowLinear,
TensorParallelColumnLinear,
......@@ -110,20 +114,22 @@ class FlashNeoxAttention(torch.nn.Module):
self.dense = load_row(
config, prefix=f"{prefix}.dense", weights=weights, bias=True
)
self.kv_head_mapping = torch.arange(
0, self.num_heads, dtype=torch.int32, device=weights.device
)
def forward(
self,
hidden_states,
cos,
sin,
start_seq,
end_seq,
start_seq_q,
end_seq_q,
start_seq_prefill,
end_seq_prefill,
kv_cache,
block_tables,
slots,
input_lengths,
max_s,
layer_past,
past_present_indices,
prefill,
):
qkv = self.query_key_value(hidden_states)
qkv = qkv.view(-1, 3, self.num_heads, self.head_size)
......@@ -132,23 +138,25 @@ class FlashNeoxAttention(torch.nn.Module):
self.rotary_emb(qkv[:, 0], cos, sin)
self.rotary_emb(qkv[:, 1], cos, sin)
# Prefill
if prefill:
# Copy to layer past
layer_past[...] = qkv[:, 1:]
vllm_cache_ops.reshape_and_cache(
qkv[:, 1], qkv[:, 2], kv_cache[0], kv_cache[1], slots
)
# output
attn_output = torch.empty_like(qkv[:, 0])
# output tensor
attn_output = torch.empty_like(qkv[:, 0])
# Prefill
if start_seq_prefill is not None:
# flash attention
flash_attn_cuda.fwd(
qkv[:, 0],
qkv[:, 1],
qkv[:, 2],
attn_output,
start_seq,
end_seq,
start_seq,
end_seq,
start_seq_prefill,
end_seq_prefill,
start_seq_prefill,
end_seq_prefill,
max_s,
max_s,
0.0,
......@@ -161,31 +169,19 @@ class FlashNeoxAttention(torch.nn.Module):
)
# Decode
else:
query = qkv[:, 0]
# Add present to the layer_past tensor at the correct indices
layer_past[past_present_indices] = qkv[:, 1:]
# output
attn_output = torch.empty_like(query)
# flash attention
flash_attn_cuda.fwd(
query,
layer_past[:, 0],
layer_past[:, 1],
# kv_cache[1] => [num_blocks, num_heads, head_size, block_size]
block_size = kv_cache[1].shape[3]
vllm_attention_ops.single_query_cached_kv_attention(
attn_output,
start_seq_q,
end_seq_q,
start_seq,
end_seq,
1,
max_s,
0.0,
qkv[:, 0],
kv_cache[0],
kv_cache[1],
self.kv_head_mapping,
self.softmax_scale,
False,
False,
False,
0,
None,
block_tables,
input_lengths,
block_size,
max_s,
)
return self.dense(attn_output.view(-1, self.num_heads * self.head_size))
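Before either branch runs, `vllm_cache_ops.reshape_and_cache` scatters the new key/value pairs into the paged cache at the positions given by `slots`: a slot is a flat index that decomposes into a block id and an offset inside that block. A simplified reference version is sketched below (reduced layouts, a plain Python loop instead of the CUDA kernel; `reshape_and_cache_ref` is a hypothetical name):

```python
def reshape_and_cache_ref(key, value, key_cache, value_cache, slots):
    """Reference slot-mapping write into a paged KV cache.

    Simplified layouts:
      key, value:  [num_tokens, num_heads, head_size]
      key_cache:   [num_blocks, num_heads, head_size, block_size]
      value_cache: [num_blocks, num_heads, head_size, block_size]
      slots:       [num_tokens] flat indices, slot = block_id * block_size + offset
    """
    block_size = key_cache.shape[3]
    for i in range(key.shape[0]):
        b = int(slots[i]) // block_size
        o = int(slots[i]) % block_size
        key_cache[b, :, :, o] = key[i]
        value_cache[b, :, :, o] = value[i]
```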
......@@ -250,14 +246,13 @@ class FlashNeoXLayer(nn.Module):
residual,
cos,
sin,
start_seq,
end_seq,
start_seq_q,
end_seq_q,
start_seq_prefill,
end_seq_prefill,
kv_cache,
block_tables,
slots,
input_lengths,
max_s,
layer_past,
past_present_indices,
prefill,
):
if self.use_parallel_residual:
ln1_hidden_states, _ = self.input_layernorm(hidden_states)
......@@ -266,14 +261,13 @@ class FlashNeoXLayer(nn.Module):
ln1_hidden_states,
cos,
sin,
start_seq,
end_seq,
start_seq_q,
end_seq_q,
start_seq_prefill,
end_seq_prefill,
kv_cache,
block_tables,
slots,
input_lengths,
max_s,
layer_past,
past_present_indices,
prefill,
)
ln2_hidden_states, _ = self.post_attention_layernorm(hidden_states)
......@@ -292,14 +286,13 @@ class FlashNeoXLayer(nn.Module):
hidden_states,
cos,
sin,
start_seq,
end_seq,
start_seq_q,
end_seq_q,
start_seq_prefill,
end_seq_prefill,
kv_cache,
block_tables,
slots,
input_lengths,
max_s,
layer_past,
past_present_indices,
prefill,
)
hidden_states, residual = self.post_attention_layernorm(
......@@ -346,40 +339,18 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
def forward(
self,
input_ids,
position_ids,
start_seq,
end_seq,
start_seq_q,
end_seq_q,
max_s,
past_present_indices,
past_key_values=None,
pre_allocate_past_size: Optional[int] = None,
):
input_ids: torch.Tensor,
position_ids: torch.Tensor,
start_seq_prefill: Optional[torch.Tensor],
end_seq_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
input_lengths: torch.Tensor,
max_s: int,
) -> torch.Tensor:
hidden_states = self.embed_in(input_ids)
# Prefill
if past_key_values is None:
assert pre_allocate_past_size is not None
prefill = True
# Create past tensor
# We create a tensor of the same size as input_ids as we don't want to slice at every layer
past_key_values = hidden_states.new_empty(
(
len(input_ids),
len(self.layers),
2,
self.num_heads,
self.head_size,
)
)
# Decode
else:
prefill = False
# Get rotary cos and sin for this forward
# Avoid to index in each layer
cos, sin = self.layers[0].attention.rotary_emb.get_cos_sin(
......@@ -393,34 +364,18 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
residual,
cos,
sin,
start_seq,
end_seq,
start_seq_q,
end_seq_q,
start_seq_prefill,
end_seq_prefill,
kv_cache[i],
block_tables,
slots,
input_lengths,
max_s,
past_key_values[:, i],
past_present_indices,
prefill,
)
if prefill:
present = past_key_values
# Create padded past tensor
past_key_values = hidden_states.new_empty(
(
pre_allocate_past_size,
len(self.layers),
2,
self.num_heads,
self.head_size,
)
)
# We slice only once instead of at every layer
past_key_values[past_present_indices] = present
hidden_states, _ = self.final_layer_norm(hidden_states, residual)
return hidden_states, past_key_values
return hidden_states
class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel):
......@@ -434,31 +389,29 @@ class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel):
def forward(
self,
input_ids,
position_ids,
start_seq,
end_seq,
start_seq_q,
end_seq_q,
max_s,
past_present_indices,
past_key_values: Optional[torch.Tensor] = None,
pre_allocate_past_size: Optional[int] = None,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
start_seq_prefill: Optional[torch.Tensor],
end_seq_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
input_lengths: torch.Tensor,
max_s: int,
lm_head_indices: Optional[torch.Tensor] = None,
):
hidden_states, present = self.gpt_neox(
) -> torch.Tensor:
hidden_states = self.gpt_neox(
input_ids,
position_ids,
start_seq,
end_seq,
start_seq_q,
end_seq_q,
start_seq_prefill,
end_seq_prefill,
kv_cache,
block_tables,
slots,
input_lengths,
max_s,
past_present_indices,
past_key_values,
pre_allocate_past_size,
)
if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices]
logits = self.embed_out(hidden_states)
return logits, present
return logits
......@@ -4,11 +4,15 @@ import torch.distributed
from torch import nn
from transformers.modeling_utils import PreTrainedModel
from transformers.configuration_utils import PretrainedConfig
from typing import Optional
from typing import Optional, List, Tuple
# Flash attention imports
import flash_attn_cuda
# vllm imports
import vllm_cache_ops
import vllm_attention_ops
from text_generation_server.utils.layers import (
TensorParallelRowLinear,
TensorParallelColumnLinear,
......@@ -126,19 +130,27 @@ class FlashRWAttention(torch.nn.Module):
config, prefix=f"{prefix}.dense", weights=weights, bias=config.bias
)
if self.num_heads_kv == 1:
self.kv_head_mapping = torch.zeros(
self.num_heads, dtype=torch.int32, device=weights.device
)
else:
self.kv_head_mapping = torch.arange(
0, self.num_heads, dtype=torch.int32, device=weights.device
)
def forward(
self,
hidden_states,
cos,
sin,
start_seq,
end_seq,
start_seq_q,
end_seq_q,
start_seq_prefill,
end_seq_prefill,
kv_cache,
block_tables,
slots,
input_lengths,
max_s,
layer_past,
past_present_indices,
prefill,
):
qkv = self.query_key_value(hidden_states)
......@@ -156,25 +168,29 @@ class FlashRWAttention(torch.nn.Module):
self.rotary_emb(query, cos, sin)
self.rotary_emb(torch.select(kv, dim=1, index=0), cos, sin)
vllm_cache_ops.reshape_and_cache(
kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots
)
# output
attn_output = torch.empty_like(query)
# Prefill
if prefill:
# Copy to layer past
layer_past[...] = kv
# Expand to query shape
kv = kv.expand(-1, 2, self.num_heads, self.head_size)
if start_seq_prefill is not None:
if self.num_heads_kv == 1:
# Expand to query shape
kv = kv.expand(-1, 2, self.num_heads, self.head_size)
# output
attn_output = torch.empty_like(query)
# flash attention
flash_attn_cuda.fwd(
query,
torch.select(kv, dim=1, index=0),
torch.select(kv, dim=1, index=1),
attn_output,
start_seq,
end_seq,
start_seq,
end_seq,
start_seq_prefill,
end_seq_prefill,
start_seq_prefill,
end_seq_prefill,
max_s,
max_s,
0.0,
......@@ -187,32 +203,19 @@ class FlashRWAttention(torch.nn.Module):
)
# Decode
else:
# Add present to the layer_past tensor at the correct indices
layer_past[past_present_indices] = kv
# Expand to query shape
kv = layer_past.expand(-1, 2, self.num_heads, self.head_size)
# output
attn_output = torch.empty_like(query)
# flash attention
flash_attn_cuda.fwd(
query,
torch.select(kv, dim=1, index=0),
torch.select(kv, dim=1, index=1),
# kv_cache[1] => [num_blocks, num_heads_kv, head_size, block_size]
block_size = kv_cache[1].shape[3]
vllm_attention_ops.single_query_cached_kv_attention(
attn_output,
start_seq_q,
end_seq_q,
start_seq,
end_seq,
1,
max_s,
0.0,
query,
kv_cache[0],
kv_cache[1],
self.kv_head_mapping,
self.softmax_scale,
False,
False,
False,
0,
None,
block_tables,
input_lengths,
block_size,
max_s,
)
return self.dense(attn_output.view(-1, self.num_heads * self.head_size))
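The new `kv_head_mapping` buffers tell the paged attention kernel which cached KV head each query head should read: an identity mapping for standard multi-head attention, all zeros for multi-query attention (one shared KV head), and a `repeat_interleave` pattern for grouped KV heads as in `FlashRWLargeAttention` below. A quick illustration with made-up head counts:

```python
import torch

num_heads = 8

# Multi-head attention: every query head has its own KV head
mha = torch.arange(0, num_heads, dtype=torch.int32)               # [0, 1, 2, 3, 4, 5, 6, 7]

# Multi-query attention: all query heads share KV head 0
mqa = torch.zeros(num_heads, dtype=torch.int32)                   # [0, 0, 0, 0, 0, 0, 0, 0]

# Grouped KV heads: 2 groups of 4 query heads each
gqa = torch.arange(0, 2, dtype=torch.int32).repeat_interleave(4)  # [0, 0, 0, 0, 1, 1, 1, 1]
```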
......@@ -264,19 +267,22 @@ class FlashRWLargeAttention(torch.nn.Module):
config, prefix=f"{prefix}.dense", weights=weights, bias=config.bias
)
self.kv_head_mapping = torch.arange(
0, self.num_groups, dtype=torch.int32, device=weights.device
).repeat_interleave(self.num_heads)
def forward(
self,
hidden_states,
cos,
sin,
start_seq,
end_seq,
start_seq_q,
end_seq_q,
start_seq_prefill,
end_seq_prefill,
kv_cache,
block_tables,
slots,
input_lengths,
max_s,
layer_past,
past_present_indices,
prefill,
):
qkv = self.query_key_value(hidden_states)
qkv = qkv.view(-1, self.num_groups, self.num_heads + 2, self.head_size)
......@@ -293,10 +299,19 @@ class FlashRWLargeAttention(torch.nn.Module):
self.rotary_emb(query, cos, sin)
self.rotary_emb(torch.select(kv, dim=2, index=0), cos, sin)
vllm_cache_ops.reshape_and_cache(
kv[:, :, 0].contiguous(),
kv[:, :, 1].contiguous(),
kv_cache[0],
kv_cache[1],
slots,
)
# output
attn_output = torch.empty_like(query)
# Prefill
if prefill:
# Copy to layer past
layer_past[...] = kv
if start_seq_prefill is not None:
# Expand to query shape
kv = (
kv.unsqueeze(2)
......@@ -304,18 +319,16 @@ class FlashRWLargeAttention(torch.nn.Module):
.reshape(-1, self.num_groups * self.num_heads, 2, self.head_size)
)
# output
attn_output = torch.empty_like(query)
# flash attention
flash_attn_cuda.fwd(
query,
torch.select(kv, dim=2, index=0),
torch.select(kv, dim=2, index=1),
attn_output,
start_seq,
end_seq,
start_seq,
end_seq,
start_seq_prefill,
end_seq_prefill,
start_seq_prefill,
end_seq_prefill,
max_s,
max_s,
0.0,
......@@ -328,36 +341,19 @@ class FlashRWLargeAttention(torch.nn.Module):
)
# Decode
else:
# Add present to the layer_past tensor at the correct indices
layer_past[past_present_indices] = kv
# Expand to query shape
kv = (
layer_past.unsqueeze(2)
.expand(-1, self.num_groups, self.num_heads, 2, self.head_size)
.reshape(-1, self.num_groups * self.num_heads, 2, self.head_size)
)
# output
attn_output = torch.empty_like(query)
# flash attention
flash_attn_cuda.fwd(
query,
torch.select(kv, dim=2, index=0),
torch.select(kv, dim=2, index=1),
# kv_cache[1] => [num_blocks, num_groups, head_size, block_size]
block_size = kv_cache[1].shape[3]
vllm_attention_ops.single_query_cached_kv_attention(
attn_output,
start_seq_q,
end_seq_q,
start_seq,
end_seq,
1,
max_s,
0.0,
query,
kv_cache[0],
kv_cache[1],
self.kv_head_mapping,
self.softmax_scale,
False,
False,
False,
0,
None,
block_tables,
input_lengths,
block_size,
max_s,
)
return self.dense(
......@@ -432,14 +428,13 @@ class FlashRWLayer(nn.Module):
residual,
cos,
sin,
start_seq,
end_seq,
start_seq_q,
end_seq_q,
start_seq_prefill,
end_seq_prefill,
kv_cache,
block_tables,
slots,
input_lengths,
max_s,
layer_past,
past_present_indices,
prefill,
):
if self.parallel_attn:
ln_hidden_states, residual = self.input_layernorm(hidden_states, residual)
......@@ -448,14 +443,13 @@ class FlashRWLayer(nn.Module):
ln_hidden_states,
cos,
sin,
start_seq,
end_seq,
start_seq_q,
end_seq_q,
start_seq_prefill,
end_seq_prefill,
kv_cache,
block_tables,
slots,
input_lengths,
max_s,
layer_past,
past_present_indices,
prefill,
)
mlp_output = self.mlp(ln_hidden_states)
......@@ -472,14 +466,13 @@ class FlashRWLayer(nn.Module):
hidden_states,
cos,
sin,
start_seq,
end_seq,
start_seq_q,
end_seq_q,
start_seq_prefill,
end_seq_prefill,
kv_cache,
block_tables,
slots,
input_lengths,
max_s,
layer_past,
past_present_indices,
prefill,
)
hidden_states, residual = self.post_attention_layernorm(
......@@ -523,14 +516,13 @@ class FlashRWLargeLayer(nn.Module):
residual,
cos,
sin,
start_seq,
end_seq,
start_seq_q,
end_seq_q,
start_seq_prefill,
end_seq_prefill,
kv_cache,
block_tables,
slots,
input_lengths,
max_s,
layer_past,
past_present_indices,
prefill,
):
ln_attn, residual = self.ln_attn(hidden_states, residual)
ln_mlp, _ = self.ln_mlp(residual)
......@@ -540,14 +532,13 @@ class FlashRWLargeLayer(nn.Module):
ln_attn,
cos,
sin,
start_seq,
end_seq,
start_seq_q,
end_seq_q,
start_seq_prefill,
end_seq_prefill,
kv_cache,
block_tables,
slots,
input_lengths,
max_s,
layer_past,
past_present_indices,
prefill,
)
# MLP.
......@@ -580,11 +571,7 @@ class FlashRWModel(FlashRWPreTrainedModel):
for layer_id in range(config.num_hidden_layers)
]
)
self.cache_size = (
2,
self.h[0].self_attention.num_heads_kv,
self.h[0].self_attention.head_size,
)
self.cache_size = self.h[0].self_attention.num_heads_kv
elif config.model_type == "RefinedWeb":
self.h = nn.ModuleList(
[
......@@ -592,11 +579,7 @@ class FlashRWModel(FlashRWPreTrainedModel):
for layer_id in range(config.num_hidden_layers)
]
)
self.cache_size = (
self.h[0].self_attention.num_groups,
2,
self.h[0].self_attention.head_size,
)
self.cache_size = self.h[0].self_attention.num_groups
else:
raise NotImplementedError(
f"model_type {config.model_type} is not supported."
......@@ -612,38 +595,18 @@ class FlashRWModel(FlashRWPreTrainedModel):
def forward(
self,
input_ids,
position_ids,
start_seq,
end_seq,
start_seq_q,
end_seq_q,
max_s,
past_present_indices,
past_key_values=None,
pre_allocate_past_size: Optional[int] = None,
):
input_ids: torch.Tensor,
position_ids: torch.Tensor,
start_seq_prefill: Optional[torch.Tensor],
end_seq_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
input_lengths: torch.Tensor,
max_s: int,
) -> torch.Tensor:
hidden_states = self.word_embeddings(input_ids)
# Prefill
if past_key_values is None:
assert pre_allocate_past_size is not None
prefill = True
# Create past tensor
# We create a tensor of the same size as input_ids as we don't want to slice at every layer
past_key_values = hidden_states.new_empty(
(
len(input_ids),
len(self.h),
*self.cache_size,
)
)
# Decode
else:
prefill = False
# Get rotary cos and sin for this forward
# Avoid to index in each layer
cos, sin = self.h[0].self_attention.rotary_emb.get_cos_sin(
......@@ -657,32 +620,18 @@ class FlashRWModel(FlashRWPreTrainedModel):
residual,
cos,
sin,
start_seq,
end_seq,
start_seq_q,
end_seq_q,
start_seq_prefill,
end_seq_prefill,
kv_cache[i],
block_tables,
slots,
input_lengths,
max_s,
torch.select(past_key_values, dim=1, index=i),
past_present_indices,
prefill,
)
if prefill:
present = past_key_values
# Create padded past tensor
past_key_values = hidden_states.new_empty(
(
pre_allocate_past_size,
len(self.h),
*self.cache_size,
)
)
# We slice only once instead of at every layer
past_key_values[past_present_indices] = present
hidden_states, _ = self.ln_f(hidden_states, residual)
return hidden_states, past_key_values
return hidden_states
class FlashRWForCausalLM(FlashRWPreTrainedModel):
......@@ -697,31 +646,29 @@ class FlashRWForCausalLM(FlashRWPreTrainedModel):
def forward(
self,
input_ids,
position_ids,
start_seq,
end_seq,
start_seq_q,
end_seq_q,
max_s,
past_present_indices,
past_key_values: Optional[torch.Tensor] = None,
pre_allocate_past_size: Optional[int] = None,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
start_seq_prefill: Optional[torch.Tensor],
end_seq_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
input_lengths: torch.Tensor,
max_s: int,
lm_head_indices: Optional[torch.Tensor] = None,
):
hidden_states, present = self.transformer(
) -> torch.Tensor:
hidden_states = self.transformer(
input_ids,
position_ids,
start_seq,
end_seq,
start_seq_q,
end_seq_q,
start_seq_prefill,
end_seq_prefill,
kv_cache,
block_tables,
slots,
input_lengths,
max_s,
past_present_indices,
past_key_values,
pre_allocate_past_size,
)
if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices]
logits = self.lm_head(hidden_states)
return logits, present
return logits
......@@ -3,11 +3,15 @@ import torch.distributed
from torch import nn
from transformers.activations import ACT2FN
from typing import Optional
from typing import Optional, List, Tuple
# Flash attention imports
import flash_attn_cuda
# vllm imports
import vllm_cache_ops
import vllm_attention_ops
from text_generation_server.utils.layers import (
TensorParallelRowLinear,
TensorParallelColumnLinear,
......@@ -221,18 +225,20 @@ class FlashMQAttention(torch.nn.Module):
self.c_proj = load_row(
config, prefix=f"{prefix}.c_proj", weights=weights, bias=True
)
self.kv_head_mapping = torch.zeros(
self.num_heads, dtype=torch.int32, device=weights.device
)
def forward(
self,
hidden_states,
start_seq,
end_seq,
start_seq_q,
end_seq_q,
start_seq_prefill,
end_seq_prefill,
kv_cache,
block_tables,
slots,
input_lengths,
max_s,
layer_past,
past_present_indices,
prefill,
):
qkv = self.c_attn(hidden_states)
......@@ -245,25 +251,28 @@ class FlashMQAttention(torch.nn.Module):
query = query.view(-1, self.num_heads, self.head_size)
key_value = key_value.view(-1, 2, 1, self.head_size)
vllm_cache_ops.reshape_and_cache(
key_value[:, 0], key_value[:, 1], kv_cache[0], kv_cache[1], slots
)
# output
attn_output = torch.empty_like(query)
# Prefill
if prefill:
# Copy to layer past
layer_past[...] = key_value
if start_seq_prefill is not None:
# Expand from 1 to num_heads
key_value = key_value.expand(-1, 2, self.num_heads, self.head_size)
# output
attn_output = torch.empty_like(query)
# flash attention
flash_attn_cuda.fwd(
query,
torch.select(key_value, dim=1, index=0),
torch.select(key_value, dim=1, index=1),
attn_output,
start_seq,
end_seq,
start_seq,
end_seq,
start_seq_prefill,
end_seq_prefill,
start_seq_prefill,
end_seq_prefill,
max_s,
max_s,
0.0,
......@@ -276,32 +285,19 @@ class FlashMQAttention(torch.nn.Module):
)
# Decode
else:
# Add present to the layer_past tensor at the correct indices
layer_past[past_present_indices] = key_value
# Expand from 1 to num_heads
key_value = layer_past.expand(-1, 2, self.num_heads, self.head_size)
# output
attn_output = torch.empty_like(query)
# flash attention
flash_attn_cuda.fwd(
query,
torch.select(key_value, dim=1, index=0),
torch.select(key_value, dim=1, index=1),
# kv_cache[1] => [num_blocks, 1, head_size, block_size]
block_size = kv_cache[1].shape[3]
vllm_attention_ops.single_query_cached_kv_attention(
attn_output,
start_seq_q,
end_seq_q,
start_seq,
end_seq,
1,
max_s,
0.0,
query,
kv_cache[0],
kv_cache[1],
self.kv_head_mapping,
self.softmax_scale,
False,
False,
False,
0,
None,
block_tables,
input_lengths,
block_size,
max_s,
)
return self.c_proj(attn_output.view(-1, self.num_heads * self.head_size))
......@@ -361,27 +357,25 @@ class Block(nn.Module):
self,
hidden_states,
residual,
start_seq,
end_seq,
start_seq_q,
end_seq_q,
start_seq_prefill,
end_seq_prefill,
kv_cache,
block_tables,
slots,
input_lengths,
max_s,
layer_past,
past_present_indices,
prefill,
):
hidden_states, residual = self.ln_1(hidden_states, residual)
hidden_states = self.attn(
hidden_states,
start_seq,
end_seq,
start_seq_q,
end_seq_q,
start_seq_prefill,
end_seq_prefill,
kv_cache,
block_tables,
slots,
input_lengths,
max_s,
layer_past,
past_present_indices,
prefill,
)
hidden_states, residual = self.ln_2(hidden_states, residual)
......@@ -427,64 +421,38 @@ class FlashSantacoderModel(nn.Module):
def forward(
self,
input_ids,
position_ids,
start_seq,
end_seq,
start_seq_q,
end_seq_q,
max_s,
past_present_indices,
past_key_values=None,
pre_allocate_past_size: Optional[int] = None,
):
input_ids: torch.Tensor,
position_ids: torch.Tensor,
start_seq_prefill: Optional[torch.Tensor],
end_seq_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
input_lengths: torch.Tensor,
max_s: int,
) -> torch.Tensor:
hidden_states = self.wte(input_ids) + self.wpe(position_ids)
if self.process_group.size() > 1:
torch.distributed.all_reduce(hidden_states, group=self.process_group)
# Prefill
if past_key_values is None:
assert pre_allocate_past_size is not None
prefill = True
# Create past tensor
# We create a tensor of the same size as input_ids as we don't want to slice at every layer
past_key_values = hidden_states.new_zeros(
(len(input_ids), len(self.h), 2, 1, self.head_size)
)
# Decode
else:
prefill = False
residual = None
for i, layer in enumerate(self.h):
hidden_states, residual = layer(
hidden_states,
residual,
start_seq,
end_seq,
start_seq_q,
end_seq_q,
start_seq_prefill,
end_seq_prefill,
kv_cache[i],
block_tables,
slots,
input_lengths,
max_s,
torch.select(past_key_values, dim=1, index=i),
past_present_indices,
prefill,
)
if prefill:
present = past_key_values
# Create padded past tensor
past_key_values = hidden_states.new_empty(
(pre_allocate_past_size, len(self.h), 2, 1, self.head_size)
)
# We slice only once instead of at every layer
past_key_values[past_present_indices] = present
hidden_states, _ = self.ln_f(hidden_states, residual)
return hidden_states, past_key_values
return hidden_states
class FlashSantacoderForCausalLM(nn.Module):
......@@ -497,31 +465,29 @@ class FlashSantacoderForCausalLM(nn.Module):
def forward(
self,
input_ids,
position_ids,
start_seq,
end_seq,
start_seq_q,
end_seq_q,
max_s,
past_present_indices,
past_key_values: Optional[torch.Tensor] = None,
pre_allocate_past_size: Optional[int] = None,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
start_seq_prefill: Optional[torch.Tensor],
end_seq_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
input_lengths: torch.Tensor,
max_s: int,
lm_head_indices: Optional[torch.Tensor] = None,
):
hidden_states, present = self.transformer(
) -> torch.Tensor:
hidden_states = self.transformer(
input_ids,
position_ids,
start_seq,
end_seq,
start_seq_q,
end_seq_q,
start_seq_prefill,
end_seq_prefill,
kv_cache,
block_tables,
slots,
input_lengths,
max_s,
past_present_indices,
past_key_values,
pre_allocate_past_size,
)
if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices]
logits = self.lm_head(hidden_states)
return logits, present
return logits
......@@ -1004,7 +1004,9 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
try:
self.shared = TensorParallelEmbedding(prefix="shared", weights=weights)
except RuntimeError:
self.shared = TensorParallelEmbedding(prefix="encoder.embed_tokens", weights=weights)
self.shared = TensorParallelEmbedding(
prefix="encoder.embed_tokens", weights=weights
)
encoder_config = copy.deepcopy(config)
encoder_config.is_decoder = False
......