Unverified commit 5ce89059, authored by OlivierDehaene and committed by GitHub

feat(server): pre-allocate past key values for flash causal LM (#412)

parent ca650e5b
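The commit replaces the shared cumulative-sequence-length tensors (`cu_seqlens` / `cu_seqlens_q`) with explicit `start_seq` / `end_seq` (and `start_seq_q` / `end_seq_q`) offsets, and switches the past key/value cache to a token-major, pre-allocated layout: prefill writes K/V densely, then scatters it once into a padded cache at `past_present_indices`; each decode step writes one row per sequence directly into its reserved slot. The standalone sketch below only illustrates that cache flow; the slot layout and index computation are assumptions made for the example and are not code from this repository (the real indices are built by the batching code, which is not part of this diff).

```python
import torch

# Illustrative shapes only: 2 sequences in a batch, with room pre-allocated for
# the prompt tokens plus the tokens that will be generated later.
num_layers, num_heads, head_size = 2, 4, 8
input_lengths = [3, 5]          # prompt lengths of the two sequences
max_new_tokens = 4
pre_allocate_past_size = sum(l + max_new_tokens for l in input_lengths)
num_prompt_tokens = sum(input_lengths)

# Prefill: K/V for the prompt tokens is produced densely, one row per token
# (shape: [num_tokens, num_layers, 2, num_heads, head_size]).
present = torch.randn(num_prompt_tokens, num_layers, 2, num_heads, head_size)

# past_present_indices maps each prompt token to its slot in the padded cache,
# leaving `max_new_tokens` empty slots after every sequence (assumed layout).
starts = []
offset = 0
for length in input_lengths:
    starts.append(offset)
    offset += length + max_new_tokens
past_present_indices = torch.cat(
    [torch.arange(start, start + length) for start, length in zip(starts, input_lengths)]
)

# Allocate the padded cache once and scatter the prefill K/V into it;
# the models in this PR do this a single time after the layer loop.
past_key_values = present.new_empty(
    (pre_allocate_past_size, num_layers, 2, num_heads, head_size)
)
past_key_values[past_present_indices] = present

# Decode: each step writes exactly one new K/V row per sequence at the next
# free slot, e.g. positions 3 and 12 for the first generated token here.
decode_indices = torch.tensor([3, 12])
past_key_values[decode_indices] = torch.randn(2, num_layers, 2, num_heads, head_size)
```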
-flash_att_commit := d478eeec8f16c7939c54e4617dbd36f59b8eeed7
+flash_att_commit := 06ece1a1525ebcf4e183ac76b1e5108d2872f57f
 flash-attention:
 	# Clone flash attention
 	pip install packaging
-	git clone https://github.com/HazyResearch/flash-attention.git
+	git clone https://github.com/OlivierDehaene/flash-attention.git
 build-flash-attention: flash-attention
 	cd flash-attention && git fetch && git checkout $(flash_att_commit)
...
@@ -128,11 +128,14 @@ class FlashLlamaAttention(torch.nn.Module):
         hidden_states,
         cos,
         sin,
-        cu_seqlens,
+        start_seq,
+        end_seq,
+        start_seq_q,
+        end_seq_q,
         max_s,
         layer_past,
-        layer_past_present_indices,
-        cu_seqlens_q,
+        past_present_indices,
+        prefill,
     ):
         qkv = self.query_key_value(hidden_states)
         qkv = qkv.view(-1, 3, self.num_heads, self.head_size)
@@ -142,7 +145,7 @@ class FlashLlamaAttention(torch.nn.Module):
         self.rotary_emb(qkv[:, 1], cos, sin)
         # Prefill
-        if layer_past_present_indices is None:
+        if prefill:
             # Copy to layer past
             layer_past[...] = qkv[:, 1:]
@@ -154,8 +157,10 @@ class FlashLlamaAttention(torch.nn.Module):
                 qkv[:, 1],
                 qkv[:, 2],
                 attn_output,
-                cu_seqlens,
-                cu_seqlens,
+                start_seq,
+                end_seq,
+                start_seq,
+                end_seq,
                 max_s,
                 max_s,
                 0.0,
@@ -170,7 +175,7 @@ class FlashLlamaAttention(torch.nn.Module):
         else:
             query = qkv[:, 0]
             # Add present to the layer_past tensor at the correct indices
-            layer_past[layer_past_present_indices] = qkv[:, 1:]
+            layer_past[past_present_indices] = qkv[:, 1:]
             # output
             attn_output = torch.empty_like(query)
@@ -180,8 +185,10 @@ class FlashLlamaAttention(torch.nn.Module):
                 layer_past[:, 0],
                 layer_past[:, 1],
                 attn_output,
-                cu_seqlens_q,
-                cu_seqlens,
+                start_seq_q,
+                end_seq_q,
+                start_seq,
+                end_seq,
                 1,
                 max_s,
                 0.0,
@@ -258,11 +265,14 @@ class FlashLlamaLayer(nn.Module):
         residual,
         cos,
         sin,
-        cu_seqlens,
+        start_seq,
+        end_seq,
+        start_seq_q,
+        end_seq_q,
         max_s,
         layer_past,
-        layer_past_present_indices,
-        cu_seqlens_q,
+        past_present_indices,
+        prefill,
     ):
         normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
@@ -271,11 +281,14 @@ class FlashLlamaLayer(nn.Module):
             normed_hidden_states,
             cos,
             sin,
-            cu_seqlens,
+            start_seq,
+            end_seq,
+            start_seq_q,
+            end_seq_q,
             max_s,
             layer_past,
-            layer_past_present_indices,
-            cu_seqlens_q,
+            past_present_indices,
+            prefill,
         )
         # faster post attention rms norm
@@ -322,35 +335,37 @@ class FlashLlamaModel(torch.nn.Module):
         self,
         input_ids,
         position_ids,
-        cu_seqlens,
-        cu_seqlens_q,
+        start_seq,
+        end_seq,
+        start_seq_q,
+        end_seq_q,
         max_s,
-        past_key_values: Optional[torch.Tensor] = None,
+        past_present_indices,
+        past_key_values=None,
         pre_allocate_past_size: Optional[int] = None,
     ):
         hidden_states = self.embed_tokens(input_ids)
         # Prefill
         if past_key_values is None:
+            assert pre_allocate_past_size is not None
+            prefill = True
             # Create past tensor
+            # We create a tensor of the same size as input_ids as we don't want to slice at every layer
             past_key_values = hidden_states.new_empty(
                 (
+                    len(input_ids),
                     len(self.layers),
-                    len(hidden_states)
-                    if pre_allocate_past_size is None
-                    else pre_allocate_past_size,
                     2,
                     self.num_heads,
                     self.head_size,
                 )
             )
-            layer_past_present_indices = None
-            slice_past_index = len(hidden_states)
         # Decode
         else:
-            # Create indices from cumulative sequence lengths
-            layer_past_present_indices = cu_seqlens[1:] - 1
-            slice_past_index = None
+            prefill = False
         # Get rotary cos and sin for this forward
         # Avoid to index in each layer
@@ -360,24 +375,35 @@ class FlashLlamaModel(torch.nn.Module):
         residual = None
         for i, layer in enumerate(self.layers):
-            # We added padding that we now need to slice
-            layer_past_key_values = (
-                past_key_values[i]
-                if slice_past_index is None
-                else past_key_values[i, :slice_past_index]
-            )
             hidden_states, residual = layer(
                 hidden_states,
                 residual,
                 cos,
                 sin,
-                cu_seqlens,
+                start_seq,
+                end_seq,
+                start_seq_q,
+                end_seq_q,
                 max_s,
-                layer_past_key_values,
-                layer_past_present_indices,
-                cu_seqlens_q,
+                past_key_values[:, i],
+                past_present_indices,
+                prefill,
             )
+        if prefill:
+            present = past_key_values
+            # Create padded past tensor
+            past_key_values = hidden_states.new_empty(
+                (
+                    pre_allocate_past_size,
+                    len(self.layers),
+                    2,
+                    self.num_heads,
+                    self.head_size,
+                )
+            )
+            # We slice only once instead of at every layer
+            past_key_values[past_present_indices] = present
         hidden_states, _ = self.norm(hidden_states, residual)
@@ -399,9 +425,12 @@ class FlashLlamaForCausalLM(torch.nn.Module):
         self,
         input_ids,
         position_ids,
-        cu_seqlens,
-        cu_seqlens_q,
+        start_seq,
+        end_seq,
+        start_seq_q,
+        end_seq_q,
         max_s,
+        past_present_indices,
         past_key_values: Optional[torch.Tensor] = None,
         pre_allocate_past_size: Optional[int] = None,
         lm_head_indices: Optional[torch.Tensor] = None,
@@ -409,9 +438,12 @@ class FlashLlamaForCausalLM(torch.nn.Module):
         hidden_states, present = self.model(
             input_ids,
             position_ids,
-            cu_seqlens,
-            cu_seqlens_q,
+            start_seq,
+            end_seq,
+            start_seq_q,
+            end_seq_q,
             max_s,
+            past_present_indices,
            past_key_values,
            pre_allocate_past_size,
        )
...
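Every forward signature above now takes `start_seq` / `end_seq` instead of `cu_seqlens`, plus `start_seq_q` / `end_seq_q` for the single-token decode queries. For a densely packed prefill batch the two encodings carry the same information; the snippet below is only meant to show that correspondence and is not taken from the repository.

```python
import torch

# Three sequences of lengths 3, 5, 2 packed back-to-back, as in prefill.
lengths = torch.tensor([3, 5, 2], dtype=torch.int32)

# Old style: one shared tensor of cumulative sequence boundaries.
cu_seqlens = torch.cat(
    [torch.zeros(1, dtype=torch.int32), torch.cumsum(lengths, dim=0).to(torch.int32)]
)
# tensor([ 0,  3,  8, 10], dtype=torch.int32)

# New style: explicit per-sequence start and end offsets.
start_seq = cu_seqlens[:-1]   # tensor([0, 3, 8])
end_seq = cu_seqlens[1:]      # tensor([3, 8, 10])

# With separate start/end tensors, a sequence's key/value slice no longer has
# to end exactly where the next one begins, which is what allows padding gaps
# to be reserved for future tokens in the pre-allocated cache.
```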
@@ -113,11 +113,14 @@ class FlashNeoxAttention(torch.nn.Module):
         hidden_states,
         cos,
         sin,
-        cu_seqlens,
+        start_seq,
+        end_seq,
+        start_seq_q,
+        end_seq_q,
         max_s,
         layer_past,
-        layer_past_present_indices,
-        cu_seqlens_q,
+        past_present_indices,
+        prefill,
     ):
         qkv = self.query_key_value(hidden_states)
         qkv = qkv.view(-1, 3, self.num_heads, self.head_size)
@@ -127,7 +130,7 @@ class FlashNeoxAttention(torch.nn.Module):
         self.rotary_emb(qkv[:, 1], cos, sin)
         # Prefill
-        if layer_past_present_indices is None:
+        if prefill:
             # Copy to layer past
             layer_past[...] = qkv[:, 1:]
@@ -139,8 +142,10 @@ class FlashNeoxAttention(torch.nn.Module):
                 qkv[:, 1],
                 qkv[:, 2],
                 attn_output,
-                cu_seqlens,
-                cu_seqlens,
+                start_seq,
+                end_seq,
+                start_seq,
+                end_seq,
                 max_s,
                 max_s,
                 0.0,
@@ -155,7 +160,7 @@ class FlashNeoxAttention(torch.nn.Module):
         else:
             query = qkv[:, 0]
             # Add present to the layer_past tensor at the correct indices
-            layer_past[layer_past_present_indices] = qkv[:, 1:]
+            layer_past[past_present_indices] = qkv[:, 1:]
             # output
             attn_output = torch.empty_like(query)
@@ -165,8 +170,10 @@ class FlashNeoxAttention(torch.nn.Module):
                 layer_past[:, 0],
                 layer_past[:, 1],
                 attn_output,
-                cu_seqlens_q,
-                cu_seqlens,
+                start_seq_q,
+                end_seq_q,
+                start_seq,
+                end_seq,
                 1,
                 max_s,
                 0.0,
@@ -240,11 +247,14 @@ class FlashNeoXLayer(nn.Module):
         residual,
         cos,
         sin,
-        cu_seqlens,
+        start_seq,
+        end_seq,
+        start_seq_q,
+        end_seq_q,
         max_s,
         layer_past,
-        layer_past_present_indices,
-        cu_seqlens_q,
+        past_present_indices,
+        prefill,
     ):
         if self.use_parallel_residual:
             ln1_hidden_states, _ = self.input_layernorm(hidden_states)
@@ -253,11 +263,14 @@ class FlashNeoXLayer(nn.Module):
                 ln1_hidden_states,
                 cos,
                 sin,
-                cu_seqlens,
+                start_seq,
+                end_seq,
+                start_seq_q,
+                end_seq_q,
                 max_s,
                 layer_past,
-                layer_past_present_indices,
-                cu_seqlens_q,
+                past_present_indices,
+                prefill,
             )
             ln2_hidden_states, _ = self.post_attention_layernorm(hidden_states)
@@ -276,11 +289,14 @@ class FlashNeoXLayer(nn.Module):
                 hidden_states,
                 cos,
                 sin,
-                cu_seqlens,
+                start_seq,
+                end_seq,
+                start_seq_q,
+                end_seq_q,
                 max_s,
                 layer_past,
-                layer_past_present_indices,
-                cu_seqlens_q,
+                past_present_indices,
+                prefill,
             )
             hidden_states, residual = self.post_attention_layernorm(
@@ -329,9 +345,12 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
         self,
         input_ids,
         position_ids,
-        cu_seqlens,
-        cu_seqlens_q,
+        start_seq,
+        end_seq,
+        start_seq_q,
+        end_seq_q,
         max_s,
+        past_present_indices,
         past_key_values=None,
         pre_allocate_past_size: Optional[int] = None,
     ):
@@ -339,25 +358,24 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
         # Prefill
         if past_key_values is None:
+            assert pre_allocate_past_size is not None
+            prefill = True
             # Create past tensor
+            # We create a tensor of the same size as input_ids as we don't want to slice at every layer
             past_key_values = hidden_states.new_empty(
                 (
+                    len(input_ids),
                     len(self.layers),
-                    len(hidden_states)
-                    if pre_allocate_past_size is None
-                    else pre_allocate_past_size,
                     2,
                     self.num_heads,
                     self.head_size,
                 )
             )
-            layer_past_present_indices = None
-            slice_past_index = len(hidden_states)
         # Decode
         else:
-            # Create indices from cumulative sequence lengths
-            layer_past_present_indices = cu_seqlens[1:] - 1
-            slice_past_index = None
+            prefill = False
         # Get rotary cos and sin for this forward
         # Avoid to index in each layer
@@ -367,24 +385,35 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
         residual = None
         for i, layer in enumerate(self.layers):
-            # We added padding that we now need to slice
-            layer_past_key_values = (
-                past_key_values[i]
-                if slice_past_index is None
-                else past_key_values[i, :slice_past_index]
-            )
             hidden_states, residual = layer(
                 hidden_states,
                 residual,
                 cos,
                 sin,
-                cu_seqlens,
+                start_seq,
+                end_seq,
+                start_seq_q,
+                end_seq_q,
                 max_s,
-                layer_past_key_values,
-                layer_past_present_indices,
-                cu_seqlens_q,
+                past_key_values[:, i],
+                past_present_indices,
+                prefill,
            )
+        if prefill:
+            present = past_key_values
+            # Create padded past tensor
+            past_key_values = hidden_states.new_empty(
+                (
+                    pre_allocate_past_size,
+                    len(self.layers),
+                    2,
+                    self.num_heads,
+                    self.head_size,
+                )
+            )
+            # We slice only once instead of at every layer
+            past_key_values[past_present_indices] = present
         hidden_states, _ = self.final_layer_norm(hidden_states, residual)
@@ -404,9 +433,12 @@ class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel):
         self,
         input_ids,
         position_ids,
-        cu_seqlens,
-        cu_seqlens_q,
+        start_seq,
+        end_seq,
+        start_seq_q,
+        end_seq_q,
         max_s,
+        past_present_indices,
         past_key_values: Optional[torch.Tensor] = None,
         pre_allocate_past_size: Optional[int] = None,
         lm_head_indices: Optional[torch.Tensor] = None,
@@ -414,9 +446,12 @@ class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel):
         hidden_states, present = self.gpt_neox(
             input_ids,
             position_ids,
-            cu_seqlens,
-            cu_seqlens_q,
+            start_seq,
+            end_seq,
+            start_seq_q,
+            end_seq_q,
             max_s,
+            past_present_indices,
            past_key_values,
            pre_allocate_past_size,
        )
...
@@ -130,11 +130,14 @@ class FlashRWAttention(torch.nn.Module):
         hidden_states,
         cos,
         sin,
-        cu_seqlens,
+        start_seq,
+        end_seq,
+        start_seq_q,
+        end_seq_q,
         max_s,
         layer_past,
-        layer_past_present_indices,
-        cu_seqlens_q,
+        past_present_indices,
+        prefill,
     ):
         qkv = self.query_key_value(hidden_states)
@@ -150,10 +153,10 @@ class FlashRWAttention(torch.nn.Module):
         # Inplace rotary
         self.rotary_emb(query, cos, sin)
-        self.rotary_emb(kv[:, 0], cos, sin)
+        self.rotary_emb(torch.select(kv, dim=1, index=0), cos, sin)
         # Prefill
-        if layer_past_present_indices is None:
+        if prefill:
             # Copy to layer past
             layer_past[...] = kv
             # Expand to query shape
@@ -164,11 +167,13 @@ class FlashRWAttention(torch.nn.Module):
             # flash attention
             flash_attn_cuda.fwd(
                 query,
-                kv[:, 0],
-                kv[:, 1],
+                torch.select(kv, dim=1, index=0),
+                torch.select(kv, dim=1, index=1),
                 attn_output,
-                cu_seqlens,
-                cu_seqlens,
+                start_seq,
+                end_seq,
+                start_seq,
+                end_seq,
                 max_s,
                 max_s,
                 0.0,
@@ -182,7 +187,7 @@ class FlashRWAttention(torch.nn.Module):
         # Decode
         else:
             # Add present to the layer_past tensor at the correct indices
-            layer_past[layer_past_present_indices] = kv
+            layer_past[past_present_indices] = kv
             # Expand to query shape
             kv = layer_past.expand(-1, 2, self.num_heads, self.head_size)
@@ -191,11 +196,13 @@ class FlashRWAttention(torch.nn.Module):
             # flash attention
             flash_attn_cuda.fwd(
                 query,
-                kv[:, 0],
-                kv[:, 1],
+                torch.select(kv, dim=1, index=0),
+                torch.select(kv, dim=1, index=1),
                 attn_output,
-                cu_seqlens_q,
-                cu_seqlens,
+                start_seq_q,
+                end_seq_q,
+                start_seq,
+                end_seq,
                 1,
                 max_s,
                 0.0,
@@ -261,11 +268,14 @@ class FlashRWLargeAttention(torch.nn.Module):
         hidden_states,
         cos,
         sin,
-        cu_seqlens,
+        start_seq,
+        end_seq,
+        start_seq_q,
+        end_seq_q,
         max_s,
         layer_past,
-        layer_past_present_indices,
-        cu_seqlens_q,
+        past_present_indices,
+        prefill,
     ):
         qkv = self.query_key_value(hidden_states)
         qkv = qkv.view(-1, self.num_groups, self.num_heads + 2, self.head_size)
@@ -280,10 +290,10 @@ class FlashRWLargeAttention(torch.nn.Module):
         # Inplace rotary
         self.rotary_emb(query, cos, sin)
-        self.rotary_emb(kv[:, :, 0], cos, sin)
+        self.rotary_emb(torch.select(kv, dim=2, index=0), cos, sin)
         # Prefill
-        if layer_past_present_indices is None:
+        if prefill:
             # Copy to layer past
             layer_past[...] = kv
             # Expand to query shape
@@ -298,11 +308,13 @@ class FlashRWLargeAttention(torch.nn.Module):
             # flash attention
             flash_attn_cuda.fwd(
                 query,
-                kv[:, :, 0],
-                kv[:, :, 1],
+                torch.select(kv, dim=2, index=0),
+                torch.select(kv, dim=2, index=1),
                 attn_output,
-                cu_seqlens,
-                cu_seqlens,
+                start_seq,
+                end_seq,
+                start_seq,
+                end_seq,
                 max_s,
                 max_s,
                 0.0,
@@ -316,7 +328,7 @@ class FlashRWLargeAttention(torch.nn.Module):
         # Decode
         else:
             # Add present to the layer_past tensor at the correct indices
-            layer_past[layer_past_present_indices] = kv
+            layer_past[past_present_indices] = kv
             # Expand to query shape
             kv = (
                 layer_past.unsqueeze(2)
@@ -329,11 +341,13 @@ class FlashRWLargeAttention(torch.nn.Module):
             # flash attention
             flash_attn_cuda.fwd(
                 query,
-                kv[:, :, 0],
-                kv[:, :, 1],
+                torch.select(kv, dim=2, index=0),
+                torch.select(kv, dim=2, index=1),
                 attn_output,
-                cu_seqlens_q,
-                cu_seqlens,
+                start_seq_q,
+                end_seq_q,
+                start_seq,
+                end_seq,
                 1,
                 max_s,
                 0.0,
@@ -417,11 +431,14 @@ class FlashRWLayer(nn.Module):
         residual,
         cos,
         sin,
-        cu_seqlens,
+        start_seq,
+        end_seq,
+        start_seq_q,
+        end_seq_q,
         max_s,
         layer_past,
-        layer_past_present_indices,
-        cu_seqlens_q,
+        past_present_indices,
+        prefill,
     ):
         if self.parallel_attn:
             ln_hidden_states, residual = self.input_layernorm(hidden_states, residual)
@@ -430,11 +447,14 @@ class FlashRWLayer(nn.Module):
                 ln_hidden_states,
                 cos,
                 sin,
-                cu_seqlens,
+                start_seq,
+                end_seq,
+                start_seq_q,
+                end_seq_q,
                 max_s,
                 layer_past,
-                layer_past_present_indices,
-                cu_seqlens_q,
+                past_present_indices,
+                prefill,
             )
             mlp_output = self.mlp(ln_hidden_states)
@@ -451,11 +471,14 @@ class FlashRWLayer(nn.Module):
                 hidden_states,
                 cos,
                 sin,
-                cu_seqlens,
+                start_seq,
+                end_seq,
+                start_seq_q,
+                end_seq_q,
                 max_s,
                 layer_past,
-                layer_past_present_indices,
-                cu_seqlens_q,
+                past_present_indices,
+                prefill,
             )
             hidden_states, residual = self.post_attention_layernorm(
@@ -499,11 +522,14 @@ class FlashRWLargeLayer(nn.Module):
         residual,
         cos,
         sin,
-        cu_seqlens,
+        start_seq,
+        end_seq,
+        start_seq_q,
+        end_seq_q,
         max_s,
         layer_past,
-        layer_past_present_indices,
-        cu_seqlens_q,
+        past_present_indices,
+        prefill,
     ):
         ln_attn, residual = self.ln_attn(hidden_states, residual)
         ln_mlp, _ = self.ln_mlp(residual)
@@ -513,11 +539,14 @@ class FlashRWLargeLayer(nn.Module):
             ln_attn,
             cos,
             sin,
-            cu_seqlens,
+            start_seq,
+            end_seq,
+            start_seq_q,
+            end_seq_q,
             max_s,
             layer_past,
-            layer_past_present_indices,
-            cu_seqlens_q,
+            past_present_indices,
+            prefill,
         )
         # MLP.
@@ -584,9 +613,12 @@ class FlashRWModel(FlashRWPreTrainedModel):
         self,
         input_ids,
         position_ids,
-        cu_seqlens,
-        cu_seqlens_q,
+        start_seq,
+        end_seq,
+        start_seq_q,
+        end_seq_q,
         max_s,
+        past_present_indices,
         past_key_values=None,
         pre_allocate_past_size: Optional[int] = None,
     ):
@@ -594,23 +626,22 @@ class FlashRWModel(FlashRWPreTrainedModel):
         # Prefill
         if past_key_values is None:
+            assert pre_allocate_past_size is not None
+            prefill = True
             # Create past tensor
+            # We create a tensor of the same size as input_ids as we don't want to slice at every layer
             past_key_values = hidden_states.new_empty(
                 (
+                    len(input_ids),
                     len(self.h),
-                    len(hidden_states)
-                    if pre_allocate_past_size is None
-                    else pre_allocate_past_size,
                     *self.cache_size,
                 )
             )
-            layer_past_present_indices = None
-            slice_past_index = len(hidden_states)
         # Decode
         else:
-            # Create indices from cumulative sequence lengths
-            layer_past_present_indices = cu_seqlens[1:] - 1
-            slice_past_index = None
+            prefill = False
         # Get rotary cos and sin for this forward
         # Avoid to index in each layer
@@ -620,24 +651,33 @@ class FlashRWModel(FlashRWPreTrainedModel):
         residual = None
         for i, layer in enumerate(self.h):
-            # We added padding that we now need to slice
-            layer_past_key_values = (
-                past_key_values[i]
-                if slice_past_index is None
-                else past_key_values[i, :slice_past_index]
-            )
             hidden_states, residual = layer(
                 hidden_states,
                 residual,
                 cos,
                 sin,
-                cu_seqlens,
+                start_seq,
+                end_seq,
+                start_seq_q,
+                end_seq_q,
                 max_s,
-                layer_past_key_values,
-                layer_past_present_indices,
-                cu_seqlens_q,
+                torch.select(past_key_values, dim=1, index=i),
+                past_present_indices,
+                prefill,
            )
+        if prefill:
+            present = past_key_values
+            # Create padded past tensor
+            past_key_values = hidden_states.new_empty(
+                (
+                    pre_allocate_past_size,
+                    len(self.h),
+                    *self.cache_size,
+                )
+            )
+            # We slice only once instead of at every layer
+            past_key_values[past_present_indices] = present
         hidden_states, _ = self.ln_f(hidden_states, residual)
@@ -658,9 +698,12 @@ class FlashRWForCausalLM(FlashRWPreTrainedModel):
         self,
         input_ids,
         position_ids,
-        cu_seqlens,
-        cu_seqlens_q,
+        start_seq,
+        end_seq,
+        start_seq_q,
+        end_seq_q,
         max_s,
+        past_present_indices,
         past_key_values: Optional[torch.Tensor] = None,
         pre_allocate_past_size: Optional[int] = None,
         lm_head_indices: Optional[torch.Tensor] = None,
@@ -668,9 +711,12 @@ class FlashRWForCausalLM(FlashRWPreTrainedModel):
         hidden_states, present = self.transformer(
             input_ids,
             position_ids,
-            cu_seqlens,
-            cu_seqlens_q,
+            start_seq,
+            end_seq,
+            start_seq_q,
+            end_seq_q,
             max_s,
+            past_present_indices,
            past_key_values,
            pre_allocate_past_size,
        )
...
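In the RW (Falcon) and Santacoder files, indexing expressions such as `kv[:, 0]` are rewritten as `torch.select(kv, dim=1, index=0)`. Both forms return the same non-copying view of the cache, as the small check below (not part of the PR) illustrates.

```python
import torch

kv = torch.randn(10, 2, 4, 8)  # [tokens, k/v, heads, head_size]

# torch.select returns the same view as basic integer indexing on that dim.
assert torch.equal(torch.select(kv, dim=1, index=0), kv[:, 0])
assert torch.equal(torch.select(kv, dim=1, index=1), kv[:, 1])

# Both share storage with kv, so writing through either view updates the cache.
torch.select(kv, dim=1, index=0).zero_()
assert kv[:, 0].abs().sum() == 0
```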
@@ -7,6 +7,7 @@ from typing import Optional
 # Flash attention imports
 import flash_attn_cuda
 from text_generation_server.utils.layers import (
     TensorParallelRowLinear,
     TensorParallelColumnLinear,
@@ -148,11 +149,14 @@ class FlashMQAttention(torch.nn.Module):
     def forward(
         self,
         hidden_states,
-        cu_seqlens,
+        start_seq,
+        end_seq,
+        start_seq_q,
+        end_seq_q,
         max_s,
         layer_past,
-        layer_past_present_indices,
-        cu_seqlens_q,
+        past_present_indices,
+        prefill,
     ):
         qkv = self.c_attn(hidden_states)
@@ -166,7 +170,7 @@ class FlashMQAttention(torch.nn.Module):
         key_value = key_value.view(-1, 2, 1, self.head_size)
         # Prefill
-        if layer_past_present_indices is None:
+        if prefill:
             # Copy to layer past
             layer_past[...] = key_value
             # Expand from 1 to num_heads
@@ -177,11 +181,13 @@ class FlashMQAttention(torch.nn.Module):
             # flash attention
             flash_attn_cuda.fwd(
                 query,
-                key_value[:, 0],
-                key_value[:, 1],
+                torch.select(key_value, dim=1, index=0),
+                torch.select(key_value, dim=1, index=1),
                 attn_output,
-                cu_seqlens,
-                cu_seqlens,
+                start_seq,
+                end_seq,
+                start_seq,
+                end_seq,
                 max_s,
                 max_s,
                 0.0,
@@ -195,7 +201,7 @@ class FlashMQAttention(torch.nn.Module):
         # Decode
         else:
             # Add present to the layer_past tensor at the correct indices
-            layer_past[layer_past_present_indices] = key_value
+            layer_past[past_present_indices] = key_value
             # Expand from 1 to num_heads
             key_value = layer_past.expand(-1, 2, self.num_heads, self.head_size)
@@ -204,11 +210,13 @@ class FlashMQAttention(torch.nn.Module):
             # flash attention
             flash_attn_cuda.fwd(
                 query,
-                key_value[:, 0],
-                key_value[:, 1],
+                torch.select(key_value, dim=1, index=0),
+                torch.select(key_value, dim=1, index=1),
                 attn_output,
-                cu_seqlens_q,
-                cu_seqlens,
+                start_seq_q,
+                end_seq_q,
+                start_seq,
+                end_seq,
                 1,
                 max_s,
                 0.0,
@@ -277,21 +285,27 @@ class Block(nn.Module):
         self,
         hidden_states,
         residual,
-        cu_seqlens,
+        start_seq,
+        end_seq,
+        start_seq_q,
+        end_seq_q,
         max_s,
         layer_past,
-        layer_past_present_indices,
-        cu_seqlens_q,
+        past_present_indices,
+        prefill,
     ):
         hidden_states, residual = self.ln_1(hidden_states, residual)
         hidden_states = self.attn(
             hidden_states,
-            cu_seqlens,
+            start_seq,
+            end_seq,
+            start_seq_q,
+            end_seq_q,
             max_s,
             layer_past,
-            layer_past_present_indices,
-            cu_seqlens_q,
+            past_present_indices,
+            prefill,
         )
         hidden_states, residual = self.ln_2(hidden_states, residual)
@@ -339,10 +353,13 @@ class FlashSantacoderModel(nn.Module):
         self,
         input_ids,
         position_ids,
-        cu_seqlens,
-        cu_seqlens_q,
+        start_seq,
+        end_seq,
+        start_seq_q,
+        end_seq_q,
         max_s,
-        past_key_values: Optional[torch.Tensor] = None,
+        past_present_indices,
+        past_key_values=None,
         pre_allocate_past_size: Optional[int] = None,
     ):
         hidden_states = self.wte(input_ids) + self.wpe(position_ids)
@@ -352,44 +369,42 @@ class FlashSantacoderModel(nn.Module):
         # Prefill
         if past_key_values is None:
+            assert pre_allocate_past_size is not None
+            prefill = True
             # Create past tensor
-            past_key_values = hidden_states.new_empty(
-                (
-                    len(self.h),
-                    len(hidden_states)
-                    if pre_allocate_past_size is None
-                    else pre_allocate_past_size,
-                    2,
-                    1,
-                    self.head_size,
-                )
+            # We create a tensor of the same size as input_ids as we don't want to slice at every layer
+            past_key_values = hidden_states.new_zeros(
+                (len(input_ids), len(self.h), 2, 1, self.head_size)
             )
-            layer_past_present_indices = None
-            slice_past_index = len(hidden_states)
         # Decode
         else:
-            # Create indices from cumulative sequence lengths
-            layer_past_present_indices = cu_seqlens[1:] - 1
-            slice_past_index = None
+            prefill = False
         residual = None
         for i, layer in enumerate(self.h):
-            # We added padding that we now need to slice
-            layer_past_key_values = (
-                past_key_values[i]
-                if slice_past_index is None
-                else past_key_values[i, :slice_past_index]
-            )
             hidden_states, residual = layer(
                 hidden_states,
                 residual,
-                cu_seqlens,
+                start_seq,
+                end_seq,
+                start_seq_q,
+                end_seq_q,
                 max_s,
-                layer_past_key_values,
-                layer_past_present_indices,
-                cu_seqlens_q,
+                torch.select(past_key_values, dim=1, index=i),
+                past_present_indices,
+                prefill,
            )
+        if prefill:
+            present = past_key_values
+            # Create padded past tensor
+            past_key_values = hidden_states.new_empty(
+                (pre_allocate_past_size, len(self.h), 2, 1, self.head_size)
+            )
+            # We slice only once instead of at every layer
+            past_key_values[past_present_indices] = present
         hidden_states, _ = self.ln_f(hidden_states, residual)
@@ -408,9 +423,12 @@ class FlashSantacoderForCausalLM(nn.Module):
         self,
         input_ids,
         position_ids,
-        cu_seqlens,
-        cu_seqlens_q,
+        start_seq,
+        end_seq,
+        start_seq_q,
+        end_seq_q,
         max_s,
+        past_present_indices,
         past_key_values: Optional[torch.Tensor] = None,
         pre_allocate_past_size: Optional[int] = None,
         lm_head_indices: Optional[torch.Tensor] = None,
@@ -418,9 +436,12 @@ class FlashSantacoderForCausalLM(nn.Module):
         hidden_states, present = self.transformer(
             input_ids,
             position_ids,
-            cu_seqlens,
-            cu_seqlens_q,
+            start_seq,
+            end_seq,
+            start_seq_q,
+            end_seq_q,
             max_s,
+            past_present_indices,
            past_key_values,
            pre_allocate_past_size,
        )
...
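FlashMQAttention keeps a single shared key/value head per token (`layer_past` has a head dimension of 1) and broadcasts it across all query heads with `.expand(-1, 2, self.num_heads, self.head_size)` right before calling the attention kernel. The sketch below only illustrates that multi-query broadcast together with the indexed cache write; the shapes are invented for the example and it is not code from the model file.

```python
import torch

num_slots, num_heads, head_size = 16, 8, 64

# Multi-query cache: one shared K/V head per pre-allocated token slot.
layer_past = torch.zeros(num_slots, 2, 1, head_size)

# Write the newly computed K/V of two decoding sequences into their slots.
past_present_indices = torch.tensor([3, 12])
key_value = torch.randn(2, 2, 1, head_size)
layer_past[past_present_indices] = key_value

# Broadcast the single head across all query heads without copying memory.
expanded = layer_past.expand(-1, 2, num_heads, head_size)
assert expanded.shape == (num_slots, 2, num_heads, head_size)
assert expanded.data_ptr() == layer_past.data_ptr()  # still a view, no copy
```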