Unverified Commit 4327210e authored by Nicolas Patry's avatar Nicolas Patry Committed by GitHub
Browse files

[Major Change][Undecided yet] Move to FlashDecoding instead of PagedAttention kernel. (#1940)

* Using flash decoding

Conditional flashdecoding.

Fix max_q.

Working kvcache

Working version with flash decoding.

Make it work for mistral.

Fix after rebase..

Less intrusive.

REvert changes in modeling.

Speedup flashdecoding.

HHachweew
Hack to make other models work.

Fixing non flash decoding llama path.

Router logic knows about page size.

Missing 2 models.

Missing cohere.

Fixing cohere flash decoding.

Revamped all this architecture.

Fix cohere.

Fixing falcon.

Enabling custom block size schedule.

Update router/src/infer.rs

Not sending preallocated output.

* Making it work on non flash decoding.

* Fix Cohere.

* Fix non decoding paths.

* Rebased.

* No need for cache_manager anymore.

* Update?

* "ipex" -> "cpu"

* These do not belong.

* Factoring cu_seqlen_qk for better abstracting over every model.

* Fixing non flash tests/imports.

* Changing return everywhere.

* Update mistral past.

* Fixing Mi{s,x}tral (non functional in Flash Decoding mode though).

* Fixup mistral clamping (had issues with cuda graphs).

* No need to recreate anything actually.
parent 4f55f158
......@@ -39,7 +39,14 @@ impl SchedulerV2 {
speculate: u32,
generation_health: Arc<AtomicBool>,
) -> Self {
let queue = Queue::new(requires_padding, 16, window_size, speculate);
// Infer shared state
let flashdecoding = if let Ok(flashdecoding) = std::env::var("FLASH_DECODING") {
matches!(flashdecoding.to_lowercase().as_str(), "1" | "true")
} else {
false
};
let block_size = if flashdecoding { 256 } else { 16 };
let queue = Queue::new(requires_padding, block_size, window_size, speculate);
let batching_task_notifier = Arc::new(Notify::new());
// Spawn batching background task that contains all the inference logic
......
......@@ -39,9 +39,15 @@ impl SchedulerV3 {
speculate: u32,
generation_health: Arc<AtomicBool>,
) -> Self {
let flashdecoding = if let Ok(flashdecoding) = std::env::var("FLASH_DECODING") {
matches!(flashdecoding.to_lowercase().as_str(), "1" | "true")
} else {
false
};
let block_size = if flashdecoding { 256 } else { 16 };
let queue = Queue::new(
requires_padding,
16,
block_size,
window_size,
speculate,
max_batch_total_tokens,
......
from text_generation_server.utils.import_utils import SYSTEM
import os
from .common import Seqlen
if os.getenv("USE_FLASH_ATTENTION", "").lower() == "false":
raise ImportError("`USE_FLASH_ATTENTION` is false.")
if SYSTEM == "cuda":
......
from dataclasses import dataclass
from text_generation_server.models.globals import FLASH_DECODING
import torch
from typing import Optional
if FLASH_DECODING:
@dataclass
class Seqlen:
input_lengths: torch.Tensor
cu_seqlen_q: Optional[torch.Tensor]
cu_seqlen_k: Optional[torch.Tensor]
def __init__(self, input_lengths):
self.input_lengths = input_lengths
device = self.input_lengths.device
shape = self.input_lengths.shape
cu_seqlen_q = torch.arange(
shape[0] + 1,
device=device,
dtype=torch.int32,
)
cu_seqlen_k = torch.zeros(shape[-1] + 1, device=device, dtype=torch.int32)
# cuda graphs don't like this and this is necessary to clamp within mistral
# Although FA2 might not want the clamping
# cu_seqlen_k[0] = 0
torch.cumsum(self.input_lengths, -1, out=cu_seqlen_k[1:])
self.cu_seqlen_q = cu_seqlen_q
self.cu_seqlen_k = cu_seqlen_k
def clamp(self, max):
# Flash decoding doesn't need to clamp
return self
else:
@dataclass
class Seqlen:
input_lengths: torch.Tensor
def clamp(self, max):
return Seqlen(torch.clamp(self.input_lengths, max=max))
import torch
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.models.globals import FLASH_DECODING, BLOCK_SIZE
from text_generation_server.layers.attention import Seqlen
major, minor = torch.cuda.get_device_capability()
is_sm75 = major == 7 and minor == 5
......@@ -21,7 +23,14 @@ def reshape_and_cache(
value_cache: torch.Tensor,
slots: torch.Tensor,
):
cache_ops.reshape_and_cache(key, value, key_cache, value_cache, slots, "auto", 1.0)
if FLASH_DECODING:
shape = key_cache.shape
key_cache.view(-1, shape[-2], shape[-1])[slots] = key
value_cache.view(-1, shape[-2], shape[-1])[slots] = value
else:
cache_ops.reshape_and_cache(
key, value, key_cache, value_cache, slots, "auto", 1.0
)
def paged_attention(
......@@ -32,7 +41,7 @@ def paged_attention(
kv_head_mapping: torch.Tensor,
softmax_scale: float,
block_tables: torch.Tensor,
input_lengths: torch.Tensor,
seqlen: Seqlen,
max_s: int,
):
# Adapted from: https://github.com/vllm-project/vllm/blob/f8a1e39fae05ca610be8d5a78be9d40f5274e5fc/vllm/model_executor/layers/attention.py
......@@ -53,7 +62,8 @@ def paged_attention(
#
# value_cache => [num_blocks, num_heads, head_size, block_size]
block_size = value_cache.shape[3]
# block_size = value_cache.shape[3]
block_size = BLOCK_SIZE
num_seqs, num_heads, head_size = query.shape
max_num_partitions = (max_s + _PARTITION_SIZE - 1) // _PARTITION_SIZE
......@@ -62,58 +72,95 @@ def paged_attention(
# V1 to avoid the overhead of reduction. Also, if the number of
# sequences or heads is large, we use V1 since there is enough work
# to parallelize.
from vllm._C import ops
if FLASH_DECODING:
max_q = 1
max_k = max_s
import flash_attn_2_cuda
use_v1 = max_s <= 8192 and (max_num_partitions == 1 or num_seqs * num_heads > 512)
if use_v1:
ops.paged_attention_v1(
out,
# TODO fixme when flash contains the fix.
# Number of splits is not correctly handled
# by the current path
# https://github.com/Dao-AILab/flash-attention/blob/320fb59487658f033f56711efd3d61b7c7a6f8f3/csrc/flash_attn/flash_api.cpp#L577
# This fails becuase we're using causal, therefore window_right is set to 0 and the split logic is never applied.
out2 = flash_attn_2_cuda.varlen_fwd(
query,
key_cache,
value_cache,
kv_head_mapping,
softmax_scale,
None,
seqlen.cu_seqlen_q,
seqlen.cu_seqlen_k,
None,
block_tables,
input_lengths,
block_size,
max_s,
None,
"auto",
1.0,
max_q,
max_k,
0.0, # dropout
softmax_scale,
False, # zero_tensors
True, # causal
-1, # Window_left
-1, # Window right
False, # return softmax
None, # generator
)
return out2[0]
else:
# Run PagedAttention V2.
assert _PARTITION_SIZE % block_size == 0
tmp_output = torch.empty(
size=(num_seqs, num_heads, max_num_partitions, head_size),
dtype=out.dtype,
device=out.device,
)
exp_sums = torch.empty(
size=(num_seqs, num_heads, max_num_partitions),
dtype=torch.float32,
device=out.device,
)
max_logits = torch.empty_like(exp_sums)
input_lengths = seqlen.input_lengths
from vllm._C import ops
ops.paged_attention_v2(
out,
exp_sums,
max_logits,
tmp_output,
query,
key_cache,
value_cache,
kv_head_mapping,
softmax_scale,
block_tables,
input_lengths,
block_size,
max_s,
None,
"auto",
1.0,
use_v1 = max_s <= 8192 and (
max_num_partitions == 1 or num_seqs * num_heads > 512
)
if use_v1:
ops.paged_attention_v1(
out,
query,
key_cache,
value_cache,
kv_head_mapping,
softmax_scale,
block_tables,
input_lengths,
block_size,
max_s,
None,
"auto",
1.0,
)
else:
# Run PagedAttention V2.
assert _PARTITION_SIZE % block_size == 0
tmp_output = torch.empty(
size=(num_seqs, num_heads, max_num_partitions, head_size),
dtype=out.dtype,
device=out.device,
)
exp_sums = torch.empty(
size=(num_seqs, num_heads, max_num_partitions),
dtype=torch.float32,
device=out.device,
)
max_logits = torch.empty_like(exp_sums)
ops.paged_attention_v2(
out,
exp_sums,
max_logits,
tmp_output,
query,
key_cache,
value_cache,
kv_head_mapping,
softmax_scale,
block_tables,
input_lengths,
block_size,
max_s,
None,
"auto",
1.0,
)
return out
try:
......
......@@ -55,7 +55,8 @@ def paged_attention(
kv_head_mapping: torch.Tensor,
softmax_scale: float,
block_tables: torch.Tensor,
input_lengths: torch.Tensor,
cu_seqlen_q: torch.Tensor,
cu_seqlen_k: torch.Tensor,
max_s: int,
):
return ipex.llm.modules.PagedAttention.single_query_cached_kv_attention(
......@@ -66,7 +67,7 @@ def paged_attention(
kv_head_mapping,
softmax_scale,
block_tables,
input_lengths,
cu_seqlen_q,
BLOCK_SIZE,
max_s,
None,
......
import os
import torch
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.models.globals import FLASH_DECODING
from loguru import logger
major, minor = torch.cuda.get_device_capability()
......@@ -26,7 +27,14 @@ def reshape_and_cache(
value_cache: torch.Tensor,
slots: torch.Tensor,
):
cache_ops.reshape_and_cache(key, value, key_cache, value_cache, slots, "auto", 1.0)
if FLASH_DECODING:
shape = key_cache.shape
key_cache.view(-1, shape[-2], shape[-1])[slots] = key
value_cache.view(-1, shape[-2], shape[-1])[slots] = value
else:
cache_ops.reshape_and_cache(
key, value, key_cache, value_cache, slots, "auto", 1.0
)
def paged_attention(
......@@ -37,7 +45,8 @@ def paged_attention(
kv_head_mapping: torch.Tensor,
softmax_scale: float,
block_tables: torch.Tensor,
input_lengths: torch.Tensor,
cu_seqlen_q: torch.Tensor,
cu_seqlen_k: torch.Tensor,
max_s: int,
):
# Adapted from: https://github.com/vllm-project/vllm/blob/f8a1e39fae05ca610be8d5a78be9d40f5274e5fc/vllm/model_executor/layers/attention.py
......@@ -61,6 +70,7 @@ def paged_attention(
block_size = value_cache.shape[3]
num_seqs, num_heads, head_size = query.shape
max_num_partitions = (max_s + _PARTITION_SIZE - 1) // _PARTITION_SIZE
input_lengths = cu_seqlen_k
# NOTE(woosuk): We use a simple heuristic to decide whether to use
# PagedAttention V1 or V2. If the number of partitions is 1, we use
......@@ -119,6 +129,7 @@ def paged_attention(
"auto",
1.0,
)
return out
if ENGINE != "triton":
......
......@@ -12,7 +12,6 @@ from pathlib import Path
from text_generation_server.utils.speculate import get_speculate, set_speculate
from text_generation_server.models.model import Model
from text_generation_server.models.causal_lm import CausalLM
from text_generation_server.models.flash_causal_lm import FlashCausalLM
from text_generation_server.models.bloom import BLOOMSharded
from text_generation_server.models.mpt import MPTSharded
from text_generation_server.models.seq2seq_lm import Seq2SeqLM
......@@ -53,6 +52,7 @@ FLASH_ATT_ERROR_MESSAGE = "{} requires Flash Attention enabled models."
FLASH_ATTENTION = True
try:
from text_generation_server.models.flash_causal_lm import FlashCausalLM
from text_generation_server.models.flash_rw import FlashRWSharded
from text_generation_server.models.flash_gpt2 import FlashGPT2
from text_generation_server.models.flash_neox import FlashNeoXSharded
......@@ -92,6 +92,7 @@ except ImportError as e:
FLASH_ATTENTION = False
if FLASH_ATTENTION:
__all__.append(FlashCausalLM)
__all__.append(FlashGPT2)
__all__.append(FlashNeoXSharded)
__all__.append(FlashRWSharded)
......
......@@ -30,6 +30,7 @@ from text_generation_server.layers.attention import (
attention,
reshape_and_cache,
)
from text_generation_server.models.globals import FLASH_DECODING
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.layers import (
TensorParallelRowLinear,
......@@ -259,8 +260,8 @@ class FlashCohereAttention(torch.nn.Module):
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
input_lengths,
slots,
max_s,
):
qkv = self.query_key_value(hidden_states)
......@@ -304,7 +305,7 @@ class FlashCohereAttention(torch.nn.Module):
)
# Decode
else:
paged_attention(
attn_output = paged_attention(
attn_output,
query,
kv_cache[0],
......@@ -464,6 +465,7 @@ class FlashCohereModel(torch.nn.Module):
)
residual = None
for i, layer in enumerate(self.layers):
hidden_states, residual = layer(
hidden_states,
......
......@@ -336,7 +336,7 @@ class DbrxAttention(torch.nn.Module):
)
# Decode
else:
paged_attention(
attn_output = paged_attention(
attn_output,
query,
kv_cache[0],
......
......@@ -251,7 +251,7 @@ class FlashGemma2Attention(torch.nn.Module):
)
# Decode
else:
paged_attention(
attn_output = paged_attention(
attn_output,
query,
kv_cache[0],
......
......@@ -245,7 +245,7 @@ class FlashGemmaAttention(torch.nn.Module):
)
# Decode
else:
paged_attention(
attn_output = paged_attention(
attn_output,
query,
kv_cache[0],
......
......@@ -245,7 +245,7 @@ class FlashGPT2Attention(torch.nn.Module):
)
# Decode
else:
paged_attention(
attn_output = paged_attention(
attn_output,
query,
kv_cache[0],
......
......@@ -33,6 +33,7 @@ from text_generation_server.layers.attention import (
attention,
reshape_and_cache,
)
from text_generation_server.models.globals import FLASH_DECODING
from text_generation_server.layers import (
TensorParallelRowLinear,
TensorParallelColumnLinear,
......@@ -213,7 +214,7 @@ class FlashLlamaAttention(torch.nn.Module):
)
# Decode
else:
paged_attention(
attn_output = paged_attention(
attn_output,
query,
kv_cache[0],
......
......@@ -28,6 +28,7 @@ from typing import Optional, List, Tuple
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.layers.attention import (
Seqlen,
paged_attention,
attention,
reshape_and_cache,
......@@ -229,7 +230,7 @@ class MistralAttention(torch.nn.Module):
)
# Decode
else:
paged_attention(
attn_output = paged_attention(
attn_output,
query,
kv_cache[0],
......@@ -512,7 +513,7 @@ class FlashMistralForCausalLM(torch.nn.Module):
elif self.max_past is not None:
# Clamp in decode mode as paged attention requires clamped values whereas the flash attention
# kernel requires the true values
input_lengths = torch.clamp(input_lengths, max=self.max_past_tensor)
input_lengths = input_lengths.clamp(max=self.max_past_tensor)
inputs_embeds = self.embed_tokens(input_ids)
hidden_states = self.model(
......
......@@ -291,7 +291,7 @@ class MixtralAttention(torch.nn.Module):
)
# Decode
else:
paged_attention(
attn_output = paged_attention(
attn_output,
query,
kv_cache[0],
......@@ -647,7 +647,7 @@ class FlashMixtralForCausalLM(torch.nn.Module):
elif self.max_past is not None:
# Clamp in decode mode as paged attention requires clamped values whereas the flash attention
# kernel requires the true values
input_lengths = torch.clamp(input_lengths, max=self.max_past_tensor)
input_lengths = input_lengths.clamp(max=self.max_past_tensor)
hidden_states = self.model(
input_ids,
......
......@@ -168,7 +168,7 @@ class FlashNeoxAttention(torch.nn.Module):
)
# Decode
else:
paged_attention(
attn_output = paged_attention(
attn_output,
qkv[:, 0],
kv_cache[0],
......
......@@ -207,7 +207,7 @@ class FlashPhiAttention(torch.nn.Module):
)
# Decode
else:
paged_attention(
attn_output = paged_attention(
attn_output,
query,
kv_cache[0],
......
......@@ -149,7 +149,7 @@ class Qwen2Attention(torch.nn.Module):
)
# Decode
else:
paged_attention(
attn_output = paged_attention(
attn_output,
query,
kv_cache[0],
......
......@@ -217,7 +217,7 @@ class FlashRWAttention(torch.nn.Module):
)
# Decode
else:
paged_attention(
attn_output = paged_attention(
attn_output,
query,
kv_cache[0],
......@@ -340,7 +340,7 @@ class FlashRWLargeAttention(torch.nn.Module):
)
# Decode
else:
paged_attention(
attn_output = paged_attention(
attn_output,
query,
kv_cache[0],
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment