Unverified commit e415b690 authored by Nicolas Patry, committed by GitHub

Lots of improvements (Still 2 allocators) (#2449)



* Making prefix/flashinfer the default and running the full release tests.

* Include flashinfer in the docker.

* Using prebuilt.

* Allowing window_left_size (dummy version).

* Disabling flashinfer/prefix caching on odd head_dim

* Disable prefix caching for lora.

* More specific codes.

* Update lock

* Updating integration tests with new values for FI/FD.

Remove paged as a default too, and use FD everywhere.

* Update Cargo lock?

* Upgrade to Rust 1.80 because of bitstream...

* Everywhere 1.80

* Forgot last default place.

* Apply suggestions from code review
Co-authored-by: drbh <david.richard.holtz@gmail.com>

* Updated flake lock

* Tmp

* Upgrade the resolution system for fewer resolution errors.

* Remove a lambda for a cleaner function.

* Handling debugger.

* Override the env in server tests.

* Is this enough to make it work?

* This seems to be working.

* Downgrade some logs.

* Fixing the default for VLM.

* Don't enable prefix caching on VLM just yet.

* Change `add_special_tokens` so that chat input gets the correct tokens and plain generate input does not (this is super important now with the prefixing).

* Fixing prefix caching for flashdecoding.

* Update all models.

* Fixed flashinfer version.

* add_special_tokens is internal only

* Fixing seqlen with the new VLMs.

* Fixing the issue with `add_special_tokens` not being passed around.

* Fixing the test.

* Removing encoder_decoder (seq2seq).

* Update the chat test.

* Fixing the batching tokenization in flash causal lm.

* Truncating left for radix purposes.

* Oops this doesn't belong here.

* Put back default pure shell.

* Update server tests

- Default to throughput test in k6
- Use TGI_WIGGLE_ROOM to adjust wiggle room

* Only n_heads / process_group.size() are necessary.

* Revert the integration tests change (seems linked to the head_size modification).

* Adding error message when assert is violated.

* Fixing the free algorithm to handle cases where the common prefix is smaller.

* Apply suggestions from code review
Co-authored-by: OlivierDehaene <olivier@huggingface.co>

* Update server/text_generation_server/layers/attention/common.py
Co-authored-by: OlivierDehaene <olivier@huggingface.co>

* Fix disabling prefix caching - Fix windowing checks.

* Revert the Cohere tokenizer change (for now using a revision instead).

* Fmt.

---------
Co-authored-by: drbh <david.richard.holtz@gmail.com>
Co-authored-by: OlivierDehaene <olivier@huggingface.co>
parent 4e821c00
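The diff below replaces the bare `input_lengths` tensor with a `Seqlen` object throughout every model's forward path. As a rough orientation, here is a minimal sketch of what such a container could look like, inferred only from how it is constructed and used in this diff (field names, defaults, and the `clamp` helper are assumptions, not the exact upstream definition in `layers/attention/common.py`):

```python
from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class Seqlen:
    # Number of new tokens per request in the batch.
    input_lengths: torch.Tensor
    # Length of the cached prefix per request (zeros when prefix caching is off).
    prefix_lengths: torch.Tensor
    # Cumulative query lengths for prefill kernels; None in decode mode.
    cu_seqlen_q: Optional[torch.Tensor] = None
    # Longest query (1 in decode mode) and longest key sequence in the batch.
    max_q: int = 1
    max_k: int = 1

    def clamp(self, max: int) -> "Seqlen":
        # Sliding-window models (Mistral, Mixtral, Qwen2, Starcoder2) clamp the
        # lengths in decode mode, mirroring the old `input_lengths.clamp(...)` call.
        return Seqlen(
            input_lengths=self.input_lengths.clamp(max=max),
            prefix_lengths=self.prefix_lengths,
            cu_seqlen_q=self.cu_seqlen_q,
            max_q=self.max_q,
            max_k=self.max_k,
        )
```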
......@@ -30,6 +30,7 @@ from text_generation_server.layers.attention import (
paged_attention,
attention,
reshape_and_cache,
Seqlen,
)
from text_generation_server.layers import (
TensorParallelRowLinear,
......@@ -213,7 +214,7 @@ class FlashGemma2Attention(torch.nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
):
qkv = self.query_key_value(hidden_states)
......@@ -236,12 +237,10 @@ class FlashGemma2Attention(torch.nn.Module):
# flash attention
attn_output = attention(
query,
torch.select(kv, dim=1, index=0),
torch.select(kv, dim=1, index=1),
kv_cache[0],
kv_cache[1],
cu_seqlen_prefill,
max_s,
seqlen,
block_tables,
self.softmax_scale,
causal=self.causal,
window_size_left=self.window_size,
......@@ -256,7 +255,7 @@ class FlashGemma2Attention(torch.nn.Module):
self.kv_head_mapping,
self.softmax_scale,
block_tables,
input_lengths,
seqlen,
max_s,
softcap=self.softcap,
)
......@@ -343,7 +342,7 @@ class FlashGemma2Layer(nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
):
normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
......@@ -357,7 +356,7 @@ class FlashGemma2Layer(nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
)
......@@ -408,7 +407,7 @@ class FlashGemma2Model(torch.nn.Module):
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
input_lengths: torch.Tensor,
seqlen: Seqlen,
max_s: int,
) -> torch.Tensor:
hidden_states = inputs_embeds
......@@ -430,7 +429,7 @@ class FlashGemma2Model(torch.nn.Module):
kv_cache[i],
block_tables,
slots,
input_lengths,
seqlen,
max_s,
)
......@@ -477,7 +476,7 @@ class FlashGemma2ForCausalLM(torch.nn.Module):
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
input_lengths: torch.Tensor,
seqlen: Seqlen,
max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor] = None,
......@@ -491,7 +490,7 @@ class FlashGemma2ForCausalLM(torch.nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
)
if lm_head_indices is not None:
......
......@@ -30,6 +30,7 @@ from text_generation_server.layers.attention import (
paged_attention,
attention,
reshape_and_cache,
Seqlen,
)
from text_generation_server.layers import (
TensorParallelRowLinear,
......@@ -207,7 +208,7 @@ class FlashGemmaAttention(torch.nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
):
qkv = self.query_key_value(hidden_states)
......@@ -230,12 +231,10 @@ class FlashGemmaAttention(torch.nn.Module):
# flash attention
attn_output = attention(
query,
torch.select(kv, dim=1, index=0),
torch.select(kv, dim=1, index=1),
kv_cache[0],
kv_cache[1],
cu_seqlen_prefill,
max_s,
seqlen,
block_tables,
self.softmax_scale,
causal=self.causal,
)
......@@ -248,7 +247,7 @@ class FlashGemmaAttention(torch.nn.Module):
self.kv_head_mapping,
self.softmax_scale,
block_tables,
input_lengths,
seqlen,
max_s,
)
......@@ -320,7 +319,7 @@ class FlashGemmaLayer(nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
):
normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
......@@ -334,7 +333,7 @@ class FlashGemmaLayer(nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
)
......@@ -382,7 +381,7 @@ class FlashGemmaModel(torch.nn.Module):
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
input_lengths: torch.Tensor,
seqlen: Seqlen,
max_s: int,
) -> torch.Tensor:
hidden_states = inputs_embeds
......@@ -404,7 +403,7 @@ class FlashGemmaModel(torch.nn.Module):
kv_cache[i],
block_tables,
slots,
input_lengths,
seqlen,
max_s,
)
......@@ -449,7 +448,7 @@ class FlashGemmaForCausalLM(torch.nn.Module):
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
input_lengths: torch.Tensor,
seqlen: Seqlen,
max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor] = None,
......@@ -463,7 +462,7 @@ class FlashGemmaForCausalLM(torch.nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
)
if lm_head_indices is not None:
......
......@@ -29,6 +29,7 @@ from text_generation_server.layers.attention import (
paged_attention,
attention,
reshape_and_cache,
Seqlen,
)
from text_generation_server.layers import (
TensorParallelRowLinear,
......@@ -213,7 +214,7 @@ class FlashGPT2Attention(torch.nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
):
query, key, value = self.query_key_value(hidden_states).split(
......@@ -230,12 +231,10 @@ class FlashGPT2Attention(torch.nn.Module):
# flash attention
attn_output = attention(
query,
key,
value,
kv_cache[0],
kv_cache[1],
cu_seqlen_prefill,
max_s,
seqlen,
block_tables,
self.softmax_scale,
)
# Decode
......@@ -247,7 +246,7 @@ class FlashGPT2Attention(torch.nn.Module):
self.kv_head_mapping,
self.softmax_scale,
block_tables,
input_lengths,
seqlen,
max_s,
)
......@@ -316,7 +315,7 @@ class FlashGPT2Layer(nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
):
residual = hidden_states
......@@ -329,7 +328,7 @@ class FlashGPT2Layer(nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
)
......@@ -382,7 +381,7 @@ class FlashGPT2Model(torch.nn.Module):
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
input_lengths: torch.Tensor,
seqlen: Seqlen,
max_s: int,
true_max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
......@@ -398,7 +397,7 @@ class FlashGPT2Model(torch.nn.Module):
kv_cache[i],
block_tables,
slots,
input_lengths,
seqlen,
max_s,
)
......@@ -435,7 +434,7 @@ class FlashGPT2ForCausalLM(torch.nn.Module):
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
input_lengths: torch.Tensor,
seqlen: Seqlen,
max_s: int,
prefill_cache_indices: Optional[torch.Tensor] = None,
lm_head_indices: Optional[torch.Tensor] = None,
......@@ -451,7 +450,7 @@ class FlashGPT2ForCausalLM(torch.nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
true_max_s=max_s,
prefill_cache_indices=prefill_cache_indices,
......
......@@ -29,6 +29,7 @@ from text_generation_server.layers.attention import (
paged_attention,
attention,
reshape_and_cache,
Seqlen,
)
from text_generation_server.layers import (
TensorParallelRowLinear,
......@@ -167,7 +168,7 @@ class FlashGPTJAttention(torch.nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
):
query, key, value = self.query_key_value(hidden_states).split(
......@@ -192,10 +193,10 @@ class FlashGPTJAttention(torch.nn.Module):
# flash attention
attn_output = attention(
query,
key,
value,
cu_seqlen_prefill,
max_s,
kv_cache[0],
kv_cache[1],
seqlen,
block_tables,
self.softmax_scale,
)
# Decode
......@@ -207,7 +208,7 @@ class FlashGPTJAttention(torch.nn.Module):
self.kv_head_mapping,
self.softmax_scale,
block_tables,
input_lengths,
seqlen,
max_s,
)
......@@ -268,7 +269,7 @@ class FlashGPTJLayer(nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
):
hidden_states, residual = self.input_layernorm(hidden_states, residual)
......@@ -281,7 +282,7 @@ class FlashGPTJLayer(nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
)
......@@ -328,7 +329,7 @@ class FlashGPTJModel(torch.nn.Module):
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
input_lengths: torch.Tensor,
seqlen: Seqlen,
max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
) -> torch.Tensor:
......@@ -351,7 +352,7 @@ class FlashGPTJModel(torch.nn.Module):
kv_cache[i],
block_tables,
slots,
input_lengths,
seqlen,
max_s,
)
......@@ -382,7 +383,7 @@ class FlashGPTJForCausalLM(torch.nn.Module):
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
input_lengths: torch.Tensor,
seqlen: Seqlen,
max_s: int,
prefill_cache_indices: Optional[torch.Tensor] = None,
lm_head_indices: Optional[torch.Tensor] = None,
......@@ -395,7 +396,7 @@ class FlashGPTJForCausalLM(torch.nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
prefill_cache_indices=prefill_cache_indices,
)
......
......@@ -32,6 +32,7 @@ from text_generation_server.layers.attention import (
paged_attention,
attention,
reshape_and_cache,
Seqlen,
)
from text_generation_server.layers import (
TensorParallelRowLinear,
......@@ -194,7 +195,7 @@ class FlashLlamaAttention(torch.nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
adapter_data,
):
......@@ -218,12 +219,10 @@ class FlashLlamaAttention(torch.nn.Module):
# flash attention
attn_output = attention(
query,
torch.select(kv, dim=1, index=0),
torch.select(kv, dim=1, index=1),
kv_cache[0],
kv_cache[1],
cu_seqlen_prefill,
max_s,
seqlen,
block_tables,
self.softmax_scale,
)
# Decode
......@@ -235,7 +234,7 @@ class FlashLlamaAttention(torch.nn.Module):
self.kv_head_mapping,
self.softmax_scale,
block_tables,
input_lengths,
seqlen,
max_s,
)
......@@ -375,7 +374,7 @@ class FlashLlamaLayer(nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
adapter_data,
):
......@@ -390,7 +389,7 @@ class FlashLlamaLayer(nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
adapter_data,
)
......@@ -479,7 +478,7 @@ class FlashLlamaModel(torch.nn.Module):
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
input_lengths: torch.Tensor,
seqlen: Seqlen,
max_s: int,
true_max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
......@@ -504,7 +503,7 @@ class FlashLlamaModel(torch.nn.Module):
kv_cache[i],
block_tables,
slots,
input_lengths,
seqlen,
max_s,
adapter_data,
)
......@@ -548,7 +547,7 @@ class FlashLlamaForCausalLM(torch.nn.Module):
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
input_lengths: torch.Tensor,
seqlen: Seqlen,
max_s: int,
prefill_cache_indices: Optional[torch.Tensor] = None,
lm_head_indices: Optional[torch.Tensor] = None,
......@@ -562,7 +561,7 @@ class FlashLlamaForCausalLM(torch.nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
true_max_s=max_s,
prefill_cache_indices=prefill_cache_indices,
......
......@@ -31,6 +31,7 @@ from text_generation_server.layers.attention import (
paged_attention,
attention,
reshape_and_cache,
Seqlen,
)
from text_generation_server.layers import (
TensorParallelRowLinear,
......@@ -185,7 +186,7 @@ class MistralAttention(torch.nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
prefill_cache_indices,
adapter_data,
......@@ -217,12 +218,10 @@ class MistralAttention(torch.nn.Module):
# flash attention
attn_output = attention(
query,
torch.select(kv, dim=1, index=0),
torch.select(kv, dim=1, index=1),
kv_cache[0],
kv_cache[1],
cu_seqlen_prefill,
max_s,
seqlen,
block_tables,
self.softmax_scale,
window_size_left=self.max_past,
)
......@@ -235,7 +234,7 @@ class MistralAttention(torch.nn.Module):
self.kv_head_mapping,
self.softmax_scale,
block_tables,
input_lengths,
seqlen,
max_s,
)
......@@ -356,7 +355,7 @@ class MistralLayer(nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
prefill_cache_indices,
adapter_data,
......@@ -372,7 +371,7 @@ class MistralLayer(nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
prefill_cache_indices,
adapter_data,
......@@ -424,7 +423,7 @@ class MistralModel(torch.nn.Module):
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
input_lengths: torch.Tensor,
seqlen: Seqlen,
max_s: int,
true_max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
......@@ -448,7 +447,7 @@ class MistralModel(torch.nn.Module):
kv_cache[i],
block_tables,
slots,
input_lengths,
seqlen,
max_s,
prefill_cache_indices,
adapter_data,
......@@ -499,7 +498,7 @@ class FlashMistralForCausalLM(torch.nn.Module):
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
input_lengths: torch.Tensor,
seqlen: Seqlen,
max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor] = None,
......@@ -512,7 +511,7 @@ class FlashMistralForCausalLM(torch.nn.Module):
elif self.max_past is not None:
# Clamp in decode mode as paged attention requires clamped values whereas the flash attention
# kernel requires the true values
input_lengths = input_lengths.clamp(max=self.max_past_tensor)
seqlen = seqlen.clamp(max=self.max_past_tensor)
inputs_embeds = self.embed_tokens(input_ids)
hidden_states = self.model(
......@@ -522,7 +521,7 @@ class FlashMistralForCausalLM(torch.nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
true_max_s,
prefill_cache_indices,
......
......@@ -35,6 +35,7 @@ from text_generation_server.layers.attention import (
paged_attention,
attention,
reshape_and_cache,
Seqlen,
)
from text_generation_server.layers import (
FastLinear,
......@@ -243,7 +244,7 @@ class MixtralAttention(torch.nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
prefill_cache_indices,
):
......@@ -274,12 +275,10 @@ class MixtralAttention(torch.nn.Module):
# flash attention
attn_output = attention(
query,
torch.select(kv, dim=1, index=0),
torch.select(kv, dim=1, index=1),
kv_cache[0],
kv_cache[1],
cu_seqlen_prefill,
max_s,
seqlen,
block_tables,
self.softmax_scale,
window_size_left=self.max_past,
)
......@@ -292,7 +291,7 @@ class MixtralAttention(torch.nn.Module):
self.kv_head_mapping,
self.softmax_scale,
block_tables,
input_lengths,
seqlen,
max_s,
)
......@@ -498,7 +497,7 @@ class MixtralLayer(nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
prefill_cache_indices,
):
......@@ -513,7 +512,7 @@ class MixtralLayer(nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
prefill_cache_indices,
)
......@@ -568,7 +567,7 @@ class MixtralModel(torch.nn.Module):
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
input_lengths: torch.Tensor,
seqlen: Seqlen,
max_s: int,
true_max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
......@@ -592,7 +591,7 @@ class MixtralModel(torch.nn.Module):
kv_cache[i],
block_tables,
slots,
input_lengths,
seqlen,
max_s,
prefill_cache_indices,
)
......@@ -627,7 +626,7 @@ class FlashMixtralForCausalLM(torch.nn.Module):
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
input_lengths: torch.Tensor,
seqlen: Seqlen,
max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor] = None,
......@@ -640,7 +639,7 @@ class FlashMixtralForCausalLM(torch.nn.Module):
elif self.max_past is not None:
# Clamp in decode mode as paged attention requires clamped values whereas the flash attention
# kernel requires the true values
input_lengths = input_lengths.clamp(max=self.max_past_tensor)
seqlen = seqlen.clamp(max=self.max_past_tensor)
hidden_states = self.model(
input_ids,
......@@ -649,7 +648,7 @@ class FlashMixtralForCausalLM(torch.nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
true_max_s,
prefill_cache_indices,
......
......@@ -31,6 +31,7 @@ from text_generation_server.layers.attention import (
paged_attention,
attention,
reshape_and_cache,
Seqlen,
)
from text_generation_server.layers import (
TensorParallelRowLinear,
......@@ -147,7 +148,7 @@ class FlashNeoxAttention(torch.nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
):
qkv = self.query_key_value(hidden_states)
......@@ -171,12 +172,10 @@ class FlashNeoxAttention(torch.nn.Module):
# flash attention
attn_output = attention(
qkv[:, 0],
qkv[:, 1],
qkv[:, 2],
kv_cache[0],
kv_cache[1],
cu_seqlen_prefill,
max_s,
seqlen,
block_tables,
self.softmax_scale,
)
# Decode
......@@ -188,7 +187,7 @@ class FlashNeoxAttention(torch.nn.Module):
self.kv_head_mapping,
self.softmax_scale,
block_tables,
input_lengths,
seqlen,
max_s,
)
......@@ -258,7 +257,7 @@ class FlashNeoXLayer(nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
):
if self.use_parallel_residual:
......@@ -272,7 +271,7 @@ class FlashNeoXLayer(nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
)
......@@ -296,7 +295,7 @@ class FlashNeoXLayer(nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
)
......@@ -350,7 +349,7 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
input_lengths: torch.Tensor,
seqlen: Seqlen,
max_s: int,
) -> torch.Tensor:
hidden_states = self.embed_in(input_ids)
......@@ -372,7 +371,7 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
kv_cache[i],
block_tables,
slots,
input_lengths,
seqlen,
max_s,
)
......@@ -404,7 +403,7 @@ class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel):
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
input_lengths: torch.Tensor,
seqlen: Seqlen,
max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor] = None,
......@@ -417,7 +416,7 @@ class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
)
if lm_head_indices is not None:
......
......@@ -19,6 +19,7 @@ from torch import nn
from typing import Optional, List, Tuple
from text_generation_server.layers.tensor_parallel import TensorParallelColumnLinear
from text_generation_server.layers.attention import Seqlen
from text_generation_server.models.custom_modeling.vlm import (
load_text_model,
load_vision_model,
......@@ -70,7 +71,7 @@ class PaliGemmaForConditionalGeneration(nn.Module):
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
input_lengths: torch.Tensor,
seqlen: Seqlen,
max_s: int,
prefill_cache_indices: Optional[torch.Tensor] = None,
lm_head_indices: Optional[torch.Tensor] = None,
......@@ -107,7 +108,7 @@ class PaliGemmaForConditionalGeneration(nn.Module):
kv_cache=kv_cache,
block_tables=block_tables,
slots=slots,
input_lengths=input_lengths,
seqlen=seqlen,
max_s=max_s,
)
......
......@@ -10,6 +10,7 @@ from text_generation_server.layers.attention import (
paged_attention,
attention,
reshape_and_cache,
Seqlen,
)
from text_generation_server.layers import (
TensorParallelRowLinear,
......@@ -159,7 +160,7 @@ class FlashPhiAttention(torch.nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
):
# Compute query, key, value and split
......@@ -192,12 +193,10 @@ class FlashPhiAttention(torch.nn.Module):
if cu_seqlen_prefill is not None:
attn_output = attention(
query,
torch.select(kv, dim=1, index=0),
torch.select(kv, dim=1, index=1),
kv_cache[0],
kv_cache[1],
cu_seqlen_prefill,
max_s,
seqlen,
block_tables,
self.softmax_scale,
)
# Decode
......@@ -209,7 +208,7 @@ class FlashPhiAttention(torch.nn.Module):
self.kv_head_mapping,
self.softmax_scale,
block_tables,
input_lengths,
seqlen,
max_s,
)
......@@ -276,7 +275,7 @@ class FlashPhiLayer(nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
):
hidden_states, res = self.input_layernorm(hidden_states, residual)
......@@ -289,7 +288,7 @@ class FlashPhiLayer(nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
)
......@@ -341,7 +340,7 @@ class FlashPhiModel(torch.nn.Module):
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
input_lengths: torch.Tensor,
seqlen: Seqlen,
max_s: int,
) -> torch.Tensor:
hidden_states = self.embed_tokens(input_ids)
......@@ -363,7 +362,7 @@ class FlashPhiModel(torch.nn.Module):
kv_cache[i],
block_tables,
slots,
input_lengths,
seqlen,
max_s,
)
......@@ -396,7 +395,7 @@ class FlashPhiForCausalLM(torch.nn.Module):
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
input_lengths: torch.Tensor,
seqlen: Seqlen,
max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor] = None,
......@@ -409,7 +408,7 @@ class FlashPhiForCausalLM(torch.nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
)
if lm_head_indices is not None:
......
......@@ -9,6 +9,7 @@ from text_generation_server.layers.attention import (
paged_attention,
attention,
reshape_and_cache,
Seqlen,
)
from text_generation_server.layers import (
TensorParallelRowLinear,
......@@ -104,7 +105,7 @@ class Qwen2Attention(torch.nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
prefill_cache_indices,
):
......@@ -135,12 +136,10 @@ class Qwen2Attention(torch.nn.Module):
# flash attention
attn_output = attention(
query,
torch.select(kv, dim=1, index=0),
torch.select(kv, dim=1, index=1),
kv_cache[0],
kv_cache[1],
cu_seqlen_prefill,
max_s,
seqlen,
block_tables,
self.softmax_scale,
window_size_left=self.max_past,
)
......@@ -153,7 +152,7 @@ class Qwen2Attention(torch.nn.Module):
self.kv_head_mapping,
self.softmax_scale,
block_tables,
input_lengths,
seqlen,
max_s,
)
......@@ -225,7 +224,7 @@ class Qwen2Layer(nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
prefill_cache_indices,
):
......@@ -240,7 +239,7 @@ class Qwen2Layer(nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
prefill_cache_indices,
)
......@@ -296,7 +295,7 @@ class Qwen2Model(torch.nn.Module):
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
input_lengths: torch.Tensor,
seqlen: Seqlen,
max_s: int,
true_max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
......@@ -320,7 +319,7 @@ class Qwen2Model(torch.nn.Module):
kv_cache[i],
block_tables,
slots,
input_lengths,
seqlen,
max_s,
prefill_cache_indices,
)
......@@ -361,7 +360,7 @@ class Qwen2ForCausalLM(torch.nn.Module):
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
input_lengths: torch.Tensor,
seqlen: Seqlen,
max_s: int,
prefill_cache_indices: Optional[torch.Tensor] = None,
lm_head_indices: Optional[torch.Tensor] = None,
......@@ -374,7 +373,7 @@ class Qwen2ForCausalLM(torch.nn.Module):
elif self.max_past is not None:
# Clamp in decode mode as paged attention requires clamped values whereas the flash attention
# kernel requires the true values
input_lengths = input_lengths.clamp(max=self.max_past_tensor)
seqlen = seqlen.clamp(max=self.max_past_tensor)
hidden_states = self.model(
input_ids,
......@@ -383,7 +382,7 @@ class Qwen2ForCausalLM(torch.nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
true_max_s,
prefill_cache_indices,
......
......@@ -19,6 +19,7 @@ from text_generation_server.layers.attention import (
attention,
paged_attention,
reshape_and_cache,
Seqlen,
)
......@@ -181,7 +182,7 @@ class FlashRWAttention(torch.nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
):
qkv = self.query_key_value(hidden_states)
......@@ -206,12 +207,10 @@ class FlashRWAttention(torch.nn.Module):
# flash attention
attn_output = attention(
query,
torch.select(kv, dim=1, index=0),
torch.select(kv, dim=1, index=1),
kv_cache[0],
kv_cache[1],
cu_seqlen_prefill,
max_s,
seqlen,
block_tables,
self.softmax_scale,
)
# Decode
......@@ -223,7 +222,7 @@ class FlashRWAttention(torch.nn.Module):
self.kv_head_mapping,
self.softmax_scale,
block_tables,
input_lengths,
seqlen,
max_s,
)
......@@ -296,7 +295,7 @@ class FlashRWLargeAttention(torch.nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
):
qkv = self.query_key_value(hidden_states)
......@@ -343,7 +342,7 @@ class FlashRWLargeAttention(torch.nn.Module):
self.kv_head_mapping,
self.softmax_scale,
block_tables,
input_lengths,
seqlen,
max_s,
)
......@@ -429,7 +428,7 @@ class FlashRWLayer(nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
):
if self.parallel_attn:
......@@ -443,7 +442,7 @@ class FlashRWLayer(nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
)
......@@ -465,7 +464,7 @@ class FlashRWLayer(nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
)
......@@ -552,7 +551,7 @@ class FlashRWLargeLayer(nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
):
# Layer norm.
......@@ -567,7 +566,7 @@ class FlashRWLargeLayer(nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
)
......@@ -628,7 +627,7 @@ class FlashRWModel(FlashRWPreTrainedModel):
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
input_lengths: torch.Tensor,
seqlen: Seqlen,
max_s: int,
) -> torch.Tensor:
hidden_states = self.word_embeddings(input_ids)
......@@ -650,7 +649,7 @@ class FlashRWModel(FlashRWPreTrainedModel):
kv_cache[i],
block_tables,
slots,
input_lengths,
seqlen,
max_s,
)
......@@ -680,7 +679,7 @@ class FlashRWForCausalLM(FlashRWPreTrainedModel):
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
input_lengths: torch.Tensor,
seqlen: Seqlen,
max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor] = None,
......@@ -693,7 +692,7 @@ class FlashRWForCausalLM(FlashRWPreTrainedModel):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
)
if lm_head_indices is not None:
......
......@@ -9,6 +9,7 @@ from text_generation_server.layers.attention import (
paged_attention,
attention,
reshape_and_cache,
Seqlen,
)
from text_generation_server.layers import (
TensorParallelRowLinear,
......@@ -268,7 +269,7 @@ class FlashMQAttention(torch.nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
):
qkv = self.c_attn(hidden_states)
......@@ -291,12 +292,10 @@ class FlashMQAttention(torch.nn.Module):
# flash attention
attn_output = attention(
query,
torch.select(key_value, dim=1, index=0),
torch.select(key_value, dim=1, index=1),
kv_cache[0],
kv_cache[1],
cu_seqlen_prefill,
max_s,
seqlen,
block_tables,
self.softmax_scale,
)
# Decode
......@@ -308,7 +307,7 @@ class FlashMQAttention(torch.nn.Module):
self.kv_head_mapping,
self.softmax_scale,
block_tables,
input_lengths,
seqlen,
max_s,
)
......@@ -373,7 +372,7 @@ class Block(nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
):
hidden_states, residual = self.ln_1(hidden_states, residual)
......@@ -383,7 +382,7 @@ class Block(nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
)
......@@ -437,7 +436,7 @@ class FlashSantacoderModel(nn.Module):
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
input_lengths: torch.Tensor,
seqlen: Seqlen,
max_s: int,
) -> torch.Tensor:
hidden_states = self.wte(input_ids) + self.wpe(position_ids)
......@@ -454,7 +453,7 @@ class FlashSantacoderModel(nn.Module):
kv_cache[i],
block_tables,
slots,
input_lengths,
seqlen,
max_s,
)
......@@ -486,7 +485,7 @@ class FlashSantacoderForCausalLM(nn.Module):
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
input_lengths: torch.Tensor,
seqlen: Seqlen,
max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor] = None,
......@@ -499,7 +498,7 @@ class FlashSantacoderForCausalLM(nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
)
if lm_head_indices is not None:
......
......@@ -30,6 +30,7 @@ from text_generation_server.layers.attention import (
paged_attention,
attention,
reshape_and_cache,
Seqlen,
)
from text_generation_server.layers import (
TensorParallelRowLinear,
......@@ -209,7 +210,7 @@ class Starcoder2Attention(torch.nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
prefill_cache_indices,
):
......@@ -240,12 +241,10 @@ class Starcoder2Attention(torch.nn.Module):
# flash attention
attn_output = attention(
query,
torch.select(kv, dim=1, index=0),
torch.select(kv, dim=1, index=1),
kv_cache[0],
kv_cache[1],
cu_seqlen_prefill,
max_s,
seqlen,
block_tables,
self.softmax_scale,
window_size_left=self.max_past,
)
......@@ -258,7 +257,7 @@ class Starcoder2Attention(torch.nn.Module):
self.kv_head_mapping,
self.softmax_scale,
block_tables,
input_lengths,
seqlen,
max_s,
)
......@@ -381,7 +380,7 @@ class Starcoder2Layer(nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
prefill_cache_indices,
):
......@@ -396,7 +395,7 @@ class Starcoder2Layer(nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
prefill_cache_indices,
)
......@@ -449,7 +448,7 @@ class Starcoder2Model(torch.nn.Module):
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
input_lengths: torch.Tensor,
seqlen: Seqlen,
max_s: int,
true_max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
......@@ -473,7 +472,7 @@ class Starcoder2Model(torch.nn.Module):
kv_cache[i],
block_tables,
slots,
input_lengths,
seqlen,
max_s,
prefill_cache_indices,
)
......@@ -521,7 +520,7 @@ class FlashStarcoder2ForCausalLM(torch.nn.Module):
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
input_lengths: torch.Tensor,
seqlen: Seqlen,
max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor] = None,
......@@ -534,7 +533,7 @@ class FlashStarcoder2ForCausalLM(torch.nn.Module):
elif self.max_past is not None:
# Clamp in decode mode as paged attention requires clamped values whereas the flash attention
# kernel requires the true values
input_lengths = input_lengths.clamp(max=self.max_past_tensor)
seqlen = seqlen.clamp(max=self.max_past_tensor)
hidden_states = self.model(
input_ids,
......@@ -543,7 +542,7 @@ class FlashStarcoder2ForCausalLM(torch.nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
true_max_s,
prefill_cache_indices,
......
......@@ -25,6 +25,7 @@ from transformers.activations import ACT2FN
from text_generation_server.models.custom_modeling.vlm import (
load_text_model,
)
from text_generation_server.layers.attention import Seqlen
from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask
from text_generation_server.layers import (
......@@ -740,7 +741,7 @@ class Idefics2ForConditionalGeneration(nn.Module):
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
input_lengths: torch.Tensor,
seqlen: Seqlen,
max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor] = None,
......@@ -826,7 +827,7 @@ class Idefics2ForConditionalGeneration(nn.Module):
kv_cache=kv_cache,
block_tables=block_tables,
slots=slots,
input_lengths=input_lengths,
seqlen=seqlen,
max_s=max_s,
true_max_s=max_s,
prefill_cache_indices=None,
......
......@@ -23,6 +23,7 @@ from torch import nn
from transformers.activations import ACT2FN
from transformers.image_processing_utils import select_best_resolution
from text_generation_server.layers.attention import Seqlen
from text_generation_server.models.custom_modeling.vlm import (
load_text_model,
load_vision_model,
......@@ -170,7 +171,7 @@ class LlavaNextForConditionalGeneration(nn.Module):
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
input_lengths: torch.Tensor,
seqlen: Seqlen,
max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor] = None,
......@@ -276,7 +277,7 @@ class LlavaNextForConditionalGeneration(nn.Module):
kv_cache=kv_cache,
block_tables=block_tables,
slots=slots,
input_lengths=input_lengths,
seqlen=seqlen,
max_s=max_s,
true_max_s=max_s,
prefill_cache_indices=None,
......
......@@ -43,7 +43,7 @@ from text_generation_server.models.globals import (
ATTENTION,
BLOCK_SIZE,
CUDA_GRAPHS,
PREFIX_CACHING,
TGI_WIGGLE_ROOM,
get_adapter_to_index,
)
from text_generation_server.layers.attention import Seqlen
......@@ -189,16 +189,21 @@ class FlashCausalLMBatch(Batch):
def batch_tokenized_inputs(
cls, requests: Iterable[generate_pb2.Request], tokenizer
):
batch_inputs = []
max_truncation = 0
max_length = 0
all_input_ids = []
batch_size = 0
for r in requests:
batch_inputs.append(concat_text_chunks(r.input_chunks.chunks))
max_truncation = max(max_truncation, r.truncate)
batch_tokenized_inputs = tokenizer(
batch_inputs, truncation=True, max_length=max_truncation
)["input_ids"]
return batch_tokenized_inputs
batch_size += 1
inputs = concat_text_chunks(r.input_chunks.chunks)
input_ids = tokenizer(
inputs,
truncation=True,
max_length=r.truncate,
add_special_tokens=r.add_special_tokens,
)["input_ids"]
max_length = max(max_length, len(input_ids))
all_input_ids.append(input_ids)
return all_input_ids
@classmethod
def from_tokenized(
......@@ -257,22 +262,15 @@ class FlashCausalLMBatch(Batch):
# request id -> idx in list mapping
requests_idx_mapping[r.id] = i
tokenized_input = tokenized_input[-r.truncate :]
if (
tokenized_input[0] == tokenizer.bos_token_id
and tokenized_input[1] == tokenizer.bos_token_id
):
tokenized_input = tokenized_input[1:]
orig_input_length = len(tokenized_input)
if PREFIX_CACHING:
prefix_len = r.prefix_len
if prefix_len == orig_input_length:
assert prefix_len > 0
prefix_len -= 1
else:
prefix_len = 0
prefix_len = r.prefix_len
assert (
prefix_len <= orig_input_length
), f"Prefix {prefix_len} vs input {orig_input_length}"
if prefix_len == orig_input_length:
assert prefix_len > 0
prefix_len -= 1
prefix_ids.append(tokenized_input[:prefix_len])
tokenized_input = tokenized_input[prefix_len:]
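A toy walk-through of the prefix handling just above, to show why at least one token is always kept out of the prefix (token ids and lengths are made up):

```python
# Suppose the whole tokenized input is already present in the radix cache.
tokenized_input = [2, 45, 981, 7, 13]
prefix_len = 5                      # r.prefix_len, equal to the full input here
assert prefix_len <= len(tokenized_input)
if prefix_len == len(tokenized_input):
    assert prefix_len > 0
    prefix_len -= 1                 # keep one token so the model has something to run on
prefix_ids = tokenized_input[:prefix_len]       # [2, 45, 981, 7] -> served from cache
tokenized_input = tokenized_input[prefix_len:]  # [13]            -> actually computed
```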
......@@ -998,7 +996,7 @@ class FlashCausalLM(Model):
config.sliding_window = None
self.num_layers = config.num_hidden_layers
self.num_heads = config.num_attention_heads
self.num_heads = config.num_attention_heads // self.process_group.size()
# Validation is done in the model itself
if num_kv_heads is None:
num_kv_heads = getattr(config, "num_key_value_heads", None)
......@@ -1160,8 +1158,15 @@ class FlashCausalLM(Model):
"block_tables": block_tables,
"slots": slots,
"input_lengths": input_lengths_tensor,
"prefix_lengths": prefix_lengths_tensor,
}
input_lengths_ = Seqlen(input_lengths=input_lengths_tensor)
seqlen = Seqlen(
input_lengths=input_lengths_tensor,
prefix_lengths=prefix_lengths_tensor,
cu_seqlen_q=None,
max_q=1,
max_k=max_s,
)
graph = torch.cuda.CUDAGraph()
self.cuda_graphs[bs]["graph"] = graph
......@@ -1204,7 +1209,7 @@ class FlashCausalLM(Model):
kv_cache=self.kv_cache,
block_tables=block_tables,
slots=slots,
input_lengths=input_lengths_,
seqlen=seqlen,
max_s=max_s,
prefill_cache_indices=None,
lm_head_indices=None,
......@@ -1213,7 +1218,13 @@ class FlashCausalLM(Model):
torch.cuda.synchronize()
with torch.cuda.graph(graph, pool=MEM_POOL):
input_lengths_tensor = Seqlen(input_lengths=input_lengths_tensor)
seqlen = Seqlen(
input_lengths=input_lengths_tensor,
prefix_lengths=prefix_lengths_tensor,
cu_seqlen_q=None,
max_q=1,
max_k=max_s,
)
logits, speculative_logits = self.model.forward(
input_ids=input_ids,
position_ids=position_ids,
......@@ -1221,7 +1232,7 @@ class FlashCausalLM(Model):
kv_cache=self.kv_cache,
block_tables=block_tables,
slots=slots,
input_lengths=input_lengths_tensor,
seqlen=seqlen,
max_s=max_s,
prefill_cache_indices=None,
lm_head_indices=None,
......@@ -1268,7 +1279,7 @@ class FlashCausalLM(Model):
num_blocks = (
# Leave 5% for some wiggle room
int((free_memory * 0.95) // total_cache_size)
int((free_memory * TGI_WIGGLE_ROOM) // total_cache_size)
# Add batch.num_blocks as we allocated it above, so it is included in the peak memory.
+ batch_num_blocks
)
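A quick worked example of the block count above, with made-up sizes, to show how `TGI_WIGGLE_ROOM` trades free memory for KV-cache blocks:

```python
free_memory = 20 * 1024**3        # 20 GiB of free GPU memory after warmup (example)
total_cache_size = 2 * 1024**2    # bytes of KV cache needed per block (example)
batch_num_blocks = 128            # blocks already allocated for the warmup batch
TGI_WIGGLE_ROOM = 0.95            # default; overridable via the env var

num_blocks = int((free_memory * TGI_WIGGLE_ROOM) // total_cache_size) + batch_num_blocks
print(num_blocks)  # ~9856 blocks with these example values
```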
......@@ -1360,18 +1371,26 @@ class FlashCausalLM(Model):
# Dummy value, some models (starcoder2) don't accept `None`.
input_lengths = torch.ones(seqlen, dtype=torch.int32, device=self.device)
input_lengths = Seqlen(input_lengths=input_lengths)
prefix_lens_tensor = torch.zeros(seqlen, dtype=torch.int32, device=self.device)
cu_seqlen_prefill = torch.tensor(
[0, seqlen], device=self.device, dtype=torch.int32
)
seqlen = Seqlen(
input_lengths=input_lengths,
prefix_lengths=prefix_lens_tensor,
cu_seqlen_q=cu_seqlen_prefill,
max_q=1,
max_k=seqlen,
)
# We pass a `cu_seqlen_prefill` in order not to have to deal with paged attention cache allocation/deallocation.
self.model.forward(
input_ids=input_ids,
position_ids=position_ids,
cu_seqlen_prefill=torch.tensor(
[0, seqlen], device=self.device, dtype=torch.int32
),
cu_seqlen_prefill=cu_seqlen_prefill,
kv_cache=self.kv_cache,
block_tables=None,
input_lengths=input_lengths,
seqlen=seqlen,
slots=slots,
max_s=seqlen,
lm_head_indices=None,
......@@ -1451,8 +1470,7 @@ class FlashCausalLM(Model):
cuda_graph = None
if cu_seqlen_prefill is not None or cuda_graph is None:
input_lengths = input_lengths + prefix_lens_tensor
if PREFIX_CACHING:
if ATTENTION == "flashinfer":
block_tables = block_tables_to_ragged(
block_tables=block_tables,
input_lengths=batch.input_lengths,
......@@ -1462,11 +1480,18 @@ class FlashCausalLM(Model):
block_tables=block_tables,
cu_seqlen_prefill=cu_seqlen_prefill,
input_lengths=batch.input_lengths,
input_lengths_tensor=input_lengths,
input_lengths_tensor=input_lengths + prefix_lens_tensor,
prefix_lens=batch.prefix_lens,
prefix_lens_tensor=prefix_lens_tensor,
):
input_lengths = Seqlen(input_lengths=input_lengths)
max_k = (input_lengths + prefix_lens_tensor).max().item()
seqlen = Seqlen(
input_lengths=input_lengths,
prefix_lengths=prefix_lens_tensor,
cu_seqlen_q=cu_seqlen_prefill,
max_q=max_s,
max_k=max_k,
)
logits, speculative_logits = self.model.forward(
input_ids=input_ids,
position_ids=position_ids,
......@@ -1474,7 +1499,7 @@ class FlashCausalLM(Model):
kv_cache=kv_cache,
block_tables=block_tables,
slots=slots,
input_lengths=input_lengths,
seqlen=seqlen,
max_s=max_s,
prefill_cache_indices=batch.prefill_cache_indices,
lm_head_indices=lm_head_indices,
......
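Putting the two `Seqlen` construction sites above side by side: the prefill path fills in `cu_seqlen_q` and lets `max_q`/`max_k` track the longest new input and total sequence, while the CUDA-graph decode path pins `max_q` to 1. A toy illustration, reusing the `Seqlen` sketch shown after the commit message (tensor values are made up):

```python
import torch

input_lengths = torch.tensor([5, 3], dtype=torch.int32)         # new tokens per request
prefix_lens_tensor = torch.tensor([8, 0], dtype=torch.int32)    # cached prefix per request
cu_seqlen_prefill = torch.tensor([0, 5, 8], dtype=torch.int32)  # cumulative query lengths

# Prefill: queries cover all new tokens, keys cover prefix + new tokens.
prefill_seqlen = Seqlen(
    input_lengths=input_lengths,
    prefix_lengths=prefix_lens_tensor,
    cu_seqlen_q=cu_seqlen_prefill,
    max_q=int(input_lengths.max()),                          # 5
    max_k=int((input_lengths + prefix_lens_tensor).max()),   # 13
)

# Decode (CUDA-graph path): a single query token per request.
decode_seqlen = Seqlen(
    input_lengths=input_lengths,
    prefix_lengths=prefix_lens_tensor,
    cu_seqlen_q=None,
    max_q=1,
    max_k=int((input_lengths + prefix_lens_tensor).max()),
)
```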
......@@ -5,19 +5,22 @@ from typing import Dict, Optional
from text_generation_server.utils.log import log_master
PREFIX_CACHING = os.getenv("USE_PREFIX_CACHING", "0").lower() in {"1", "true"}
PREFIX_CACHING = os.getenv("USE_PREFIX_CACHING").lower() in {"1", "true"}
log_master(logger.info, f"Using prefix caching = {PREFIX_CACHING}")
ATTENTION = os.getenv("ATTENTION", "flashinfer" if PREFIX_CACHING else "paged")
ATTENTION = os.getenv("ATTENTION")
_expected = {"paged", "flashdecoding", "flashinfer"}
assert (
ATTENTION in _expected
), f"Attention is not valid {ATTENTION}, expected {_expected}"
log_master(logger.info, f"Using Attention = {ATTENTION}")
if PREFIX_CACHING and ATTENTION != "flashinfer":
if PREFIX_CACHING and ATTENTION not in {"flashinfer", "flashdecoding"}:
raise RuntimeError("Prefix caching is only supported with flashinfer")
MEM_POOL = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None
TGI_WIGGLE_ROOM = float(os.getenv("TGI_WIGGLE_ROOM", "0.95"))
assert TGI_WIGGLE_ROOM > 0
assert TGI_WIGGLE_ROOM < 1
# This is overridden by the cli
BLOCK_SIZE: int
......
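A small, self-contained sketch of how the three environment knobs above interact (the values set here are examples; in a real deployment the launcher exports `USE_PREFIX_CACHING` and `ATTENTION` itself):

```python
import os

# Example settings; the launcher normally provides these.
os.environ.setdefault("USE_PREFIX_CACHING", "1")
os.environ.setdefault("ATTENTION", "flashinfer")
os.environ.setdefault("TGI_WIGGLE_ROOM", "0.95")

PREFIX_CACHING = os.environ["USE_PREFIX_CACHING"].lower() in {"1", "true"}
ATTENTION = os.environ["ATTENTION"]
assert ATTENTION in {"paged", "flashdecoding", "flashinfer"}

# Prefix caching only works with backends that can reuse a cached prefix.
if PREFIX_CACHING and ATTENTION not in {"flashinfer", "flashdecoding"}:
    raise RuntimeError("Prefix caching is only supported with flashinfer")

# Fraction of free GPU memory handed to the KV cache (strictly between 0 and 1).
TGI_WIGGLE_ROOM = float(os.environ["TGI_WIGGLE_ROOM"])
assert 0.0 < TGI_WIGGLE_ROOM < 1.0
```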
......@@ -372,7 +372,14 @@ class VlmCausalLM(FlashCausalLM):
prefix_lens=batch.prefix_lens,
prefix_lens_tensor=prefix_lens_tensor,
):
input_lengths = Seqlen(input_lengths=input_lengths)
max_k = (input_lengths + prefix_lens_tensor).max().item()
seqlen = Seqlen(
input_lengths=input_lengths,
prefix_lengths=prefix_lens_tensor,
cu_seqlen_q=cu_seqlen_prefill,
max_q=max_s,
max_k=max_k,
)
logits, speculative_logits = self.model.forward(
input_ids=input_ids,
position_ids=position_ids,
......@@ -380,7 +387,7 @@ class VlmCausalLM(FlashCausalLM):
kv_cache=kv_cache,
block_tables=block_tables,
slots=slots,
input_lengths=input_lengths,
seqlen=seqlen,
max_s=max_s,
prefill_cache_indices=batch.prefill_cache_indices,
lm_head_indices=lm_head_indices,
......