Commit a668890f authored by Tri Dao

[Gen] Add option to run generation with FT attention kernel

parent be1afaa2
@@ -30,6 +30,11 @@ try:
 except ImportError:
     RotaryEmbedding = None
 
+try:
+    import ft_attention
+except ImportError:
+    ft_attention = None
+
 
 class FlashSelfAttention(nn.Module):
     """Implement the scaled dot product attention with softmax.
@@ -360,23 +365,32 @@ class MHA(nn.Module):
         assert self.layer_idx is not None, 'Generation requires layer_idx in the constructor'
         # Pre-allocate memory for key-values for inference.
         if self.layer_idx not in inference_params.key_value_memory_dict:
-            inference_kv_cache = torch.empty(
+            kv_cache = torch.empty(
                 inference_params.max_batch_size, inference_params.max_sequence_len, 2,
                 self.num_heads, self.head_dim, dtype=kv.dtype, device=kv.device
             )
-            inference_params.key_value_memory_dict[self.layer_idx] = inference_kv_cache
+            inference_params.key_value_memory_dict[self.layer_idx] = kv_cache
         else:
-            inference_kv_cache = inference_params.key_value_memory_dict[self.layer_idx]
+            assert not inference_params.fused_ft_kernel, 'fused_ft_kernel should not take this path'
+            kv_cache = inference_params.key_value_memory_dict[self.layer_idx]
         # Adjust key and value for inference
         batch_start = inference_params.batch_size_offset
         batch_end = batch_start + kv.shape[0]
-        assert batch_end <= inference_kv_cache.shape[0]
+        assert batch_end <= kv_cache.shape[0]
         sequence_start = inference_params.sequence_len_offset
         sequence_end = sequence_start + kv.shape[1]
-        assert sequence_end <= inference_kv_cache.shape[1]
+        assert sequence_end <= kv_cache.shape[1]
         # Copy key and values.
-        inference_kv_cache[batch_start:batch_end, sequence_start:sequence_end, ...] = kv
-        kv = inference_kv_cache[batch_start:batch_end, :sequence_end, ...]
+        kv_cache[batch_start:batch_end, sequence_start:sequence_end, ...] = kv
+        kv = kv_cache[batch_start:batch_end, :sequence_end, ...]
+        if inference_params.fused_ft_kernel:
+            # FT kernel requires different layouts for the k_cache and v_cache.
+            assert kv_cache.dtype in [torch.float16, torch.bfloat16, torch.float32]
+            packsize = 4 if kv_cache.dtype == torch.float32 else 8
+            k_cache = rearrange(kv_cache[:, :, 0], 'b s h (d packsize) -> b h d s packsize',
+                                packsize=packsize).contiguous()
+            v_cache = rearrange(kv_cache[:, :, 1], 'b s h d -> b h s d').contiguous()
+            inference_params.key_value_memory_dict[self.layer_idx] = (k_cache, v_cache)
         return kv
 
     def forward(self, x, x_kv=None, key_padding_mask=None, cu_seqlens=None, max_seqlen=None,
@@ -430,6 +444,7 @@ class MHA(nn.Module):
                 else:
                     context = torch.utils.checkpoint.checkpoint(self.inner_attn, qkv, **kwargs)
             else:
+                if (not inference_params.fused_ft_kernel) or inference_params.sequence_len_offset == 0:
                     if self.rotary_emb_dim > 0:
                         qkv = self.rotary_emb(qkv, seqlen_offset=inference_params.sequence_len_offset)
                     q = qkv[:, :, 0]
@@ -438,6 +453,15 @@ class MHA(nn.Module):
                     # If we're decoding, then causal=False.
                     causal = None if inference_params.sequence_len_offset == 0 else False
                     context = self.inner_cross_attn(q, kv, causal=causal)
+                else:
+                    assert ft_attention is not None
+                    context = ft_attention.single_query_attention(
+                        *rearrange(qkv, 'b 1 three h d -> b three h d').unbind(dim=1),
+                        *inference_params.key_value_memory_dict[self.layer_idx],
+                        inference_params.lengths_per_sample, inference_params.sequence_len_offset,
+                        self.rotary_emb_dim
+                    )
+                    context = rearrange(context, 'b h d -> b 1 h d')
         else:
             if not self.return_residual:
                 q = self.Wq(x)
...
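The cache-layout conversion added to _update_kv_cache above is the core of this change: the fused FT (FasterTransformer-style) decoding kernel expects the key cache packed along the head dimension in groups of packsize elements (8 for fp16/bf16, 4 for fp32) and the value cache in (batch, heads, seqlen, head_dim) order, while the rest of the module keeps a single (batch, seqlen, 2, heads, head_dim) buffer. Below is a minimal sketch of that conversion on a dummy tensor, assuming only torch and einops; the sizes are illustrative and not taken from the commit.

import torch
from einops import rearrange

# Illustrative sizes only (not from the commit).
batch, seqlen, nheads, headdim = 2, 16, 12, 64

# kv_cache as allocated in _update_kv_cache: (batch, seqlen, 2, nheads, headdim),
# with index 0 along dim 2 holding K and index 1 holding V.
kv_cache = torch.randn(batch, seqlen, 2, nheads, headdim, dtype=torch.float16)

# Same packing rule as the diff: fp16/bf16 pack 8 elements per group, fp32 packs 4.
packsize = 4 if kv_cache.dtype == torch.float32 else 8

# K cache: split head_dim into (headdim // packsize) groups of packsize and move seqlen inside.
k_cache = rearrange(kv_cache[:, :, 0], 'b s h (d packsize) -> b h d s packsize',
                    packsize=packsize).contiguous()
# V cache: plain (batch, heads, seqlen, head_dim).
v_cache = rearrange(kv_cache[:, :, 1], 'b s h d -> b h s d').contiguous()

print(k_cache.shape)  # torch.Size([2, 12, 8, 16, 8])  i.e. (b, h, headdim // packsize, s, packsize)
print(v_cache.shape)  # torch.Size([2, 12, 16, 64])    i.e. (b, h, s, headdim)

Once this conversion runs, key_value_memory_dict[layer_idx] holds the (k_cache, v_cache) tuple instead of the combined buffer, which is why the else branch above now asserts that the fused path never reuses the old layout.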
 # Copyright (c) 2022, Tri Dao.
 # Adapted from https://github.com/NVIDIA/Megatron-LM/blob/0bb597b42c53355a567aba2a1357cc34b9d99ddd/megatron/text_generation/forward_step.py#L31
+from typing import Optional
 from dataclasses import dataclass, field
 
 import torch
+from torch import Tensor
 
 from einops import rearrange
@@ -17,9 +20,11 @@ class InferenceParams:
     sequence_len_offset: int = 0
     batch_size_offset: int = 0
     key_value_memory_dict: dict = field(default_factory=dict)
+    fused_ft_kernel: bool = False
+    lengths_per_sample: Optional[Tensor] = None
 
 
-def greedy_decode(input_ids, model, max_length):
+def greedy_decode(input_ids, model, max_length, fused_ft_kernel=True):
     """Greedy decoding. This is a very simple implementation.
     We assume that all sequences in the same batch have the same length.
     Arguments:
@@ -30,7 +35,8 @@ def greedy_decode(input_ids, model, max_length):
         scores: tuples of (batch, vocab_size)
     """
     batch_size, seqlen_og = input_ids.shape
-    inference_params = InferenceParams(max_sequence_len=max_length, max_batch_size=batch_size)
+    inference_params = InferenceParams(max_sequence_len=max_length, max_batch_size=batch_size,
+                                       fused_ft_kernel=fused_ft_kernel)
     scores = []
     with torch.inference_mode():
         logits = model(input_ids, inference_params=inference_params).logits[:, -1]
@@ -57,8 +63,9 @@ def greedy_decode(input_ids, model, max_length):
 
 
 class GenerationMixin:
-    def generate(self, input_ids, max_length, return_dict_in_generate=False, output_scores=False):
-        output = greedy_decode(input_ids, self, max_length)
+    def generate(self, input_ids, max_length, return_dict_in_generate=False, output_scores=False,
+                 **kwargs):
+        output = greedy_decode(input_ids, self, max_length, **kwargs)
         if not output_scores:
             output.scores = None
         return output if return_dict_in_generate else output.sequences
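At the API level the new flag simply flows from GenerationMixin.generate through greedy_decode into InferenceParams. A minimal usage sketch, assuming a model that follows this repo's convention of accepting inference_params in forward; the helper name compare_ft_generation and the assertion are ours, not from the commit.

import torch
from flash_attn.utils.generation import greedy_decode

def compare_ft_generation(model, input_ids, max_length=64):
    """Greedy-decode with and without the fused FT kernel and check the token outputs agree.

    `model` is assumed to accept an `inference_params` kwarg in forward(), as MHA above expects.
    """
    out_ref = greedy_decode(input_ids, model, max_length, fused_ft_kernel=False)
    out_ft = greedy_decode(input_ids, model, max_length, fused_ft_kernel=True)
    # Greedy token sequences should normally match; per-step scores may differ slightly in fp16.
    assert torch.equal(out_ref.sequences, out_ft.sequences)
    return out_ft.sequences

Equivalently, model.generate(input_ids, max_length, fused_ft_kernel=True, return_dict_in_generate=True, output_scores=True) forwards the flag through **kwargs, which is exactly what the updated test below does.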
@@ -15,10 +15,11 @@ from flash_attn.utils.generation import greedy_decode
 
 # TODO: test with rotary embedding
 
+@pytest.mark.parametrize('fused_ft_kernel', [False, True])
 @pytest.mark.parametrize('optimized', [False, True])
-# @pytest.mark.parametrize('optimized', [False])
+# @pytest.mark.parametrize('optimized', [True])
 @pytest.mark.parametrize('model_name', ["gpt2"])
-def test_greedy_decode(model_name, optimized):
+def test_greedy_decode(model_name, optimized, fused_ft_kernel):
     """Check that our implementation of GPT2 generation matches the HF implementation:
     the scores in fp16 should be around the same as the HF scores in fp16, when compared to
     the HF scores in fp32.
@@ -62,6 +63,7 @@ def test_greedy_decode(model_name, optimized):
 
     scores = tuple(scores)
 
     out = model.generate(input_ids=input_ids, max_length=max_length,
+                         fused_ft_kernel=fused_ft_kernel,
                          return_dict_in_generate=True, output_scores=True)
     out_hf = model_hf.generate(input_ids=input_ids, max_length=max_length,
...
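For readers who have not used the FT-style decode path, the following is a simplified, plain-PyTorch description of the per-token computation that ft_attention.single_query_attention performs in the MHA diff above. It is a sketch under our own simplifications, not the kernel's actual interface: it ignores rotary embeddings and lengths_per_sample and uses an unpacked (batch, heads, seqlen, head_dim) key cache instead of the packed layout shown earlier.

import math
import torch

def single_query_attention_ref(q, k, v, k_cache, v_cache, seqlen_offset):
    """Plain-PyTorch sketch of one fused decoding step (our simplification, not the FT kernel).

    q, k, v:           (batch, nheads, head_dim) projections of the single new token
    k_cache, v_cache:  (batch, nheads, max_seqlen, head_dim) pre-allocated caches
    seqlen_offset:     number of tokens already in the cache
    """
    # Write the new key/value into the cache at the current position.
    k_cache[:, :, seqlen_offset] = k
    v_cache[:, :, seqlen_offset] = v
    keys = k_cache[:, :, :seqlen_offset + 1]     # (b, h, s, d)
    values = v_cache[:, :, :seqlen_offset + 1]   # (b, h, s, d)
    # A single query attends to every cached position up to and including itself,
    # so no causal mask is needed (this is why the non-fused path sets causal=False when decoding).
    scores = torch.einsum('bhd,bhsd->bhs', q, keys) / math.sqrt(q.shape[-1])
    attn = torch.softmax(scores.float(), dim=-1).to(values.dtype)
    return torch.einsum('bhs,bhsd->bhd', attn, values)  # (b, h, d); MHA rearranges this to (b, 1, h, d)

Note that in the diff above the fused branch no longer calls _update_kv_cache once sequence_len_offset > 0, so maintaining k_cache and v_cache during decoding is left to the kernel call itself.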