Unverified Commit 5b6b74e2 authored by Daniël de Kok, committed by GitHub

Improve support for GPUs with capability < 8 (#2575)

* Improve support for GPUs with capability < 8

- For models that cannot use flashinfer, use flash-attn v1 + paged
  attention on GPUs with a compute capability lower than 8.
- Disable prefix caching when using paged attention.
- When using flash-attn v1, pass the key/value, rather than the
  cache, since v1 cannot use block tables.

* nix: add flash-attn-v1 to the server environment

* Move disabling prefix caching into the block of exceptions

* Capability as `usize`s
parent 0aa66d69
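
The hunks below replace the per-model `SYSTEM != "ipex"` checks with a single `PREFILL_IN_KV_CACHE` flag that is set once according to the selected attention backend. As a rough illustration of the idea only (the helpers `get_compute_capability` and `select_attention_backend` and the backend labels are hypothetical stand-ins, not the actual text-generation-inference code), such a flag could be derived from the device's compute capability along these lines:

# Illustrative sketch only; not the actual text-generation-inference implementation.
import torch


def get_compute_capability() -> tuple[int, int]:
    # (major, minor) CUDA compute capability of the current device.
    return torch.cuda.get_device_capability()


def select_attention_backend() -> tuple[str, bool, bool]:
    # Returns (backend_name, prefill_in_kv_cache, prefix_caching_enabled).
    major, _minor = get_compute_capability()
    if major >= 8:
        # Ampere and newer: the prefill kernel can attend directly over the
        # paged KV cache (block tables), so prefix caching can stay enabled.
        return "flashinfer", True, True
    # Older GPUs: fall back to flash-attn v1 for prefill plus paged attention
    # for decode. v1 cannot read from block tables, so prefill must receive
    # the freshly computed key/value tensors and prefix caching is disabled.
    return "paged", False, False


BACKEND, PREFILL_IN_KV_CACHE, PREFIX_CACHING_ENABLED = select_attention_backend()

With a flag like this, each model's prefill call can choose between the cache slices and the raw key/value tensors, which is exactly the branch the diff below rewrites.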
@@ -38,6 +38,7 @@ from text_generation_server.layers import (
     SpeculativeHead,
     get_linear,
 )
+from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
 from text_generation_server.layers.rotary import (
     PositionRotaryEmbedding,
 )
@@ -192,8 +193,8 @@ class FlashGPTJAttention(torch.nn.Module):
             # flash attention
             attn_output = attention(
                 query,
-                kv_cache[0] if SYSTEM != "ipex" else key,
-                kv_cache[1] if SYSTEM != "ipex" else value,
+                kv_cache[0] if PREFILL_IN_KV_CACHE else key,
+                kv_cache[1] if PREFILL_IN_KV_CACHE else value,
                 seqlen,
                 block_tables,
                 self.softmax_scale,

@@ -27,6 +27,7 @@ import torch.distributed
 from torch import nn
 from transformers.activations import ACT2FN

+from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
 from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.layers.attention import (
     paged_attention,
@@ -220,8 +221,8 @@ class FlashLlamaAttention(torch.nn.Module):
             # flash attention
             attn_output = attention(
                 query,
-                kv_cache[0] if SYSTEM != "ipex" else kv[:, 0],
-                kv_cache[1] if SYSTEM != "ipex" else kv[:, 1],
+                kv_cache[0] if PREFILL_IN_KV_CACHE else kv[:, 0],
+                kv_cache[1] if PREFILL_IN_KV_CACHE else kv[:, 1],
                 seqlen,
                 block_tables,
                 self.softmax_scale,

@@ -41,6 +41,7 @@ from text_generation_server.layers import (
     TensorParallelMultiAdapterLinear,
     TensorParallelAdapterRowLinear,
 )
+from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
 from text_generation_server.layers.rotary import PositionRotaryEmbedding
 from text_generation_server.layers.layernorm import (
     FastRMSNorm,
@@ -218,8 +219,8 @@ class MistralAttention(torch.nn.Module):
             # flash attention
             attn_output = attention(
                 query,
-                kv_cache[0] if SYSTEM != "ipex" else kv_to_cache[:, 0],
-                kv_cache[1] if SYSTEM != "ipex" else kv_to_cache[:, 1],
+                kv_cache[0] if PREFILL_IN_KV_CACHE else kv_to_cache[:, 0],
+                kv_cache[1] if PREFILL_IN_KV_CACHE else kv_to_cache[:, 1],
                 seqlen,
                 block_tables,
                 self.softmax_scale,

@@ -39,10 +39,10 @@ from text_generation_server.layers.attention import (
     paged_attention,
     reshape_and_cache,
 )
+from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
 from text_generation_server.layers.layernorm import FastRMSNorm
 from text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoELayer
 from text_generation_server.layers.rotary import PositionRotaryEmbedding
-from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.utils.weights import UnquantizedWeight
@@ -267,8 +267,8 @@ class MixtralAttention(torch.nn.Module):
             # flash attention
             attn_output = attention(
                 query,
-                kv_cache[0] if SYSTEM != "ipex" else kv_to_cache[:, 0],
-                kv_cache[1] if SYSTEM != "ipex" else kv_to_cache[:, 1],
+                kv_cache[0] if PREFILL_IN_KV_CACHE else kv_to_cache[:, 0],
+                kv_cache[1] if PREFILL_IN_KV_CACHE else kv_to_cache[:, 1],
                 seqlen,
                 block_tables,
                 self.softmax_scale,

@@ -26,7 +26,6 @@ from transformers.activations import ACT2FN
 from transformers.modeling_utils import PreTrainedModel
 from transformers.models.gpt_neox import GPTNeoXConfig as TransformersGPTNeoXConfig
 from typing import Optional, List, Tuple
-from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.layers.attention import (
     paged_attention,
     attention,
@@ -40,6 +39,7 @@ from text_generation_server.layers import (
     SpeculativeHead,
     get_linear,
 )
+from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
 from text_generation_server.layers.layernorm import (
     FastLayerNorm,
 )
@@ -172,8 +172,8 @@ class FlashNeoxAttention(torch.nn.Module):
             # flash attention
             attn_output = attention(
                 qkv[:, 0],
-                kv_cache[0] if SYSTEM != "ipex" else qkv[:, 1],
-                kv_cache[1] if SYSTEM != "ipex" else qkv[:, 2],
+                kv_cache[0] if PREFILL_IN_KV_CACHE else qkv[:, 1],
+                kv_cache[1] if PREFILL_IN_KV_CACHE else qkv[:, 2],
                 seqlen,
                 block_tables,
                 self.softmax_scale,

@@ -19,13 +19,13 @@ from text_generation_server.layers import (
     SpeculativeHead,
     get_linear,
 )
+from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
 from text_generation_server.layers.layernorm import (
     FastLayerNorm,
 )
 from text_generation_server.layers.rotary import (
     PositionRotaryEmbedding,
 )
-from text_generation_server.utils.import_utils import SYSTEM


 class PhiConfig(PretrainedConfig):
@@ -194,8 +194,8 @@ class FlashPhiAttention(torch.nn.Module):
         if cu_seqlen_prefill is not None:
             attn_output = attention(
                 query,
-                kv_cache[0] if SYSTEM != "ipex" else kv[:, 0],
-                kv_cache[1] if SYSTEM != "ipex" else kv[:, 1],
+                kv_cache[0] if PREFILL_IN_KV_CACHE else kv[:, 0],
+                kv_cache[1] if PREFILL_IN_KV_CACHE else kv[:, 1],
                 seqlen,
                 block_tables,
                 self.softmax_scale,

@@ -17,11 +17,11 @@ from text_generation_server.layers import (
     TensorParallelEmbedding,
     SpeculativeHead,
 )
+from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
 from text_generation_server.layers.rotary import PositionRotaryEmbedding
 from text_generation_server.layers.layernorm import (
     FastRMSNorm,
 )
-from text_generation_server.utils.import_utils import SYSTEM


 def load_attention(config, prefix, weights):
@@ -137,8 +137,8 @@ class Qwen2Attention(torch.nn.Module):
             # flash attention
             attn_output = attention(
                 query,
-                kv_cache[0] if SYSTEM != "ipex" else kv_to_cache[:, 0],
-                kv_cache[1] if SYSTEM != "ipex" else kv_to_cache[:, 1],
+                kv_cache[0] if PREFILL_IN_KV_CACHE else kv_to_cache[:, 0],
+                kv_cache[1] if PREFILL_IN_KV_CACHE else kv_to_cache[:, 1],
                 seqlen,
                 block_tables,
                 self.softmax_scale,

@@ -5,7 +5,6 @@ import torch.distributed
 from torch import nn
 from transformers.configuration_utils import PretrainedConfig
 from transformers.modeling_utils import PreTrainedModel
-from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.layers import (
     SpeculativeHead,
     TensorParallelColumnLinear,
@@ -13,6 +12,7 @@ from text_generation_server.layers import (
     TensorParallelRowLinear,
     get_linear,
 )
+from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
 from text_generation_server.layers.layernorm import FastLayerNorm
 from text_generation_server.layers.rotary import PositionRotaryEmbedding
 from text_generation_server.layers.attention import (
@@ -207,8 +207,8 @@ class FlashRWAttention(torch.nn.Module):
             # flash attention
             attn_output = attention(
                 query,
-                kv_cache[0] if SYSTEM != "ipex" else kv[:, 0],
-                kv_cache[1] if SYSTEM != "ipex" else kv[:, 1],
+                kv_cache[0] if PREFILL_IN_KV_CACHE else kv[:, 0],
+                kv_cache[1] if PREFILL_IN_KV_CACHE else kv[:, 1],
                 seqlen,
                 block_tables,
                 self.softmax_scale,
@@ -325,8 +325,8 @@ class FlashRWLargeAttention(torch.nn.Module):
             # flash attention
             attn_output = attention(
                 query,
-                kv_cache[0] if SYSTEM != "ipex" else kv[:, :, 0].contiguous(),
-                kv_cache[1] if SYSTEM != "ipex" else kv[:, :, 1].contiguous(),
+                kv_cache[0] if PREFILL_IN_KV_CACHE else kv[:, :, 0].contiguous(),
+                kv_cache[1] if PREFILL_IN_KV_CACHE else kv[:, :, 1].contiguous(),
                 seqlen,
                 block_tables,
                 self.softmax_scale,

@@ -18,11 +18,11 @@ from text_generation_server.layers import (
     TensorParallelEmbedding,
     get_linear,
 )
+from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
 from text_generation_server.layers.gptq import GPTQWeightsLoader
 from text_generation_server.layers.layernorm import (
     FastLayerNorm,
 )
-from text_generation_server.utils.import_utils import SYSTEM


 def load_multi_mqa(
@@ -293,8 +293,8 @@ class FlashMQAttention(torch.nn.Module):
             # flash attention
             attn_output = attention(
                 query,
-                kv_cache[0] if SYSTEM != "ipex" else key_value[:, 0],
-                kv_cache[1] if SYSTEM != "ipex" else key_value[:, 1],
+                kv_cache[0] if PREFILL_IN_KV_CACHE else key_value[:, 0],
+                kv_cache[1] if PREFILL_IN_KV_CACHE else key_value[:, 1],
                 seqlen,
                 block_tables,
                 self.softmax_scale,

@@ -39,6 +39,7 @@ from text_generation_server.layers import (
     SpeculativeHead,
     get_linear,
 )
+from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
 from text_generation_server.layers.layernorm import (
     FastLayerNorm,
     FastRMSNorm,
@@ -47,7 +48,6 @@ from text_generation_server.layers.rotary import (
     PositionRotaryEmbedding,
 )
 from text_generation_server.utils.weights import UnquantizedWeight
-from text_generation_server.utils.import_utils import SYSTEM


 class Starcoder2Config(PretrainedConfig):
@@ -242,8 +242,8 @@ class Starcoder2Attention(torch.nn.Module):
             # flash attention
             attn_output = attention(
                 query,
-                kv_cache[0] if SYSTEM != "ipex" else kv_to_cache[:, 0],
-                kv_cache[1] if SYSTEM != "ipex" else kv_to_cache[:, 1],
+                kv_cache[0] if PREFILL_IN_KV_CACHE else kv_to_cache[:, 0],
+                kv_cache[1] if PREFILL_IN_KV_CACHE else kv_to_cache[:, 1],
                 seqlen,
                 block_tables,
                 self.softmax_scale,