OpenDAS / text-generation-inference
"git@developer.sourcefind.cn:OpenDAS/fairscale.git" did not exist on "3f240fbb3734ab5f112a3d26d3856cf0a0e1a092"
Commit 8ec57558 (unverified), authored Oct 17, 2024 by Daniël de Kok, committed by GitHub on Oct 17, 2024
Parent: 5f32dea1

Break cycle between the attention implementations and KV cache (#2627)
Showing 5 changed files with 38 additions and 69 deletions (+38, −69)
server/text_generation_server/layers/attention/__init__.py    +0  −4
server/text_generation_server/layers/attention/cuda.py        +0  −25
server/text_generation_server/layers/attention/ipex.py        +0  −13
server/text_generation_server/layers/attention/kv_cache.py    +38 −3
server/text_generation_server/layers/attention/rocm.py        +0  −24
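Before this change, each backend module (cuda.py, ipex.py, rocm.py) defined its own reshape_and_cache, the attention package re-exported it, and KVCache in kv_cache.py called it through a function-local import of text_generation_server.layers.attention, while the backends themselves import KVCache from kv_cache.py, forming a cycle. The commit inlines the backend dispatch into kv_cache.py as paged_reshape_and_cache, so the KV cache no longer depends on the attention package and the backends keep only attention/paged_attention. The sketch below is a minimal, self-contained illustration of the pure-PyTorch slot write performed by the flashdecoding/flashinfer branch; the name toy_paged_reshape_and_cache and the concrete tensor shapes are assumptions for the example, only the view/scatter logic comes from the diff.

    import torch

    # Stand-in for text_generation_server.utils.import_utils.SYSTEM (assumption
    # for this sketch; the real module detects cuda/rocm/ipex at import time).
    SYSTEM = "cpu"


    def toy_paged_reshape_and_cache(
        key: torch.Tensor,        # [num_tokens, num_heads, head_size]
        value: torch.Tensor,      # [num_tokens, num_heads, head_size]
        key_cache: torch.Tensor,  # [num_blocks, block_size, num_heads, head_size]
        value_cache: torch.Tensor,
        slots: torch.Tensor,      # [num_tokens], flat slot index per token
    ):
        # The real helper dispatches on SYSTEM to vllm/ipex kernels; this sketch
        # only reproduces the pure-PyTorch slot write used by the
        # flashdecoding/flashinfer path in the diff.
        if SYSTEM in ("cuda", "rocm", "ipex"):
            raise NotImplementedError("sketch only covers the pure-PyTorch path")
        shape = key_cache.shape
        # Collapse (num_blocks, block_size) into a single slot axis and scatter
        # each token's K/V into its assigned slot.
        key_cache.view(-1, shape[-2], shape[-1])[slots] = key
        value_cache.view(-1, shape[-2], shape[-1])[slots] = value


    # Tiny usage example: 2 blocks of 4 slots each, 1 head, head_size 2.
    key_cache = torch.zeros(2, 4, 1, 2)
    value_cache = torch.zeros(2, 4, 1, 2)
    key = torch.ones(3, 1, 2)
    value = torch.full((3, 1, 2), 2.0)
    slots = torch.tensor([0, 5, 7])  # token 0 -> block 0, tokens 1-2 -> block 1
    toy_paged_reshape_and_cache(key, value, key_cache, value_cache, slots)
    assert key_cache[1, 1].eq(1.0).all()    # slot 5 == block 1, offset 1
    assert value_cache[1, 3].eq(2.0).all()  # slot 7 == block 1, offset 3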
server/text_generation_server/layers/attention/__init__.py

@@ -11,21 +11,18 @@ if SYSTEM == "cuda":
         SUPPORTS_WINDOWING,
         attention,
         paged_attention,
-        reshape_and_cache,
     )
 elif SYSTEM == "rocm":
     from .rocm import (
         SUPPORTS_WINDOWING,
         attention,
         paged_attention,
-        reshape_and_cache,
     )
 elif SYSTEM == "ipex":
     from .ipex import (
         SUPPORTS_WINDOWING,
         attention,
         paged_attention,
-        reshape_and_cache,
     )
 else:
     raise ImportError(f"System {SYSTEM} doesn't support flash/paged attention")

@@ -36,7 +33,6 @@ from .kv_cache import KVCache
 __all__ = [
     "attention",
     "paged_attention",
-    "reshape_and_cache",
     "SUPPORTS_WINDOWING",
     "KVCache",
     "Seqlen",
server/text_generation_server/layers/attention/cuda.py

@@ -12,30 +12,6 @@ major, minor = torch.cuda.get_device_capability()
 is_sm75 = major == 7 and minor == 5
 _PARTITION_SIZE = 512
 
-try:
-    from vllm._C import cache_ops
-except Exception as e:
-    raise ImportError(
-        f"Could not import vllm paged attention. Make sure your installation is correct. Complete error: {e}"
-    )
-
-
-def reshape_and_cache(
-    key: torch.Tensor,
-    value: torch.Tensor,
-    key_cache: torch.Tensor,
-    value_cache: torch.Tensor,
-    slots: torch.Tensor,
-):
-    if ATTENTION in {"flashdecoding", "flashinfer"}:
-        shape = key_cache.shape
-        key_cache.view(-1, shape[-2], shape[-1])[slots] = key
-        value_cache.view(-1, shape[-2], shape[-1])[slots] = value
-    else:
-        cache_ops.reshape_and_cache(
-            key, value, key_cache, value_cache, slots, "auto", 1.0
-        )
-
 
 def paged_attention(
     query: torch.Tensor,

@@ -346,5 +322,4 @@ __all__ = [
     "SUPPORTS_WINDOWING",
     "attention",
     "paged_attention",
-    "reshape_and_cache",
 ]
server/text_generation_server/layers/attention/ipex.py

@@ -47,18 +47,6 @@ def attention(
     return out
 
 
-def reshape_and_cache(
-    key: torch.Tensor,
-    value: torch.Tensor,
-    key_cache: torch.Tensor,
-    value_cache: torch.Tensor,
-    slots: torch.Tensor,
-):
-    ipex.llm.modules.PagedAttention.reshape_and_cache(
-        key, value, key_cache, value_cache, slots
-    )
-
-
 def paged_attention(
     query: torch.Tensor,
     kv_cache: KVCache,

@@ -94,5 +82,4 @@ __all__ = [
     "SUPPORTS_WINDOWING",
     "attention",
     "paged_attention",
-    "reshape_and_cache",
 ]
server/text_generation_server/layers/attention/kv_cache.py

@@ -115,6 +115,41 @@ class KVCache:
             key_cache.view(-1, shape[-2], shape[-1])[slots] = key
             value_cache.view(-1, shape[-2], shape[-1])[slots] = value
         else:
-            from text_generation_server.layers.attention import reshape_and_cache
-
-            reshape_and_cache(key, value, key_cache, value_cache, slots)
+            paged_reshape_and_cache(key, value, key_cache, value_cache, slots)
+
+
+def paged_reshape_and_cache(
+    key: torch.Tensor,
+    value: torch.Tensor,
+    key_cache: torch.Tensor,
+    value_cache: torch.Tensor,
+    slots: torch.Tensor,
+):
+    if SYSTEM == "cuda":
+        try:
+            from vllm._C import cache_ops
+        except Exception as e:
+            raise ImportError(
+                f"Could not import vllm paged attention. Make sure your installation is correct. Complete error: {e}"
+            )
+        cache_ops.reshape_and_cache(
+            key, value, key_cache, value_cache, slots, "auto", 1.0
+        )
+    elif SYSTEM == "rocm":
+        try:
+            import vllm._custom_ops as ops
+        except Exception as e:
+            raise ImportError(
+                f"Could not import vllm paged attention. Make sure your installation is correct. Complete error: {e}"
+            )
+        ops.reshape_and_cache(key, value, key_cache, value_cache, slots, "auto", 1.0)
+    elif SYSTEM == "ipex":
+        import intel_extension_for_pytorch as ipex
+
+        ipex.llm.modules.PagedAttention.reshape_and_cache(
+            key, value, key_cache, value_cache, slots
+        )
+    else:
+        raise NotImplementedError(
+            f"Cannot reshape and cache for paged attention, system '{SYSTEM}' not supported"
+        )
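A side effect of the consolidation is visible in the removed hunks above and below: cuda.py and rocm.py used to import the vllm kernels at module import time, whereas paged_reshape_and_cache defers that import until a paged cache write actually happens and re-raises failures with an actionable message. A generic sketch of that deferred, guarded import pattern follows; the helper name _load_paged_cache_ops is hypothetical, only the vllm._C.cache_ops module path and the error wording appear in the diff.

    # Illustrative sketch of the lazy, guarded import that paged_reshape_and_cache
    # performs on its CUDA branch; _load_paged_cache_ops is a hypothetical name.
    def _load_paged_cache_ops():
        try:
            # Optional dependency: only present in builds that ship vllm kernels.
            from vllm._C import cache_ops
        except Exception as e:
            raise ImportError(
                "Could not import vllm paged attention. "
                f"Make sure your installation is correct. Complete error: {e}"
            )
        return cache_ops


    try:
        cache_ops = _load_paged_cache_ops()
    except ImportError as err:
        # On machines without vllm this reports the wrapped error instead of
        # failing at module import time, which is the point of deferring the import.
        print(f"paged attention kernels unavailable: {err}")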
server/text_generation_server/layers/attention/rocm.py

@@ -3,7 +3,6 @@ from typing import Optional
 import torch
 from text_generation_server.layers.attention.kv_cache import KVCache
 from text_generation_server.utils.import_utils import SYSTEM
-from text_generation_server.models.globals import ATTENTION
 from text_generation_server.layers.attention import Seqlen
 from text_generation_server.utils.log import log_master
 from loguru import logger

@@ -28,28 +27,6 @@ except ImportError as e:
     )
     use_rocm_custom_paged_attn = False
 
-try:
-    import vllm._custom_ops as ops
-except Exception as e:
-    raise ImportError(
-        f"Could not import vllm paged attention. Make sure your installation is correct. Complete error: {e}"
-    )
-
-
-def reshape_and_cache(
-    key: torch.Tensor,
-    value: torch.Tensor,
-    key_cache: torch.Tensor,
-    value_cache: torch.Tensor,
-    slots: torch.Tensor,
-):
-    if ATTENTION == "flashdecoding":
-        shape = key_cache.shape
-        key_cache.view(-1, shape[-2], shape[-1])[slots] = key
-        value_cache.view(-1, shape[-2], shape[-1])[slots] = value
-    else:
-        ops.reshape_and_cache(key, value, key_cache, value_cache, slots, "auto", 1.0)
-
 
 def paged_attention(
     query: torch.Tensor,

@@ -305,5 +282,4 @@ __all__ = [
     "SUPPORTS_WINDOWING",
     "attention",
     "paged_attention",
-    "reshape_and_cache",
 ]