Unverified commit e415b690, authored by Nicolas Patry, committed by GitHub

Lots of improvements (Still 2 allocators) (#2449)



* Making prefix caching/flashinfer the default and running the full release tests.

* Include flashinfer in the docker.

* Using prebuilt.

* Allowing window_left_size (dummy version).

* Disabling flashinfer/prefix caching on odd head_dim

* Disable prefix caching for lora.
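For reference, a minimal sketch of what that gating could look like; the helper name, the env var names (`ATTENTION`, `USE_PREFIX_CACHING`) and the `flashdecoding` fallback are illustrative assumptions, not the exact launcher logic of this PR:

```python
import os


def resolve_attention_backend(head_dim: int, has_lora_adapters: bool):
    """Illustrative only: pick an attention backend and decide if prefix caching stays on."""
    backend = os.environ.get("ATTENTION", "flashinfer")
    prefix_caching = os.environ.get("USE_PREFIX_CACHING", "1") == "1"

    # flashinfer/prefix caching is disabled on odd head_dim
    if head_dim % 2 != 0:
        backend = "flashdecoding"
        prefix_caching = False

    # prefix caching is disabled when LoRA adapters are loaded
    if has_lora_adapters:
        prefix_caching = False

    return backend, prefix_caching
```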

* More specific codes.

* Update lock

* Updating integration tests with the new FI/FD values.

Remove paged as the default too, and use FD everywhere.

* Update Cargo lock?

* Upgrade to Rust 1.80 because of bitstream...

* Everywhere 1.80

* Forgot last default place.

* Apply suggestions from code review
Co-authored-by: drbh <david.richard.holtz@gmail.com>

* Updated flake lock

* Tmp

* Upgrade the resolution system for fewer resolution errors.

* Remove lambda for cleaner function.

* Handling debugger.

* Override the env in server tests.

* Is this enough to make it work?

* This seems to be working.

* Downgrade some logs.

* Fixing the default for vlm.

* Don't enable prefix caching on VLM just yet.

* Change `add_special_tokens` so that chat input gets the correct tokens (this is super important with the prefix caching now).
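A hedged illustration of why this matters: chat templates typically already insert the special tokens, so encoding the rendered chat prompt again with `add_special_tokens=True` would duplicate the BOS token and change the prompt prefix that the cache keys on. The model name below is only an example:

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")

# Chat path: the template already contains the special tokens,
# so the follow-up encode must not add them again.
chat_prompt = tokenizer.apply_chat_template(
    [{"role": "user", "content": "Hello"}], tokenize=False
)
chat_ids = tokenizer(chat_prompt, add_special_tokens=False)["input_ids"]

# Raw completion path: no template, special tokens are added as usual.
raw_ids = tokenizer("Hello", add_special_tokens=True)["input_ids"]
```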

* Fixing prefix caching for flashdecoding.

* Update all models.

* Fixed flashinfer version.

* add_special_tokens is internal only

* Fixing seqlen with the new vlms.

* Fixing the issue with `add_special_tokens` not being passed around.

* Fixing the test.

* Removing encoder_decoder (seq2seq).

* Update the chat test.

* Fixing the batching tokenization in flash causal lm.

* Truncating left for radix purposes.

* Oops this doesn't belong here.

* Put back default pure shell.

* Update server tests

- Default to throughput test in k6
- Use TGI_WIGGLE_ROOM to adjust wiggle room
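A small sketch of how such a knob can be read; the default value and the place it is applied are assumptions for illustration, not the exact test code:

```python
import os


def wiggle_room(default: float = 0.95) -> float:
    """Illustrative: read TGI_WIGGLE_ROOM and sanity-check it."""
    value = float(os.environ.get("TGI_WIGGLE_ROOM", default))
    assert 0.0 < value <= 1.0, "TGI_WIGGLE_ROOM must be in (0, 1]"
    return value


# e.g. scale an estimated budget before asserting on it
free_memory = 20_000_000_000  # example value in bytes
usable_memory = int(free_memory * wiggle_room())
```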

* Only n_heads / process_group.size() heads are necessary.
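This is the usual tensor-parallel head split; a small sketch under that assumption (names are illustrative), which is also where the assert/error message mentioned below comes in:

```python
def local_head_count(n_heads: int, world_size: int) -> int:
    """Each tensor-parallel rank only needs its own shard of the attention heads."""
    assert n_heads % world_size == 0, (
        f"n_heads ({n_heads}) must be divisible by the number of shards ({world_size})"
    )
    return n_heads // world_size


# e.g. 32 heads over 4 GPUs -> 8 heads per rank
assert local_head_count(32, 4) == 8
```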

* Revert the integration tests change (seems linked to the head_size modification).

* Adding error message when assert is violated.

* Fixing the free algorithm to handle cases where the common prefix is smaller.
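A toy sketch of the bookkeeping this refers to: when a request is freed, only the blocks beyond the prefix still shared with the cache can be released, and that shared prefix may be shorter than what the request originally matched. This is an illustration, not the allocator code in this PR:

```python
def blocks_to_free(allocated_blocks: list, shared_prefix_len: int, block_size: int) -> list:
    """Return the block ids that can be released when a request finishes.

    `shared_prefix_len` is the number of prompt tokens still matched by the
    prefix cache; only fully shared blocks stay pinned.
    """
    kept = shared_prefix_len // block_size
    return allocated_blocks[kept:]


# A request that used 5 blocks but now only shares 2 full blocks of prefix
assert blocks_to_free([10, 11, 12, 13, 14], shared_prefix_len=2 * 16, block_size=16) == [12, 13, 14]
```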

* Apply suggestions from code review
Co-authored-by: OlivierDehaene <olivier@huggingface.co>

* Update server/text_generation_server/layers/attention/common.py
Co-authored-by: OlivierDehaene <olivier@huggingface.co>

* Fix disabling prefix caching - Fix windowing checks.

* Revert the Cohere tokenizer change (for now using a revision instead).

* Fmt.

---------
Co-authored-by: drbh <david.richard.holtz@gmail.com>
Co-authored-by: OlivierDehaene <olivier@huggingface.co>
parent 4e821c00
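The diffs below all apply the same refactor: model `forward` signatures change from `input_lengths: torch.Tensor` to `seqlen: Seqlen`, the prefill `attention(...)` call now takes the KV cache and `block_tables` instead of the separate key/value tensors plus `cu_seqlen_prefill`/`max_s`, and the sliding-window models (Mistral, Mixtral, Qwen2) clamp `seqlen` instead of `input_lengths`. As a rough sketch, a `Seqlen` wrapper compatible with that usage could look like the following; the real definition (touched in this PR in `server/text_generation_server/layers/attention/common.py`) likely carries more fields, such as cumulative sequence lengths:

```python
from dataclasses import dataclass

import torch


@dataclass
class Seqlen:
    # Per-request sequence lengths used by the attention kernels.
    input_lengths: torch.Tensor

    def clamp(self, max):
        # Mirrors `seqlen = seqlen.clamp(max=self.max_past_tensor)` in the
        # sliding-window models below; returns a new wrapper, leaving the
        # original lengths untouched.
        return Seqlen(input_lengths=torch.clamp(self.input_lengths, max=max))
```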
@@ -30,6 +30,7 @@ from text_generation_server.layers.attention import (
     paged_attention,
     attention,
     reshape_and_cache,
+    Seqlen,
 )
 from text_generation_server.layers import (
     TensorParallelRowLinear,
@@ -213,7 +214,7 @@ class FlashGemma2Attention(torch.nn.Module):
         kv_cache,
         block_tables,
         slots,
-        input_lengths,
+        seqlen,
         max_s,
     ):
         qkv = self.query_key_value(hidden_states)
@@ -236,12 +237,10 @@ class FlashGemma2Attention(torch.nn.Module):
             # flash attention
             attn_output = attention(
                 query,
-                torch.select(kv, dim=1, index=0),
-                torch.select(kv, dim=1, index=1),
                 kv_cache[0],
                 kv_cache[1],
-                cu_seqlen_prefill,
-                max_s,
+                seqlen,
+                block_tables,
                 self.softmax_scale,
                 causal=self.causal,
                 window_size_left=self.window_size,
@@ -256,7 +255,7 @@ class FlashGemma2Attention(torch.nn.Module):
             self.kv_head_mapping,
             self.softmax_scale,
             block_tables,
-            input_lengths,
+            seqlen,
             max_s,
             softcap=self.softcap,
         )
@@ -343,7 +342,7 @@ class FlashGemma2Layer(nn.Module):
         kv_cache,
         block_tables,
         slots,
-        input_lengths,
+        seqlen,
         max_s,
     ):
         normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
@@ -357,7 +356,7 @@ class FlashGemma2Layer(nn.Module):
             kv_cache,
             block_tables,
             slots,
-            input_lengths,
+            seqlen,
             max_s,
         )
@@ -408,7 +407,7 @@ class FlashGemma2Model(torch.nn.Module):
         kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
         block_tables: torch.Tensor,
         slots: torch.Tensor,
-        input_lengths: torch.Tensor,
+        seqlen: Seqlen,
         max_s: int,
     ) -> torch.Tensor:
         hidden_states = inputs_embeds
@@ -430,7 +429,7 @@ class FlashGemma2Model(torch.nn.Module):
                 kv_cache[i],
                 block_tables,
                 slots,
-                input_lengths,
+                seqlen,
                 max_s,
             )
@@ -477,7 +476,7 @@ class FlashGemma2ForCausalLM(torch.nn.Module):
         kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
         block_tables: torch.Tensor,
         slots: torch.Tensor,
-        input_lengths: torch.Tensor,
+        seqlen: Seqlen,
         max_s: int,
         prefill_cache_indices: Optional[torch.Tensor],
         lm_head_indices: Optional[torch.Tensor] = None,
@@ -491,7 +490,7 @@ class FlashGemma2ForCausalLM(torch.nn.Module):
             kv_cache,
             block_tables,
             slots,
-            input_lengths,
+            seqlen,
             max_s,
         )
         if lm_head_indices is not None:
...
@@ -30,6 +30,7 @@ from text_generation_server.layers.attention import (
     paged_attention,
     attention,
     reshape_and_cache,
+    Seqlen,
 )
 from text_generation_server.layers import (
     TensorParallelRowLinear,
@@ -207,7 +208,7 @@ class FlashGemmaAttention(torch.nn.Module):
         kv_cache,
         block_tables,
         slots,
-        input_lengths,
+        seqlen,
         max_s,
     ):
         qkv = self.query_key_value(hidden_states)
@@ -230,12 +231,10 @@ class FlashGemmaAttention(torch.nn.Module):
             # flash attention
             attn_output = attention(
                 query,
-                torch.select(kv, dim=1, index=0),
-                torch.select(kv, dim=1, index=1),
                 kv_cache[0],
                 kv_cache[1],
-                cu_seqlen_prefill,
-                max_s,
+                seqlen,
+                block_tables,
                 self.softmax_scale,
                 causal=self.causal,
             )
@@ -248,7 +247,7 @@ class FlashGemmaAttention(torch.nn.Module):
             self.kv_head_mapping,
             self.softmax_scale,
             block_tables,
-            input_lengths,
+            seqlen,
             max_s,
         )
@@ -320,7 +319,7 @@ class FlashGemmaLayer(nn.Module):
         kv_cache,
         block_tables,
         slots,
-        input_lengths,
+        seqlen,
         max_s,
     ):
         normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
@@ -334,7 +333,7 @@ class FlashGemmaLayer(nn.Module):
             kv_cache,
             block_tables,
             slots,
-            input_lengths,
+            seqlen,
             max_s,
         )
@@ -382,7 +381,7 @@ class FlashGemmaModel(torch.nn.Module):
         kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
         block_tables: torch.Tensor,
         slots: torch.Tensor,
-        input_lengths: torch.Tensor,
+        seqlen: Seqlen,
         max_s: int,
     ) -> torch.Tensor:
         hidden_states = inputs_embeds
@@ -404,7 +403,7 @@ class FlashGemmaModel(torch.nn.Module):
                 kv_cache[i],
                 block_tables,
                 slots,
-                input_lengths,
+                seqlen,
                 max_s,
             )
@@ -449,7 +448,7 @@ class FlashGemmaForCausalLM(torch.nn.Module):
         kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
         block_tables: torch.Tensor,
         slots: torch.Tensor,
-        input_lengths: torch.Tensor,
+        seqlen: Seqlen,
         max_s: int,
         prefill_cache_indices: Optional[torch.Tensor],
         lm_head_indices: Optional[torch.Tensor] = None,
@@ -463,7 +462,7 @@ class FlashGemmaForCausalLM(torch.nn.Module):
             kv_cache,
             block_tables,
             slots,
-            input_lengths,
+            seqlen,
             max_s,
         )
         if lm_head_indices is not None:
...
@@ -29,6 +29,7 @@ from text_generation_server.layers.attention import (
     paged_attention,
     attention,
     reshape_and_cache,
+    Seqlen,
 )
 from text_generation_server.layers import (
     TensorParallelRowLinear,
@@ -213,7 +214,7 @@ class FlashGPT2Attention(torch.nn.Module):
         kv_cache,
         block_tables,
         slots,
-        input_lengths,
+        seqlen,
         max_s,
     ):
         query, key, value = self.query_key_value(hidden_states).split(
@@ -230,12 +231,10 @@ class FlashGPT2Attention(torch.nn.Module):
             # flash attention
             attn_output = attention(
                 query,
-                key,
-                value,
                 kv_cache[0],
                 kv_cache[1],
-                cu_seqlen_prefill,
-                max_s,
+                seqlen,
+                block_tables,
                 self.softmax_scale,
             )
         # Decode
@@ -247,7 +246,7 @@ class FlashGPT2Attention(torch.nn.Module):
             self.kv_head_mapping,
             self.softmax_scale,
             block_tables,
-            input_lengths,
+            seqlen,
             max_s,
         )
@@ -316,7 +315,7 @@ class FlashGPT2Layer(nn.Module):
         kv_cache,
         block_tables,
         slots,
-        input_lengths,
+        seqlen,
         max_s,
     ):
         residual = hidden_states
@@ -329,7 +328,7 @@ class FlashGPT2Layer(nn.Module):
             kv_cache,
             block_tables,
             slots,
-            input_lengths,
+            seqlen,
             max_s,
         )
@@ -382,7 +381,7 @@ class FlashGPT2Model(torch.nn.Module):
         kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
         block_tables: torch.Tensor,
         slots: torch.Tensor,
-        input_lengths: torch.Tensor,
+        seqlen: Seqlen,
         max_s: int,
         true_max_s: int,
         prefill_cache_indices: Optional[torch.Tensor],
@@ -398,7 +397,7 @@ class FlashGPT2Model(torch.nn.Module):
                 kv_cache[i],
                 block_tables,
                 slots,
-                input_lengths,
+                seqlen,
                 max_s,
             )
@@ -435,7 +434,7 @@ class FlashGPT2ForCausalLM(torch.nn.Module):
         kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
         block_tables: torch.Tensor,
         slots: torch.Tensor,
-        input_lengths: torch.Tensor,
+        seqlen: Seqlen,
         max_s: int,
         prefill_cache_indices: Optional[torch.Tensor] = None,
         lm_head_indices: Optional[torch.Tensor] = None,
@@ -451,7 +450,7 @@ class FlashGPT2ForCausalLM(torch.nn.Module):
             kv_cache,
             block_tables,
             slots,
-            input_lengths,
+            seqlen,
             max_s,
             true_max_s=max_s,
             prefill_cache_indices=prefill_cache_indices,
...
@@ -29,6 +29,7 @@ from text_generation_server.layers.attention import (
     paged_attention,
     attention,
     reshape_and_cache,
+    Seqlen,
 )
 from text_generation_server.layers import (
     TensorParallelRowLinear,
@@ -167,7 +168,7 @@ class FlashGPTJAttention(torch.nn.Module):
         kv_cache,
         block_tables,
         slots,
-        input_lengths,
+        seqlen,
         max_s,
     ):
         query, key, value = self.query_key_value(hidden_states).split(
@@ -192,10 +193,10 @@ class FlashGPTJAttention(torch.nn.Module):
             # flash attention
             attn_output = attention(
                 query,
-                key,
-                value,
-                cu_seqlen_prefill,
-                max_s,
+                kv_cache[0],
+                kv_cache[1],
+                seqlen,
+                block_tables,
                 self.softmax_scale,
             )
         # Decode
@@ -207,7 +208,7 @@ class FlashGPTJAttention(torch.nn.Module):
             self.kv_head_mapping,
             self.softmax_scale,
             block_tables,
-            input_lengths,
+            seqlen,
             max_s,
         )
@@ -268,7 +269,7 @@ class FlashGPTJLayer(nn.Module):
         kv_cache,
         block_tables,
         slots,
-        input_lengths,
+        seqlen,
         max_s,
     ):
         hidden_states, residual = self.input_layernorm(hidden_states, residual)
@@ -281,7 +282,7 @@ class FlashGPTJLayer(nn.Module):
             kv_cache,
             block_tables,
             slots,
-            input_lengths,
+            seqlen,
             max_s,
         )
@@ -328,7 +329,7 @@ class FlashGPTJModel(torch.nn.Module):
         kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
         block_tables: torch.Tensor,
         slots: torch.Tensor,
-        input_lengths: torch.Tensor,
+        seqlen: Seqlen,
         max_s: int,
         prefill_cache_indices: Optional[torch.Tensor],
     ) -> torch.Tensor:
@@ -351,7 +352,7 @@ class FlashGPTJModel(torch.nn.Module):
                 kv_cache[i],
                 block_tables,
                 slots,
-                input_lengths,
+                seqlen,
                 max_s,
             )
@@ -382,7 +383,7 @@ class FlashGPTJForCausalLM(torch.nn.Module):
         kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
         block_tables: torch.Tensor,
         slots: torch.Tensor,
-        input_lengths: torch.Tensor,
+        seqlen: Seqlen,
         max_s: int,
         prefill_cache_indices: Optional[torch.Tensor] = None,
         lm_head_indices: Optional[torch.Tensor] = None,
@@ -395,7 +396,7 @@ class FlashGPTJForCausalLM(torch.nn.Module):
             kv_cache,
             block_tables,
             slots,
-            input_lengths,
+            seqlen,
             max_s,
             prefill_cache_indices=prefill_cache_indices,
         )
...
@@ -32,6 +32,7 @@ from text_generation_server.layers.attention import (
     paged_attention,
     attention,
     reshape_and_cache,
+    Seqlen,
 )
 from text_generation_server.layers import (
     TensorParallelRowLinear,
@@ -194,7 +195,7 @@ class FlashLlamaAttention(torch.nn.Module):
         kv_cache,
         block_tables,
         slots,
-        input_lengths,
+        seqlen,
         max_s,
         adapter_data,
     ):
@@ -218,12 +219,10 @@ class FlashLlamaAttention(torch.nn.Module):
             # flash attention
             attn_output = attention(
                 query,
-                torch.select(kv, dim=1, index=0),
-                torch.select(kv, dim=1, index=1),
                 kv_cache[0],
                 kv_cache[1],
-                cu_seqlen_prefill,
-                max_s,
+                seqlen,
+                block_tables,
                 self.softmax_scale,
             )
         # Decode
@@ -235,7 +234,7 @@ class FlashLlamaAttention(torch.nn.Module):
             self.kv_head_mapping,
             self.softmax_scale,
             block_tables,
-            input_lengths,
+            seqlen,
             max_s,
         )
@@ -375,7 +374,7 @@ class FlashLlamaLayer(nn.Module):
         kv_cache,
         block_tables,
         slots,
-        input_lengths,
+        seqlen,
         max_s,
         adapter_data,
     ):
@@ -390,7 +389,7 @@ class FlashLlamaLayer(nn.Module):
             kv_cache,
             block_tables,
             slots,
-            input_lengths,
+            seqlen,
             max_s,
             adapter_data,
         )
@@ -479,7 +478,7 @@ class FlashLlamaModel(torch.nn.Module):
         kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
         block_tables: torch.Tensor,
         slots: torch.Tensor,
-        input_lengths: torch.Tensor,
+        seqlen: Seqlen,
         max_s: int,
         true_max_s: int,
         prefill_cache_indices: Optional[torch.Tensor],
@@ -504,7 +503,7 @@ class FlashLlamaModel(torch.nn.Module):
                 kv_cache[i],
                 block_tables,
                 slots,
-                input_lengths,
+                seqlen,
                 max_s,
                 adapter_data,
             )
@@ -548,7 +547,7 @@ class FlashLlamaForCausalLM(torch.nn.Module):
         kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
         block_tables: torch.Tensor,
         slots: torch.Tensor,
-        input_lengths: torch.Tensor,
+        seqlen: Seqlen,
         max_s: int,
         prefill_cache_indices: Optional[torch.Tensor] = None,
         lm_head_indices: Optional[torch.Tensor] = None,
@@ -562,7 +561,7 @@ class FlashLlamaForCausalLM(torch.nn.Module):
             kv_cache,
             block_tables,
             slots,
-            input_lengths,
+            seqlen,
             max_s,
             true_max_s=max_s,
             prefill_cache_indices=prefill_cache_indices,
...
@@ -31,6 +31,7 @@ from text_generation_server.layers.attention import (
     paged_attention,
     attention,
     reshape_and_cache,
+    Seqlen,
 )
 from text_generation_server.layers import (
     TensorParallelRowLinear,
@@ -185,7 +186,7 @@ class MistralAttention(torch.nn.Module):
         kv_cache,
         block_tables,
         slots,
-        input_lengths,
+        seqlen,
         max_s,
         prefill_cache_indices,
         adapter_data,
@@ -217,12 +218,10 @@ class MistralAttention(torch.nn.Module):
             # flash attention
             attn_output = attention(
                 query,
-                torch.select(kv, dim=1, index=0),
-                torch.select(kv, dim=1, index=1),
                 kv_cache[0],
                 kv_cache[1],
-                cu_seqlen_prefill,
-                max_s,
+                seqlen,
+                block_tables,
                 self.softmax_scale,
                 window_size_left=self.max_past,
             )
@@ -235,7 +234,7 @@ class MistralAttention(torch.nn.Module):
             self.kv_head_mapping,
             self.softmax_scale,
             block_tables,
-            input_lengths,
+            seqlen,
             max_s,
         )
@@ -356,7 +355,7 @@ class MistralLayer(nn.Module):
         kv_cache,
         block_tables,
         slots,
-        input_lengths,
+        seqlen,
         max_s,
         prefill_cache_indices,
         adapter_data,
@@ -372,7 +371,7 @@ class MistralLayer(nn.Module):
             kv_cache,
             block_tables,
             slots,
-            input_lengths,
+            seqlen,
             max_s,
             prefill_cache_indices,
             adapter_data,
@@ -424,7 +423,7 @@ class MistralModel(torch.nn.Module):
         kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
         block_tables: torch.Tensor,
         slots: torch.Tensor,
-        input_lengths: torch.Tensor,
+        seqlen: Seqlen,
         max_s: int,
         true_max_s: int,
         prefill_cache_indices: Optional[torch.Tensor],
@@ -448,7 +447,7 @@ class MistralModel(torch.nn.Module):
                 kv_cache[i],
                 block_tables,
                 slots,
-                input_lengths,
+                seqlen,
                 max_s,
                 prefill_cache_indices,
                 adapter_data,
@@ -499,7 +498,7 @@ class FlashMistralForCausalLM(torch.nn.Module):
         kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
         block_tables: torch.Tensor,
         slots: torch.Tensor,
-        input_lengths: torch.Tensor,
+        seqlen: Seqlen,
         max_s: int,
         prefill_cache_indices: Optional[torch.Tensor],
         lm_head_indices: Optional[torch.Tensor] = None,
@@ -512,7 +511,7 @@ class FlashMistralForCausalLM(torch.nn.Module):
         elif self.max_past is not None:
             # Clamp in decode mode as paged attention requires clamped values whereas the flash attention
             # kernel requires the true values
-            input_lengths = input_lengths.clamp(max=self.max_past_tensor)
+            seqlen = seqlen.clamp(max=self.max_past_tensor)
         inputs_embeds = self.embed_tokens(input_ids)
         hidden_states = self.model(
@@ -522,7 +521,7 @@ class FlashMistralForCausalLM(torch.nn.Module):
             kv_cache,
             block_tables,
             slots,
-            input_lengths,
+            seqlen,
             max_s,
             true_max_s,
             prefill_cache_indices,
...
@@ -35,6 +35,7 @@ from text_generation_server.layers.attention import (
     paged_attention,
     attention,
     reshape_and_cache,
+    Seqlen,
 )
 from text_generation_server.layers import (
     FastLinear,
@@ -243,7 +244,7 @@ class MixtralAttention(torch.nn.Module):
         kv_cache,
         block_tables,
         slots,
-        input_lengths,
+        seqlen,
         max_s,
         prefill_cache_indices,
     ):
@@ -274,12 +275,10 @@ class MixtralAttention(torch.nn.Module):
             # flash attention
             attn_output = attention(
                 query,
-                torch.select(kv, dim=1, index=0),
-                torch.select(kv, dim=1, index=1),
                 kv_cache[0],
                 kv_cache[1],
-                cu_seqlen_prefill,
-                max_s,
+                seqlen,
+                block_tables,
                 self.softmax_scale,
                 window_size_left=self.max_past,
             )
@@ -292,7 +291,7 @@ class MixtralAttention(torch.nn.Module):
             self.kv_head_mapping,
             self.softmax_scale,
             block_tables,
-            input_lengths,
+            seqlen,
             max_s,
         )
@@ -498,7 +497,7 @@ class MixtralLayer(nn.Module):
         kv_cache,
         block_tables,
         slots,
-        input_lengths,
+        seqlen,
         max_s,
         prefill_cache_indices,
     ):
@@ -513,7 +512,7 @@ class MixtralLayer(nn.Module):
             kv_cache,
             block_tables,
             slots,
-            input_lengths,
+            seqlen,
             max_s,
             prefill_cache_indices,
         )
@@ -568,7 +567,7 @@ class MixtralModel(torch.nn.Module):
         kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
         block_tables: torch.Tensor,
         slots: torch.Tensor,
-        input_lengths: torch.Tensor,
+        seqlen: Seqlen,
         max_s: int,
         true_max_s: int,
         prefill_cache_indices: Optional[torch.Tensor],
@@ -592,7 +591,7 @@ class MixtralModel(torch.nn.Module):
                 kv_cache[i],
                 block_tables,
                 slots,
-                input_lengths,
+                seqlen,
                 max_s,
                 prefill_cache_indices,
             )
@@ -627,7 +626,7 @@ class FlashMixtralForCausalLM(torch.nn.Module):
         kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
         block_tables: torch.Tensor,
         slots: torch.Tensor,
-        input_lengths: torch.Tensor,
+        seqlen: Seqlen,
         max_s: int,
         prefill_cache_indices: Optional[torch.Tensor],
         lm_head_indices: Optional[torch.Tensor] = None,
@@ -640,7 +639,7 @@ class FlashMixtralForCausalLM(torch.nn.Module):
         elif self.max_past is not None:
             # Clamp in decode mode as paged attention requires clamped values whereas the flash attention
             # kernel requires the true values
-            input_lengths = input_lengths.clamp(max=self.max_past_tensor)
+            seqlen = seqlen.clamp(max=self.max_past_tensor)
         hidden_states = self.model(
             input_ids,
@@ -649,7 +648,7 @@ class FlashMixtralForCausalLM(torch.nn.Module):
             kv_cache,
             block_tables,
             slots,
-            input_lengths,
+            seqlen,
             max_s,
             true_max_s,
             prefill_cache_indices,
...
@@ -31,6 +31,7 @@ from text_generation_server.layers.attention import (
     paged_attention,
     attention,
     reshape_and_cache,
+    Seqlen,
 )
 from text_generation_server.layers import (
     TensorParallelRowLinear,
@@ -147,7 +148,7 @@ class FlashNeoxAttention(torch.nn.Module):
         kv_cache,
         block_tables,
         slots,
-        input_lengths,
+        seqlen,
         max_s,
     ):
         qkv = self.query_key_value(hidden_states)
@@ -171,12 +172,10 @@ class FlashNeoxAttention(torch.nn.Module):
             # flash attention
             attn_output = attention(
                 qkv[:, 0],
-                qkv[:, 1],
-                qkv[:, 2],
                 kv_cache[0],
                 kv_cache[1],
-                cu_seqlen_prefill,
-                max_s,
+                seqlen,
+                block_tables,
                 self.softmax_scale,
             )
         # Decode
@@ -188,7 +187,7 @@ class FlashNeoxAttention(torch.nn.Module):
             self.kv_head_mapping,
             self.softmax_scale,
             block_tables,
-            input_lengths,
+            seqlen,
             max_s,
         )
@@ -258,7 +257,7 @@ class FlashNeoXLayer(nn.Module):
         kv_cache,
         block_tables,
         slots,
-        input_lengths,
+        seqlen,
         max_s,
     ):
         if self.use_parallel_residual:
@@ -272,7 +271,7 @@ class FlashNeoXLayer(nn.Module):
                 kv_cache,
                 block_tables,
                 slots,
-                input_lengths,
+                seqlen,
                 max_s,
             )
@@ -296,7 +295,7 @@ class FlashNeoXLayer(nn.Module):
                 kv_cache,
                 block_tables,
                 slots,
-                input_lengths,
+                seqlen,
                 max_s,
             )
@@ -350,7 +349,7 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
         kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
         block_tables: torch.Tensor,
         slots: torch.Tensor,
-        input_lengths: torch.Tensor,
+        seqlen: Seqlen,
         max_s: int,
     ) -> torch.Tensor:
         hidden_states = self.embed_in(input_ids)
@@ -372,7 +371,7 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
                 kv_cache[i],
                 block_tables,
                 slots,
-                input_lengths,
+                seqlen,
                 max_s,
             )
@@ -404,7 +403,7 @@ class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel):
         kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
         block_tables: torch.Tensor,
         slots: torch.Tensor,
-        input_lengths: torch.Tensor,
+        seqlen: Seqlen,
         max_s: int,
         prefill_cache_indices: Optional[torch.Tensor],
         lm_head_indices: Optional[torch.Tensor] = None,
@@ -417,7 +416,7 @@ class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel):
             kv_cache,
             block_tables,
            slots,
-            input_lengths,
+            seqlen,
             max_s,
         )
         if lm_head_indices is not None:
...
@@ -19,6 +19,7 @@ from torch import nn
 from typing import Optional, List, Tuple
 from text_generation_server.layers.tensor_parallel import TensorParallelColumnLinear
+from text_generation_server.layers.attention import Seqlen
 from text_generation_server.models.custom_modeling.vlm import (
     load_text_model,
     load_vision_model,
@@ -70,7 +71,7 @@ class PaliGemmaForConditionalGeneration(nn.Module):
         kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
         block_tables: torch.Tensor,
         slots: torch.Tensor,
-        input_lengths: torch.Tensor,
+        seqlen: Seqlen,
         max_s: int,
         prefill_cache_indices: Optional[torch.Tensor] = None,
         lm_head_indices: Optional[torch.Tensor] = None,
@@ -107,7 +108,7 @@ class PaliGemmaForConditionalGeneration(nn.Module):
             kv_cache=kv_cache,
             block_tables=block_tables,
             slots=slots,
-            input_lengths=input_lengths,
+            seqlen=seqlen,
             max_s=max_s,
         )
...
@@ -10,6 +10,7 @@ from text_generation_server.layers.attention import (
     paged_attention,
     attention,
     reshape_and_cache,
+    Seqlen,
 )
 from text_generation_server.layers import (
     TensorParallelRowLinear,
@@ -159,7 +160,7 @@ class FlashPhiAttention(torch.nn.Module):
         kv_cache,
         block_tables,
         slots,
-        input_lengths,
+        seqlen,
         max_s,
     ):
         # Compute query, key, value and split
@@ -192,12 +193,10 @@ class FlashPhiAttention(torch.nn.Module):
         if cu_seqlen_prefill is not None:
             attn_output = attention(
                 query,
-                torch.select(kv, dim=1, index=0),
-                torch.select(kv, dim=1, index=1),
                 kv_cache[0],
                 kv_cache[1],
-                cu_seqlen_prefill,
-                max_s,
+                seqlen,
+                block_tables,
                 self.softmax_scale,
             )
         # Decode
@@ -209,7 +208,7 @@ class FlashPhiAttention(torch.nn.Module):
             self.kv_head_mapping,
             self.softmax_scale,
             block_tables,
-            input_lengths,
+            seqlen,
             max_s,
         )
@@ -276,7 +275,7 @@ class FlashPhiLayer(nn.Module):
         kv_cache,
         block_tables,
         slots,
-        input_lengths,
+        seqlen,
         max_s,
     ):
         hidden_states, res = self.input_layernorm(hidden_states, residual)
@@ -289,7 +288,7 @@ class FlashPhiLayer(nn.Module):
             kv_cache,
             block_tables,
             slots,
-            input_lengths,
+            seqlen,
             max_s,
         )
@@ -341,7 +340,7 @@ class FlashPhiModel(torch.nn.Module):
         kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
         block_tables: torch.Tensor,
         slots: torch.Tensor,
-        input_lengths: torch.Tensor,
+        seqlen: Seqlen,
         max_s: int,
     ) -> torch.Tensor:
         hidden_states = self.embed_tokens(input_ids)
@@ -363,7 +362,7 @@ class FlashPhiModel(torch.nn.Module):
                 kv_cache[i],
                 block_tables,
                 slots,
-                input_lengths,
+                seqlen,
                 max_s,
             )
@@ -396,7 +395,7 @@ class FlashPhiForCausalLM(torch.nn.Module):
         kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
         block_tables: torch.Tensor,
         slots: torch.Tensor,
-        input_lengths: torch.Tensor,
+        seqlen: Seqlen,
         max_s: int,
         prefill_cache_indices: Optional[torch.Tensor],
         lm_head_indices: Optional[torch.Tensor] = None,
@@ -409,7 +408,7 @@ class FlashPhiForCausalLM(torch.nn.Module):
             kv_cache,
             block_tables,
             slots,
-            input_lengths,
+            seqlen,
             max_s,
         )
         if lm_head_indices is not None:
...
@@ -9,6 +9,7 @@ from text_generation_server.layers.attention import (
     paged_attention,
     attention,
     reshape_and_cache,
+    Seqlen,
 )
 from text_generation_server.layers import (
     TensorParallelRowLinear,
@@ -104,7 +105,7 @@ class Qwen2Attention(torch.nn.Module):
         kv_cache,
         block_tables,
         slots,
-        input_lengths,
+        seqlen,
         max_s,
         prefill_cache_indices,
     ):
@@ -135,12 +136,10 @@ class Qwen2Attention(torch.nn.Module):
             # flash attention
             attn_output = attention(
                 query,
-                torch.select(kv, dim=1, index=0),
-                torch.select(kv, dim=1, index=1),
                 kv_cache[0],
                 kv_cache[1],
-                cu_seqlen_prefill,
-                max_s,
+                seqlen,
+                block_tables,
                 self.softmax_scale,
                 window_size_left=self.max_past,
             )
@@ -153,7 +152,7 @@ class Qwen2Attention(torch.nn.Module):
             self.kv_head_mapping,
             self.softmax_scale,
             block_tables,
-            input_lengths,
+            seqlen,
             max_s,
         )
@@ -225,7 +224,7 @@ class Qwen2Layer(nn.Module):
         kv_cache,
         block_tables,
         slots,
-        input_lengths,
+        seqlen,
         max_s,
         prefill_cache_indices,
     ):
@@ -240,7 +239,7 @@ class Qwen2Layer(nn.Module):
             kv_cache,
             block_tables,
             slots,
-            input_lengths,
+            seqlen,
             max_s,
             prefill_cache_indices,
         )
@@ -296,7 +295,7 @@ class Qwen2Model(torch.nn.Module):
         kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
         block_tables: torch.Tensor,
         slots: torch.Tensor,
-        input_lengths: torch.Tensor,
+        seqlen: Seqlen,
         max_s: int,
         true_max_s: int,
         prefill_cache_indices: Optional[torch.Tensor],
@@ -320,7 +319,7 @@ class Qwen2Model(torch.nn.Module):
                 kv_cache[i],
                 block_tables,
                 slots,
-                input_lengths,
+                seqlen,
                 max_s,
                 prefill_cache_indices,
             )
@@ -361,7 +360,7 @@ class Qwen2ForCausalLM(torch.nn.Module):
         kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
         block_tables: torch.Tensor,
         slots: torch.Tensor,
-        input_lengths: torch.Tensor,
+        seqlen: Seqlen,
         max_s: int,
         prefill_cache_indices: Optional[torch.Tensor] = None,
         lm_head_indices: Optional[torch.Tensor] = None,
@@ -374,7 +373,7 @@ class Qwen2ForCausalLM(torch.nn.Module):
         elif self.max_past is not None:
             # Clamp in decode mode as paged attention requires clamped values whereas the flash attention
             # kernel requires the true values
-            input_lengths = input_lengths.clamp(max=self.max_past_tensor)
+            seqlen = seqlen.clamp(max=self.max_past_tensor)
         hidden_states = self.model(
             input_ids,
@@ -383,7 +382,7 @@ class Qwen2ForCausalLM(torch.nn.Module):
             kv_cache,
             block_tables,
             slots,
-            input_lengths,
+            seqlen,
             max_s,
             true_max_s,
             prefill_cache_indices,
...
@@ -9,6 +9,7 @@ from text_generation_server.layers.attention import (
     paged_attention,
     attention,
     reshape_and_cache,
+    Seqlen,
 )
 from text_generation_server.layers import (
     TensorParallelRowLinear,
@@ -268,7 +269,7 @@ class FlashMQAttention(torch.nn.Module):
         kv_cache,
         block_tables,
         slots,
-        input_lengths,
+        seqlen,
         max_s,
     ):
         qkv = self.c_attn(hidden_states)
@@ -291,12 +292,10 @@ class FlashMQAttention(torch.nn.Module):
             # flash attention
             attn_output = attention(
                 query,
-                torch.select(key_value, dim=1, index=0),
-                torch.select(key_value, dim=1, index=1),
                 kv_cache[0],
                 kv_cache[1],
-                cu_seqlen_prefill,
-                max_s,
+                seqlen,
+                block_tables,
                 self.softmax_scale,
             )
         # Decode
@@ -308,7 +307,7 @@ class FlashMQAttention(torch.nn.Module):
             self.kv_head_mapping,
             self.softmax_scale,
             block_tables,
-            input_lengths,
+            seqlen,
             max_s,
         )
@@ -373,7 +372,7 @@ class Block(nn.Module):
         kv_cache,
         block_tables,
         slots,
-        input_lengths,
+        seqlen,
         max_s,
     ):
         hidden_states, residual = self.ln_1(hidden_states, residual)
@@ -383,7 +382,7 @@ class Block(nn.Module):
             kv_cache,
             block_tables,
             slots,
-            input_lengths,
+            seqlen,
             max_s,
         )
@@ -437,7 +436,7 @@ class FlashSantacoderModel(nn.Module):
         kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
         block_tables: torch.Tensor,
         slots: torch.Tensor,
-        input_lengths: torch.Tensor,
+        seqlen: Seqlen,
         max_s: int,
     ) -> torch.Tensor:
         hidden_states = self.wte(input_ids) + self.wpe(position_ids)
@@ -454,7 +453,7 @@ class FlashSantacoderModel(nn.Module):
                 kv_cache[i],
                 block_tables,
                 slots,
-                input_lengths,
+                seqlen,
                 max_s,
             )
@@ -486,7 +485,7 @@ class FlashSantacoderForCausalLM(nn.Module):
         kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
         block_tables: torch.Tensor,
         slots: torch.Tensor,
-        input_lengths: torch.Tensor,
+        seqlen: Seqlen,
         max_s: int,
         prefill_cache_indices: Optional[torch.Tensor],
         lm_head_indices: Optional[torch.Tensor] = None,
@@ -499,7 +498,7 @@ class FlashSantacoderForCausalLM(nn.Module):
             kv_cache,
             block_tables,
             slots,
-            input_lengths,
+            seqlen,
             max_s,
         )
         if lm_head_indices is not None:
...