Unverified Commit e415b690 authored by Nicolas Patry, committed by GitHub

Lots of improvements (Still 2 allocators) (#2449)



* Making prefix caching/flashinfer the default and running the full release tests.

* Include flashinfer in the docker.

* Using prebuilt.

* Allowing window_left_size (dummy version).

* Disabling flashinfer/prefix caching on odd head_dim

* Disable prefix caching for lora.

* More specific codes.

* Update lock

* Updating integration tests with new values for FI/FD (flashinfer/flash-decoding).

Remove paged attention as a default too, and use FD everywhere.

* Update cargo lock?

* Upgrade to Rust 1.80 because of bitstream...

* Everywhere 1.80

* Forgot last default place.

* Apply suggestions from code review
Co-authored-by: drbh <david.richard.holtz@gmail.com>

* Updated flake lock

* Tmp

* Upgrade the resolution system for fewer resolution errors.

* Remove lambda for cleaner function.

* Handling debugger.

* Override the env in server tests.

* Is this enough to make it work?

* This seems to be working.

* Downgrade some logs.

* Fixing the default for VLM.

* Don't enable prefix caching on VLM just yet.

* Change `add_special_tokens` so that chat input gets the correct tokens (this is super important with prefix caching now); see the tokenization sketch at the end of this list.

* Fixing prefix caching for flashdecoding.

* Update all models.

* Fixed flashinfer version.

* add_special_tokens is internal only

* Fixing seqlen with the new vlms.

* Fixing the issue with `add_special_tokens` not being passed around.

* Fixing the test.

* Removing encoder_decoder (seq2seq).

* Update the chat test.

* Fixing the batching tokenization in flash causal lm.

* Truncating left for radix purposes.

* Oops this doesn't belong here.

* Put back default pure shell.

* Update server tests

- Default to throughput test in k6
- Use TGI_WIGGLE_ROOM to adjust wiggle room

* Only n_heads / process_group.size() are necessary.

* Revert the integration tests change (seems linked to the head_size modification).

* Adding an error message when the assert is violated.

* Fixing the free algorithm to handle cases where the common prefix is smaller.

* Apply suggestions from code review
Co-authored-by: OlivierDehaene <olivier@huggingface.co>

* Update server/text_generation_server/layers/attention/common.py
Co-authored-by: OlivierDehaene <olivier@huggingface.co>

* Fix disabling prefix caching - Fix windowing checks.

* Revert the Cohere tokenizer change (for now using a revision instead).

* Fmt.
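
Why `add_special_tokens` matters here (an illustration, not part of the PR): prefix caching keys on token ids, so if the router re-adds special tokens to a chat-templated prompt, the resulting ids stop sharing a prefix with what was cached and every lookup misses. A minimal sketch, assuming a chat template that already embeds the BOS token (the model id below is only an example):

```python
# Hedged illustration, not TGI code: re-adding special tokens shifts the token
# prefix, so a radix/prefix cache keyed on token ids would never match.
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2")
messages = [{"role": "user", "content": "Hello!"}]
text = tok.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

# The template already starts with the BOS token; encoding with
# add_special_tokens=True prepends a second one, changing the prefix.
ids_default = tok(text, add_special_tokens=True)["input_ids"]
ids_chat = tok(text, add_special_tokens=False)["input_ids"]
print(ids_default[:3])
print(ids_chat[:3])
```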

---------
Co-authored-by: drbh <david.richard.holtz@gmail.com>
Co-authored-by: OlivierDehaene <olivier@huggingface.co>
parent 4e821c00
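
The hunks below thread a `Seqlen` object through every flash model in place of the raw `input_lengths` tensor, and make the prefill `attention(...)` call read K/V from the paged cache (`kv_cache[0]` / `kv_cache[1]`) instead of slices of the fused projection. For orientation only, a minimal sketch of what such a wrapper could look like; the real definition lives in `server/text_generation_server/layers/attention/common.py`, and anything beyond `input_lengths` and the `clamp` helper used by the sliding-window models is an assumption:

```python
# Hypothetical sketch, not the actual TGI implementation.
from dataclasses import dataclass

import torch


@dataclass
class Seqlen:
    # True per-request sequence lengths, previously passed around as a bare
    # `input_lengths` tensor.
    input_lengths: torch.Tensor

    def clamp(self, max) -> "Seqlen":
        # Sliding-window models (Mistral, Mixtral, Qwen2) clamp lengths before
        # decode, since the paged KV cache only holds the last `max` tokens.
        return Seqlen(input_lengths=torch.clamp(self.input_lengths, max=max))
```

The diff passes `self.max_past_tensor` into `clamp`, so the real helper accepts a tensor bound as well as a scalar.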
......@@ -30,6 +30,7 @@ from text_generation_server.layers.attention import (
paged_attention,
attention,
reshape_and_cache,
Seqlen,
)
from text_generation_server.layers import (
TensorParallelRowLinear,
......@@ -213,7 +214,7 @@ class FlashGemma2Attention(torch.nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
):
qkv = self.query_key_value(hidden_states)
......@@ -236,12 +237,10 @@ class FlashGemma2Attention(torch.nn.Module):
# flash attention
attn_output = attention(
query,
torch.select(kv, dim=1, index=0),
torch.select(kv, dim=1, index=1),
kv_cache[0],
kv_cache[1],
cu_seqlen_prefill,
max_s,
seqlen,
block_tables,
self.softmax_scale,
causal=self.causal,
window_size_left=self.window_size,
......@@ -256,7 +255,7 @@ class FlashGemma2Attention(torch.nn.Module):
self.kv_head_mapping,
self.softmax_scale,
block_tables,
input_lengths,
seqlen,
max_s,
softcap=self.softcap,
)
......@@ -343,7 +342,7 @@ class FlashGemma2Layer(nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
):
normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
......@@ -357,7 +356,7 @@ class FlashGemma2Layer(nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
)
......@@ -408,7 +407,7 @@ class FlashGemma2Model(torch.nn.Module):
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
input_lengths: torch.Tensor,
seqlen: Seqlen,
max_s: int,
) -> torch.Tensor:
hidden_states = inputs_embeds
......@@ -430,7 +429,7 @@ class FlashGemma2Model(torch.nn.Module):
kv_cache[i],
block_tables,
slots,
input_lengths,
seqlen,
max_s,
)
......@@ -477,7 +476,7 @@ class FlashGemma2ForCausalLM(torch.nn.Module):
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
input_lengths: torch.Tensor,
seqlen: Seqlen,
max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor] = None,
......@@ -491,7 +490,7 @@ class FlashGemma2ForCausalLM(torch.nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
)
if lm_head_indices is not None:
......
......@@ -30,6 +30,7 @@ from text_generation_server.layers.attention import (
paged_attention,
attention,
reshape_and_cache,
Seqlen,
)
from text_generation_server.layers import (
TensorParallelRowLinear,
......@@ -207,7 +208,7 @@ class FlashGemmaAttention(torch.nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
):
qkv = self.query_key_value(hidden_states)
......@@ -230,12 +231,10 @@ class FlashGemmaAttention(torch.nn.Module):
# flash attention
attn_output = attention(
query,
torch.select(kv, dim=1, index=0),
torch.select(kv, dim=1, index=1),
kv_cache[0],
kv_cache[1],
cu_seqlen_prefill,
max_s,
seqlen,
block_tables,
self.softmax_scale,
causal=self.causal,
)
......@@ -248,7 +247,7 @@ class FlashGemmaAttention(torch.nn.Module):
self.kv_head_mapping,
self.softmax_scale,
block_tables,
input_lengths,
seqlen,
max_s,
)
......@@ -320,7 +319,7 @@ class FlashGemmaLayer(nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
):
normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
......@@ -334,7 +333,7 @@ class FlashGemmaLayer(nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
)
......@@ -382,7 +381,7 @@ class FlashGemmaModel(torch.nn.Module):
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
input_lengths: torch.Tensor,
seqlen: Seqlen,
max_s: int,
) -> torch.Tensor:
hidden_states = inputs_embeds
......@@ -404,7 +403,7 @@ class FlashGemmaModel(torch.nn.Module):
kv_cache[i],
block_tables,
slots,
input_lengths,
seqlen,
max_s,
)
......@@ -449,7 +448,7 @@ class FlashGemmaForCausalLM(torch.nn.Module):
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
input_lengths: torch.Tensor,
seqlen: Seqlen,
max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor] = None,
......@@ -463,7 +462,7 @@ class FlashGemmaForCausalLM(torch.nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
)
if lm_head_indices is not None:
......
......@@ -29,6 +29,7 @@ from text_generation_server.layers.attention import (
paged_attention,
attention,
reshape_and_cache,
Seqlen,
)
from text_generation_server.layers import (
TensorParallelRowLinear,
......@@ -213,7 +214,7 @@ class FlashGPT2Attention(torch.nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
):
query, key, value = self.query_key_value(hidden_states).split(
......@@ -230,12 +231,10 @@ class FlashGPT2Attention(torch.nn.Module):
# flash attention
attn_output = attention(
query,
key,
value,
kv_cache[0],
kv_cache[1],
cu_seqlen_prefill,
max_s,
seqlen,
block_tables,
self.softmax_scale,
)
# Decode
......@@ -247,7 +246,7 @@ class FlashGPT2Attention(torch.nn.Module):
self.kv_head_mapping,
self.softmax_scale,
block_tables,
input_lengths,
seqlen,
max_s,
)
......@@ -316,7 +315,7 @@ class FlashGPT2Layer(nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
):
residual = hidden_states
......@@ -329,7 +328,7 @@ class FlashGPT2Layer(nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
)
......@@ -382,7 +381,7 @@ class FlashGPT2Model(torch.nn.Module):
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
input_lengths: torch.Tensor,
seqlen: Seqlen,
max_s: int,
true_max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
......@@ -398,7 +397,7 @@ class FlashGPT2Model(torch.nn.Module):
kv_cache[i],
block_tables,
slots,
input_lengths,
seqlen,
max_s,
)
......@@ -435,7 +434,7 @@ class FlashGPT2ForCausalLM(torch.nn.Module):
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
input_lengths: torch.Tensor,
seqlen: Seqlen,
max_s: int,
prefill_cache_indices: Optional[torch.Tensor] = None,
lm_head_indices: Optional[torch.Tensor] = None,
......@@ -451,7 +450,7 @@ class FlashGPT2ForCausalLM(torch.nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
true_max_s=max_s,
prefill_cache_indices=prefill_cache_indices,
......
......@@ -29,6 +29,7 @@ from text_generation_server.layers.attention import (
paged_attention,
attention,
reshape_and_cache,
Seqlen,
)
from text_generation_server.layers import (
TensorParallelRowLinear,
......@@ -167,7 +168,7 @@ class FlashGPTJAttention(torch.nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
):
query, key, value = self.query_key_value(hidden_states).split(
......@@ -192,10 +193,10 @@ class FlashGPTJAttention(torch.nn.Module):
# flash attention
attn_output = attention(
query,
key,
value,
cu_seqlen_prefill,
max_s,
kv_cache[0],
kv_cache[1],
seqlen,
block_tables,
self.softmax_scale,
)
# Decode
......@@ -207,7 +208,7 @@ class FlashGPTJAttention(torch.nn.Module):
self.kv_head_mapping,
self.softmax_scale,
block_tables,
input_lengths,
seqlen,
max_s,
)
......@@ -268,7 +269,7 @@ class FlashGPTJLayer(nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
):
hidden_states, residual = self.input_layernorm(hidden_states, residual)
......@@ -281,7 +282,7 @@ class FlashGPTJLayer(nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
)
......@@ -328,7 +329,7 @@ class FlashGPTJModel(torch.nn.Module):
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
input_lengths: torch.Tensor,
seqlen: Seqlen,
max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
) -> torch.Tensor:
......@@ -351,7 +352,7 @@ class FlashGPTJModel(torch.nn.Module):
kv_cache[i],
block_tables,
slots,
input_lengths,
seqlen,
max_s,
)
......@@ -382,7 +383,7 @@ class FlashGPTJForCausalLM(torch.nn.Module):
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
input_lengths: torch.Tensor,
seqlen: Seqlen,
max_s: int,
prefill_cache_indices: Optional[torch.Tensor] = None,
lm_head_indices: Optional[torch.Tensor] = None,
......@@ -395,7 +396,7 @@ class FlashGPTJForCausalLM(torch.nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
prefill_cache_indices=prefill_cache_indices,
)
......
......@@ -32,6 +32,7 @@ from text_generation_server.layers.attention import (
paged_attention,
attention,
reshape_and_cache,
Seqlen,
)
from text_generation_server.layers import (
TensorParallelRowLinear,
......@@ -194,7 +195,7 @@ class FlashLlamaAttention(torch.nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
adapter_data,
):
......@@ -218,12 +219,10 @@ class FlashLlamaAttention(torch.nn.Module):
# flash attention
attn_output = attention(
query,
torch.select(kv, dim=1, index=0),
torch.select(kv, dim=1, index=1),
kv_cache[0],
kv_cache[1],
cu_seqlen_prefill,
max_s,
seqlen,
block_tables,
self.softmax_scale,
)
# Decode
......@@ -235,7 +234,7 @@ class FlashLlamaAttention(torch.nn.Module):
self.kv_head_mapping,
self.softmax_scale,
block_tables,
input_lengths,
seqlen,
max_s,
)
......@@ -375,7 +374,7 @@ class FlashLlamaLayer(nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
adapter_data,
):
......@@ -390,7 +389,7 @@ class FlashLlamaLayer(nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
adapter_data,
)
......@@ -479,7 +478,7 @@ class FlashLlamaModel(torch.nn.Module):
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
input_lengths: torch.Tensor,
seqlen: Seqlen,
max_s: int,
true_max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
......@@ -504,7 +503,7 @@ class FlashLlamaModel(torch.nn.Module):
kv_cache[i],
block_tables,
slots,
input_lengths,
seqlen,
max_s,
adapter_data,
)
......@@ -548,7 +547,7 @@ class FlashLlamaForCausalLM(torch.nn.Module):
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
input_lengths: torch.Tensor,
seqlen: Seqlen,
max_s: int,
prefill_cache_indices: Optional[torch.Tensor] = None,
lm_head_indices: Optional[torch.Tensor] = None,
......@@ -562,7 +561,7 @@ class FlashLlamaForCausalLM(torch.nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
true_max_s=max_s,
prefill_cache_indices=prefill_cache_indices,
......
......@@ -31,6 +31,7 @@ from text_generation_server.layers.attention import (
paged_attention,
attention,
reshape_and_cache,
Seqlen,
)
from text_generation_server.layers import (
TensorParallelRowLinear,
......@@ -185,7 +186,7 @@ class MistralAttention(torch.nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
prefill_cache_indices,
adapter_data,
......@@ -217,12 +218,10 @@ class MistralAttention(torch.nn.Module):
# flash attention
attn_output = attention(
query,
torch.select(kv, dim=1, index=0),
torch.select(kv, dim=1, index=1),
kv_cache[0],
kv_cache[1],
cu_seqlen_prefill,
max_s,
seqlen,
block_tables,
self.softmax_scale,
window_size_left=self.max_past,
)
......@@ -235,7 +234,7 @@ class MistralAttention(torch.nn.Module):
self.kv_head_mapping,
self.softmax_scale,
block_tables,
input_lengths,
seqlen,
max_s,
)
......@@ -356,7 +355,7 @@ class MistralLayer(nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
prefill_cache_indices,
adapter_data,
......@@ -372,7 +371,7 @@ class MistralLayer(nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
prefill_cache_indices,
adapter_data,
......@@ -424,7 +423,7 @@ class MistralModel(torch.nn.Module):
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
input_lengths: torch.Tensor,
seqlen: Seqlen,
max_s: int,
true_max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
......@@ -448,7 +447,7 @@ class MistralModel(torch.nn.Module):
kv_cache[i],
block_tables,
slots,
input_lengths,
seqlen,
max_s,
prefill_cache_indices,
adapter_data,
......@@ -499,7 +498,7 @@ class FlashMistralForCausalLM(torch.nn.Module):
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
input_lengths: torch.Tensor,
seqlen: Seqlen,
max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor] = None,
......@@ -512,7 +511,7 @@ class FlashMistralForCausalLM(torch.nn.Module):
elif self.max_past is not None:
# Clamp in decode mode as paged attention requires clamped values whereas the flash attention
# kernel requires the true values
input_lengths = input_lengths.clamp(max=self.max_past_tensor)
seqlen = seqlen.clamp(max=self.max_past_tensor)
inputs_embeds = self.embed_tokens(input_ids)
hidden_states = self.model(
......@@ -522,7 +521,7 @@ class FlashMistralForCausalLM(torch.nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
true_max_s,
prefill_cache_indices,
......
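
The comment in the Mistral hunk above is the reason `Seqlen` needs `clamp` at all: the decode path (paged attention) indexes a KV cache that only materializes the last `max_past` tokens of a sliding-window model, so lengths must be capped, while the prefill flash kernel receives the window through `window_size_left` and wants the true lengths. A tiny illustration with made-up numbers:

```python
# Made-up values, purely to illustrate the clamp comment above.
import torch

max_past = 4096                                   # sliding-window size
input_lengths = torch.tensor([1000, 6000, 9000])  # true sequence lengths

# Decode / paged attention: the paged KV cache only holds the last `max_past`
# tokens per sequence, so the lengths used to index it are clamped.
print(input_lengths.clamp(max=max_past))  # tensor([1000, 4096, 4096])

# Prefill / flash attention: the kernel applies the window itself via
# `window_size_left`, so it gets the unclamped lengths.
print(input_lengths)                      # tensor([1000, 6000, 9000])
```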
......@@ -35,6 +35,7 @@ from text_generation_server.layers.attention import (
paged_attention,
attention,
reshape_and_cache,
Seqlen,
)
from text_generation_server.layers import (
FastLinear,
......@@ -243,7 +244,7 @@ class MixtralAttention(torch.nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
prefill_cache_indices,
):
......@@ -274,12 +275,10 @@ class MixtralAttention(torch.nn.Module):
# flash attention
attn_output = attention(
query,
torch.select(kv, dim=1, index=0),
torch.select(kv, dim=1, index=1),
kv_cache[0],
kv_cache[1],
cu_seqlen_prefill,
max_s,
seqlen,
block_tables,
self.softmax_scale,
window_size_left=self.max_past,
)
......@@ -292,7 +291,7 @@ class MixtralAttention(torch.nn.Module):
self.kv_head_mapping,
self.softmax_scale,
block_tables,
input_lengths,
seqlen,
max_s,
)
......@@ -498,7 +497,7 @@ class MixtralLayer(nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
prefill_cache_indices,
):
......@@ -513,7 +512,7 @@ class MixtralLayer(nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
prefill_cache_indices,
)
......@@ -568,7 +567,7 @@ class MixtralModel(torch.nn.Module):
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
input_lengths: torch.Tensor,
seqlen: Seqlen,
max_s: int,
true_max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
......@@ -592,7 +591,7 @@ class MixtralModel(torch.nn.Module):
kv_cache[i],
block_tables,
slots,
input_lengths,
seqlen,
max_s,
prefill_cache_indices,
)
......@@ -627,7 +626,7 @@ class FlashMixtralForCausalLM(torch.nn.Module):
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
input_lengths: torch.Tensor,
seqlen: Seqlen,
max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor] = None,
......@@ -640,7 +639,7 @@ class FlashMixtralForCausalLM(torch.nn.Module):
elif self.max_past is not None:
# Clamp in decode mode as paged attention requires clamped values whereas the flash attention
# kernel requires the true values
input_lengths = input_lengths.clamp(max=self.max_past_tensor)
seqlen = seqlen.clamp(max=self.max_past_tensor)
hidden_states = self.model(
input_ids,
......@@ -649,7 +648,7 @@ class FlashMixtralForCausalLM(torch.nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
true_max_s,
prefill_cache_indices,
......
......@@ -31,6 +31,7 @@ from text_generation_server.layers.attention import (
paged_attention,
attention,
reshape_and_cache,
Seqlen,
)
from text_generation_server.layers import (
TensorParallelRowLinear,
......@@ -147,7 +148,7 @@ class FlashNeoxAttention(torch.nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
):
qkv = self.query_key_value(hidden_states)
......@@ -171,12 +172,10 @@ class FlashNeoxAttention(torch.nn.Module):
# flash attention
attn_output = attention(
qkv[:, 0],
qkv[:, 1],
qkv[:, 2],
kv_cache[0],
kv_cache[1],
cu_seqlen_prefill,
max_s,
seqlen,
block_tables,
self.softmax_scale,
)
# Decode
......@@ -188,7 +187,7 @@ class FlashNeoxAttention(torch.nn.Module):
self.kv_head_mapping,
self.softmax_scale,
block_tables,
input_lengths,
seqlen,
max_s,
)
......@@ -258,7 +257,7 @@ class FlashNeoXLayer(nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
):
if self.use_parallel_residual:
......@@ -272,7 +271,7 @@ class FlashNeoXLayer(nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
)
......@@ -296,7 +295,7 @@ class FlashNeoXLayer(nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
)
......@@ -350,7 +349,7 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
input_lengths: torch.Tensor,
seqlen: Seqlen,
max_s: int,
) -> torch.Tensor:
hidden_states = self.embed_in(input_ids)
......@@ -372,7 +371,7 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
kv_cache[i],
block_tables,
slots,
input_lengths,
seqlen,
max_s,
)
......@@ -404,7 +403,7 @@ class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel):
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
input_lengths: torch.Tensor,
seqlen: Seqlen,
max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor] = None,
......@@ -417,7 +416,7 @@ class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
)
if lm_head_indices is not None:
......
......@@ -19,6 +19,7 @@ from torch import nn
from typing import Optional, List, Tuple
from text_generation_server.layers.tensor_parallel import TensorParallelColumnLinear
from text_generation_server.layers.attention import Seqlen
from text_generation_server.models.custom_modeling.vlm import (
load_text_model,
load_vision_model,
......@@ -70,7 +71,7 @@ class PaliGemmaForConditionalGeneration(nn.Module):
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
input_lengths: torch.Tensor,
seqlen: Seqlen,
max_s: int,
prefill_cache_indices: Optional[torch.Tensor] = None,
lm_head_indices: Optional[torch.Tensor] = None,
......@@ -107,7 +108,7 @@ class PaliGemmaForConditionalGeneration(nn.Module):
kv_cache=kv_cache,
block_tables=block_tables,
slots=slots,
input_lengths=input_lengths,
seqlen=seqlen,
max_s=max_s,
)
......
......@@ -10,6 +10,7 @@ from text_generation_server.layers.attention import (
paged_attention,
attention,
reshape_and_cache,
Seqlen,
)
from text_generation_server.layers import (
TensorParallelRowLinear,
......@@ -159,7 +160,7 @@ class FlashPhiAttention(torch.nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
):
# Compute query, key, value and split
......@@ -192,12 +193,10 @@ class FlashPhiAttention(torch.nn.Module):
if cu_seqlen_prefill is not None:
attn_output = attention(
query,
torch.select(kv, dim=1, index=0),
torch.select(kv, dim=1, index=1),
kv_cache[0],
kv_cache[1],
cu_seqlen_prefill,
max_s,
seqlen,
block_tables,
self.softmax_scale,
)
# Decode
......@@ -209,7 +208,7 @@ class FlashPhiAttention(torch.nn.Module):
self.kv_head_mapping,
self.softmax_scale,
block_tables,
input_lengths,
seqlen,
max_s,
)
......@@ -276,7 +275,7 @@ class FlashPhiLayer(nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
):
hidden_states, res = self.input_layernorm(hidden_states, residual)
......@@ -289,7 +288,7 @@ class FlashPhiLayer(nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
)
......@@ -341,7 +340,7 @@ class FlashPhiModel(torch.nn.Module):
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
input_lengths: torch.Tensor,
seqlen: Seqlen,
max_s: int,
) -> torch.Tensor:
hidden_states = self.embed_tokens(input_ids)
......@@ -363,7 +362,7 @@ class FlashPhiModel(torch.nn.Module):
kv_cache[i],
block_tables,
slots,
input_lengths,
seqlen,
max_s,
)
......@@ -396,7 +395,7 @@ class FlashPhiForCausalLM(torch.nn.Module):
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
input_lengths: torch.Tensor,
seqlen: Seqlen,
max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor] = None,
......@@ -409,7 +408,7 @@ class FlashPhiForCausalLM(torch.nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
)
if lm_head_indices is not None:
......
......@@ -9,6 +9,7 @@ from text_generation_server.layers.attention import (
paged_attention,
attention,
reshape_and_cache,
Seqlen,
)
from text_generation_server.layers import (
TensorParallelRowLinear,
......@@ -104,7 +105,7 @@ class Qwen2Attention(torch.nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
prefill_cache_indices,
):
......@@ -135,12 +136,10 @@ class Qwen2Attention(torch.nn.Module):
# flash attention
attn_output = attention(
query,
torch.select(kv, dim=1, index=0),
torch.select(kv, dim=1, index=1),
kv_cache[0],
kv_cache[1],
cu_seqlen_prefill,
max_s,
seqlen,
block_tables,
self.softmax_scale,
window_size_left=self.max_past,
)
......@@ -153,7 +152,7 @@ class Qwen2Attention(torch.nn.Module):
self.kv_head_mapping,
self.softmax_scale,
block_tables,
input_lengths,
seqlen,
max_s,
)
......@@ -225,7 +224,7 @@ class Qwen2Layer(nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
prefill_cache_indices,
):
......@@ -240,7 +239,7 @@ class Qwen2Layer(nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
prefill_cache_indices,
)
......@@ -296,7 +295,7 @@ class Qwen2Model(torch.nn.Module):
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
input_lengths: torch.Tensor,
seqlen: Seqlen,
max_s: int,
true_max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
......@@ -320,7 +319,7 @@ class Qwen2Model(torch.nn.Module):
kv_cache[i],
block_tables,
slots,
input_lengths,
seqlen,
max_s,
prefill_cache_indices,
)
......@@ -361,7 +360,7 @@ class Qwen2ForCausalLM(torch.nn.Module):
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
input_lengths: torch.Tensor,
seqlen: Seqlen,
max_s: int,
prefill_cache_indices: Optional[torch.Tensor] = None,
lm_head_indices: Optional[torch.Tensor] = None,
......@@ -374,7 +373,7 @@ class Qwen2ForCausalLM(torch.nn.Module):
elif self.max_past is not None:
# Clamp in decode mode as paged attention requires clamped values whereas the flash attention
# kernel requires the true values
input_lengths = input_lengths.clamp(max=self.max_past_tensor)
seqlen = seqlen.clamp(max=self.max_past_tensor)
hidden_states = self.model(
input_ids,
......@@ -383,7 +382,7 @@ class Qwen2ForCausalLM(torch.nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
true_max_s,
prefill_cache_indices,
......
......@@ -9,6 +9,7 @@ from text_generation_server.layers.attention import (
paged_attention,
attention,
reshape_and_cache,
Seqlen,
)
from text_generation_server.layers import (
TensorParallelRowLinear,
......@@ -268,7 +269,7 @@ class FlashMQAttention(torch.nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
):
qkv = self.c_attn(hidden_states)
......@@ -291,12 +292,10 @@ class FlashMQAttention(torch.nn.Module):
# flash attention
attn_output = attention(
query,
torch.select(key_value, dim=1, index=0),
torch.select(key_value, dim=1, index=1),
kv_cache[0],
kv_cache[1],
cu_seqlen_prefill,
max_s,
seqlen,
block_tables,
self.softmax_scale,
)
# Decode
......@@ -308,7 +307,7 @@ class FlashMQAttention(torch.nn.Module):
self.kv_head_mapping,
self.softmax_scale,
block_tables,
input_lengths,
seqlen,
max_s,
)
......@@ -373,7 +372,7 @@ class Block(nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
):
hidden_states, residual = self.ln_1(hidden_states, residual)
......@@ -383,7 +382,7 @@ class Block(nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
)
......@@ -437,7 +436,7 @@ class FlashSantacoderModel(nn.Module):
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
input_lengths: torch.Tensor,
seqlen: Seqlen,
max_s: int,
) -> torch.Tensor:
hidden_states = self.wte(input_ids) + self.wpe(position_ids)
......@@ -454,7 +453,7 @@ class FlashSantacoderModel(nn.Module):
kv_cache[i],
block_tables,
slots,
input_lengths,
seqlen,
max_s,
)
......@@ -486,7 +485,7 @@ class FlashSantacoderForCausalLM(nn.Module):
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
input_lengths: torch.Tensor,
seqlen: Seqlen,
max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor] = None,
......@@ -499,7 +498,7 @@ class FlashSantacoderForCausalLM(nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
)
if lm_head_indices is not None:
......