Commit 5ce89059 (unverified), authored by OlivierDehaene, committed by GitHub

feat(server): pre-allocate past key values for flash causal LM (#412)

parent ca650e5b
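
The core change: instead of growing the past key/value tensors on every decode step, the model now pre-allocates a cache buffer large enough for the whole generation during prefill and then scatters new key/value pairs into it by index. A minimal sketch of the idea in Python, with illustrative names and shapes rather than the exact TGI code:

import torch

num_layers, num_heads, head_size = 2, 4, 64

def prefill(hidden_states, input_len, pre_allocate_past_size):
    # During prefill, keys/values are written into a buffer sized to the prompt
    # tokens only, so layers never slice a padded tensor inside the layer loop.
    past_key_values = hidden_states.new_empty(
        (input_len, num_layers, 2, num_heads, head_size)
    )
    # ... each layer i fills past_key_values[:, i] during the forward pass ...

    # Afterwards, the prompt keys/values are copied once into a buffer that is
    # pre-allocated for the whole generation (prompt plus future decode tokens).
    padded = hidden_states.new_empty(
        (pre_allocate_past_size, num_layers, 2, num_heads, head_size)
    )
    past_present_indices = torch.arange(input_len)  # illustrative slot indices
    padded[past_present_indices] = past_key_values
    return padded

def decode_step(padded, layer_id, new_kv, past_present_indices):
    # During decode, each sequence's new key/value pair is scattered into its
    # pre-allocated slot: no torch.cat and no re-allocation per generated token.
    padded[past_present_indices, layer_id] = new_kv

hidden_states = torch.randn(5, 128)  # 5 prompt tokens
cache = prefill(hidden_states, input_len=5, pre_allocate_past_size=5 + 20)
decode_step(cache, layer_id=0, new_kv=torch.randn(1, 2, num_heads, head_size),
            past_present_indices=torch.tensor([5]))
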
-flash_att_commit := d478eeec8f16c7939c54e4617dbd36f59b8eeed7
+flash_att_commit := 06ece1a1525ebcf4e183ac76b1e5108d2872f57f

flash-attention:
    # Clone flash attention
    pip install packaging
-   git clone https://github.com/HazyResearch/flash-attention.git
+   git clone https://github.com/OlivierDehaene/flash-attention.git

build-flash-attention: flash-attention
    cd flash-attention && git fetch && git checkout $(flash_att_commit)
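
The fork is needed because the attention calls in this diff pass separate per-sequence start/end offset tensors (start_seq, end_seq, start_seq_q, end_seq_q) to flash_attn_cuda.fwd instead of the single cumulative-lengths tensors (cu_seqlens, cu_seqlens_q) used before. For a contiguous, unpadded batch the two representations are assumed to relate roughly as below; this snippet is an illustration, not code from the fork:

import torch

cu_seqlens = torch.tensor([0, 3, 7, 12])  # example cumulative sequence lengths
start_seq = cu_seqlens[:-1]               # tensor([0, 3, 7])
end_seq = cu_seqlens[1:]                  # tensor([3, 7, 12])
# Keeping start and end separate lets each sequence live at an arbitrary,
# pre-allocated offset in the key/value cache instead of being forced into a
# cumulative, contiguous layout.
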
......@@ -128,11 +128,14 @@ class FlashLlamaAttention(torch.nn.Module):
hidden_states,
cos,
sin,
cu_seqlens,
start_seq,
end_seq,
start_seq_q,
end_seq_q,
max_s,
layer_past,
layer_past_present_indices,
cu_seqlens_q,
past_present_indices,
prefill,
):
qkv = self.query_key_value(hidden_states)
qkv = qkv.view(-1, 3, self.num_heads, self.head_size)
......@@ -142,7 +145,7 @@ class FlashLlamaAttention(torch.nn.Module):
self.rotary_emb(qkv[:, 1], cos, sin)
# Prefill
if layer_past_present_indices is None:
if prefill:
# Copy to layer past
layer_past[...] = qkv[:, 1:]
......@@ -154,8 +157,10 @@ class FlashLlamaAttention(torch.nn.Module):
qkv[:, 1],
qkv[:, 2],
attn_output,
cu_seqlens,
cu_seqlens,
start_seq,
end_seq,
start_seq,
end_seq,
max_s,
max_s,
0.0,
......@@ -170,7 +175,7 @@ class FlashLlamaAttention(torch.nn.Module):
else:
query = qkv[:, 0]
# Add present to the layer_past tensor at the correct indices
layer_past[layer_past_present_indices] = qkv[:, 1:]
layer_past[past_present_indices] = qkv[:, 1:]
# output
attn_output = torch.empty_like(query)
......@@ -180,8 +185,10 @@ class FlashLlamaAttention(torch.nn.Module):
layer_past[:, 0],
layer_past[:, 1],
attn_output,
cu_seqlens_q,
cu_seqlens,
start_seq_q,
end_seq_q,
start_seq,
end_seq,
1,
max_s,
0.0,
......@@ -258,11 +265,14 @@ class FlashLlamaLayer(nn.Module):
residual,
cos,
sin,
cu_seqlens,
start_seq,
end_seq,
start_seq_q,
end_seq_q,
max_s,
layer_past,
layer_past_present_indices,
cu_seqlens_q,
past_present_indices,
prefill,
):
normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
......@@ -271,11 +281,14 @@ class FlashLlamaLayer(nn.Module):
normed_hidden_states,
cos,
sin,
cu_seqlens,
start_seq,
end_seq,
start_seq_q,
end_seq_q,
max_s,
layer_past,
layer_past_present_indices,
cu_seqlens_q,
past_present_indices,
prefill,
)
# faster post attention rms norm
......@@ -322,35 +335,37 @@ class FlashLlamaModel(torch.nn.Module):
self,
input_ids,
position_ids,
cu_seqlens,
cu_seqlens_q,
start_seq,
end_seq,
start_seq_q,
end_seq_q,
max_s,
past_key_values: Optional[torch.Tensor] = None,
past_present_indices,
past_key_values=None,
pre_allocate_past_size: Optional[int] = None,
):
hidden_states = self.embed_tokens(input_ids)
# Prefill
if past_key_values is None:
assert pre_allocate_past_size is not None
prefill = True
# Create past tensor
# We create a tensor of the same size as input_ids as we don't want to slice at every layer
past_key_values = hidden_states.new_empty(
(
len(input_ids),
len(self.layers),
len(hidden_states)
if pre_allocate_past_size is None
else pre_allocate_past_size,
2,
self.num_heads,
self.head_size,
)
)
layer_past_present_indices = None
slice_past_index = len(hidden_states)
# Decode
else:
# Create indices from cumulative sequence lengths
layer_past_present_indices = cu_seqlens[1:] - 1
slice_past_index = None
prefill = False
# Get rotary cos and sin for this forward
# Avoid to index in each layer
......@@ -360,24 +375,35 @@ class FlashLlamaModel(torch.nn.Module):
residual = None
for i, layer in enumerate(self.layers):
# We added padding that we now need to slice
layer_past_key_values = (
past_key_values[i]
if slice_past_index is None
else past_key_values[i, :slice_past_index]
)
hidden_states, residual = layer(
hidden_states,
residual,
cos,
sin,
cu_seqlens,
start_seq,
end_seq,
start_seq_q,
end_seq_q,
max_s,
layer_past_key_values,
layer_past_present_indices,
cu_seqlens_q,
past_key_values[:, i],
past_present_indices,
prefill,
)
if prefill:
present = past_key_values
# Create padded past tensor
past_key_values = hidden_states.new_empty(
(
pre_allocate_past_size,
len(self.layers),
2,
self.num_heads,
self.head_size,
)
)
# We slice only once instead of at every layer
past_key_values[past_present_indices] = present
hidden_states, _ = self.norm(hidden_states, residual)
......@@ -399,9 +425,12 @@ class FlashLlamaForCausalLM(torch.nn.Module):
self,
input_ids,
position_ids,
cu_seqlens,
cu_seqlens_q,
start_seq,
end_seq,
start_seq_q,
end_seq_q,
max_s,
past_present_indices,
past_key_values: Optional[torch.Tensor] = None,
pre_allocate_past_size: Optional[int] = None,
lm_head_indices: Optional[torch.Tensor] = None,
......@@ -409,9 +438,12 @@ class FlashLlamaForCausalLM(torch.nn.Module):
hidden_states, present = self.model(
input_ids,
position_ids,
cu_seqlens,
cu_seqlens_q,
start_seq,
end_seq,
start_seq_q,
end_seq_q,
max_s,
past_present_indices,
past_key_values,
pre_allocate_past_size,
)
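
At decode time the query side covers exactly one token per sequence (hence the 1 passed as the query max length in the decode attention calls), while the key/value side spans each sequence's region of the pre-allocated cache. A hedged sketch of how the index tensors could be built for one decode step; the names mirror the diff, but this construction is an assumption about the batch side, not part of the commit:

import torch

input_lengths = torch.tensor([5, 9, 3])    # tokens already cached per sequence
cache_offsets = torch.tensor([0, 16, 32])  # assumed start slot of each sequence

start_seq = cache_offsets                     # first cached position per sequence
end_seq = cache_offsets + input_lengths + 1   # one past the slot written this step
past_present_indices = end_seq - 1            # slot receiving the new key/value

# The query ranges are unit-length: one new token per sequence.
start_seq_q = torch.arange(len(input_lengths))
end_seq_q = start_seq_q + 1
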
......@@ -113,11 +113,14 @@ class FlashNeoxAttention(torch.nn.Module):
hidden_states,
cos,
sin,
cu_seqlens,
start_seq,
end_seq,
start_seq_q,
end_seq_q,
max_s,
layer_past,
layer_past_present_indices,
cu_seqlens_q,
past_present_indices,
prefill,
):
qkv = self.query_key_value(hidden_states)
qkv = qkv.view(-1, 3, self.num_heads, self.head_size)
......@@ -127,7 +130,7 @@ class FlashNeoxAttention(torch.nn.Module):
self.rotary_emb(qkv[:, 1], cos, sin)
# Prefill
if layer_past_present_indices is None:
if prefill:
# Copy to layer past
layer_past[...] = qkv[:, 1:]
......@@ -139,8 +142,10 @@ class FlashNeoxAttention(torch.nn.Module):
qkv[:, 1],
qkv[:, 2],
attn_output,
cu_seqlens,
cu_seqlens,
start_seq,
end_seq,
start_seq,
end_seq,
max_s,
max_s,
0.0,
......@@ -155,7 +160,7 @@ class FlashNeoxAttention(torch.nn.Module):
else:
query = qkv[:, 0]
# Add present to the layer_past tensor at the correct indices
layer_past[layer_past_present_indices] = qkv[:, 1:]
layer_past[past_present_indices] = qkv[:, 1:]
# output
attn_output = torch.empty_like(query)
......@@ -165,8 +170,10 @@ class FlashNeoxAttention(torch.nn.Module):
layer_past[:, 0],
layer_past[:, 1],
attn_output,
cu_seqlens_q,
cu_seqlens,
start_seq_q,
end_seq_q,
start_seq,
end_seq,
1,
max_s,
0.0,
......@@ -240,11 +247,14 @@ class FlashNeoXLayer(nn.Module):
residual,
cos,
sin,
cu_seqlens,
start_seq,
end_seq,
start_seq_q,
end_seq_q,
max_s,
layer_past,
layer_past_present_indices,
cu_seqlens_q,
past_present_indices,
prefill,
):
if self.use_parallel_residual:
ln1_hidden_states, _ = self.input_layernorm(hidden_states)
......@@ -253,11 +263,14 @@ class FlashNeoXLayer(nn.Module):
ln1_hidden_states,
cos,
sin,
cu_seqlens,
start_seq,
end_seq,
start_seq_q,
end_seq_q,
max_s,
layer_past,
layer_past_present_indices,
cu_seqlens_q,
past_present_indices,
prefill,
)
ln2_hidden_states, _ = self.post_attention_layernorm(hidden_states)
......@@ -276,11 +289,14 @@ class FlashNeoXLayer(nn.Module):
hidden_states,
cos,
sin,
cu_seqlens,
start_seq,
end_seq,
start_seq_q,
end_seq_q,
max_s,
layer_past,
layer_past_present_indices,
cu_seqlens_q,
past_present_indices,
prefill,
)
hidden_states, residual = self.post_attention_layernorm(
......@@ -329,9 +345,12 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
self,
input_ids,
position_ids,
cu_seqlens,
cu_seqlens_q,
start_seq,
end_seq,
start_seq_q,
end_seq_q,
max_s,
past_present_indices,
past_key_values=None,
pre_allocate_past_size: Optional[int] = None,
):
......@@ -339,25 +358,24 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
# Prefill
if past_key_values is None:
assert pre_allocate_past_size is not None
prefill = True
# Create past tensor
# We create a tensor of the same size as input_ids as we don't want to slice at every layer
past_key_values = hidden_states.new_empty(
(
len(input_ids),
len(self.layers),
len(hidden_states)
if pre_allocate_past_size is None
else pre_allocate_past_size,
2,
self.num_heads,
self.head_size,
)
)
layer_past_present_indices = None
slice_past_index = len(hidden_states)
# Decode
else:
# Create indices from cumulative sequence lengths
layer_past_present_indices = cu_seqlens[1:] - 1
slice_past_index = None
prefill = False
# Get rotary cos and sin for this forward
# Avoid to index in each layer
......@@ -367,24 +385,35 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
residual = None
for i, layer in enumerate(self.layers):
# We added padding that we now need to slice
layer_past_key_values = (
past_key_values[i]
if slice_past_index is None
else past_key_values[i, :slice_past_index]
)
hidden_states, residual = layer(
hidden_states,
residual,
cos,
sin,
cu_seqlens,
start_seq,
end_seq,
start_seq_q,
end_seq_q,
max_s,
layer_past_key_values,
layer_past_present_indices,
cu_seqlens_q,
past_key_values[:, i],
past_present_indices,
prefill,
)
if prefill:
present = past_key_values
# Create padded past tensor
past_key_values = hidden_states.new_empty(
(
pre_allocate_past_size,
len(self.layers),
2,
self.num_heads,
self.head_size,
)
)
# We slice only once instead of at every layer
past_key_values[past_present_indices] = present
hidden_states, _ = self.final_layer_norm(hidden_states, residual)
......@@ -404,9 +433,12 @@ class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel):
self,
input_ids,
position_ids,
cu_seqlens,
cu_seqlens_q,
start_seq,
end_seq,
start_seq_q,
end_seq_q,
max_s,
past_present_indices,
past_key_values: Optional[torch.Tensor] = None,
pre_allocate_past_size: Optional[int] = None,
lm_head_indices: Optional[torch.Tensor] = None,
......@@ -414,9 +446,12 @@ class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel):
hidden_states, present = self.gpt_neox(
input_ids,
position_ids,
cu_seqlens,
cu_seqlens_q,
start_seq,
end_seq,
start_seq_q,
end_seq_q,
max_s,
past_present_indices,
past_key_values,
pre_allocate_past_size,
)
......@@ -130,11 +130,14 @@ class FlashRWAttention(torch.nn.Module):
hidden_states,
cos,
sin,
cu_seqlens,
start_seq,
end_seq,
start_seq_q,
end_seq_q,
max_s,
layer_past,
layer_past_present_indices,
cu_seqlens_q,
past_present_indices,
prefill,
):
qkv = self.query_key_value(hidden_states)
......@@ -150,10 +153,10 @@ class FlashRWAttention(torch.nn.Module):
# Inplace rotary
self.rotary_emb(query, cos, sin)
self.rotary_emb(kv[:, 0], cos, sin)
self.rotary_emb(torch.select(kv, dim=1, index=0), cos, sin)
# Prefill
if layer_past_present_indices is None:
if prefill:
# Copy to layer past
layer_past[...] = kv
# Expand to query shape
......@@ -164,11 +167,13 @@ class FlashRWAttention(torch.nn.Module):
# flash attention
flash_attn_cuda.fwd(
query,
kv[:, 0],
kv[:, 1],
torch.select(kv, dim=1, index=0),
torch.select(kv, dim=1, index=1),
attn_output,
cu_seqlens,
cu_seqlens,
start_seq,
end_seq,
start_seq,
end_seq,
max_s,
max_s,
0.0,
......@@ -182,7 +187,7 @@ class FlashRWAttention(torch.nn.Module):
# Decode
else:
# Add present to the layer_past tensor at the correct indices
layer_past[layer_past_present_indices] = kv
layer_past[past_present_indices] = kv
# Expand to query shape
kv = layer_past.expand(-1, 2, self.num_heads, self.head_size)
......@@ -191,11 +196,13 @@ class FlashRWAttention(torch.nn.Module):
# flash attention
flash_attn_cuda.fwd(
query,
kv[:, 0],
kv[:, 1],
torch.select(kv, dim=1, index=0),
torch.select(kv, dim=1, index=1),
attn_output,
cu_seqlens_q,
cu_seqlens,
start_seq_q,
end_seq_q,
start_seq,
end_seq,
1,
max_s,
0.0,
......@@ -261,11 +268,14 @@ class FlashRWLargeAttention(torch.nn.Module):
hidden_states,
cos,
sin,
cu_seqlens,
start_seq,
end_seq,
start_seq_q,
end_seq_q,
max_s,
layer_past,
layer_past_present_indices,
cu_seqlens_q,
past_present_indices,
prefill,
):
qkv = self.query_key_value(hidden_states)
qkv = qkv.view(-1, self.num_groups, self.num_heads + 2, self.head_size)
......@@ -280,10 +290,10 @@ class FlashRWLargeAttention(torch.nn.Module):
# Inplace rotary
self.rotary_emb(query, cos, sin)
self.rotary_emb(kv[:, :, 0], cos, sin)
self.rotary_emb(torch.select(kv, dim=2, index=0), cos, sin)
# Prefill
if layer_past_present_indices is None:
if prefill:
# Copy to layer past
layer_past[...] = kv
# Expand to query shape
......@@ -298,11 +308,13 @@ class FlashRWLargeAttention(torch.nn.Module):
# flash attention
flash_attn_cuda.fwd(
query,
kv[:, :, 0],
kv[:, :, 1],
torch.select(kv, dim=2, index=0),
torch.select(kv, dim=2, index=1),
attn_output,
cu_seqlens,
cu_seqlens,
start_seq,
end_seq,
start_seq,
end_seq,
max_s,
max_s,
0.0,
......@@ -316,7 +328,7 @@ class FlashRWLargeAttention(torch.nn.Module):
# Decode
else:
# Add present to the layer_past tensor at the correct indices
layer_past[layer_past_present_indices] = kv
layer_past[past_present_indices] = kv
# Expand to query shape
kv = (
layer_past.unsqueeze(2)
......@@ -329,11 +341,13 @@ class FlashRWLargeAttention(torch.nn.Module):
# flash attention
flash_attn_cuda.fwd(
query,
kv[:, :, 0],
kv[:, :, 1],
torch.select(kv, dim=2, index=0),
torch.select(kv, dim=2, index=1),
attn_output,
cu_seqlens_q,
cu_seqlens,
start_seq_q,
end_seq_q,
start_seq,
end_seq,
1,
max_s,
0.0,
......@@ -417,11 +431,14 @@ class FlashRWLayer(nn.Module):
residual,
cos,
sin,
cu_seqlens,
start_seq,
end_seq,
start_seq_q,
end_seq_q,
max_s,
layer_past,
layer_past_present_indices,
cu_seqlens_q,
past_present_indices,
prefill,
):
if self.parallel_attn:
ln_hidden_states, residual = self.input_layernorm(hidden_states, residual)
......@@ -430,11 +447,14 @@ class FlashRWLayer(nn.Module):
ln_hidden_states,
cos,
sin,
cu_seqlens,
start_seq,
end_seq,
start_seq_q,
end_seq_q,
max_s,
layer_past,
layer_past_present_indices,
cu_seqlens_q,
past_present_indices,
prefill,
)
mlp_output = self.mlp(ln_hidden_states)
......@@ -451,11 +471,14 @@ class FlashRWLayer(nn.Module):
hidden_states,
cos,
sin,
cu_seqlens,
start_seq,
end_seq,
start_seq_q,
end_seq_q,
max_s,
layer_past,
layer_past_present_indices,
cu_seqlens_q,
past_present_indices,
prefill,
)
hidden_states, residual = self.post_attention_layernorm(
......@@ -499,11 +522,14 @@ class FlashRWLargeLayer(nn.Module):
residual,
cos,
sin,
cu_seqlens,
start_seq,
end_seq,
start_seq_q,
end_seq_q,
max_s,
layer_past,
layer_past_present_indices,
cu_seqlens_q,
past_present_indices,
prefill,
):
ln_attn, residual = self.ln_attn(hidden_states, residual)
ln_mlp, _ = self.ln_mlp(residual)
......@@ -513,11 +539,14 @@ class FlashRWLargeLayer(nn.Module):
ln_attn,
cos,
sin,
cu_seqlens,
start_seq,
end_seq,
start_seq_q,
end_seq_q,
max_s,
layer_past,
layer_past_present_indices,
cu_seqlens_q,
past_present_indices,
prefill,
)
# MLP.
......@@ -584,9 +613,12 @@ class FlashRWModel(FlashRWPreTrainedModel):
self,
input_ids,
position_ids,
cu_seqlens,
cu_seqlens_q,
start_seq,
end_seq,
start_seq_q,
end_seq_q,
max_s,
past_present_indices,
past_key_values=None,
pre_allocate_past_size: Optional[int] = None,
):
......@@ -594,23 +626,22 @@ class FlashRWModel(FlashRWPreTrainedModel):
# Prefill
if past_key_values is None:
assert pre_allocate_past_size is not None
prefill = True
# Create past tensor
# We create a tensor of the same size as input_ids as we don't want to slice at every layer
past_key_values = hidden_states.new_empty(
(
len(input_ids),
len(self.h),
len(hidden_states)
if pre_allocate_past_size is None
else pre_allocate_past_size,
*self.cache_size,
)
)
layer_past_present_indices = None
slice_past_index = len(hidden_states)
# Decode
else:
# Create indices from cumulative sequence lengths
layer_past_present_indices = cu_seqlens[1:] - 1
slice_past_index = None
prefill = False
# Get rotary cos and sin for this forward
# Avoid to index in each layer
......@@ -620,24 +651,33 @@ class FlashRWModel(FlashRWPreTrainedModel):
residual = None
for i, layer in enumerate(self.h):
# We added padding that we now need to slice
layer_past_key_values = (
past_key_values[i]
if slice_past_index is None
else past_key_values[i, :slice_past_index]
)
hidden_states, residual = layer(
hidden_states,
residual,
cos,
sin,
cu_seqlens,
start_seq,
end_seq,
start_seq_q,
end_seq_q,
max_s,
layer_past_key_values,
layer_past_present_indices,
cu_seqlens_q,
torch.select(past_key_values, dim=1, index=i),
past_present_indices,
prefill,
)
if prefill:
present = past_key_values
# Create padded past tensor
past_key_values = hidden_states.new_empty(
(
pre_allocate_past_size,
len(self.h),
*self.cache_size,
)
)
# We slice only once instead of at every layer
past_key_values[past_present_indices] = present
hidden_states, _ = self.ln_f(hidden_states, residual)
......@@ -658,9 +698,12 @@ class FlashRWForCausalLM(FlashRWPreTrainedModel):
self,
input_ids,
position_ids,
cu_seqlens,
cu_seqlens_q,
start_seq,
end_seq,
start_seq_q,
end_seq_q,
max_s,
past_present_indices,
past_key_values: Optional[torch.Tensor] = None,
pre_allocate_past_size: Optional[int] = None,
lm_head_indices: Optional[torch.Tensor] = None,
......@@ -668,9 +711,12 @@ class FlashRWForCausalLM(FlashRWPreTrainedModel):
hidden_states, present = self.transformer(
input_ids,
position_ids,
cu_seqlens,
cu_seqlens_q,
start_seq,
end_seq,
start_seq_q,
end_seq_q,
max_s,
past_present_indices,
past_key_values,
pre_allocate_past_size,
)
......@@ -7,6 +7,7 @@ from typing import Optional
# Flash attention imports
import flash_attn_cuda
from text_generation_server.utils.layers import (
TensorParallelRowLinear,
TensorParallelColumnLinear,
......@@ -148,11 +149,14 @@ class FlashMQAttention(torch.nn.Module):
def forward(
self,
hidden_states,
cu_seqlens,
start_seq,
end_seq,
start_seq_q,
end_seq_q,
max_s,
layer_past,
layer_past_present_indices,
cu_seqlens_q,
past_present_indices,
prefill,
):
qkv = self.c_attn(hidden_states)
......@@ -166,7 +170,7 @@ class FlashMQAttention(torch.nn.Module):
key_value = key_value.view(-1, 2, 1, self.head_size)
# Prefill
if layer_past_present_indices is None:
if prefill:
# Copy to layer past
layer_past[...] = key_value
# Expand from 1 to num_heads
......@@ -177,11 +181,13 @@ class FlashMQAttention(torch.nn.Module):
# flash attention
flash_attn_cuda.fwd(
query,
key_value[:, 0],
key_value[:, 1],
torch.select(key_value, dim=1, index=0),
torch.select(key_value, dim=1, index=1),
attn_output,
cu_seqlens,
cu_seqlens,
start_seq,
end_seq,
start_seq,
end_seq,
max_s,
max_s,
0.0,
......@@ -195,7 +201,7 @@ class FlashMQAttention(torch.nn.Module):
# Decode
else:
# Add present to the layer_past tensor at the correct indices
layer_past[layer_past_present_indices] = key_value
layer_past[past_present_indices] = key_value
# Expand from 1 to num_heads
key_value = layer_past.expand(-1, 2, self.num_heads, self.head_size)
......@@ -204,11 +210,13 @@ class FlashMQAttention(torch.nn.Module):
# flash attention
flash_attn_cuda.fwd(
query,
key_value[:, 0],
key_value[:, 1],
torch.select(key_value, dim=1, index=0),
torch.select(key_value, dim=1, index=1),
attn_output,
cu_seqlens_q,
cu_seqlens,
start_seq_q,
end_seq_q,
start_seq,
end_seq,
1,
max_s,
0.0,
......@@ -277,21 +285,27 @@ class Block(nn.Module):
self,
hidden_states,
residual,
cu_seqlens,
start_seq,
end_seq,
start_seq_q,
end_seq_q,
max_s,
layer_past,
layer_past_present_indices,
cu_seqlens_q,
past_present_indices,
prefill,
):
hidden_states, residual = self.ln_1(hidden_states, residual)
hidden_states = self.attn(
hidden_states,
cu_seqlens,
start_seq,
end_seq,
start_seq_q,
end_seq_q,
max_s,
layer_past,
layer_past_present_indices,
cu_seqlens_q,
past_present_indices,
prefill,
)
hidden_states, residual = self.ln_2(hidden_states, residual)
......@@ -339,10 +353,13 @@ class FlashSantacoderModel(nn.Module):
self,
input_ids,
position_ids,
cu_seqlens,
cu_seqlens_q,
start_seq,
end_seq,
start_seq_q,
end_seq_q,
max_s,
past_key_values: Optional[torch.Tensor] = None,
past_present_indices,
past_key_values=None,
pre_allocate_past_size: Optional[int] = None,
):
hidden_states = self.wte(input_ids) + self.wpe(position_ids)
......@@ -352,44 +369,42 @@ class FlashSantacoderModel(nn.Module):
# Prefill
if past_key_values is None:
assert pre_allocate_past_size is not None
prefill = True
# Create past tensor
past_key_values = hidden_states.new_empty(
(
len(self.h),
len(hidden_states)
if pre_allocate_past_size is None
else pre_allocate_past_size,
2,
1,
self.head_size,
# We create a tensor of the same size as input_ids as we don't want to slice at every layer
past_key_values = hidden_states.new_zeros(
(len(input_ids), len(self.h), 2, 1, self.head_size)
)
)
layer_past_present_indices = None
slice_past_index = len(hidden_states)
# Decode
else:
# Create indices from cumulative sequence lengths
layer_past_present_indices = cu_seqlens[1:] - 1
slice_past_index = None
prefill = False
residual = None
for i, layer in enumerate(self.h):
# We added padding that we now need to slice
layer_past_key_values = (
past_key_values[i]
if slice_past_index is None
else past_key_values[i, :slice_past_index]
)
hidden_states, residual = layer(
hidden_states,
residual,
cu_seqlens,
start_seq,
end_seq,
start_seq_q,
end_seq_q,
max_s,
layer_past_key_values,
layer_past_present_indices,
cu_seqlens_q,
torch.select(past_key_values, dim=1, index=i),
past_present_indices,
prefill,
)
if prefill:
present = past_key_values
# Create padded past tensor
past_key_values = hidden_states.new_empty(
(pre_allocate_past_size, len(self.h), 2, 1, self.head_size)
)
# We slice only once instead of at every layer
past_key_values[past_present_indices] = present
hidden_states, _ = self.ln_f(hidden_states, residual)
......@@ -408,9 +423,12 @@ class FlashSantacoderForCausalLM(nn.Module):
self,
input_ids,
position_ids,
cu_seqlens,
cu_seqlens_q,
start_seq,
end_seq,
start_seq_q,
end_seq_q,
max_s,
past_present_indices,
past_key_values: Optional[torch.Tensor] = None,
pre_allocate_past_size: Optional[int] = None,
lm_head_indices: Optional[torch.Tensor] = None,
......@@ -418,9 +436,12 @@ class FlashSantacoderForCausalLM(nn.Module):
hidden_states, present = self.transformer(
input_ids,
position_ids,
cu_seqlens,
cu_seqlens_q,
start_seq,
end_seq,
start_seq_q,
end_seq_q,
max_s,
past_present_indices,
past_key_values,
pre_allocate_past_size,
)