Unverified Commit e415b690 authored by Nicolas Patry's avatar Nicolas Patry Committed by GitHub
Browse files

Lots of improvements (Still 2 allocators) (#2449)



* Making prefix/flashinfer the default and testing the full release tests.

* Include flashinfer in the docker.

* Using prebuilt.

* Allowing window_left_size (dummy version).

* Disabling flashinfer/prefix caching on odd head_dim

* Disable prefix caching for lora.

* More specific codes.

* Update lock

* Updating integration tests with new values with FI/FD.

Remove paged as a default too, and using FD everywhere.

* Update cargo lock ?

* Upgrade to 1.80 because of bitstream...

* Everywhere 1.80

* Forgot last default place.

* Apply suggestions from code review
Co-authored-by: default avatardrbh <david.richard.holtz@gmail.com>

* Updated flake lock

* Tmp

* Upgrade resolution system for less errors in resolution.

* Remove lambda for cleaner function.

* Handling debugger.

* OVerride the env in server tests.

* Is this enough to make it work ?

* This seems to be working.

* Downgrade some logs.

* Fixing the default for vlm.

* Don't enable prefix caching on VLM just yet.

* Change `add_special_tokens` in order to have the correct tokens for chat
input and not (since it's super important with the prefixing now)

* Fixing prefix caching for flashdecoding.

* Update all models.

* Fixed flashinfer version.

* add_special_tokens is internal only

* Fixing seqlen with the new vlms.

* Fixing the issue with `add_special_tokens` not being passed around.

* Fixing the test.

* Removing encoder_decoder (seq2seq).

* Update the chat test.

* Fixing the batching tokenization in flash causal lm.

* Truncating left for radix purposes.

* Oops this doesn't belong here.

* Put back default pure shell.

* Update server tests

- Default to throughput test in k6
- Use TGI_WIGGLE_ROOM to adjust wiggle room

* Only n_heads / process_group.size() are necessary.

* Revert the integrationt tests change (seem linked to head_size
modification).

* Adding error message when assert is violated.

* Fixing the free algorithm to handle times where the common prefix is
smaller.

* Apply suggestions from code review
Co-authored-by: default avatarOlivierDehaene <olivier@huggingface.co>

* Update server/text_generation_server/layers/attention/common.py
Co-authored-by: default avatarOlivierDehaene <olivier@huggingface.co>

* Fix disabling prefix caching - Fix windowing checks.

* Revert the Cohere tokenizer change (for now using a revision instead).

* Fmt.

---------
Co-authored-by: default avatardrbh <david.richard.holtz@gmail.com>
Co-authored-by: default avatarOlivierDehaene <olivier@huggingface.co>
parent 4e821c00
...@@ -30,6 +30,7 @@ from text_generation_server.layers.attention import ( ...@@ -30,6 +30,7 @@ from text_generation_server.layers.attention import (
paged_attention, paged_attention,
attention, attention,
reshape_and_cache, reshape_and_cache,
Seqlen,
) )
from text_generation_server.layers import ( from text_generation_server.layers import (
TensorParallelRowLinear, TensorParallelRowLinear,
...@@ -213,7 +214,7 @@ class FlashGemma2Attention(torch.nn.Module): ...@@ -213,7 +214,7 @@ class FlashGemma2Attention(torch.nn.Module):
kv_cache, kv_cache,
block_tables, block_tables,
slots, slots,
input_lengths, seqlen,
max_s, max_s,
): ):
qkv = self.query_key_value(hidden_states) qkv = self.query_key_value(hidden_states)
...@@ -236,12 +237,10 @@ class FlashGemma2Attention(torch.nn.Module): ...@@ -236,12 +237,10 @@ class FlashGemma2Attention(torch.nn.Module):
# flash attention # flash attention
attn_output = attention( attn_output = attention(
query, query,
torch.select(kv, dim=1, index=0),
torch.select(kv, dim=1, index=1),
kv_cache[0], kv_cache[0],
kv_cache[1], kv_cache[1],
cu_seqlen_prefill, seqlen,
max_s, block_tables,
self.softmax_scale, self.softmax_scale,
causal=self.causal, causal=self.causal,
window_size_left=self.window_size, window_size_left=self.window_size,
...@@ -256,7 +255,7 @@ class FlashGemma2Attention(torch.nn.Module): ...@@ -256,7 +255,7 @@ class FlashGemma2Attention(torch.nn.Module):
self.kv_head_mapping, self.kv_head_mapping,
self.softmax_scale, self.softmax_scale,
block_tables, block_tables,
input_lengths, seqlen,
max_s, max_s,
softcap=self.softcap, softcap=self.softcap,
) )
...@@ -343,7 +342,7 @@ class FlashGemma2Layer(nn.Module): ...@@ -343,7 +342,7 @@ class FlashGemma2Layer(nn.Module):
kv_cache, kv_cache,
block_tables, block_tables,
slots, slots,
input_lengths, seqlen,
max_s, max_s,
): ):
normed_hidden_states, res = self.input_layernorm(hidden_states, residual) normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
...@@ -357,7 +356,7 @@ class FlashGemma2Layer(nn.Module): ...@@ -357,7 +356,7 @@ class FlashGemma2Layer(nn.Module):
kv_cache, kv_cache,
block_tables, block_tables,
slots, slots,
input_lengths, seqlen,
max_s, max_s,
) )
...@@ -408,7 +407,7 @@ class FlashGemma2Model(torch.nn.Module): ...@@ -408,7 +407,7 @@ class FlashGemma2Model(torch.nn.Module):
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor, block_tables: torch.Tensor,
slots: torch.Tensor, slots: torch.Tensor,
input_lengths: torch.Tensor, seqlen: Seqlen,
max_s: int, max_s: int,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = inputs_embeds hidden_states = inputs_embeds
...@@ -430,7 +429,7 @@ class FlashGemma2Model(torch.nn.Module): ...@@ -430,7 +429,7 @@ class FlashGemma2Model(torch.nn.Module):
kv_cache[i], kv_cache[i],
block_tables, block_tables,
slots, slots,
input_lengths, seqlen,
max_s, max_s,
) )
...@@ -477,7 +476,7 @@ class FlashGemma2ForCausalLM(torch.nn.Module): ...@@ -477,7 +476,7 @@ class FlashGemma2ForCausalLM(torch.nn.Module):
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor, block_tables: torch.Tensor,
slots: torch.Tensor, slots: torch.Tensor,
input_lengths: torch.Tensor, seqlen: Seqlen,
max_s: int, max_s: int,
prefill_cache_indices: Optional[torch.Tensor], prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor] = None, lm_head_indices: Optional[torch.Tensor] = None,
...@@ -491,7 +490,7 @@ class FlashGemma2ForCausalLM(torch.nn.Module): ...@@ -491,7 +490,7 @@ class FlashGemma2ForCausalLM(torch.nn.Module):
kv_cache, kv_cache,
block_tables, block_tables,
slots, slots,
input_lengths, seqlen,
max_s, max_s,
) )
if lm_head_indices is not None: if lm_head_indices is not None:
......
...@@ -30,6 +30,7 @@ from text_generation_server.layers.attention import ( ...@@ -30,6 +30,7 @@ from text_generation_server.layers.attention import (
paged_attention, paged_attention,
attention, attention,
reshape_and_cache, reshape_and_cache,
Seqlen,
) )
from text_generation_server.layers import ( from text_generation_server.layers import (
TensorParallelRowLinear, TensorParallelRowLinear,
...@@ -207,7 +208,7 @@ class FlashGemmaAttention(torch.nn.Module): ...@@ -207,7 +208,7 @@ class FlashGemmaAttention(torch.nn.Module):
kv_cache, kv_cache,
block_tables, block_tables,
slots, slots,
input_lengths, seqlen,
max_s, max_s,
): ):
qkv = self.query_key_value(hidden_states) qkv = self.query_key_value(hidden_states)
...@@ -230,12 +231,10 @@ class FlashGemmaAttention(torch.nn.Module): ...@@ -230,12 +231,10 @@ class FlashGemmaAttention(torch.nn.Module):
# flash attention # flash attention
attn_output = attention( attn_output = attention(
query, query,
torch.select(kv, dim=1, index=0),
torch.select(kv, dim=1, index=1),
kv_cache[0], kv_cache[0],
kv_cache[1], kv_cache[1],
cu_seqlen_prefill, seqlen,
max_s, block_tables,
self.softmax_scale, self.softmax_scale,
causal=self.causal, causal=self.causal,
) )
...@@ -248,7 +247,7 @@ class FlashGemmaAttention(torch.nn.Module): ...@@ -248,7 +247,7 @@ class FlashGemmaAttention(torch.nn.Module):
self.kv_head_mapping, self.kv_head_mapping,
self.softmax_scale, self.softmax_scale,
block_tables, block_tables,
input_lengths, seqlen,
max_s, max_s,
) )
...@@ -320,7 +319,7 @@ class FlashGemmaLayer(nn.Module): ...@@ -320,7 +319,7 @@ class FlashGemmaLayer(nn.Module):
kv_cache, kv_cache,
block_tables, block_tables,
slots, slots,
input_lengths, seqlen,
max_s, max_s,
): ):
normed_hidden_states, res = self.input_layernorm(hidden_states, residual) normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
...@@ -334,7 +333,7 @@ class FlashGemmaLayer(nn.Module): ...@@ -334,7 +333,7 @@ class FlashGemmaLayer(nn.Module):
kv_cache, kv_cache,
block_tables, block_tables,
slots, slots,
input_lengths, seqlen,
max_s, max_s,
) )
...@@ -382,7 +381,7 @@ class FlashGemmaModel(torch.nn.Module): ...@@ -382,7 +381,7 @@ class FlashGemmaModel(torch.nn.Module):
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor, block_tables: torch.Tensor,
slots: torch.Tensor, slots: torch.Tensor,
input_lengths: torch.Tensor, seqlen: Seqlen,
max_s: int, max_s: int,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = inputs_embeds hidden_states = inputs_embeds
...@@ -404,7 +403,7 @@ class FlashGemmaModel(torch.nn.Module): ...@@ -404,7 +403,7 @@ class FlashGemmaModel(torch.nn.Module):
kv_cache[i], kv_cache[i],
block_tables, block_tables,
slots, slots,
input_lengths, seqlen,
max_s, max_s,
) )
...@@ -449,7 +448,7 @@ class FlashGemmaForCausalLM(torch.nn.Module): ...@@ -449,7 +448,7 @@ class FlashGemmaForCausalLM(torch.nn.Module):
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor, block_tables: torch.Tensor,
slots: torch.Tensor, slots: torch.Tensor,
input_lengths: torch.Tensor, seqlen: Seqlen,
max_s: int, max_s: int,
prefill_cache_indices: Optional[torch.Tensor], prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor] = None, lm_head_indices: Optional[torch.Tensor] = None,
...@@ -463,7 +462,7 @@ class FlashGemmaForCausalLM(torch.nn.Module): ...@@ -463,7 +462,7 @@ class FlashGemmaForCausalLM(torch.nn.Module):
kv_cache, kv_cache,
block_tables, block_tables,
slots, slots,
input_lengths, seqlen,
max_s, max_s,
) )
if lm_head_indices is not None: if lm_head_indices is not None:
......
...@@ -29,6 +29,7 @@ from text_generation_server.layers.attention import ( ...@@ -29,6 +29,7 @@ from text_generation_server.layers.attention import (
paged_attention, paged_attention,
attention, attention,
reshape_and_cache, reshape_and_cache,
Seqlen,
) )
from text_generation_server.layers import ( from text_generation_server.layers import (
TensorParallelRowLinear, TensorParallelRowLinear,
...@@ -213,7 +214,7 @@ class FlashGPT2Attention(torch.nn.Module): ...@@ -213,7 +214,7 @@ class FlashGPT2Attention(torch.nn.Module):
kv_cache, kv_cache,
block_tables, block_tables,
slots, slots,
input_lengths, seqlen,
max_s, max_s,
): ):
query, key, value = self.query_key_value(hidden_states).split( query, key, value = self.query_key_value(hidden_states).split(
...@@ -230,12 +231,10 @@ class FlashGPT2Attention(torch.nn.Module): ...@@ -230,12 +231,10 @@ class FlashGPT2Attention(torch.nn.Module):
# flash attention # flash attention
attn_output = attention( attn_output = attention(
query, query,
key,
value,
kv_cache[0], kv_cache[0],
kv_cache[1], kv_cache[1],
cu_seqlen_prefill, seqlen,
max_s, block_tables,
self.softmax_scale, self.softmax_scale,
) )
# Decode # Decode
...@@ -247,7 +246,7 @@ class FlashGPT2Attention(torch.nn.Module): ...@@ -247,7 +246,7 @@ class FlashGPT2Attention(torch.nn.Module):
self.kv_head_mapping, self.kv_head_mapping,
self.softmax_scale, self.softmax_scale,
block_tables, block_tables,
input_lengths, seqlen,
max_s, max_s,
) )
...@@ -316,7 +315,7 @@ class FlashGPT2Layer(nn.Module): ...@@ -316,7 +315,7 @@ class FlashGPT2Layer(nn.Module):
kv_cache, kv_cache,
block_tables, block_tables,
slots, slots,
input_lengths, seqlen,
max_s, max_s,
): ):
residual = hidden_states residual = hidden_states
...@@ -329,7 +328,7 @@ class FlashGPT2Layer(nn.Module): ...@@ -329,7 +328,7 @@ class FlashGPT2Layer(nn.Module):
kv_cache, kv_cache,
block_tables, block_tables,
slots, slots,
input_lengths, seqlen,
max_s, max_s,
) )
...@@ -382,7 +381,7 @@ class FlashGPT2Model(torch.nn.Module): ...@@ -382,7 +381,7 @@ class FlashGPT2Model(torch.nn.Module):
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor, block_tables: torch.Tensor,
slots: torch.Tensor, slots: torch.Tensor,
input_lengths: torch.Tensor, seqlen: Seqlen,
max_s: int, max_s: int,
true_max_s: int, true_max_s: int,
prefill_cache_indices: Optional[torch.Tensor], prefill_cache_indices: Optional[torch.Tensor],
...@@ -398,7 +397,7 @@ class FlashGPT2Model(torch.nn.Module): ...@@ -398,7 +397,7 @@ class FlashGPT2Model(torch.nn.Module):
kv_cache[i], kv_cache[i],
block_tables, block_tables,
slots, slots,
input_lengths, seqlen,
max_s, max_s,
) )
...@@ -435,7 +434,7 @@ class FlashGPT2ForCausalLM(torch.nn.Module): ...@@ -435,7 +434,7 @@ class FlashGPT2ForCausalLM(torch.nn.Module):
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor, block_tables: torch.Tensor,
slots: torch.Tensor, slots: torch.Tensor,
input_lengths: torch.Tensor, seqlen: Seqlen,
max_s: int, max_s: int,
prefill_cache_indices: Optional[torch.Tensor] = None, prefill_cache_indices: Optional[torch.Tensor] = None,
lm_head_indices: Optional[torch.Tensor] = None, lm_head_indices: Optional[torch.Tensor] = None,
...@@ -451,7 +450,7 @@ class FlashGPT2ForCausalLM(torch.nn.Module): ...@@ -451,7 +450,7 @@ class FlashGPT2ForCausalLM(torch.nn.Module):
kv_cache, kv_cache,
block_tables, block_tables,
slots, slots,
input_lengths, seqlen,
max_s, max_s,
true_max_s=max_s, true_max_s=max_s,
prefill_cache_indices=prefill_cache_indices, prefill_cache_indices=prefill_cache_indices,
......
...@@ -29,6 +29,7 @@ from text_generation_server.layers.attention import ( ...@@ -29,6 +29,7 @@ from text_generation_server.layers.attention import (
paged_attention, paged_attention,
attention, attention,
reshape_and_cache, reshape_and_cache,
Seqlen,
) )
from text_generation_server.layers import ( from text_generation_server.layers import (
TensorParallelRowLinear, TensorParallelRowLinear,
...@@ -167,7 +168,7 @@ class FlashGPTJAttention(torch.nn.Module): ...@@ -167,7 +168,7 @@ class FlashGPTJAttention(torch.nn.Module):
kv_cache, kv_cache,
block_tables, block_tables,
slots, slots,
input_lengths, seqlen,
max_s, max_s,
): ):
query, key, value = self.query_key_value(hidden_states).split( query, key, value = self.query_key_value(hidden_states).split(
...@@ -192,10 +193,10 @@ class FlashGPTJAttention(torch.nn.Module): ...@@ -192,10 +193,10 @@ class FlashGPTJAttention(torch.nn.Module):
# flash attention # flash attention
attn_output = attention( attn_output = attention(
query, query,
key, kv_cache[0],
value, kv_cache[1],
cu_seqlen_prefill, seqlen,
max_s, block_tables,
self.softmax_scale, self.softmax_scale,
) )
# Decode # Decode
...@@ -207,7 +208,7 @@ class FlashGPTJAttention(torch.nn.Module): ...@@ -207,7 +208,7 @@ class FlashGPTJAttention(torch.nn.Module):
self.kv_head_mapping, self.kv_head_mapping,
self.softmax_scale, self.softmax_scale,
block_tables, block_tables,
input_lengths, seqlen,
max_s, max_s,
) )
...@@ -268,7 +269,7 @@ class FlashGPTJLayer(nn.Module): ...@@ -268,7 +269,7 @@ class FlashGPTJLayer(nn.Module):
kv_cache, kv_cache,
block_tables, block_tables,
slots, slots,
input_lengths, seqlen,
max_s, max_s,
): ):
hidden_states, residual = self.input_layernorm(hidden_states, residual) hidden_states, residual = self.input_layernorm(hidden_states, residual)
...@@ -281,7 +282,7 @@ class FlashGPTJLayer(nn.Module): ...@@ -281,7 +282,7 @@ class FlashGPTJLayer(nn.Module):
kv_cache, kv_cache,
block_tables, block_tables,
slots, slots,
input_lengths, seqlen,
max_s, max_s,
) )
...@@ -328,7 +329,7 @@ class FlashGPTJModel(torch.nn.Module): ...@@ -328,7 +329,7 @@ class FlashGPTJModel(torch.nn.Module):
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor, block_tables: torch.Tensor,
slots: torch.Tensor, slots: torch.Tensor,
input_lengths: torch.Tensor, seqlen: Seqlen,
max_s: int, max_s: int,
prefill_cache_indices: Optional[torch.Tensor], prefill_cache_indices: Optional[torch.Tensor],
) -> torch.Tensor: ) -> torch.Tensor:
...@@ -351,7 +352,7 @@ class FlashGPTJModel(torch.nn.Module): ...@@ -351,7 +352,7 @@ class FlashGPTJModel(torch.nn.Module):
kv_cache[i], kv_cache[i],
block_tables, block_tables,
slots, slots,
input_lengths, seqlen,
max_s, max_s,
) )
...@@ -382,7 +383,7 @@ class FlashGPTJForCausalLM(torch.nn.Module): ...@@ -382,7 +383,7 @@ class FlashGPTJForCausalLM(torch.nn.Module):
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor, block_tables: torch.Tensor,
slots: torch.Tensor, slots: torch.Tensor,
input_lengths: torch.Tensor, seqlen: Seqlen,
max_s: int, max_s: int,
prefill_cache_indices: Optional[torch.Tensor] = None, prefill_cache_indices: Optional[torch.Tensor] = None,
lm_head_indices: Optional[torch.Tensor] = None, lm_head_indices: Optional[torch.Tensor] = None,
...@@ -395,7 +396,7 @@ class FlashGPTJForCausalLM(torch.nn.Module): ...@@ -395,7 +396,7 @@ class FlashGPTJForCausalLM(torch.nn.Module):
kv_cache, kv_cache,
block_tables, block_tables,
slots, slots,
input_lengths, seqlen,
max_s, max_s,
prefill_cache_indices=prefill_cache_indices, prefill_cache_indices=prefill_cache_indices,
) )
......
...@@ -32,6 +32,7 @@ from text_generation_server.layers.attention import ( ...@@ -32,6 +32,7 @@ from text_generation_server.layers.attention import (
paged_attention, paged_attention,
attention, attention,
reshape_and_cache, reshape_and_cache,
Seqlen,
) )
from text_generation_server.layers import ( from text_generation_server.layers import (
TensorParallelRowLinear, TensorParallelRowLinear,
...@@ -194,7 +195,7 @@ class FlashLlamaAttention(torch.nn.Module): ...@@ -194,7 +195,7 @@ class FlashLlamaAttention(torch.nn.Module):
kv_cache, kv_cache,
block_tables, block_tables,
slots, slots,
input_lengths, seqlen,
max_s, max_s,
adapter_data, adapter_data,
): ):
...@@ -218,12 +219,10 @@ class FlashLlamaAttention(torch.nn.Module): ...@@ -218,12 +219,10 @@ class FlashLlamaAttention(torch.nn.Module):
# flash attention # flash attention
attn_output = attention( attn_output = attention(
query, query,
torch.select(kv, dim=1, index=0),
torch.select(kv, dim=1, index=1),
kv_cache[0], kv_cache[0],
kv_cache[1], kv_cache[1],
cu_seqlen_prefill, seqlen,
max_s, block_tables,
self.softmax_scale, self.softmax_scale,
) )
# Decode # Decode
...@@ -235,7 +234,7 @@ class FlashLlamaAttention(torch.nn.Module): ...@@ -235,7 +234,7 @@ class FlashLlamaAttention(torch.nn.Module):
self.kv_head_mapping, self.kv_head_mapping,
self.softmax_scale, self.softmax_scale,
block_tables, block_tables,
input_lengths, seqlen,
max_s, max_s,
) )
...@@ -375,7 +374,7 @@ class FlashLlamaLayer(nn.Module): ...@@ -375,7 +374,7 @@ class FlashLlamaLayer(nn.Module):
kv_cache, kv_cache,
block_tables, block_tables,
slots, slots,
input_lengths, seqlen,
max_s, max_s,
adapter_data, adapter_data,
): ):
...@@ -390,7 +389,7 @@ class FlashLlamaLayer(nn.Module): ...@@ -390,7 +389,7 @@ class FlashLlamaLayer(nn.Module):
kv_cache, kv_cache,
block_tables, block_tables,
slots, slots,
input_lengths, seqlen,
max_s, max_s,
adapter_data, adapter_data,
) )
...@@ -479,7 +478,7 @@ class FlashLlamaModel(torch.nn.Module): ...@@ -479,7 +478,7 @@ class FlashLlamaModel(torch.nn.Module):
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor, block_tables: torch.Tensor,
slots: torch.Tensor, slots: torch.Tensor,
input_lengths: torch.Tensor, seqlen: Seqlen,
max_s: int, max_s: int,
true_max_s: int, true_max_s: int,
prefill_cache_indices: Optional[torch.Tensor], prefill_cache_indices: Optional[torch.Tensor],
...@@ -504,7 +503,7 @@ class FlashLlamaModel(torch.nn.Module): ...@@ -504,7 +503,7 @@ class FlashLlamaModel(torch.nn.Module):
kv_cache[i], kv_cache[i],
block_tables, block_tables,
slots, slots,
input_lengths, seqlen,
max_s, max_s,
adapter_data, adapter_data,
) )
...@@ -548,7 +547,7 @@ class FlashLlamaForCausalLM(torch.nn.Module): ...@@ -548,7 +547,7 @@ class FlashLlamaForCausalLM(torch.nn.Module):
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor, block_tables: torch.Tensor,
slots: torch.Tensor, slots: torch.Tensor,
input_lengths: torch.Tensor, seqlen: Seqlen,
max_s: int, max_s: int,
prefill_cache_indices: Optional[torch.Tensor] = None, prefill_cache_indices: Optional[torch.Tensor] = None,
lm_head_indices: Optional[torch.Tensor] = None, lm_head_indices: Optional[torch.Tensor] = None,
...@@ -562,7 +561,7 @@ class FlashLlamaForCausalLM(torch.nn.Module): ...@@ -562,7 +561,7 @@ class FlashLlamaForCausalLM(torch.nn.Module):
kv_cache, kv_cache,
block_tables, block_tables,
slots, slots,
input_lengths, seqlen,
max_s, max_s,
true_max_s=max_s, true_max_s=max_s,
prefill_cache_indices=prefill_cache_indices, prefill_cache_indices=prefill_cache_indices,
......
...@@ -31,6 +31,7 @@ from text_generation_server.layers.attention import ( ...@@ -31,6 +31,7 @@ from text_generation_server.layers.attention import (
paged_attention, paged_attention,
attention, attention,
reshape_and_cache, reshape_and_cache,
Seqlen,
) )
from text_generation_server.layers import ( from text_generation_server.layers import (
TensorParallelRowLinear, TensorParallelRowLinear,
...@@ -185,7 +186,7 @@ class MistralAttention(torch.nn.Module): ...@@ -185,7 +186,7 @@ class MistralAttention(torch.nn.Module):
kv_cache, kv_cache,
block_tables, block_tables,
slots, slots,
input_lengths, seqlen,
max_s, max_s,
prefill_cache_indices, prefill_cache_indices,
adapter_data, adapter_data,
...@@ -217,12 +218,10 @@ class MistralAttention(torch.nn.Module): ...@@ -217,12 +218,10 @@ class MistralAttention(torch.nn.Module):
# flash attention # flash attention
attn_output = attention( attn_output = attention(
query, query,
torch.select(kv, dim=1, index=0),
torch.select(kv, dim=1, index=1),
kv_cache[0], kv_cache[0],
kv_cache[1], kv_cache[1],
cu_seqlen_prefill, seqlen,
max_s, block_tables,
self.softmax_scale, self.softmax_scale,
window_size_left=self.max_past, window_size_left=self.max_past,
) )
...@@ -235,7 +234,7 @@ class MistralAttention(torch.nn.Module): ...@@ -235,7 +234,7 @@ class MistralAttention(torch.nn.Module):
self.kv_head_mapping, self.kv_head_mapping,
self.softmax_scale, self.softmax_scale,
block_tables, block_tables,
input_lengths, seqlen,
max_s, max_s,
) )
...@@ -356,7 +355,7 @@ class MistralLayer(nn.Module): ...@@ -356,7 +355,7 @@ class MistralLayer(nn.Module):
kv_cache, kv_cache,
block_tables, block_tables,
slots, slots,
input_lengths, seqlen,
max_s, max_s,
prefill_cache_indices, prefill_cache_indices,
adapter_data, adapter_data,
...@@ -372,7 +371,7 @@ class MistralLayer(nn.Module): ...@@ -372,7 +371,7 @@ class MistralLayer(nn.Module):
kv_cache, kv_cache,
block_tables, block_tables,
slots, slots,
input_lengths, seqlen,
max_s, max_s,
prefill_cache_indices, prefill_cache_indices,
adapter_data, adapter_data,
...@@ -424,7 +423,7 @@ class MistralModel(torch.nn.Module): ...@@ -424,7 +423,7 @@ class MistralModel(torch.nn.Module):
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor, block_tables: torch.Tensor,
slots: torch.Tensor, slots: torch.Tensor,
input_lengths: torch.Tensor, seqlen: Seqlen,
max_s: int, max_s: int,
true_max_s: int, true_max_s: int,
prefill_cache_indices: Optional[torch.Tensor], prefill_cache_indices: Optional[torch.Tensor],
...@@ -448,7 +447,7 @@ class MistralModel(torch.nn.Module): ...@@ -448,7 +447,7 @@ class MistralModel(torch.nn.Module):
kv_cache[i], kv_cache[i],
block_tables, block_tables,
slots, slots,
input_lengths, seqlen,
max_s, max_s,
prefill_cache_indices, prefill_cache_indices,
adapter_data, adapter_data,
...@@ -499,7 +498,7 @@ class FlashMistralForCausalLM(torch.nn.Module): ...@@ -499,7 +498,7 @@ class FlashMistralForCausalLM(torch.nn.Module):
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor, block_tables: torch.Tensor,
slots: torch.Tensor, slots: torch.Tensor,
input_lengths: torch.Tensor, seqlen: Seqlen,
max_s: int, max_s: int,
prefill_cache_indices: Optional[torch.Tensor], prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor] = None, lm_head_indices: Optional[torch.Tensor] = None,
...@@ -512,7 +511,7 @@ class FlashMistralForCausalLM(torch.nn.Module): ...@@ -512,7 +511,7 @@ class FlashMistralForCausalLM(torch.nn.Module):
elif self.max_past is not None: elif self.max_past is not None:
# Clamp in decode mode as paged attention requires clamped values whereas the flash attention # Clamp in decode mode as paged attention requires clamped values whereas the flash attention
# kernel requires the true values # kernel requires the true values
input_lengths = input_lengths.clamp(max=self.max_past_tensor) seqlen = seqlen.clamp(max=self.max_past_tensor)
inputs_embeds = self.embed_tokens(input_ids) inputs_embeds = self.embed_tokens(input_ids)
hidden_states = self.model( hidden_states = self.model(
...@@ -522,7 +521,7 @@ class FlashMistralForCausalLM(torch.nn.Module): ...@@ -522,7 +521,7 @@ class FlashMistralForCausalLM(torch.nn.Module):
kv_cache, kv_cache,
block_tables, block_tables,
slots, slots,
input_lengths, seqlen,
max_s, max_s,
true_max_s, true_max_s,
prefill_cache_indices, prefill_cache_indices,
......
...@@ -35,6 +35,7 @@ from text_generation_server.layers.attention import ( ...@@ -35,6 +35,7 @@ from text_generation_server.layers.attention import (
paged_attention, paged_attention,
attention, attention,
reshape_and_cache, reshape_and_cache,
Seqlen,
) )
from text_generation_server.layers import ( from text_generation_server.layers import (
FastLinear, FastLinear,
...@@ -243,7 +244,7 @@ class MixtralAttention(torch.nn.Module): ...@@ -243,7 +244,7 @@ class MixtralAttention(torch.nn.Module):
kv_cache, kv_cache,
block_tables, block_tables,
slots, slots,
input_lengths, seqlen,
max_s, max_s,
prefill_cache_indices, prefill_cache_indices,
): ):
...@@ -274,12 +275,10 @@ class MixtralAttention(torch.nn.Module): ...@@ -274,12 +275,10 @@ class MixtralAttention(torch.nn.Module):
# flash attention # flash attention
attn_output = attention( attn_output = attention(
query, query,
torch.select(kv, dim=1, index=0),
torch.select(kv, dim=1, index=1),
kv_cache[0], kv_cache[0],
kv_cache[1], kv_cache[1],
cu_seqlen_prefill, seqlen,
max_s, block_tables,
self.softmax_scale, self.softmax_scale,
window_size_left=self.max_past, window_size_left=self.max_past,
) )
...@@ -292,7 +291,7 @@ class MixtralAttention(torch.nn.Module): ...@@ -292,7 +291,7 @@ class MixtralAttention(torch.nn.Module):
self.kv_head_mapping, self.kv_head_mapping,
self.softmax_scale, self.softmax_scale,
block_tables, block_tables,
input_lengths, seqlen,
max_s, max_s,
) )
...@@ -498,7 +497,7 @@ class MixtralLayer(nn.Module): ...@@ -498,7 +497,7 @@ class MixtralLayer(nn.Module):
kv_cache, kv_cache,
block_tables, block_tables,
slots, slots,
input_lengths, seqlen,
max_s, max_s,
prefill_cache_indices, prefill_cache_indices,
): ):
...@@ -513,7 +512,7 @@ class MixtralLayer(nn.Module): ...@@ -513,7 +512,7 @@ class MixtralLayer(nn.Module):
kv_cache, kv_cache,
block_tables, block_tables,
slots, slots,
input_lengths, seqlen,
max_s, max_s,
prefill_cache_indices, prefill_cache_indices,
) )
...@@ -568,7 +567,7 @@ class MixtralModel(torch.nn.Module): ...@@ -568,7 +567,7 @@ class MixtralModel(torch.nn.Module):
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor, block_tables: torch.Tensor,
slots: torch.Tensor, slots: torch.Tensor,
input_lengths: torch.Tensor, seqlen: Seqlen,
max_s: int, max_s: int,
true_max_s: int, true_max_s: int,
prefill_cache_indices: Optional[torch.Tensor], prefill_cache_indices: Optional[torch.Tensor],
...@@ -592,7 +591,7 @@ class MixtralModel(torch.nn.Module): ...@@ -592,7 +591,7 @@ class MixtralModel(torch.nn.Module):
kv_cache[i], kv_cache[i],
block_tables, block_tables,
slots, slots,
input_lengths, seqlen,
max_s, max_s,
prefill_cache_indices, prefill_cache_indices,
) )
...@@ -627,7 +626,7 @@ class FlashMixtralForCausalLM(torch.nn.Module): ...@@ -627,7 +626,7 @@ class FlashMixtralForCausalLM(torch.nn.Module):
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor, block_tables: torch.Tensor,
slots: torch.Tensor, slots: torch.Tensor,
input_lengths: torch.Tensor, seqlen: Seqlen,
max_s: int, max_s: int,
prefill_cache_indices: Optional[torch.Tensor], prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor] = None, lm_head_indices: Optional[torch.Tensor] = None,
...@@ -640,7 +639,7 @@ class FlashMixtralForCausalLM(torch.nn.Module): ...@@ -640,7 +639,7 @@ class FlashMixtralForCausalLM(torch.nn.Module):
elif self.max_past is not None: elif self.max_past is not None:
# Clamp in decode mode as paged attention requires clamped values whereas the flash attention # Clamp in decode mode as paged attention requires clamped values whereas the flash attention
# kernel requires the true values # kernel requires the true values
input_lengths = input_lengths.clamp(max=self.max_past_tensor) seqlen = seqlen.clamp(max=self.max_past_tensor)
hidden_states = self.model( hidden_states = self.model(
input_ids, input_ids,
...@@ -649,7 +648,7 @@ class FlashMixtralForCausalLM(torch.nn.Module): ...@@ -649,7 +648,7 @@ class FlashMixtralForCausalLM(torch.nn.Module):
kv_cache, kv_cache,
block_tables, block_tables,
slots, slots,
input_lengths, seqlen,
max_s, max_s,
true_max_s, true_max_s,
prefill_cache_indices, prefill_cache_indices,
......
...@@ -31,6 +31,7 @@ from text_generation_server.layers.attention import ( ...@@ -31,6 +31,7 @@ from text_generation_server.layers.attention import (
paged_attention, paged_attention,
attention, attention,
reshape_and_cache, reshape_and_cache,
Seqlen,
) )
from text_generation_server.layers import ( from text_generation_server.layers import (
TensorParallelRowLinear, TensorParallelRowLinear,
...@@ -147,7 +148,7 @@ class FlashNeoxAttention(torch.nn.Module): ...@@ -147,7 +148,7 @@ class FlashNeoxAttention(torch.nn.Module):
kv_cache, kv_cache,
block_tables, block_tables,
slots, slots,
input_lengths, seqlen,
max_s, max_s,
): ):
qkv = self.query_key_value(hidden_states) qkv = self.query_key_value(hidden_states)
...@@ -171,12 +172,10 @@ class FlashNeoxAttention(torch.nn.Module): ...@@ -171,12 +172,10 @@ class FlashNeoxAttention(torch.nn.Module):
# flash attention # flash attention
attn_output = attention( attn_output = attention(
qkv[:, 0], qkv[:, 0],
qkv[:, 1],
qkv[:, 2],
kv_cache[0], kv_cache[0],
kv_cache[1], kv_cache[1],
cu_seqlen_prefill, seqlen,
max_s, block_tables,
self.softmax_scale, self.softmax_scale,
) )
# Decode # Decode
...@@ -188,7 +187,7 @@ class FlashNeoxAttention(torch.nn.Module): ...@@ -188,7 +187,7 @@ class FlashNeoxAttention(torch.nn.Module):
self.kv_head_mapping, self.kv_head_mapping,
self.softmax_scale, self.softmax_scale,
block_tables, block_tables,
input_lengths, seqlen,
max_s, max_s,
) )
...@@ -258,7 +257,7 @@ class FlashNeoXLayer(nn.Module): ...@@ -258,7 +257,7 @@ class FlashNeoXLayer(nn.Module):
kv_cache, kv_cache,
block_tables, block_tables,
slots, slots,
input_lengths, seqlen,
max_s, max_s,
): ):
if self.use_parallel_residual: if self.use_parallel_residual:
...@@ -272,7 +271,7 @@ class FlashNeoXLayer(nn.Module): ...@@ -272,7 +271,7 @@ class FlashNeoXLayer(nn.Module):
kv_cache, kv_cache,
block_tables, block_tables,
slots, slots,
input_lengths, seqlen,
max_s, max_s,
) )
...@@ -296,7 +295,7 @@ class FlashNeoXLayer(nn.Module): ...@@ -296,7 +295,7 @@ class FlashNeoXLayer(nn.Module):
kv_cache, kv_cache,
block_tables, block_tables,
slots, slots,
input_lengths, seqlen,
max_s, max_s,
) )
...@@ -350,7 +349,7 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel): ...@@ -350,7 +349,7 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor, block_tables: torch.Tensor,
slots: torch.Tensor, slots: torch.Tensor,
input_lengths: torch.Tensor, seqlen: Seqlen,
max_s: int, max_s: int,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.embed_in(input_ids) hidden_states = self.embed_in(input_ids)
...@@ -372,7 +371,7 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel): ...@@ -372,7 +371,7 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
kv_cache[i], kv_cache[i],
block_tables, block_tables,
slots, slots,
input_lengths, seqlen,
max_s, max_s,
) )
...@@ -404,7 +403,7 @@ class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel): ...@@ -404,7 +403,7 @@ class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel):
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor, block_tables: torch.Tensor,
slots: torch.Tensor, slots: torch.Tensor,
input_lengths: torch.Tensor, seqlen: Seqlen,
max_s: int, max_s: int,
prefill_cache_indices: Optional[torch.Tensor], prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor] = None, lm_head_indices: Optional[torch.Tensor] = None,
...@@ -417,7 +416,7 @@ class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel): ...@@ -417,7 +416,7 @@ class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel):
kv_cache, kv_cache,
block_tables, block_tables,
slots, slots,
input_lengths, seqlen,
max_s, max_s,
) )
if lm_head_indices is not None: if lm_head_indices is not None:
......
...@@ -19,6 +19,7 @@ from torch import nn ...@@ -19,6 +19,7 @@ from torch import nn
from typing import Optional, List, Tuple from typing import Optional, List, Tuple
from text_generation_server.layers.tensor_parallel import TensorParallelColumnLinear from text_generation_server.layers.tensor_parallel import TensorParallelColumnLinear
from text_generation_server.layers.attention import Seqlen
from text_generation_server.models.custom_modeling.vlm import ( from text_generation_server.models.custom_modeling.vlm import (
load_text_model, load_text_model,
load_vision_model, load_vision_model,
...@@ -70,7 +71,7 @@ class PaliGemmaForConditionalGeneration(nn.Module): ...@@ -70,7 +71,7 @@ class PaliGemmaForConditionalGeneration(nn.Module):
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor, block_tables: torch.Tensor,
slots: torch.Tensor, slots: torch.Tensor,
input_lengths: torch.Tensor, seqlen: Seqlen,
max_s: int, max_s: int,
prefill_cache_indices: Optional[torch.Tensor] = None, prefill_cache_indices: Optional[torch.Tensor] = None,
lm_head_indices: Optional[torch.Tensor] = None, lm_head_indices: Optional[torch.Tensor] = None,
...@@ -107,7 +108,7 @@ class PaliGemmaForConditionalGeneration(nn.Module): ...@@ -107,7 +108,7 @@ class PaliGemmaForConditionalGeneration(nn.Module):
kv_cache=kv_cache, kv_cache=kv_cache,
block_tables=block_tables, block_tables=block_tables,
slots=slots, slots=slots,
input_lengths=input_lengths, seqlen=seqlen,
max_s=max_s, max_s=max_s,
) )
......
...@@ -10,6 +10,7 @@ from text_generation_server.layers.attention import ( ...@@ -10,6 +10,7 @@ from text_generation_server.layers.attention import (
paged_attention, paged_attention,
attention, attention,
reshape_and_cache, reshape_and_cache,
Seqlen,
) )
from text_generation_server.layers import ( from text_generation_server.layers import (
TensorParallelRowLinear, TensorParallelRowLinear,
...@@ -159,7 +160,7 @@ class FlashPhiAttention(torch.nn.Module): ...@@ -159,7 +160,7 @@ class FlashPhiAttention(torch.nn.Module):
kv_cache, kv_cache,
block_tables, block_tables,
slots, slots,
input_lengths, seqlen,
max_s, max_s,
): ):
# Compute query, key, value and split # Compute query, key, value and split
...@@ -192,12 +193,10 @@ class FlashPhiAttention(torch.nn.Module): ...@@ -192,12 +193,10 @@ class FlashPhiAttention(torch.nn.Module):
if cu_seqlen_prefill is not None: if cu_seqlen_prefill is not None:
attn_output = attention( attn_output = attention(
query, query,
torch.select(kv, dim=1, index=0),
torch.select(kv, dim=1, index=1),
kv_cache[0], kv_cache[0],
kv_cache[1], kv_cache[1],
cu_seqlen_prefill, seqlen,
max_s, block_tables,
self.softmax_scale, self.softmax_scale,
) )
# Decode # Decode
...@@ -209,7 +208,7 @@ class FlashPhiAttention(torch.nn.Module): ...@@ -209,7 +208,7 @@ class FlashPhiAttention(torch.nn.Module):
self.kv_head_mapping, self.kv_head_mapping,
self.softmax_scale, self.softmax_scale,
block_tables, block_tables,
input_lengths, seqlen,
max_s, max_s,
) )
...@@ -276,7 +275,7 @@ class FlashPhiLayer(nn.Module): ...@@ -276,7 +275,7 @@ class FlashPhiLayer(nn.Module):
kv_cache, kv_cache,
block_tables, block_tables,
slots, slots,
input_lengths, seqlen,
max_s, max_s,
): ):
hidden_states, res = self.input_layernorm(hidden_states, residual) hidden_states, res = self.input_layernorm(hidden_states, residual)
...@@ -289,7 +288,7 @@ class FlashPhiLayer(nn.Module): ...@@ -289,7 +288,7 @@ class FlashPhiLayer(nn.Module):
kv_cache, kv_cache,
block_tables, block_tables,
slots, slots,
input_lengths, seqlen,
max_s, max_s,
) )
...@@ -341,7 +340,7 @@ class FlashPhiModel(torch.nn.Module): ...@@ -341,7 +340,7 @@ class FlashPhiModel(torch.nn.Module):
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor, block_tables: torch.Tensor,
slots: torch.Tensor, slots: torch.Tensor,
input_lengths: torch.Tensor, seqlen: Seqlen,
max_s: int, max_s: int,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.embed_tokens(input_ids) hidden_states = self.embed_tokens(input_ids)
...@@ -363,7 +362,7 @@ class FlashPhiModel(torch.nn.Module): ...@@ -363,7 +362,7 @@ class FlashPhiModel(torch.nn.Module):
kv_cache[i], kv_cache[i],
block_tables, block_tables,
slots, slots,
input_lengths, seqlen,
max_s, max_s,
) )
...@@ -396,7 +395,7 @@ class FlashPhiForCausalLM(torch.nn.Module): ...@@ -396,7 +395,7 @@ class FlashPhiForCausalLM(torch.nn.Module):
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor, block_tables: torch.Tensor,
slots: torch.Tensor, slots: torch.Tensor,
input_lengths: torch.Tensor, seqlen: Seqlen,
max_s: int, max_s: int,
prefill_cache_indices: Optional[torch.Tensor], prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor] = None, lm_head_indices: Optional[torch.Tensor] = None,
...@@ -409,7 +408,7 @@ class FlashPhiForCausalLM(torch.nn.Module): ...@@ -409,7 +408,7 @@ class FlashPhiForCausalLM(torch.nn.Module):
kv_cache, kv_cache,
block_tables, block_tables,
slots, slots,
input_lengths, seqlen,
max_s, max_s,
) )
if lm_head_indices is not None: if lm_head_indices is not None:
......
...@@ -9,6 +9,7 @@ from text_generation_server.layers.attention import ( ...@@ -9,6 +9,7 @@ from text_generation_server.layers.attention import (
paged_attention, paged_attention,
attention, attention,
reshape_and_cache, reshape_and_cache,
Seqlen,
) )
from text_generation_server.layers import ( from text_generation_server.layers import (
TensorParallelRowLinear, TensorParallelRowLinear,
...@@ -104,7 +105,7 @@ class Qwen2Attention(torch.nn.Module): ...@@ -104,7 +105,7 @@ class Qwen2Attention(torch.nn.Module):
kv_cache, kv_cache,
block_tables, block_tables,
slots, slots,
input_lengths, seqlen,
max_s, max_s,
prefill_cache_indices, prefill_cache_indices,
): ):
...@@ -135,12 +136,10 @@ class Qwen2Attention(torch.nn.Module): ...@@ -135,12 +136,10 @@ class Qwen2Attention(torch.nn.Module):
# flash attention # flash attention
attn_output = attention( attn_output = attention(
query, query,
torch.select(kv, dim=1, index=0),
torch.select(kv, dim=1, index=1),
kv_cache[0], kv_cache[0],
kv_cache[1], kv_cache[1],
cu_seqlen_prefill, seqlen,
max_s, block_tables,
self.softmax_scale, self.softmax_scale,
window_size_left=self.max_past, window_size_left=self.max_past,
) )
...@@ -153,7 +152,7 @@ class Qwen2Attention(torch.nn.Module): ...@@ -153,7 +152,7 @@ class Qwen2Attention(torch.nn.Module):
self.kv_head_mapping, self.kv_head_mapping,
self.softmax_scale, self.softmax_scale,
block_tables, block_tables,
input_lengths, seqlen,
max_s, max_s,
) )
...@@ -225,7 +224,7 @@ class Qwen2Layer(nn.Module): ...@@ -225,7 +224,7 @@ class Qwen2Layer(nn.Module):
kv_cache, kv_cache,
block_tables, block_tables,
slots, slots,
input_lengths, seqlen,
max_s, max_s,
prefill_cache_indices, prefill_cache_indices,
): ):
...@@ -240,7 +239,7 @@ class Qwen2Layer(nn.Module): ...@@ -240,7 +239,7 @@ class Qwen2Layer(nn.Module):
kv_cache, kv_cache,
block_tables, block_tables,
slots, slots,
input_lengths, seqlen,
max_s, max_s,
prefill_cache_indices, prefill_cache_indices,
) )
...@@ -296,7 +295,7 @@ class Qwen2Model(torch.nn.Module): ...@@ -296,7 +295,7 @@ class Qwen2Model(torch.nn.Module):
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor, block_tables: torch.Tensor,
slots: torch.Tensor, slots: torch.Tensor,
input_lengths: torch.Tensor, seqlen: Seqlen,
max_s: int, max_s: int,
true_max_s: int, true_max_s: int,
prefill_cache_indices: Optional[torch.Tensor], prefill_cache_indices: Optional[torch.Tensor],
...@@ -320,7 +319,7 @@ class Qwen2Model(torch.nn.Module): ...@@ -320,7 +319,7 @@ class Qwen2Model(torch.nn.Module):
kv_cache[i], kv_cache[i],
block_tables, block_tables,
slots, slots,
input_lengths, seqlen,
max_s, max_s,
prefill_cache_indices, prefill_cache_indices,
) )
...@@ -361,7 +360,7 @@ class Qwen2ForCausalLM(torch.nn.Module): ...@@ -361,7 +360,7 @@ class Qwen2ForCausalLM(torch.nn.Module):
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor, block_tables: torch.Tensor,
slots: torch.Tensor, slots: torch.Tensor,
input_lengths: torch.Tensor, seqlen: Seqlen,
max_s: int, max_s: int,
prefill_cache_indices: Optional[torch.Tensor] = None, prefill_cache_indices: Optional[torch.Tensor] = None,
lm_head_indices: Optional[torch.Tensor] = None, lm_head_indices: Optional[torch.Tensor] = None,
...@@ -374,7 +373,7 @@ class Qwen2ForCausalLM(torch.nn.Module): ...@@ -374,7 +373,7 @@ class Qwen2ForCausalLM(torch.nn.Module):
elif self.max_past is not None: elif self.max_past is not None:
# Clamp in decode mode as paged attention requires clamped values whereas the flash attention # Clamp in decode mode as paged attention requires clamped values whereas the flash attention
# kernel requires the true values # kernel requires the true values
input_lengths = input_lengths.clamp(max=self.max_past_tensor) seqlen = seqlen.clamp(max=self.max_past_tensor)
hidden_states = self.model( hidden_states = self.model(
input_ids, input_ids,
...@@ -383,7 +382,7 @@ class Qwen2ForCausalLM(torch.nn.Module): ...@@ -383,7 +382,7 @@ class Qwen2ForCausalLM(torch.nn.Module):
kv_cache, kv_cache,
block_tables, block_tables,
slots, slots,
input_lengths, seqlen,
max_s, max_s,
true_max_s, true_max_s,
prefill_cache_indices, prefill_cache_indices,
......
...@@ -19,6 +19,7 @@ from text_generation_server.layers.attention import ( ...@@ -19,6 +19,7 @@ from text_generation_server.layers.attention import (
attention, attention,
paged_attention, paged_attention,
reshape_and_cache, reshape_and_cache,
Seqlen,
) )
...@@ -181,7 +182,7 @@ class FlashRWAttention(torch.nn.Module): ...@@ -181,7 +182,7 @@ class FlashRWAttention(torch.nn.Module):
kv_cache, kv_cache,
block_tables, block_tables,
slots, slots,
input_lengths, seqlen,
max_s, max_s,
): ):
qkv = self.query_key_value(hidden_states) qkv = self.query_key_value(hidden_states)
...@@ -206,12 +207,10 @@ class FlashRWAttention(torch.nn.Module): ...@@ -206,12 +207,10 @@ class FlashRWAttention(torch.nn.Module):
# flash attention # flash attention
attn_output = attention( attn_output = attention(
query, query,
torch.select(kv, dim=1, index=0),
torch.select(kv, dim=1, index=1),
kv_cache[0], kv_cache[0],
kv_cache[1], kv_cache[1],
cu_seqlen_prefill, seqlen,
max_s, block_tables,
self.softmax_scale, self.softmax_scale,
) )
# Decode # Decode
...@@ -223,7 +222,7 @@ class FlashRWAttention(torch.nn.Module): ...@@ -223,7 +222,7 @@ class FlashRWAttention(torch.nn.Module):
self.kv_head_mapping, self.kv_head_mapping,
self.softmax_scale, self.softmax_scale,
block_tables, block_tables,
input_lengths, seqlen,
max_s, max_s,
) )
...@@ -296,7 +295,7 @@ class FlashRWLargeAttention(torch.nn.Module): ...@@ -296,7 +295,7 @@ class FlashRWLargeAttention(torch.nn.Module):
kv_cache, kv_cache,
block_tables, block_tables,
slots, slots,
input_lengths, seqlen,
max_s, max_s,
): ):
qkv = self.query_key_value(hidden_states) qkv = self.query_key_value(hidden_states)
...@@ -343,7 +342,7 @@ class FlashRWLargeAttention(torch.nn.Module): ...@@ -343,7 +342,7 @@ class FlashRWLargeAttention(torch.nn.Module):
self.kv_head_mapping, self.kv_head_mapping,
self.softmax_scale, self.softmax_scale,
block_tables, block_tables,
input_lengths, seqlen,
max_s, max_s,
) )
...@@ -429,7 +428,7 @@ class FlashRWLayer(nn.Module): ...@@ -429,7 +428,7 @@ class FlashRWLayer(nn.Module):
kv_cache, kv_cache,
block_tables, block_tables,
slots, slots,
input_lengths, seqlen,
max_s, max_s,
): ):
if self.parallel_attn: if self.parallel_attn:
...@@ -443,7 +442,7 @@ class FlashRWLayer(nn.Module): ...@@ -443,7 +442,7 @@ class FlashRWLayer(nn.Module):
kv_cache, kv_cache,
block_tables, block_tables,
slots, slots,
input_lengths, seqlen,
max_s, max_s,
) )
...@@ -465,7 +464,7 @@ class FlashRWLayer(nn.Module): ...@@ -465,7 +464,7 @@ class FlashRWLayer(nn.Module):
kv_cache, kv_cache,
block_tables, block_tables,
slots, slots,
input_lengths, seqlen,
max_s, max_s,
) )
...@@ -552,7 +551,7 @@ class FlashRWLargeLayer(nn.Module): ...@@ -552,7 +551,7 @@ class FlashRWLargeLayer(nn.Module):
kv_cache, kv_cache,
block_tables, block_tables,
slots, slots,
input_lengths, seqlen,
max_s, max_s,
): ):
# Layer norm. # Layer norm.
...@@ -567,7 +566,7 @@ class FlashRWLargeLayer(nn.Module): ...@@ -567,7 +566,7 @@ class FlashRWLargeLayer(nn.Module):
kv_cache, kv_cache,
block_tables, block_tables,
slots, slots,
input_lengths, seqlen,
max_s, max_s,
) )
...@@ -628,7 +627,7 @@ class FlashRWModel(FlashRWPreTrainedModel): ...@@ -628,7 +627,7 @@ class FlashRWModel(FlashRWPreTrainedModel):
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor, block_tables: torch.Tensor,
slots: torch.Tensor, slots: torch.Tensor,
input_lengths: torch.Tensor, seqlen: Seqlen,
max_s: int, max_s: int,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.word_embeddings(input_ids) hidden_states = self.word_embeddings(input_ids)
...@@ -650,7 +649,7 @@ class FlashRWModel(FlashRWPreTrainedModel): ...@@ -650,7 +649,7 @@ class FlashRWModel(FlashRWPreTrainedModel):
kv_cache[i], kv_cache[i],
block_tables, block_tables,
slots, slots,
input_lengths, seqlen,
max_s, max_s,
) )
...@@ -680,7 +679,7 @@ class FlashRWForCausalLM(FlashRWPreTrainedModel): ...@@ -680,7 +679,7 @@ class FlashRWForCausalLM(FlashRWPreTrainedModel):
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor, block_tables: torch.Tensor,
slots: torch.Tensor, slots: torch.Tensor,
input_lengths: torch.Tensor, seqlen: Seqlen,
max_s: int, max_s: int,
prefill_cache_indices: Optional[torch.Tensor], prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor] = None, lm_head_indices: Optional[torch.Tensor] = None,
...@@ -693,7 +692,7 @@ class FlashRWForCausalLM(FlashRWPreTrainedModel): ...@@ -693,7 +692,7 @@ class FlashRWForCausalLM(FlashRWPreTrainedModel):
kv_cache, kv_cache,
block_tables, block_tables,
slots, slots,
input_lengths, seqlen,
max_s, max_s,
) )
if lm_head_indices is not None: if lm_head_indices is not None:
......
...@@ -9,6 +9,7 @@ from text_generation_server.layers.attention import ( ...@@ -9,6 +9,7 @@ from text_generation_server.layers.attention import (
paged_attention, paged_attention,
attention, attention,
reshape_and_cache, reshape_and_cache,
Seqlen,
) )
from text_generation_server.layers import ( from text_generation_server.layers import (
TensorParallelRowLinear, TensorParallelRowLinear,
...@@ -268,7 +269,7 @@ class FlashMQAttention(torch.nn.Module): ...@@ -268,7 +269,7 @@ class FlashMQAttention(torch.nn.Module):
kv_cache, kv_cache,
block_tables, block_tables,
slots, slots,
input_lengths, seqlen,
max_s, max_s,
): ):
qkv = self.c_attn(hidden_states) qkv = self.c_attn(hidden_states)
...@@ -291,12 +292,10 @@ class FlashMQAttention(torch.nn.Module): ...@@ -291,12 +292,10 @@ class FlashMQAttention(torch.nn.Module):
# flash attention # flash attention
attn_output = attention( attn_output = attention(
query, query,
torch.select(key_value, dim=1, index=0),
torch.select(key_value, dim=1, index=1),
kv_cache[0], kv_cache[0],
kv_cache[1], kv_cache[1],
cu_seqlen_prefill, seqlen,
max_s, block_tables,
self.softmax_scale, self.softmax_scale,
) )
# Decode # Decode
...@@ -308,7 +307,7 @@ class FlashMQAttention(torch.nn.Module): ...@@ -308,7 +307,7 @@ class FlashMQAttention(torch.nn.Module):
self.kv_head_mapping, self.kv_head_mapping,
self.softmax_scale, self.softmax_scale,
block_tables, block_tables,
input_lengths, seqlen,
max_s, max_s,
) )
...@@ -373,7 +372,7 @@ class Block(nn.Module): ...@@ -373,7 +372,7 @@ class Block(nn.Module):
kv_cache, kv_cache,
block_tables, block_tables,
slots, slots,
input_lengths, seqlen,
max_s, max_s,
): ):
hidden_states, residual = self.ln_1(hidden_states, residual) hidden_states, residual = self.ln_1(hidden_states, residual)
...@@ -383,7 +382,7 @@ class Block(nn.Module): ...@@ -383,7 +382,7 @@ class Block(nn.Module):
kv_cache, kv_cache,
block_tables, block_tables,
slots, slots,
input_lengths, seqlen,
max_s, max_s,
) )
...@@ -437,7 +436,7 @@ class FlashSantacoderModel(nn.Module): ...@@ -437,7 +436,7 @@ class FlashSantacoderModel(nn.Module):
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor, block_tables: torch.Tensor,
slots: torch.Tensor, slots: torch.Tensor,
input_lengths: torch.Tensor, seqlen: Seqlen,
max_s: int, max_s: int,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.wte(input_ids) + self.wpe(position_ids) hidden_states = self.wte(input_ids) + self.wpe(position_ids)
...@@ -454,7 +453,7 @@ class FlashSantacoderModel(nn.Module): ...@@ -454,7 +453,7 @@ class FlashSantacoderModel(nn.Module):
kv_cache[i], kv_cache[i],
block_tables, block_tables,
slots, slots,
input_lengths, seqlen,
max_s, max_s,
) )
...@@ -486,7 +485,7 @@ class FlashSantacoderForCausalLM(nn.Module): ...@@ -486,7 +485,7 @@ class FlashSantacoderForCausalLM(nn.Module):
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor, block_tables: torch.Tensor,
slots: torch.Tensor, slots: torch.Tensor,
input_lengths: torch.Tensor, seqlen: Seqlen,
max_s: int, max_s: int,
prefill_cache_indices: Optional[torch.Tensor], prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor] = None, lm_head_indices: Optional[torch.Tensor] = None,
...@@ -499,7 +498,7 @@ class FlashSantacoderForCausalLM(nn.Module): ...@@ -499,7 +498,7 @@ class FlashSantacoderForCausalLM(nn.Module):
kv_cache, kv_cache,
block_tables, block_tables,
slots, slots,
input_lengths, seqlen,
max_s, max_s,
) )
if lm_head_indices is not None: if lm_head_indices is not None:
......
...@@ -30,6 +30,7 @@ from text_generation_server.layers.attention import ( ...@@ -30,6 +30,7 @@ from text_generation_server.layers.attention import (
paged_attention, paged_attention,
attention, attention,
reshape_and_cache, reshape_and_cache,
Seqlen,
) )
from text_generation_server.layers import ( from text_generation_server.layers import (
TensorParallelRowLinear, TensorParallelRowLinear,
...@@ -209,7 +210,7 @@ class Starcoder2Attention(torch.nn.Module): ...@@ -209,7 +210,7 @@ class Starcoder2Attention(torch.nn.Module):
kv_cache, kv_cache,
block_tables, block_tables,
slots, slots,
input_lengths, seqlen,
max_s, max_s,
prefill_cache_indices, prefill_cache_indices,
): ):
...@@ -240,12 +241,10 @@ class Starcoder2Attention(torch.nn.Module): ...@@ -240,12 +241,10 @@ class Starcoder2Attention(torch.nn.Module):
# flash attention # flash attention
attn_output = attention( attn_output = attention(
query, query,
torch.select(kv, dim=1, index=0),
torch.select(kv, dim=1, index=1),
kv_cache[0], kv_cache[0],
kv_cache[1], kv_cache[1],
cu_seqlen_prefill, seqlen,
max_s, block_tables,
self.softmax_scale, self.softmax_scale,
window_size_left=self.max_past, window_size_left=self.max_past,
) )
...@@ -258,7 +257,7 @@ class Starcoder2Attention(torch.nn.Module): ...@@ -258,7 +257,7 @@ class Starcoder2Attention(torch.nn.Module):
self.kv_head_mapping, self.kv_head_mapping,
self.softmax_scale, self.softmax_scale,
block_tables, block_tables,
input_lengths, seqlen,
max_s, max_s,
) )
...@@ -381,7 +380,7 @@ class Starcoder2Layer(nn.Module): ...@@ -381,7 +380,7 @@ class Starcoder2Layer(nn.Module):
kv_cache, kv_cache,
block_tables, block_tables,
slots, slots,
input_lengths, seqlen,
max_s, max_s,
prefill_cache_indices, prefill_cache_indices,
): ):
...@@ -396,7 +395,7 @@ class Starcoder2Layer(nn.Module): ...@@ -396,7 +395,7 @@ class Starcoder2Layer(nn.Module):
kv_cache, kv_cache,
block_tables, block_tables,
slots, slots,
input_lengths, seqlen,
max_s, max_s,
prefill_cache_indices, prefill_cache_indices,
) )
...@@ -449,7 +448,7 @@ class Starcoder2Model(torch.nn.Module): ...@@ -449,7 +448,7 @@ class Starcoder2Model(torch.nn.Module):
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor, block_tables: torch.Tensor,
slots: torch.Tensor, slots: torch.Tensor,
input_lengths: torch.Tensor, seqlen: Seqlen,
max_s: int, max_s: int,
true_max_s: int, true_max_s: int,
prefill_cache_indices: Optional[torch.Tensor], prefill_cache_indices: Optional[torch.Tensor],
...@@ -473,7 +472,7 @@ class Starcoder2Model(torch.nn.Module): ...@@ -473,7 +472,7 @@ class Starcoder2Model(torch.nn.Module):
kv_cache[i], kv_cache[i],
block_tables, block_tables,
slots, slots,
input_lengths, seqlen,
max_s, max_s,
prefill_cache_indices, prefill_cache_indices,
) )
...@@ -521,7 +520,7 @@ class FlashStarcoder2ForCausalLM(torch.nn.Module): ...@@ -521,7 +520,7 @@ class FlashStarcoder2ForCausalLM(torch.nn.Module):
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor, block_tables: torch.Tensor,
slots: torch.Tensor, slots: torch.Tensor,
input_lengths: torch.Tensor, seqlen: Seqlen,
max_s: int, max_s: int,
prefill_cache_indices: Optional[torch.Tensor], prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor] = None, lm_head_indices: Optional[torch.Tensor] = None,
...@@ -534,7 +533,7 @@ class FlashStarcoder2ForCausalLM(torch.nn.Module): ...@@ -534,7 +533,7 @@ class FlashStarcoder2ForCausalLM(torch.nn.Module):
elif self.max_past is not None: elif self.max_past is not None:
# Clamp in decode mode as paged attention requires clamped values whereas the flash attention # Clamp in decode mode as paged attention requires clamped values whereas the flash attention
# kernel requires the true values # kernel requires the true values
input_lengths = input_lengths.clamp(max=self.max_past_tensor) seqlen = seqlen.clamp(max=self.max_past_tensor)
hidden_states = self.model( hidden_states = self.model(
input_ids, input_ids,
...@@ -543,7 +542,7 @@ class FlashStarcoder2ForCausalLM(torch.nn.Module): ...@@ -543,7 +542,7 @@ class FlashStarcoder2ForCausalLM(torch.nn.Module):
kv_cache, kv_cache,
block_tables, block_tables,
slots, slots,
input_lengths, seqlen,
max_s, max_s,
true_max_s, true_max_s,
prefill_cache_indices, prefill_cache_indices,
......
...@@ -25,6 +25,7 @@ from transformers.activations import ACT2FN ...@@ -25,6 +25,7 @@ from transformers.activations import ACT2FN
from text_generation_server.models.custom_modeling.vlm import ( from text_generation_server.models.custom_modeling.vlm import (
load_text_model, load_text_model,
) )
from text_generation_server.layers.attention import Seqlen
from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask
from text_generation_server.layers import ( from text_generation_server.layers import (
...@@ -740,7 +741,7 @@ class Idefics2ForConditionalGeneration(nn.Module): ...@@ -740,7 +741,7 @@ class Idefics2ForConditionalGeneration(nn.Module):
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor, block_tables: torch.Tensor,
slots: torch.Tensor, slots: torch.Tensor,
input_lengths: torch.Tensor, seqlen: Seqlen,
max_s: int, max_s: int,
prefill_cache_indices: Optional[torch.Tensor], prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor] = None, lm_head_indices: Optional[torch.Tensor] = None,
...@@ -826,7 +827,7 @@ class Idefics2ForConditionalGeneration(nn.Module): ...@@ -826,7 +827,7 @@ class Idefics2ForConditionalGeneration(nn.Module):
kv_cache=kv_cache, kv_cache=kv_cache,
block_tables=block_tables, block_tables=block_tables,
slots=slots, slots=slots,
input_lengths=input_lengths, seqlen=seqlen,
max_s=max_s, max_s=max_s,
true_max_s=max_s, true_max_s=max_s,
prefill_cache_indices=None, prefill_cache_indices=None,
......
...@@ -23,6 +23,7 @@ from torch import nn ...@@ -23,6 +23,7 @@ from torch import nn
from transformers.activations import ACT2FN from transformers.activations import ACT2FN
from transformers.image_processing_utils import select_best_resolution from transformers.image_processing_utils import select_best_resolution
from text_generation_server.layers.attention import Seqlen
from text_generation_server.models.custom_modeling.vlm import ( from text_generation_server.models.custom_modeling.vlm import (
load_text_model, load_text_model,
load_vision_model, load_vision_model,
...@@ -170,7 +171,7 @@ class LlavaNextForConditionalGeneration(nn.Module): ...@@ -170,7 +171,7 @@ class LlavaNextForConditionalGeneration(nn.Module):
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor, block_tables: torch.Tensor,
slots: torch.Tensor, slots: torch.Tensor,
input_lengths: torch.Tensor, seqlen: Seqlen,
max_s: int, max_s: int,
prefill_cache_indices: Optional[torch.Tensor], prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor] = None, lm_head_indices: Optional[torch.Tensor] = None,
...@@ -276,7 +277,7 @@ class LlavaNextForConditionalGeneration(nn.Module): ...@@ -276,7 +277,7 @@ class LlavaNextForConditionalGeneration(nn.Module):
kv_cache=kv_cache, kv_cache=kv_cache,
block_tables=block_tables, block_tables=block_tables,
slots=slots, slots=slots,
input_lengths=input_lengths, seqlen=seqlen,
max_s=max_s, max_s=max_s,
true_max_s=max_s, true_max_s=max_s,
prefill_cache_indices=None, prefill_cache_indices=None,
......
...@@ -43,7 +43,7 @@ from text_generation_server.models.globals import ( ...@@ -43,7 +43,7 @@ from text_generation_server.models.globals import (
ATTENTION, ATTENTION,
BLOCK_SIZE, BLOCK_SIZE,
CUDA_GRAPHS, CUDA_GRAPHS,
PREFIX_CACHING, TGI_WIGGLE_ROOM,
get_adapter_to_index, get_adapter_to_index,
) )
from text_generation_server.layers.attention import Seqlen from text_generation_server.layers.attention import Seqlen
...@@ -189,16 +189,21 @@ class FlashCausalLMBatch(Batch): ...@@ -189,16 +189,21 @@ class FlashCausalLMBatch(Batch):
def batch_tokenized_inputs( def batch_tokenized_inputs(
cls, requests: Iterable[generate_pb2.Request], tokenizer cls, requests: Iterable[generate_pb2.Request], tokenizer
): ):
batch_inputs = [] max_length = 0
max_truncation = 0 all_input_ids = []
batch_size = 0
for r in requests: for r in requests:
batch_inputs.append(concat_text_chunks(r.input_chunks.chunks)) batch_size += 1
max_truncation = max(max_truncation, r.truncate) inputs = concat_text_chunks(r.input_chunks.chunks)
input_ids = tokenizer(
batch_tokenized_inputs = tokenizer( inputs,
batch_inputs, truncation=True, max_length=max_truncation truncation=True,
)["input_ids"] max_length=r.truncate,
return batch_tokenized_inputs add_special_tokens=r.add_special_tokens,
)["input_ids"]
max_length = max(max_length, len(input_ids))
all_input_ids.append(input_ids)
return all_input_ids
@classmethod @classmethod
def from_tokenized( def from_tokenized(
...@@ -257,22 +262,15 @@ class FlashCausalLMBatch(Batch): ...@@ -257,22 +262,15 @@ class FlashCausalLMBatch(Batch):
# request id -> idx in list mapping # request id -> idx in list mapping
requests_idx_mapping[r.id] = i requests_idx_mapping[r.id] = i
tokenized_input = tokenized_input[-r.truncate :]
if (
tokenized_input[0] == tokenizer.bos_token_id
and tokenized_input[1] == tokenizer.bos_token_id
):
tokenized_input = tokenized_input[1:]
orig_input_length = len(tokenized_input) orig_input_length = len(tokenized_input)
if PREFIX_CACHING: prefix_len = r.prefix_len
prefix_len = r.prefix_len assert (
if prefix_len == orig_input_length: prefix_len <= orig_input_length
assert prefix_len > 0 ), f"Prefix {prefix_len} vs input {orig_input_length}"
prefix_len -= 1 if prefix_len == orig_input_length:
else: assert prefix_len > 0
prefix_len = 0 prefix_len -= 1
prefix_ids.append(tokenized_input[:prefix_len]) prefix_ids.append(tokenized_input[:prefix_len])
tokenized_input = tokenized_input[prefix_len:] tokenized_input = tokenized_input[prefix_len:]
...@@ -998,7 +996,7 @@ class FlashCausalLM(Model): ...@@ -998,7 +996,7 @@ class FlashCausalLM(Model):
config.sliding_window = None config.sliding_window = None
self.num_layers = config.num_hidden_layers self.num_layers = config.num_hidden_layers
self.num_heads = config.num_attention_heads self.num_heads = config.num_attention_heads // self.process_group.size()
# Validation is done in the model itself # Validation is done in the model itself
if num_kv_heads is None: if num_kv_heads is None:
num_kv_heads = getattr(config, "num_key_value_heads", None) num_kv_heads = getattr(config, "num_key_value_heads", None)
...@@ -1160,8 +1158,15 @@ class FlashCausalLM(Model): ...@@ -1160,8 +1158,15 @@ class FlashCausalLM(Model):
"block_tables": block_tables, "block_tables": block_tables,
"slots": slots, "slots": slots,
"input_lengths": input_lengths_tensor, "input_lengths": input_lengths_tensor,
"prefix_lengths": prefix_lengths_tensor,
} }
input_lengths_ = Seqlen(input_lengths=input_lengths_tensor) seqlen = Seqlen(
input_lengths=input_lengths_tensor,
prefix_lengths=prefix_lengths_tensor,
cu_seqlen_q=None,
max_q=1,
max_k=max_s,
)
graph = torch.cuda.CUDAGraph() graph = torch.cuda.CUDAGraph()
self.cuda_graphs[bs]["graph"] = graph self.cuda_graphs[bs]["graph"] = graph
...@@ -1204,7 +1209,7 @@ class FlashCausalLM(Model): ...@@ -1204,7 +1209,7 @@ class FlashCausalLM(Model):
kv_cache=self.kv_cache, kv_cache=self.kv_cache,
block_tables=block_tables, block_tables=block_tables,
slots=slots, slots=slots,
input_lengths=input_lengths_, seqlen=seqlen,
max_s=max_s, max_s=max_s,
prefill_cache_indices=None, prefill_cache_indices=None,
lm_head_indices=None, lm_head_indices=None,
...@@ -1213,7 +1218,13 @@ class FlashCausalLM(Model): ...@@ -1213,7 +1218,13 @@ class FlashCausalLM(Model):
torch.cuda.synchronize() torch.cuda.synchronize()
with torch.cuda.graph(graph, pool=MEM_POOL): with torch.cuda.graph(graph, pool=MEM_POOL):
input_lengths_tensor = Seqlen(input_lengths=input_lengths_tensor) seqlen = Seqlen(
input_lengths=input_lengths_tensor,
prefix_lengths=prefix_lengths_tensor,
cu_seqlen_q=None,
max_q=1,
max_k=max_s,
)
logits, speculative_logits = self.model.forward( logits, speculative_logits = self.model.forward(
input_ids=input_ids, input_ids=input_ids,
position_ids=position_ids, position_ids=position_ids,
...@@ -1221,7 +1232,7 @@ class FlashCausalLM(Model): ...@@ -1221,7 +1232,7 @@ class FlashCausalLM(Model):
kv_cache=self.kv_cache, kv_cache=self.kv_cache,
block_tables=block_tables, block_tables=block_tables,
slots=slots, slots=slots,
input_lengths=input_lengths_tensor, seqlen=seqlen,
max_s=max_s, max_s=max_s,
prefill_cache_indices=None, prefill_cache_indices=None,
lm_head_indices=None, lm_head_indices=None,
...@@ -1268,7 +1279,7 @@ class FlashCausalLM(Model): ...@@ -1268,7 +1279,7 @@ class FlashCausalLM(Model):
num_blocks = ( num_blocks = (
# Leave 5% for some wiggle room # Leave 5% for some wiggle room
int((free_memory * 0.95) // total_cache_size) int((free_memory * TGI_WIGGLE_ROOM) // total_cache_size)
# Add batch.num_blocks as we allocated it above, so it is included in the peak memory. # Add batch.num_blocks as we allocated it above, so it is included in the peak memory.
+ batch_num_blocks + batch_num_blocks
) )
...@@ -1360,18 +1371,26 @@ class FlashCausalLM(Model): ...@@ -1360,18 +1371,26 @@ class FlashCausalLM(Model):
# Dummy value, some models (starcoder2) don't accept `None`. # Dummy value, some models (starcoder2) don't accept `None`.
input_lengths = torch.ones(seqlen, dtype=torch.int32, device=self.device) input_lengths = torch.ones(seqlen, dtype=torch.int32, device=self.device)
input_lengths = Seqlen(input_lengths=input_lengths) prefix_lens_tensor = torch.zeros(seqlen, dtype=torch.int32, device=self.device)
cu_seqlen_prefill = torch.tensor(
[0, seqlen], device=self.device, dtype=torch.int32
)
seqlen = Seqlen(
input_lengths=input_lengths,
prefix_lengths=prefix_lens_tensor,
cu_seqlen_q=cu_seqlen_prefill,
max_q=1,
max_k=seqlen,
)
# We pass a `cu_seqlen_prefill` in order not to have to deal with paged attention cache allocation/deallocation. # We pass a `cu_seqlen_prefill` in order not to have to deal with paged attention cache allocation/deallocation.
self.model.forward( self.model.forward(
input_ids=input_ids, input_ids=input_ids,
position_ids=position_ids, position_ids=position_ids,
cu_seqlen_prefill=torch.tensor( cu_seqlen_prefill=cu_seqlen_prefill,
[0, seqlen], device=self.device, dtype=torch.int32
),
kv_cache=self.kv_cache, kv_cache=self.kv_cache,
block_tables=None, block_tables=None,
input_lengths=input_lengths, seqlen=seqlen,
slots=slots, slots=slots,
max_s=seqlen, max_s=seqlen,
lm_head_indices=None, lm_head_indices=None,
...@@ -1451,8 +1470,7 @@ class FlashCausalLM(Model): ...@@ -1451,8 +1470,7 @@ class FlashCausalLM(Model):
cuda_graph = None cuda_graph = None
if cu_seqlen_prefill is not None or cuda_graph is None: if cu_seqlen_prefill is not None or cuda_graph is None:
input_lengths = input_lengths + prefix_lens_tensor if ATTENTION == "flashinfer":
if PREFIX_CACHING:
block_tables = block_tables_to_ragged( block_tables = block_tables_to_ragged(
block_tables=block_tables, block_tables=block_tables,
input_lengths=batch.input_lengths, input_lengths=batch.input_lengths,
...@@ -1462,11 +1480,18 @@ class FlashCausalLM(Model): ...@@ -1462,11 +1480,18 @@ class FlashCausalLM(Model):
block_tables=block_tables, block_tables=block_tables,
cu_seqlen_prefill=cu_seqlen_prefill, cu_seqlen_prefill=cu_seqlen_prefill,
input_lengths=batch.input_lengths, input_lengths=batch.input_lengths,
input_lengths_tensor=input_lengths, input_lengths_tensor=input_lengths + prefix_lens_tensor,
prefix_lens=batch.prefix_lens, prefix_lens=batch.prefix_lens,
prefix_lens_tensor=prefix_lens_tensor, prefix_lens_tensor=prefix_lens_tensor,
): ):
input_lengths = Seqlen(input_lengths=input_lengths) max_k = (input_lengths + prefix_lens_tensor).max().item()
seqlen = Seqlen(
input_lengths=input_lengths,
prefix_lengths=prefix_lens_tensor,
cu_seqlen_q=cu_seqlen_prefill,
max_q=max_s,
max_k=max_k,
)
logits, speculative_logits = self.model.forward( logits, speculative_logits = self.model.forward(
input_ids=input_ids, input_ids=input_ids,
position_ids=position_ids, position_ids=position_ids,
...@@ -1474,7 +1499,7 @@ class FlashCausalLM(Model): ...@@ -1474,7 +1499,7 @@ class FlashCausalLM(Model):
kv_cache=kv_cache, kv_cache=kv_cache,
block_tables=block_tables, block_tables=block_tables,
slots=slots, slots=slots,
input_lengths=input_lengths, seqlen=seqlen,
max_s=max_s, max_s=max_s,
prefill_cache_indices=batch.prefill_cache_indices, prefill_cache_indices=batch.prefill_cache_indices,
lm_head_indices=lm_head_indices, lm_head_indices=lm_head_indices,
......
...@@ -5,19 +5,22 @@ from typing import Dict, Optional ...@@ -5,19 +5,22 @@ from typing import Dict, Optional
from text_generation_server.utils.log import log_master from text_generation_server.utils.log import log_master
PREFIX_CACHING = os.getenv("USE_PREFIX_CACHING", "0").lower() in {"1", "true"} PREFIX_CACHING = os.getenv("USE_PREFIX_CACHING").lower() in {"1", "true"}
log_master(logger.info, f"Using prefix caching = {PREFIX_CACHING}") log_master(logger.info, f"Using prefix caching = {PREFIX_CACHING}")
ATTENTION = os.getenv("ATTENTION", "flashinfer" if PREFIX_CACHING else "paged") ATTENTION = os.getenv("ATTENTION")
_expected = {"paged", "flashdecoding", "flashinfer"} _expected = {"paged", "flashdecoding", "flashinfer"}
assert ( assert (
ATTENTION in _expected ATTENTION in _expected
), f"Attention is not valid {ATTENTION}, expected {_expected}" ), f"Attention is not valid {ATTENTION}, expected {_expected}"
log_master(logger.info, f"Using Attention = {ATTENTION}") log_master(logger.info, f"Using Attention = {ATTENTION}")
if PREFIX_CACHING and ATTENTION != "flashinfer": if PREFIX_CACHING and ATTENTION not in {"flashinfer", "flashdecoding"}:
raise RuntimeError("Prefix caching is only supported with flashinfer") raise RuntimeError("Prefix caching is only supported with flashinfer")
MEM_POOL = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None MEM_POOL = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None
TGI_WIGGLE_ROOM = float(os.getenv("TGI_WIGGLE_ROOM", "0.95"))
assert TGI_WIGGLE_ROOM > 0
assert TGI_WIGGLE_ROOM < 1
# This is overridden by the cli # This is overridden by the cli
BLOCK_SIZE: int BLOCK_SIZE: int
......
...@@ -372,7 +372,14 @@ class VlmCausalLM(FlashCausalLM): ...@@ -372,7 +372,14 @@ class VlmCausalLM(FlashCausalLM):
prefix_lens=batch.prefix_lens, prefix_lens=batch.prefix_lens,
prefix_lens_tensor=prefix_lens_tensor, prefix_lens_tensor=prefix_lens_tensor,
): ):
input_lengths = Seqlen(input_lengths=input_lengths) max_k = (input_lengths + prefix_lens_tensor).max().item()
seqlen = Seqlen(
input_lengths=input_lengths,
prefix_lengths=prefix_lens_tensor,
cu_seqlen_q=cu_seqlen_prefill,
max_q=max_s,
max_k=max_k,
)
logits, speculative_logits = self.model.forward( logits, speculative_logits = self.model.forward(
input_ids=input_ids, input_ids=input_ids,
position_ids=position_ids, position_ids=position_ids,
...@@ -380,7 +387,7 @@ class VlmCausalLM(FlashCausalLM): ...@@ -380,7 +387,7 @@ class VlmCausalLM(FlashCausalLM):
kv_cache=kv_cache, kv_cache=kv_cache,
block_tables=block_tables, block_tables=block_tables,
slots=slots, slots=slots,
input_lengths=input_lengths, seqlen=seqlen,
max_s=max_s, max_s=max_s,
prefill_cache_indices=batch.prefill_cache_indices, prefill_cache_indices=batch.prefill_cache_indices,
lm_head_indices=lm_head_indices, lm_head_indices=lm_head_indices,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment