Unverified Commit 4327210e authored by Nicolas Patry's avatar Nicolas Patry Committed by GitHub
Browse files

[Major Change][Undecided yet] Move to FlashDecoding instead of PagedAttention kernel. (#1940)

* Using flash decoding

Conditional flashdecoding.

Fix max_q.

Working kvcache

Working version with flash decoding.

Make it work for mistral.

Fix after rebase..

Less intrusive.

REvert changes in modeling.

Speedup flashdecoding.

HHachweew
Hack to make other models work.

Fixing non flash decoding llama path.

Router logic knows about page size.

Missing 2 models.

Missing cohere.

Fixing cohere flash decoding.

Revamped all this architecture.

Fix cohere.

Fixing falcon.

Enabling custom block size schedule.

Update router/src/infer.rs

Not sending preallocated output.

* Making it work on non flash decoding.

* Fix Cohere.

* Fix non decoding paths.

* Rebased.

* No need for cache_manager anymore.

* Update?

* "ipex" -> "cpu"

* These do not belong.

* Factoring cu_seqlen_qk for better abstracting over every model.

* Fixing non flash tests/imports.

* Changing return everywhere.

* Update mistral past.

* Fixing Mi{s,x}tral (non functional in Flash Decoding mode though).

* Fixup mistral clamping (had issues with cuda graphs).

* No need to recreate anything actually.
parent 4f55f158
...@@ -39,7 +39,14 @@ impl SchedulerV2 { ...@@ -39,7 +39,14 @@ impl SchedulerV2 {
speculate: u32, speculate: u32,
generation_health: Arc<AtomicBool>, generation_health: Arc<AtomicBool>,
) -> Self { ) -> Self {
let queue = Queue::new(requires_padding, 16, window_size, speculate); // Infer shared state
let flashdecoding = if let Ok(flashdecoding) = std::env::var("FLASH_DECODING") {
matches!(flashdecoding.to_lowercase().as_str(), "1" | "true")
} else {
false
};
let block_size = if flashdecoding { 256 } else { 16 };
let queue = Queue::new(requires_padding, block_size, window_size, speculate);
let batching_task_notifier = Arc::new(Notify::new()); let batching_task_notifier = Arc::new(Notify::new());
// Spawn batching background task that contains all the inference logic // Spawn batching background task that contains all the inference logic
......
...@@ -39,9 +39,15 @@ impl SchedulerV3 { ...@@ -39,9 +39,15 @@ impl SchedulerV3 {
speculate: u32, speculate: u32,
generation_health: Arc<AtomicBool>, generation_health: Arc<AtomicBool>,
) -> Self { ) -> Self {
let flashdecoding = if let Ok(flashdecoding) = std::env::var("FLASH_DECODING") {
matches!(flashdecoding.to_lowercase().as_str(), "1" | "true")
} else {
false
};
let block_size = if flashdecoding { 256 } else { 16 };
let queue = Queue::new( let queue = Queue::new(
requires_padding, requires_padding,
16, block_size,
window_size, window_size,
speculate, speculate,
max_batch_total_tokens, max_batch_total_tokens,
......
from text_generation_server.utils.import_utils import SYSTEM from text_generation_server.utils.import_utils import SYSTEM
import os import os
from .common import Seqlen
if os.getenv("USE_FLASH_ATTENTION", "").lower() == "false": if os.getenv("USE_FLASH_ATTENTION", "").lower() == "false":
raise ImportError("`USE_FLASH_ATTENTION` is false.") raise ImportError("`USE_FLASH_ATTENTION` is false.")
if SYSTEM == "cuda": if SYSTEM == "cuda":
......
from dataclasses import dataclass
from text_generation_server.models.globals import FLASH_DECODING
import torch
from typing import Optional
if FLASH_DECODING:
@dataclass
class Seqlen:
input_lengths: torch.Tensor
cu_seqlen_q: Optional[torch.Tensor]
cu_seqlen_k: Optional[torch.Tensor]
def __init__(self, input_lengths):
self.input_lengths = input_lengths
device = self.input_lengths.device
shape = self.input_lengths.shape
cu_seqlen_q = torch.arange(
shape[0] + 1,
device=device,
dtype=torch.int32,
)
cu_seqlen_k = torch.zeros(shape[-1] + 1, device=device, dtype=torch.int32)
# cuda graphs don't like this and this is necessary to clamp within mistral
# Although FA2 might not want the clamping
# cu_seqlen_k[0] = 0
torch.cumsum(self.input_lengths, -1, out=cu_seqlen_k[1:])
self.cu_seqlen_q = cu_seqlen_q
self.cu_seqlen_k = cu_seqlen_k
def clamp(self, max):
# Flash decoding doesn't need to clamp
return self
else:
@dataclass
class Seqlen:
input_lengths: torch.Tensor
def clamp(self, max):
return Seqlen(torch.clamp(self.input_lengths, max=max))
import torch import torch
from text_generation_server.utils.import_utils import SYSTEM from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.models.globals import FLASH_DECODING, BLOCK_SIZE
from text_generation_server.layers.attention import Seqlen
major, minor = torch.cuda.get_device_capability() major, minor = torch.cuda.get_device_capability()
is_sm75 = major == 7 and minor == 5 is_sm75 = major == 7 and minor == 5
...@@ -21,7 +23,14 @@ def reshape_and_cache( ...@@ -21,7 +23,14 @@ def reshape_and_cache(
value_cache: torch.Tensor, value_cache: torch.Tensor,
slots: torch.Tensor, slots: torch.Tensor,
): ):
cache_ops.reshape_and_cache(key, value, key_cache, value_cache, slots, "auto", 1.0) if FLASH_DECODING:
shape = key_cache.shape
key_cache.view(-1, shape[-2], shape[-1])[slots] = key
value_cache.view(-1, shape[-2], shape[-1])[slots] = value
else:
cache_ops.reshape_and_cache(
key, value, key_cache, value_cache, slots, "auto", 1.0
)
def paged_attention( def paged_attention(
...@@ -32,7 +41,7 @@ def paged_attention( ...@@ -32,7 +41,7 @@ def paged_attention(
kv_head_mapping: torch.Tensor, kv_head_mapping: torch.Tensor,
softmax_scale: float, softmax_scale: float,
block_tables: torch.Tensor, block_tables: torch.Tensor,
input_lengths: torch.Tensor, seqlen: Seqlen,
max_s: int, max_s: int,
): ):
# Adapted from: https://github.com/vllm-project/vllm/blob/f8a1e39fae05ca610be8d5a78be9d40f5274e5fc/vllm/model_executor/layers/attention.py # Adapted from: https://github.com/vllm-project/vllm/blob/f8a1e39fae05ca610be8d5a78be9d40f5274e5fc/vllm/model_executor/layers/attention.py
...@@ -53,7 +62,8 @@ def paged_attention( ...@@ -53,7 +62,8 @@ def paged_attention(
# #
# value_cache => [num_blocks, num_heads, head_size, block_size] # value_cache => [num_blocks, num_heads, head_size, block_size]
block_size = value_cache.shape[3] # block_size = value_cache.shape[3]
block_size = BLOCK_SIZE
num_seqs, num_heads, head_size = query.shape num_seqs, num_heads, head_size = query.shape
max_num_partitions = (max_s + _PARTITION_SIZE - 1) // _PARTITION_SIZE max_num_partitions = (max_s + _PARTITION_SIZE - 1) // _PARTITION_SIZE
...@@ -62,9 +72,45 @@ def paged_attention( ...@@ -62,9 +72,45 @@ def paged_attention(
# V1 to avoid the overhead of reduction. Also, if the number of # V1 to avoid the overhead of reduction. Also, if the number of
# sequences or heads is large, we use V1 since there is enough work # sequences or heads is large, we use V1 since there is enough work
# to parallelize. # to parallelize.
if FLASH_DECODING:
max_q = 1
max_k = max_s
import flash_attn_2_cuda
# TODO fixme when flash contains the fix.
# Number of splits is not correctly handled
# by the current path
# https://github.com/Dao-AILab/flash-attention/blob/320fb59487658f033f56711efd3d61b7c7a6f8f3/csrc/flash_attn/flash_api.cpp#L577
# This fails becuase we're using causal, therefore window_right is set to 0 and the split logic is never applied.
out2 = flash_attn_2_cuda.varlen_fwd(
query,
key_cache,
value_cache,
None,
seqlen.cu_seqlen_q,
seqlen.cu_seqlen_k,
None,
block_tables,
None,
max_q,
max_k,
0.0, # dropout
softmax_scale,
False, # zero_tensors
True, # causal
-1, # Window_left
-1, # Window right
False, # return softmax
None, # generator
)
return out2[0]
else:
input_lengths = seqlen.input_lengths
from vllm._C import ops from vllm._C import ops
use_v1 = max_s <= 8192 and (max_num_partitions == 1 or num_seqs * num_heads > 512) use_v1 = max_s <= 8192 and (
max_num_partitions == 1 or num_seqs * num_heads > 512
)
if use_v1: if use_v1:
ops.paged_attention_v1( ops.paged_attention_v1(
out, out,
...@@ -114,6 +160,7 @@ def paged_attention( ...@@ -114,6 +160,7 @@ def paged_attention(
"auto", "auto",
1.0, 1.0,
) )
return out
try: try:
......
...@@ -55,7 +55,8 @@ def paged_attention( ...@@ -55,7 +55,8 @@ def paged_attention(
kv_head_mapping: torch.Tensor, kv_head_mapping: torch.Tensor,
softmax_scale: float, softmax_scale: float,
block_tables: torch.Tensor, block_tables: torch.Tensor,
input_lengths: torch.Tensor, cu_seqlen_q: torch.Tensor,
cu_seqlen_k: torch.Tensor,
max_s: int, max_s: int,
): ):
return ipex.llm.modules.PagedAttention.single_query_cached_kv_attention( return ipex.llm.modules.PagedAttention.single_query_cached_kv_attention(
...@@ -66,7 +67,7 @@ def paged_attention( ...@@ -66,7 +67,7 @@ def paged_attention(
kv_head_mapping, kv_head_mapping,
softmax_scale, softmax_scale,
block_tables, block_tables,
input_lengths, cu_seqlen_q,
BLOCK_SIZE, BLOCK_SIZE,
max_s, max_s,
None, None,
......
import os import os
import torch import torch
from text_generation_server.utils.import_utils import SYSTEM from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.models.globals import FLASH_DECODING
from loguru import logger from loguru import logger
major, minor = torch.cuda.get_device_capability() major, minor = torch.cuda.get_device_capability()
...@@ -26,7 +27,14 @@ def reshape_and_cache( ...@@ -26,7 +27,14 @@ def reshape_and_cache(
value_cache: torch.Tensor, value_cache: torch.Tensor,
slots: torch.Tensor, slots: torch.Tensor,
): ):
cache_ops.reshape_and_cache(key, value, key_cache, value_cache, slots, "auto", 1.0) if FLASH_DECODING:
shape = key_cache.shape
key_cache.view(-1, shape[-2], shape[-1])[slots] = key
value_cache.view(-1, shape[-2], shape[-1])[slots] = value
else:
cache_ops.reshape_and_cache(
key, value, key_cache, value_cache, slots, "auto", 1.0
)
def paged_attention( def paged_attention(
...@@ -37,7 +45,8 @@ def paged_attention( ...@@ -37,7 +45,8 @@ def paged_attention(
kv_head_mapping: torch.Tensor, kv_head_mapping: torch.Tensor,
softmax_scale: float, softmax_scale: float,
block_tables: torch.Tensor, block_tables: torch.Tensor,
input_lengths: torch.Tensor, cu_seqlen_q: torch.Tensor,
cu_seqlen_k: torch.Tensor,
max_s: int, max_s: int,
): ):
# Adapted from: https://github.com/vllm-project/vllm/blob/f8a1e39fae05ca610be8d5a78be9d40f5274e5fc/vllm/model_executor/layers/attention.py # Adapted from: https://github.com/vllm-project/vllm/blob/f8a1e39fae05ca610be8d5a78be9d40f5274e5fc/vllm/model_executor/layers/attention.py
...@@ -61,6 +70,7 @@ def paged_attention( ...@@ -61,6 +70,7 @@ def paged_attention(
block_size = value_cache.shape[3] block_size = value_cache.shape[3]
num_seqs, num_heads, head_size = query.shape num_seqs, num_heads, head_size = query.shape
max_num_partitions = (max_s + _PARTITION_SIZE - 1) // _PARTITION_SIZE max_num_partitions = (max_s + _PARTITION_SIZE - 1) // _PARTITION_SIZE
input_lengths = cu_seqlen_k
# NOTE(woosuk): We use a simple heuristic to decide whether to use # NOTE(woosuk): We use a simple heuristic to decide whether to use
# PagedAttention V1 or V2. If the number of partitions is 1, we use # PagedAttention V1 or V2. If the number of partitions is 1, we use
...@@ -119,6 +129,7 @@ def paged_attention( ...@@ -119,6 +129,7 @@ def paged_attention(
"auto", "auto",
1.0, 1.0,
) )
return out
if ENGINE != "triton": if ENGINE != "triton":
......
...@@ -12,7 +12,6 @@ from pathlib import Path ...@@ -12,7 +12,6 @@ from pathlib import Path
from text_generation_server.utils.speculate import get_speculate, set_speculate from text_generation_server.utils.speculate import get_speculate, set_speculate
from text_generation_server.models.model import Model from text_generation_server.models.model import Model
from text_generation_server.models.causal_lm import CausalLM from text_generation_server.models.causal_lm import CausalLM
from text_generation_server.models.flash_causal_lm import FlashCausalLM
from text_generation_server.models.bloom import BLOOMSharded from text_generation_server.models.bloom import BLOOMSharded
from text_generation_server.models.mpt import MPTSharded from text_generation_server.models.mpt import MPTSharded
from text_generation_server.models.seq2seq_lm import Seq2SeqLM from text_generation_server.models.seq2seq_lm import Seq2SeqLM
...@@ -53,6 +52,7 @@ FLASH_ATT_ERROR_MESSAGE = "{} requires Flash Attention enabled models." ...@@ -53,6 +52,7 @@ FLASH_ATT_ERROR_MESSAGE = "{} requires Flash Attention enabled models."
FLASH_ATTENTION = True FLASH_ATTENTION = True
try: try:
from text_generation_server.models.flash_causal_lm import FlashCausalLM
from text_generation_server.models.flash_rw import FlashRWSharded from text_generation_server.models.flash_rw import FlashRWSharded
from text_generation_server.models.flash_gpt2 import FlashGPT2 from text_generation_server.models.flash_gpt2 import FlashGPT2
from text_generation_server.models.flash_neox import FlashNeoXSharded from text_generation_server.models.flash_neox import FlashNeoXSharded
...@@ -92,6 +92,7 @@ except ImportError as e: ...@@ -92,6 +92,7 @@ except ImportError as e:
FLASH_ATTENTION = False FLASH_ATTENTION = False
if FLASH_ATTENTION: if FLASH_ATTENTION:
__all__.append(FlashCausalLM)
__all__.append(FlashGPT2) __all__.append(FlashGPT2)
__all__.append(FlashNeoXSharded) __all__.append(FlashNeoXSharded)
__all__.append(FlashRWSharded) __all__.append(FlashRWSharded)
......
...@@ -30,6 +30,7 @@ from text_generation_server.layers.attention import ( ...@@ -30,6 +30,7 @@ from text_generation_server.layers.attention import (
attention, attention,
reshape_and_cache, reshape_and_cache,
) )
from text_generation_server.models.globals import FLASH_DECODING
from text_generation_server.utils.import_utils import SYSTEM from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.layers import ( from text_generation_server.layers import (
TensorParallelRowLinear, TensorParallelRowLinear,
...@@ -259,8 +260,8 @@ class FlashCohereAttention(torch.nn.Module): ...@@ -259,8 +260,8 @@ class FlashCohereAttention(torch.nn.Module):
cu_seqlen_prefill, cu_seqlen_prefill,
kv_cache, kv_cache,
block_tables, block_tables,
slots,
input_lengths, input_lengths,
slots,
max_s, max_s,
): ):
qkv = self.query_key_value(hidden_states) qkv = self.query_key_value(hidden_states)
...@@ -304,7 +305,7 @@ class FlashCohereAttention(torch.nn.Module): ...@@ -304,7 +305,7 @@ class FlashCohereAttention(torch.nn.Module):
) )
# Decode # Decode
else: else:
paged_attention( attn_output = paged_attention(
attn_output, attn_output,
query, query,
kv_cache[0], kv_cache[0],
...@@ -464,6 +465,7 @@ class FlashCohereModel(torch.nn.Module): ...@@ -464,6 +465,7 @@ class FlashCohereModel(torch.nn.Module):
) )
residual = None residual = None
for i, layer in enumerate(self.layers): for i, layer in enumerate(self.layers):
hidden_states, residual = layer( hidden_states, residual = layer(
hidden_states, hidden_states,
......
...@@ -336,7 +336,7 @@ class DbrxAttention(torch.nn.Module): ...@@ -336,7 +336,7 @@ class DbrxAttention(torch.nn.Module):
) )
# Decode # Decode
else: else:
paged_attention( attn_output = paged_attention(
attn_output, attn_output,
query, query,
kv_cache[0], kv_cache[0],
......
...@@ -251,7 +251,7 @@ class FlashGemma2Attention(torch.nn.Module): ...@@ -251,7 +251,7 @@ class FlashGemma2Attention(torch.nn.Module):
) )
# Decode # Decode
else: else:
paged_attention( attn_output = paged_attention(
attn_output, attn_output,
query, query,
kv_cache[0], kv_cache[0],
......
...@@ -245,7 +245,7 @@ class FlashGemmaAttention(torch.nn.Module): ...@@ -245,7 +245,7 @@ class FlashGemmaAttention(torch.nn.Module):
) )
# Decode # Decode
else: else:
paged_attention( attn_output = paged_attention(
attn_output, attn_output,
query, query,
kv_cache[0], kv_cache[0],
......
...@@ -245,7 +245,7 @@ class FlashGPT2Attention(torch.nn.Module): ...@@ -245,7 +245,7 @@ class FlashGPT2Attention(torch.nn.Module):
) )
# Decode # Decode
else: else:
paged_attention( attn_output = paged_attention(
attn_output, attn_output,
query, query,
kv_cache[0], kv_cache[0],
......
...@@ -33,6 +33,7 @@ from text_generation_server.layers.attention import ( ...@@ -33,6 +33,7 @@ from text_generation_server.layers.attention import (
attention, attention,
reshape_and_cache, reshape_and_cache,
) )
from text_generation_server.models.globals import FLASH_DECODING
from text_generation_server.layers import ( from text_generation_server.layers import (
TensorParallelRowLinear, TensorParallelRowLinear,
TensorParallelColumnLinear, TensorParallelColumnLinear,
...@@ -213,7 +214,7 @@ class FlashLlamaAttention(torch.nn.Module): ...@@ -213,7 +214,7 @@ class FlashLlamaAttention(torch.nn.Module):
) )
# Decode # Decode
else: else:
paged_attention( attn_output = paged_attention(
attn_output, attn_output,
query, query,
kv_cache[0], kv_cache[0],
......
...@@ -28,6 +28,7 @@ from typing import Optional, List, Tuple ...@@ -28,6 +28,7 @@ from typing import Optional, List, Tuple
from text_generation_server.utils.import_utils import SYSTEM from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.layers.attention import ( from text_generation_server.layers.attention import (
Seqlen,
paged_attention, paged_attention,
attention, attention,
reshape_and_cache, reshape_and_cache,
...@@ -229,7 +230,7 @@ class MistralAttention(torch.nn.Module): ...@@ -229,7 +230,7 @@ class MistralAttention(torch.nn.Module):
) )
# Decode # Decode
else: else:
paged_attention( attn_output = paged_attention(
attn_output, attn_output,
query, query,
kv_cache[0], kv_cache[0],
...@@ -512,7 +513,7 @@ class FlashMistralForCausalLM(torch.nn.Module): ...@@ -512,7 +513,7 @@ class FlashMistralForCausalLM(torch.nn.Module):
elif self.max_past is not None: elif self.max_past is not None:
# Clamp in decode mode as paged attention requires clamped values whereas the flash attention # Clamp in decode mode as paged attention requires clamped values whereas the flash attention
# kernel requires the true values # kernel requires the true values
input_lengths = torch.clamp(input_lengths, max=self.max_past_tensor) input_lengths = input_lengths.clamp(max=self.max_past_tensor)
inputs_embeds = self.embed_tokens(input_ids) inputs_embeds = self.embed_tokens(input_ids)
hidden_states = self.model( hidden_states = self.model(
......
...@@ -291,7 +291,7 @@ class MixtralAttention(torch.nn.Module): ...@@ -291,7 +291,7 @@ class MixtralAttention(torch.nn.Module):
) )
# Decode # Decode
else: else:
paged_attention( attn_output = paged_attention(
attn_output, attn_output,
query, query,
kv_cache[0], kv_cache[0],
...@@ -647,7 +647,7 @@ class FlashMixtralForCausalLM(torch.nn.Module): ...@@ -647,7 +647,7 @@ class FlashMixtralForCausalLM(torch.nn.Module):
elif self.max_past is not None: elif self.max_past is not None:
# Clamp in decode mode as paged attention requires clamped values whereas the flash attention # Clamp in decode mode as paged attention requires clamped values whereas the flash attention
# kernel requires the true values # kernel requires the true values
input_lengths = torch.clamp(input_lengths, max=self.max_past_tensor) input_lengths = input_lengths.clamp(max=self.max_past_tensor)
hidden_states = self.model( hidden_states = self.model(
input_ids, input_ids,
......
...@@ -168,7 +168,7 @@ class FlashNeoxAttention(torch.nn.Module): ...@@ -168,7 +168,7 @@ class FlashNeoxAttention(torch.nn.Module):
) )
# Decode # Decode
else: else:
paged_attention( attn_output = paged_attention(
attn_output, attn_output,
qkv[:, 0], qkv[:, 0],
kv_cache[0], kv_cache[0],
......
...@@ -207,7 +207,7 @@ class FlashPhiAttention(torch.nn.Module): ...@@ -207,7 +207,7 @@ class FlashPhiAttention(torch.nn.Module):
) )
# Decode # Decode
else: else:
paged_attention( attn_output = paged_attention(
attn_output, attn_output,
query, query,
kv_cache[0], kv_cache[0],
......
...@@ -149,7 +149,7 @@ class Qwen2Attention(torch.nn.Module): ...@@ -149,7 +149,7 @@ class Qwen2Attention(torch.nn.Module):
) )
# Decode # Decode
else: else:
paged_attention( attn_output = paged_attention(
attn_output, attn_output,
query, query,
kv_cache[0], kv_cache[0],
......
...@@ -217,7 +217,7 @@ class FlashRWAttention(torch.nn.Module): ...@@ -217,7 +217,7 @@ class FlashRWAttention(torch.nn.Module):
) )
# Decode # Decode
else: else:
paged_attention( attn_output = paged_attention(
attn_output, attn_output,
query, query,
kv_cache[0], kv_cache[0],
...@@ -340,7 +340,7 @@ class FlashRWLargeAttention(torch.nn.Module): ...@@ -340,7 +340,7 @@ class FlashRWLargeAttention(torch.nn.Module):
) )
# Decode # Decode
else: else:
paged_attention( attn_output = paged_attention(
attn_output, attn_output,
query, query,
kv_cache[0], kv_cache[0],
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment