"git@developer.sourcefind.cn:OpenDAS/fairscale.git" did not exist on "3f240fbb3734ab5f112a3d26d3856cf0a0e1a092"
Unverified commit 8ec57558, authored by Daniël de Kok, committed by GitHub

Break cycle between the attention implementations and KV cache (#2627)

parent 5f32dea1
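
This change removes the backend-specific `reshape_and_cache` helpers from the CUDA, ROCm and IPEX attention modules (and from the attention package's `__all__`) and replaces them with a single `paged_reshape_and_cache` function defined next to `KVCache` in `kv_cache.py`. Previously, `KVCache.store` needed a function-local `from text_generation_server.layers.attention import reshape_and_cache`, while the attention package's `__init__` in turn imports `KVCache` from `.kv_cache`; that is the cycle the title refers to. With the `SYSTEM` dispatch living in `kv_cache.py` and the vllm/ipex kernels imported lazily inside the new function, `kv_cache.py` no longer has to import the package `__init__` at all. A small, self-contained sketch of the cache write follows the `kv_cache.py` hunk below.
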
text_generation_server/layers/attention/__init__.py

@@ -11,21 +11,18 @@ if SYSTEM == "cuda":
         SUPPORTS_WINDOWING,
         attention,
         paged_attention,
-        reshape_and_cache,
     )
 elif SYSTEM == "rocm":
     from .rocm import (
         SUPPORTS_WINDOWING,
         attention,
         paged_attention,
-        reshape_and_cache,
     )
 elif SYSTEM == "ipex":
     from .ipex import (
         SUPPORTS_WINDOWING,
         attention,
         paged_attention,
-        reshape_and_cache,
     )
 else:
     raise ImportError(f"System {SYSTEM} doesn't support flash/paged attention")
@@ -36,7 +33,6 @@ from .kv_cache import KVCache
 __all__ = [
     "attention",
     "paged_attention",
-    "reshape_and_cache",
     "SUPPORTS_WINDOWING",
     "KVCache",
     "Seqlen",

text_generation_server/layers/attention/cuda.py

@@ -12,30 +12,6 @@ major, minor = torch.cuda.get_device_capability()
 is_sm75 = major == 7 and minor == 5
 _PARTITION_SIZE = 512
 
-try:
-    from vllm._C import cache_ops
-except Exception as e:
-    raise ImportError(
-        f"Could not import vllm paged attention. Make sure your installation is correct. Complete error: {e}"
-    )
-
-
-def reshape_and_cache(
-    key: torch.Tensor,
-    value: torch.Tensor,
-    key_cache: torch.Tensor,
-    value_cache: torch.Tensor,
-    slots: torch.Tensor,
-):
-    if ATTENTION in {"flashdecoding", "flashinfer"}:
-        shape = key_cache.shape
-        key_cache.view(-1, shape[-2], shape[-1])[slots] = key
-        value_cache.view(-1, shape[-2], shape[-1])[slots] = value
-    else:
-        cache_ops.reshape_and_cache(
-            key, value, key_cache, value_cache, slots, "auto", 1.0
-        )
-
 
 def paged_attention(
     query: torch.Tensor,
@@ -346,5 +322,4 @@ __all__ = [
     "SUPPORTS_WINDOWING",
     "attention",
     "paged_attention",
-    "reshape_and_cache",
 ]

text_generation_server/layers/attention/ipex.py

@@ -47,18 +47,6 @@ def attention(
     return out
 
 
-def reshape_and_cache(
-    key: torch.Tensor,
-    value: torch.Tensor,
-    key_cache: torch.Tensor,
-    value_cache: torch.Tensor,
-    slots: torch.Tensor,
-):
-    ipex.llm.modules.PagedAttention.reshape_and_cache(
-        key, value, key_cache, value_cache, slots
-    )
-
-
 def paged_attention(
     query: torch.Tensor,
     kv_cache: KVCache,
@@ -94,5 +82,4 @@ __all__ = [
     "SUPPORTS_WINDOWING",
     "attention",
     "paged_attention",
-    "reshape_and_cache",
 ]

text_generation_server/layers/attention/kv_cache.py

@@ -115,6 +115,41 @@ class KVCache:
             key_cache.view(-1, shape[-2], shape[-1])[slots] = key
             value_cache.view(-1, shape[-2], shape[-1])[slots] = value
         else:
-            from text_generation_server.layers.attention import reshape_and_cache
-
-            reshape_and_cache(key, value, key_cache, value_cache, slots)
+            paged_reshape_and_cache(key, value, key_cache, value_cache, slots)
+
+
+def paged_reshape_and_cache(
+    key: torch.Tensor,
+    value: torch.Tensor,
+    key_cache: torch.Tensor,
+    value_cache: torch.Tensor,
+    slots: torch.Tensor,
+):
+    if SYSTEM == "cuda":
+        try:
+            from vllm._C import cache_ops
+        except Exception as e:
+            raise ImportError(
+                f"Could not import vllm paged attention. Make sure your installation is correct. Complete error: {e}"
+            )
+        cache_ops.reshape_and_cache(
+            key, value, key_cache, value_cache, slots, "auto", 1.0
+        )
+    elif SYSTEM == "rocm":
+        try:
+            import vllm._custom_ops as ops
+        except Exception as e:
+            raise ImportError(
+                f"Could not import vllm paged attention. Make sure your installation is correct. Complete error: {e}"
+            )
+        ops.reshape_and_cache(key, value, key_cache, value_cache, slots, "auto", 1.0)
+    elif SYSTEM == "ipex":
+        import intel_extension_for_pytorch as ipex
+
+        ipex.llm.modules.PagedAttention.reshape_and_cache(
+            key, value, key_cache, value_cache, slots
+        )
+    else:
+        raise NotImplementedError(
+            f"Cannot reshape and cache for paged attention, system '{SYSTEM}' not supported"
+        )

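
For reference, the following is a minimal, PyTorch-only sketch of the flashdecoding/flashinfer branch that `KVCache.store` keeps inline above (the `view(-1, shape[-2], shape[-1])[slots] = ...` scatter). The block and head sizes are illustrative assumptions, not the exact TGI cache layout, and the other branches are not reproduced because they require the vllm or ipex kernels.

```python
import torch

# Illustrative shapes only: [num_blocks, block_size, num_heads, head_size].
num_blocks, block_size, num_heads, head_size = 4, 16, 2, 8
key_cache = torch.zeros(num_blocks, block_size, num_heads, head_size)
value_cache = torch.zeros_like(key_cache)

# Three new tokens landing in flattened slots 5, 6 and 33.
key = torch.randn(3, num_heads, head_size)
value = torch.randn(3, num_heads, head_size)
slots = torch.tensor([5, 6, 33])

# Same pattern as KVCache.store: collapse the block dimensions and index by slot.
shape = key_cache.shape
key_cache.view(-1, shape[-2], shape[-1])[slots] = key
value_cache.view(-1, shape[-2], shape[-1])[slots] = value

# Slot 33 is block 2, offset 1; it now holds the third token's key.
assert torch.equal(key_cache[2, 1], key[2])
```

Deferring the vllm and intel_extension_for_pytorch imports into `paged_reshape_and_cache` also means that importing `kv_cache.py` never pulls in a backend kernel and never re-enters the attention package, which is what breaks the cycle.
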
text_generation_server/layers/attention/rocm.py

@@ -3,7 +3,6 @@ from typing import Optional
 import torch
 from text_generation_server.layers.attention.kv_cache import KVCache
 from text_generation_server.utils.import_utils import SYSTEM
-from text_generation_server.models.globals import ATTENTION
 from text_generation_server.layers.attention import Seqlen
 from text_generation_server.utils.log import log_master
 from loguru import logger
@@ -28,28 +27,6 @@ except ImportError as e:
     )
     use_rocm_custom_paged_attn = False
 
-try:
-    import vllm._custom_ops as ops
-except Exception as e:
-    raise ImportError(
-        f"Could not import vllm paged attention. Make sure your installation is correct. Complete error: {e}"
-    )
-
-
-def reshape_and_cache(
-    key: torch.Tensor,
-    value: torch.Tensor,
-    key_cache: torch.Tensor,
-    value_cache: torch.Tensor,
-    slots: torch.Tensor,
-):
-    if ATTENTION == "flashdecoding":
-        shape = key_cache.shape
-        key_cache.view(-1, shape[-2], shape[-1])[slots] = key
-        value_cache.view(-1, shape[-2], shape[-1])[slots] = value
-    else:
-        ops.reshape_and_cache(key, value, key_cache, value_cache, slots, "auto", 1.0)
-
 
 def paged_attention(
     query: torch.Tensor,
@@ -305,5 +282,4 @@ __all__ = [
     "SUPPORTS_WINDOWING",
     "attention",
     "paged_attention",
-    "reshape_and_cache",
 ]