Unverified Commit 4327210e authored by Nicolas Patry, committed by GitHub

[Major Change][Undecided yet] Move to FlashDecoding instead of PagedAttention kernel. (#1940)

* Using flash decoding

Conditional flashdecoding.

Fix max_q.

Working kvcache

Working version with flash decoding.

Make it work for mistral.

Fix after rebase.

Less intrusive.

Revert changes in modeling.

Speedup flashdecoding.

Hack to make other models work.

Fixing non flash decoding llama path.

Router logic knows about page size.

Missing 2 models.

Missing cohere.

Fixing cohere flash decoding.

Revamped all this architecture.

Fix cohere.

Fixing falcon.

Enabling custom block size schedule.

Update router/src/infer.rs

Not sending preallocated output.

* Making it work on non flash decoding.

* Fix Cohere.

* Fix non decoding paths.

* Rebased.

* No need for cache_manager anymore.

* Update?

* "ipex" -> "cpu"

* These do not belong.

* Factoring cu_seqlen_qk for better abstracting over every model.

* Fixing non flash tests/imports.

* Changing return everywhere.

* Update mistral past.

* Fixing Mi{s,x}tral (non functional in Flash Decoding mode though).

* Fixup mistral clamping (had issues with cuda graphs).

* No need to recreate anything actually.
parent 4f55f158
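
The key data-structure change shows up in the cache-allocation hunk further down: with FLASH_DECODING enabled, each layer's key/value cache becomes a plain (num_blocks, BLOCK_SIZE, num_heads, head_size) tensor pair instead of the packed layout used by the PagedAttention kernel. Below is a minimal sketch contrasting the two allocations; the packed paged-attention shapes are an assumption inferred from the `x = BLOCK_SIZE // element_size` line kept as context in that hunk, not something this diff states.

import torch


def allocate_kv_cache(
    flash_decoding: bool,
    num_layers: int,
    num_blocks: int,
    block_size: int,
    num_heads: int,
    head_size: int,
    dtype: torch.dtype = torch.float16,
    device: str = "cpu",
):
    """Sketch only: contrasts the two KV-cache layouts discussed in this commit."""
    if flash_decoding:
        # FlashDecoding layout, as in the hunk below: the tokens of a block are
        # contiguous, one (key, value) pair per layer.
        return [
            (
                torch.empty((num_blocks, block_size, num_heads, head_size), dtype=dtype, device=device),
                torch.empty((num_blocks, block_size, num_heads, head_size), dtype=dtype, device=device),
            )
            for _ in range(num_layers)
        ]
    # Assumed PagedAttention layout: keys are packed in groups of `x` elements
    # so the kernel can use vectorized loads. Shapes here are a guess, not
    # taken from this diff.
    x = block_size // torch.tensor([], dtype=dtype).element_size()
    return [
        (
            torch.empty((num_blocks, num_heads, head_size // x, block_size, x), dtype=dtype, device=device),
            torch.empty((num_blocks, num_heads, head_size, block_size), dtype=dtype, device=device),
        )
        for _ in range(num_layers)
    ]


kv = allocate_kv_cache(True, num_layers=2, num_blocks=8, block_size=256, num_heads=8, head_size=64)
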
@@ -301,7 +301,7 @@ class FlashMQAttention(torch.nn.Module):
             )
         # Decode
         else:
-            paged_attention(
+            attn_output = paged_attention(
                 attn_output,
                 query,
                 kv_cache[0],
......
@@ -255,7 +255,7 @@ class Starcoder2Attention(torch.nn.Module):
             )
         # Decode
         else:
-            paged_attention(
+            attn_output = paged_attention(
                 attn_output,
                 query,
                 kv_cache[0],
......
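
Both hunks above make the same one-line change in two attention modules (FlashMQAttention and Starcoder2Attention): the decode branch now rebinds the value returned by paged_attention rather than relying purely on the preallocated attn_output buffer being filled in place (see the "Changing return everywhere" and "Not sending preallocated output" bullets). A hedged sketch of that calling-convention shift; paged_attention_like is a placeholder, not the repository's real signature.

import torch


def paged_attention_like(out: torch.Tensor, query: torch.Tensor) -> torch.Tensor:
    """Placeholder for the attention helper: a FlashDecoding-backed
    implementation may ignore `out` and allocate its own result, so the
    caller must always use the returned tensor."""
    return torch.zeros_like(query)  # a real kernel would compute attention here


query = torch.randn(2, 8, 64)
attn_output = torch.empty_like(query)
# New pattern from the hunks above: rebind the return value instead of
# assuming the preallocated buffer was filled in place.
attn_output = paged_attention_like(attn_output, query)
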
@@ -30,10 +30,13 @@ from text_generation_server.models.types import (
 from text_generation_server.pb import generate_pb2
 from text_generation_server.models.globals import (
     MEM_POOL,
+    FLASH_DECODING,
+    BLOCK_SIZE,
     CUDA_GRAPHS,
     get_adapter_to_index,
     MODEL_ID,
 )
+from text_generation_server.layers.attention import Seqlen
 from text_generation_server.utils import StoppingCriteria, HeterogeneousNextTokenChooser
 from text_generation_server.utils.dist import MEMORY_FRACTION
 from text_generation_server.utils.segments import SegmentConcatBuilder, find_segments
@@ -46,7 +49,6 @@ from text_generation_server.utils.import_utils import (
 tracer = trace.get_tracer(__name__)
-BLOCK_SIZE: int = 16
 # Will be set in init
 SLIDING_WINDOW: Optional[int] = None
@@ -856,7 +858,23 @@ class FlashCausalLM(Model):
         else:
             x = BLOCK_SIZE // element_size
-        if SYSTEM == "ipex" and device == torch.device("cpu"):
+        if FLASH_DECODING:
+            self.kv_cache = [
+                (
+                    torch.empty(
+                        (num_blocks, BLOCK_SIZE, num_heads, head_size),
+                        dtype=dtype,
+                        device=device,
+                    ),
+                    torch.empty(
+                        (num_blocks, BLOCK_SIZE, num_heads, head_size),
+                        dtype=dtype,
+                        device=device,
+                    ),
+                )
+                for _ in range(num_layers)
+            ]
+        elif SYSTEM == "ipex" and device == torch.device("cpu"):
             self.kv_cache = [
                 (
                     torch.empty(
@@ -908,6 +926,7 @@ class FlashCausalLM(Model):
             "slots": slots,
             "input_lengths": input_lengths,
         }
+        input_lengths = Seqlen(input_lengths=input_lengths)
         graph = torch.cuda.CUDAGraph()
         self.cuda_graphs[bs]["graph"] = graph
@@ -1067,6 +1086,7 @@ class FlashCausalLM(Model):
         # Dummy value, some models (starcoder2) don't accept `None`.
         input_lengths = torch.ones(seqlen, dtype=torch.int32, device=self.device)
+        input_lengths = Seqlen(input_lengths=input_lengths)
         # We pass a `cu_seqlen_prefill` in order not to have to deal with paged attention cache allocation/deallocation.
         self.model.forward(
@@ -1153,6 +1173,7 @@ class FlashCausalLM(Model):
             cuda_graph = None
         if cu_seqlen_prefill is not None or cuda_graph is None:
+            input_lengths = Seqlen(input_lengths=input_lengths)
             logits, speculative_logits = self.model.forward(
                 input_ids=input_ids,
                 position_ids=position_ids,
......
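
The flash_causal_lm hunks above wrap the raw input_lengths tensor in a Seqlen object right before every model.forward call, in line with the "Factoring cu_seqlen_qk for better abstracting over every model" bullet. The Seqlen class itself is not shown in this diff; the sketch below is only a guess at the shape of such a wrapper (field names and the derivation of the cumulative offsets are assumptions), to illustrate why a single object is passed around instead of a bare tensor.

from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class Seqlen:
    """Hypothetical stand-in for text_generation_server.layers.attention.Seqlen:
    bundles per-sequence lengths with the cumulative offsets consumed by the
    flash / flash-decoding kernels."""

    input_lengths: torch.Tensor
    cu_seqlen_q: Optional[torch.Tensor] = None
    cu_seqlen_k: Optional[torch.Tensor] = None

    def __post_init__(self):
        device = self.input_lengths.device
        n = self.input_lengths.shape[0]
        if self.cu_seqlen_q is None:
            # During decode every sequence contributes exactly one query token.
            self.cu_seqlen_q = torch.arange(n + 1, dtype=torch.int32, device=device)
        if self.cu_seqlen_k is None:
            cu = torch.zeros(n + 1, dtype=torch.int32, device=device)
            cu[1:] = torch.cumsum(self.input_lengths, dim=0)
            self.cu_seqlen_k = cu


lengths = torch.tensor([5, 3, 7], dtype=torch.int32)
s = Seqlen(input_lengths=lengths)
# s.cu_seqlen_q -> tensor([0, 1, 2, 3]), s.cu_seqlen_k -> tensor([0, 5, 8, 15])
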
@@ -5,6 +5,12 @@ from typing import Dict
 MEM_POOL = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None
 # This is overridden by the cli
+FLASH_DECODING = os.getenv("FLASH_DECODING") in {"1", "true", "True"}
+BLOCK_SIZE: int = 256 if FLASH_DECODING else 16
+if FLASH_DECODING:
+    logger.info("Using FLASH_DECODING")
 cuda_graphs = os.getenv("CUDA_GRAPHS")
 if cuda_graphs is not None:
     try:
@@ -15,8 +21,6 @@ if cuda_graphs is not None:
         )
 else:
     cuda_graphs = None
 # sorting the cuda graphs in descending order helps reduce the
 # memory impact and results in less memory usage
 if cuda_graphs is not None:
......
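
The globals.py hunk above gates the whole feature behind the FLASH_DECODING environment variable and, when it is set, raises the KV-cache page size from 16 to 256 tokens per block. That is what the "Router logic knows about page size" bullet refers to: block accounting on the scheduling side depends on BLOCK_SIZE. A rough Python illustration of that accounting follows; the helper is hypothetical, not the Rust router's actual code.

import os

# Mirrors the gating logic from the globals.py hunk above.
FLASH_DECODING = os.getenv("FLASH_DECODING") in {"1", "true", "True"}
BLOCK_SIZE = 256 if FLASH_DECODING else 16


def blocks_needed(input_length: int, max_new_tokens: int, block_size: int = BLOCK_SIZE) -> int:
    """Hypothetical helper: how many KV-cache blocks a request reserves.
    Only illustrates why the scheduler must know the page size."""
    total_tokens = input_length + max_new_tokens
    return -(-total_tokens // block_size)  # ceiling division


# 150 total tokens need 10 blocks at the PagedAttention page size of 16...
assert blocks_needed(130, 20, block_size=16) == 10
# ...but a single block at the FlashDecoding page size of 256.
assert blocks_needed(130, 20, block_size=256) == 1
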