Commit 63670fd8 authored by Tri Dao

Implement generation for GPT

parent 9d797d88
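This commit adds KV-cache-based text generation: each MHA layer gains a per-layer key/value cache keyed by layer_idx, the GPT forward passes accept an inference_params object, and GPTLMHeadModel picks up a generate() method from the new GenerationMixin. A minimal usage sketch, adapted from the test at the bottom of this diff (the "gpt2" checkpoint, fp16, and CUDA setup are taken from that test; this sketch is illustrative and not part of the commit):

import torch
from transformers import GPT2Config, GPT2Tokenizer
from flash_attn.models.gpt import GPTLMHeadModel

config = GPT2Config.from_pretrained("gpt2")
model = GPTLMHeadModel.from_pretrained("gpt2", config).cuda().to(dtype=torch.float16)
model.eval()

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
input_ids = tokenizer("Hello, my dog is cute and ", return_tensors="pt").input_ids.cuda()

# Greedy decoding with the per-layer KV cache; returns generated token ids of shape (batch, max_length).
sequences = model.generate(input_ids, max_length=30)
print(tokenizer.decode(sequences[0]))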
@@ -20,6 +20,7 @@ from flash_attn.modules.block import Block
 from flash_attn.modules.embedding import GPT2Embeddings, ParallelGPT2Embeddings
 from flash_attn.utils.distributed import sync_sequence_parallel_params
 from flash_attn.utils.pretrained import state_dict_from_pretrained
+from flash_attn.utils.generation import GenerationMixin
 
 try:
     from flash_attn.ops.fused_dense import ColumnParallelLinear
@@ -61,7 +62,7 @@ def create_mixer_cls(config, layer_idx=None, process_group=None, device=None, dt
                      if process_group is None else {})
     parallel_kwargs = {'process_group': process_group} if process_group is not None else {}
     mixer_cls = partial(mha_cls, num_heads=config.num_attention_heads, dropout=config.attn_pdrop,
-                        softmax_scale=softmax_scale, causal=True,
+                        softmax_scale=softmax_scale, causal=True, layer_idx=layer_idx,
                         rotary_emb_dim=rotary_emb_dim, rotary_emb_scale_base=rotary_emb_scale_base,
                         use_flash_attn=use_flash_attn,
                         **serial_kwargs, **parallel_kwargs, **factory_kwargs)
@@ -220,7 +221,7 @@ class GPTModel(GPTPreTrainedModel):
         if self.process_group is not None:
             sync_sequence_parallel_params(self, self.process_group)
 
-    def forward(self, input_ids, position_ids=None):
+    def forward(self, input_ids, position_ids=None, inference_params=None):
         # If using Tensor Parallel with sequence parallel, we combine the batch and the seqlen
         # dimensions so that we can split on it easily, in case of small batch size.
         # Only the attention layers need to know the seqlen.
@@ -238,12 +239,14 @@ class GPTModel(GPTPreTrainedModel):
             residual_in_fp32=True
         )
         mixer_kwargs = ({'seqlen': input_ids.shape[1]} if self.process_group is not None else {})
+        if inference_params is not None:
+            mixer_kwargs['inference_params'] = inference_params
         for layer in self.layers:
             hidden_states, residual = layer(hidden_states, residual, mixer_kwargs=mixer_kwargs)
         return hidden_states
 
 
-class GPTLMHeadModel(GPTPreTrainedModel):
+class GPTLMHeadModel(GPTPreTrainedModel, GenerationMixin):
 
     def __init__(self, config: GPT2Config, process_group=None, device=None, dtype=None):
         factory_kwargs = {'device': device, 'dtype': dtype}
@@ -267,8 +270,13 @@ class GPTLMHeadModel(GPTPreTrainedModel):
     def tie_weights(self):
         self.lm_head.weight = self.transformer.embeddings.word_embeddings.weight
 
-    def forward(self, input_ids, position_ids=None):
-        hidden_states = self.transformer(input_ids, position_ids=position_ids)
+    def forward(self, input_ids, position_ids=None, inference_params=None):
+        """
+        inference_params: for generation. Adapted from Megatron-LM (and Apex)
+        https://github.com/NVIDIA/apex/blob/3ff1a10f72ec07067c4e44759442329804ac5162/apex/transformer/testing/standalone_transformer_lm.py#L470
+        """
+        hidden_states = self.transformer(input_ids, position_ids=position_ids,
+                                         inference_params=inference_params)
         lm_logits = self.lm_head(hidden_states)
         CausalLMOutput = namedtuple('CausalLMOutput', ['logits'])
         return CausalLMOutput(logits=lm_logits)
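With the changes above, inference_params flows from GPTLMHeadModel.forward through GPTModel into every block's mixer via mixer_kwargs. A hedged sketch of driving the new signature directly (the InferenceParams container, and the greedy_decode loop that does this for real, appear later in this diff; model and input_ids are the ones from the sketch above):

import torch
from flash_attn.utils.generation import InferenceParams

# Prompt pass: sequence_len_offset == 0, so the prompt is processed with the usual causal
# mask and every MHA layer fills its key/value cache.
inference_params = InferenceParams(max_sequence_len=30, max_batch_size=input_ids.shape[0])
logits = model(input_ids, inference_params=inference_params).logits[:, -1]

# One decoding step: feed only the new token, give its absolute position explicitly, and
# advance the offset so the caches are appended to instead of overwritten.
inference_params.sequence_len_offset = input_ids.shape[1]
next_token = logits.argmax(dim=-1, keepdim=True)                                  # (batch, 1)
position_ids = torch.full_like(next_token, inference_params.sequence_len_offset)
logits = model(next_token, position_ids=position_ids,
               inference_params=inference_params).logits[:, -1]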
...
@@ -53,7 +53,7 @@ class FlashSelfAttention(nn.Module):
         self.dropout_p = attention_dropout
         self.triton = triton
 
-    def forward(self, qkv, cu_seqlens=None, max_seqlen=None):
+    def forward(self, qkv, causal=None, cu_seqlens=None, max_seqlen=None):
         """Implements the multihead softmax attention.
         Arguments
         ---------
@@ -61,6 +61,7 @@ class FlashSelfAttention(nn.Module):
                 If cu_seqlens is None and max_seqlen is None, then qkv has shape (B, S, 3, H, D).
                 If cu_seqlens is not None and max_seqlen is not None, then qkv has shape
                 (total, 3, H, D), where total is the sum of the sequence lengths in the batch.
+            causal: if passed, will override self.causal
             cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
                 of the sequences in the batch, used to index into qkv.
             max_seqlen: int. Maximum sequence length in the batch.
@@ -71,6 +72,7 @@ class FlashSelfAttention(nn.Module):
         """
         assert qkv.dtype in [torch.float16, torch.bfloat16]
         assert qkv.is_cuda
+        causal = self.causal if causal is None else causal
         unpadded = cu_seqlens is not None
         if unpadded:
             assert cu_seqlens.dtype == torch.int32
@@ -78,13 +80,13 @@ class FlashSelfAttention(nn.Module):
             assert isinstance(max_seqlen, int)
             return flash_attn_unpadded_qkvpacked_func(
                 qkv, cu_seqlens, max_seqlen, self.dropout_p if self.training else 0.0,
-                softmax_scale=self.softmax_scale, causal=self.causal
+                softmax_scale=self.softmax_scale, causal=causal
             )
         else:
             batch_size, seqlen = qkv.shape[0], qkv.shape[1]
             # Triton version doesn't support dropout
             if self.triton and (self.dropout_p == 0 or not self.training):
-                output = flash_attn_qkvpacked_func(qkv, None, self.causal, self.softmax_scale)
+                output = flash_attn_qkvpacked_func(qkv, None, causal, self.softmax_scale)
             else:
                 qkv = rearrange(qkv, 'b s ... -> (b s) ...')
                 max_seqlen = seqlen
@@ -92,7 +94,7 @@ class FlashSelfAttention(nn.Module):
                                           device=qkv.device)
                 output = flash_attn_unpadded_qkvpacked_func(
                     qkv, cu_seqlens, max_seqlen, self.dropout_p if self.training else 0.0,
-                    softmax_scale=self.softmax_scale, causal=self.causal
+                    softmax_scale=self.softmax_scale, causal=causal
                 )
                 output = rearrange(output, '(b s) ... -> b s ...', b=batch_size)
             return output
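The new causal argument lets a module constructed with causal=True turn masking off for an individual call; the MHA changes below use exactly this for single-token decoding, where the one new query is allowed to attend to every cached position. A small sketch, assuming flash-attn is installed and a CUDA device with fp16 inputs as the asserts above require (the shapes are made up):

import torch
from flash_attn.modules.mha import FlashSelfAttention

attn = FlashSelfAttention(causal=True)           # module-level default: causal masking
qkv = torch.randn(2, 128, 3, 12, 64, device='cuda', dtype=torch.float16)  # (B, S, 3, H, D)
out_default = attn(qkv)                          # uses self.causal (True)
out_override = attn(qkv, causal=False)           # per-call override, ignores self.causal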
@@ -120,12 +122,14 @@ class FlashCrossAttention(nn.Module):
         self.dropout_p = attention_dropout
         self.triton = triton
 
-    def forward(self, q, kv, cu_seqlens=None, max_seqlen=None, cu_seqlens_k=None, max_seqlen_k=None):
+    def forward(self, q, kv, causal=None, cu_seqlens=None, max_seqlen=None,
+                cu_seqlens_k=None, max_seqlen_k=None):
         """Implements the multihead softmax attention.
         Arguments
         ---------
             q: The tensor containing the query. (B, Sq, H, D)
             kv: The tensor containing the key and value. (B, Sk, 2, H, D)
+            causal: if passed, will override self.causal
             cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
                 of the sequences in the batch, used to index into q.
             max_seqlen: int. Maximum sequence length in the batch of q.
@@ -135,6 +139,7 @@ class FlashCrossAttention(nn.Module):
         """
         assert q.dtype in [torch.float16, torch.bfloat16]
         assert q.is_cuda and kv.is_cuda
+        causal = self.causal if causal is None else causal
         unpadded = cu_seqlens is not None
         if unpadded:
             assert cu_seqlens.dtype == torch.int32
@@ -147,14 +152,14 @@ class FlashCrossAttention(nn.Module):
             return flash_attn_unpadded_kvpacked_func(
                 q, kv, cu_seqlens, cu_seqlens_k, max_seqlen, max_seqlen_k,
                 self.dropout_p if self.training else 0.0,
-                softmax_scale=self.softmax_scale, causal=self.causal
+                softmax_scale=self.softmax_scale, causal=causal
             )
         else:
             batch_size, seqlen_q = q.shape[0], q.shape[1]
             seqlen_k = kv.shape[1]
             assert kv.shape[0] == batch_size and kv.shape[3] == q.shape[2] and kv.shape[4] == q.shape[3]
             if self.triton and (self.dropout_p == 0.0 or not self.training):  # Triton version doesn't support dropout
-                output = flash_attn_kvpacked_func(q, kv, None, self.causal, self.softmax_scale)
+                output = flash_attn_kvpacked_func(q, kv, None, causal, self.softmax_scale)
             else:
                 q = rearrange(q, 'b s ... -> (b s) ...')
                 kv = rearrange(kv, 'b s ... -> (b s) ...')
@@ -165,7 +170,7 @@ class FlashCrossAttention(nn.Module):
                 output = flash_attn_unpadded_kvpacked_func(
                     q, kv, cu_seqlens_q, cu_seqlens_k, seqlen_q, seqlen_k,
                     self.dropout_p if self.training else 0.0,
-                    softmax_scale=self.softmax_scale, causal=self.causal
+                    softmax_scale=self.softmax_scale, causal=causal
                 )
                 output = rearrange(output, '(b s) ... -> b s ...', b=batch_size)
             return output
@@ -187,15 +192,17 @@ class SelfAttention(nn.Module):
         self.softmax_scale = softmax_scale
         self.dropout_p = attention_dropout
 
-    def forward(self, qkv, key_padding_mask=None):
+    def forward(self, qkv, causal=None, key_padding_mask=None):
         """Implements the multihead softmax attention.
         Arguments
         ---------
             qkv: The tensor containing the query, key, and value. (B, S, 3, H, D)
+            causal: if passed, will override self.causal
             key_padding_mask: boolean mask to apply to the attention weights. True means to keep,
                 False means to mask out. (B, S)
         """
         batch_size, seqlen = qkv.shape[0], qkv.shape[1]
+        causal = self.causal if causal is None else causal
         q, k, v = qkv.unbind(dim=2)
         softmax_scale = self.softmax_scale or 1.0 / math.sqrt(q.shape[-1])
         scores = torch.einsum('bthd,bshd->bhts', q, k * softmax_scale)
@@ -205,7 +212,7 @@ class SelfAttention(nn.Module):
             padding_mask.masked_fill_(key_padding_mask, 0.0)
             # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
             scores = scores + rearrange(padding_mask, 'b s -> b 1 1 s')
-        if self.causal:
+        if causal:
             # "triu_tril_cuda_template" not implemented for 'BFloat16'
             # So we have to construct the mask in float
             causal_mask = torch.triu(torch.full((seqlen, seqlen), -10000.0, device=scores.device), 1)
@@ -233,16 +240,18 @@ class CrossAttention(nn.Module):
         self.softmax_scale = softmax_scale
         self.dropout_p = attention_dropout
 
-    def forward(self, q, kv, key_padding_mask=None):
+    def forward(self, q, kv, causal=None, key_padding_mask=None):
         """Implements the multihead softmax attention.
         Arguments
         ---------
             q: The tensor containing the query. (B, Sq, H, D)
             kv: The tensor containing the key and value. (B, Sk, 2, H, D)
+            causal: if passed, will override self.causal
             key_padding_mask: boolean mask to apply to the attention weights. True means to keep,
                 False means to mask out. (B, Sk)
         """
         batch_size, seqlen_q = q.shape[0], q.shape[1]
+        causal = self.causal if causal is None else causal
         seqlen_k = kv.shape[1]
         assert kv.shape[0] == batch_size and kv.shape[3] == q.shape[2] and kv.shape[4] == q.shape[3]
         k, v = kv.unbind(dim=2)
@@ -254,7 +263,7 @@ class CrossAttention(nn.Module):
             padding_mask.masked_fill_(key_padding_mask, 0.0)
             # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
             scores = scores + rearrange(padding_mask, 'b s -> b 1 1 s')
-        if self.causal:
+        if causal:
             # "triu_tril_cuda_template" not implemented for 'BFloat16'
             # So we have to construct the mask in float
             causal_mask = torch.triu(torch.full((seqlen_q, seqlen_k), -10000.0,
@@ -280,7 +289,7 @@ class MHA(nn.Module):
     """
 
     def __init__(self, embed_dim, num_heads, cross_attn=False, bias=True, dropout=0.0,
-                 softmax_scale=None, causal=False, dwconv=False, rotary_emb_dim=0,
+                 softmax_scale=None, causal=False, layer_idx=None, dwconv=False, rotary_emb_dim=0,
                  rotary_emb_scale_base=0,
                  fused_bias_fc=False, use_flash_attn=False, return_residual=False,
                  checkpointing=False, device=None, dtype=None) -> None:
@@ -294,6 +303,7 @@ class MHA(nn.Module):
         self.embed_dim = embed_dim
         self.cross_attn = cross_attn
         self.causal = causal
+        self.layer_idx = layer_idx
         self.dwconv = dwconv
         self.rotary_emb_dim = rotary_emb_dim
         self.use_flash_attn = use_flash_attn
@@ -315,6 +325,8 @@ class MHA(nn.Module):
         linear_cls = nn.Linear if not fused_bias_fc else FusedDense
         linear_resid_cls = (LinearResidual if not fused_bias_fc
                             else partial(FusedDense, return_residual=True))
+        inner_attn_cls = FlashSelfAttention if use_flash_attn else SelfAttention
+        inner_cross_attn_cls = FlashCrossAttention if use_flash_attn else CrossAttention
         if not self.cross_attn:
             if not self.return_residual:
                 self.Wqkv = linear_cls(embed_dim, 3 * embed_dim, bias=bias, **factory_kwargs)
@@ -323,7 +335,6 @@ class MHA(nn.Module):
             if self.dwconv:
                 self.dwconv_qkv = nn.Conv1d(3 * embed_dim, 3 * embed_dim, kernel_size=3, padding=2,
                                             groups=3 * embed_dim)
-            inner_attn_cls = FlashSelfAttention if use_flash_attn else SelfAttention
         else:
             self.Wq = linear_cls(embed_dim, embed_dim, bias=bias, **factory_kwargs)
             if not self.return_residual:
@@ -335,14 +346,41 @@ class MHA(nn.Module):
                                             groups=embed_dim)
                 self.dwconv_kv = nn.Conv1d(2 * embed_dim, 2 * embed_dim, kernel_size=3, padding=2,
                                            groups=2 * embed_dim)
-            inner_attn_cls = FlashCrossAttention if use_flash_attn else CrossAttention
         self.inner_attn = inner_attn_cls(causal=causal, softmax_scale=softmax_scale,
                                          attention_dropout=dropout)
+        self.inner_cross_attn = inner_cross_attn_cls(causal=causal, softmax_scale=softmax_scale,
+                                                     attention_dropout=dropout)
         # output projection always have the bias (for now)
         self.out_proj = linear_cls(embed_dim, embed_dim, **factory_kwargs)
 
+    def _update_kv_cache(self, kv, inference_params):
+        """kv: (batch_size, seqlen, 2, nheads, head_dim)
+        """
+        assert not self.dwconv, 'Generation does not support dwconv yet'
+        assert self.layer_idx is not None, 'Generation requires layer_idx in the constructor'
+        # Pre-allocate memory for key-values for inference.
+        if self.layer_idx not in inference_params.key_value_memory_dict:
+            inference_kv_cache = torch.empty(
+                inference_params.max_batch_size, inference_params.max_sequence_len, 2,
+                self.num_heads, self.head_dim, dtype=kv.dtype, device=kv.device
+            )
+            inference_params.key_value_memory_dict[self.layer_idx] = inference_kv_cache
+        else:
+            inference_kv_cache = inference_params.key_value_memory_dict[self.layer_idx]
+        # Adjust key and value for inference
+        batch_start = inference_params.batch_size_offset
+        batch_end = batch_start + kv.shape[0]
+        assert batch_end <= inference_kv_cache.shape[0]
+        sequence_start = inference_params.sequence_len_offset
+        sequence_end = sequence_start + kv.shape[1]
+        assert sequence_end <= inference_kv_cache.shape[1]
+        # Copy key and values.
+        inference_kv_cache[batch_start:batch_end, sequence_start:sequence_end, ...] = kv
+        kv = inference_kv_cache[batch_start:batch_end, :sequence_end, ...]
+        return kv
+
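For intuition: the cache that _update_kv_cache maintains is one pre-allocated tensor per layer of shape (max_batch_size, max_sequence_len, 2, nheads, head_dim). The prompt pass writes positions [0, prompt_len), each decoding step writes one more slot at sequence_len_offset, and the method returns everything cached so far. A toy illustration with made-up sizes (not part of the diff):

import torch

max_batch, max_seqlen, nheads, head_dim = 2, 8, 4, 16
cache = torch.empty(max_batch, max_seqlen, 2, nheads, head_dim)  # per-layer KV cache

# Prompt of length 5 (sequence_len_offset == 0): fill positions 0..4.
prompt_kv = torch.randn(max_batch, 5, 2, nheads, head_dim)
cache[:, 0:5] = prompt_kv

# First decoding step (sequence_len_offset == 5): write position 5, then attend over
# everything cached so far, i.e. cache[:, :6].
step_kv = torch.randn(max_batch, 1, 2, nheads, head_dim)
cache[:, 5:6] = step_kv
kv_for_attention = cache[:, :6]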
     def forward(self, x, x_kv=None, key_padding_mask=None, cu_seqlens=None, max_seqlen=None,
-                **kwargs):
+                inference_params=None, **kwargs):
         """
         Arguments:
             x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim) if
@@ -355,6 +393,8 @@ class MHA(nn.Module):
             max_seqlen: int. Maximum sequence length in the batch.
             key_padding_mask: boolean mask, True means to keep, False means to mask out.
                 (batch, seqlen). Only applicable when not using FlashAttention.
+            inference_params: for generation. Adapted from Megatron-LM (and Apex)
+                https://github.com/NVIDIA/apex/blob/3ff1a10f72ec07067c4e44759442329804ac5162/apex/transformer/testing/standalone_transformer_lm.py#L470
         """
         if cu_seqlens is not None:
             assert max_seqlen is not None
@@ -366,6 +406,10 @@ class MHA(nn.Module):
             assert cu_seqlens is None
             assert max_seqlen is None
             assert not self.use_flash_attn
+        if inference_params is not None:
+            assert key_padding_mask is None
+            assert cu_seqlens is None and max_seqlen is None
+            assert not self.dwconv
 
         kwargs = ({'cu_seqlens': cu_seqlens, 'max_seqlen': max_seqlen, **kwargs}
                   if self.use_flash_attn else {'key_padding_mask': key_padding_mask, **kwargs})
@@ -378,12 +422,22 @@ class MHA(nn.Module):
                 qkv = rearrange(self.dwconv_qkv(rearrange(qkv, 'b s d -> b d s'))[..., :-2],
                                 'b d s -> b s d').contiguous()
             qkv = rearrange(qkv, '... (three h d) -> ... three h d', three=3, d=self.head_dim)
-            if self.rotary_emb_dim > 0:
-                qkv = self.rotary_emb(qkv)
-            if not self.checkpointing:
-                context = self.inner_attn(qkv, **kwargs)
-            else:
-                context = torch.utils.checkpoint.checkpoint(self.inner_attn, qkv, **kwargs)
+            if inference_params is None:
+                if self.rotary_emb_dim > 0:
+                    qkv = self.rotary_emb(qkv)
+                if not self.checkpointing:
+                    context = self.inner_attn(qkv, **kwargs)
+                else:
+                    context = torch.utils.checkpoint.checkpoint(self.inner_attn, qkv, **kwargs)
+            else:
+                if self.rotary_emb_dim > 0:
+                    qkv = self.rotary_emb(qkv, seqlen_offset=inference_params.sequence_len_offset)
+                q = qkv[:, :, 0]
+                kv = self._update_kv_cache(qkv[:, :, 1:], inference_params)
+                # If we're processing the prompt, causal=None (use self.causal).
+                # If we're decoding, then causal=False.
+                causal = None if inference_params.sequence_len_offset == 0 else False
+                context = self.inner_cross_attn(q, kv, causal=causal)
         else:
             if not self.return_residual:
                 q = self.Wq(x)
@@ -401,10 +455,14 @@ class MHA(nn.Module):
                                 'b d s -> b s d').contiguous()
                 kv = rearrange(self.dwconv_kv(rearrange(kv, 'b s d -> b d s'))[..., :-2],
                                'b d s -> b s d').contiguous()
-            if not self.checkpointing:
-                context = self.inner_attn(q, kv, **kwargs)
-            else:
-                context = torch.utils.checkpoint.checkpoint(self.inner_attn, q, kv, **kwargs)
+            if inference_params is None:
+                if not self.checkpointing:
+                    context = self.inner_attn(q, kv, **kwargs)
+                else:
+                    context = torch.utils.checkpoint.checkpoint(self.inner_attn, q, kv, **kwargs)
+            else:
+                kv = self._update_kv_cache(kv, inference_params)
+                context = self.inner_cross_attn(q, kv, causal=False)
 
         out = self.out_proj(rearrange(context, '... h d -> ... (h d)'))
         return out if not self.return_residual else (out, x)
...
# Copyright (c) 2022, Tri Dao.
# Adapted from https://github.com/NVIDIA/Megatron-LM/blob/0bb597b42c53355a567aba2a1357cc34b9d99ddd/megatron/text_generation/forward_step.py#L31
from dataclasses import dataclass, field

import torch
from einops import rearrange

from transformers.generation import GreedySearchDecoderOnlyOutput


@dataclass
class InferenceParams:
    """Inference parameters that are passed to the main model in order
    to efficiently calculate and store the context during inference."""
    max_sequence_len: int
    max_batch_size: int
    sequence_len_offset: int = 0
    batch_size_offset: int = 0
    key_value_memory_dict: dict = field(default_factory=dict)
def greedy_decode(input_ids, model, max_length):
    """Greedy decoding. This is a very simple implementation.
    We assume that all sequences in the same batch have the same length.
    Arguments:
        input_ids: (batch, seq_len)
        max_length: int
    Returns: GreedySearchDecoderOnlyOutput, with the following fields:
        sequences: (batch, max_length)
        scores: tuples of (batch, vocab_size)
    """
    batch_size, seqlen_og = input_ids.shape
    inference_params = InferenceParams(max_sequence_len=max_length, max_batch_size=batch_size)
    scores = []
    with torch.inference_mode():
        logits = model(input_ids, inference_params=inference_params).logits[:, -1]
        scores.append(logits)
        next_token = logits.argmax(dim=-1)
        sequences = [next_token]
        inference_params.sequence_len_offset = seqlen_og
        while True:
            position_ids = torch.full((batch_size, 1), inference_params.sequence_len_offset,
                                      dtype=torch.long, device=input_ids.device)
            logits = model(rearrange(next_token, 'b -> b 1'), position_ids=position_ids,
                           inference_params=inference_params).logits[:, -1]
            scores.append(logits)
            next_token = logits.argmax(dim=-1)
            sequences.append(next_token)
            inference_params.sequence_len_offset += 1
            if inference_params.sequence_len_offset >= max_length - 1:
                break
    return GreedySearchDecoderOnlyOutput(
        sequences=torch.cat([input_ids, torch.stack(sequences, dim=1)], dim=1),
        scores=tuple(scores)
    )
class GenerationMixin:

    def generate(self, input_ids, max_length, return_dict_in_generate=False, output_scores=False):
        output = greedy_decode(input_ids, self, max_length)
        if not output_scores:
            output.scores = None
        return output if return_dict_in_generate else output.sequences
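GenerationMixin is what GPTLMHeadModel now inherits from (see the gpt.py hunk above): generate() wraps greedy_decode and mirrors the HuggingFace return convention. A brief sketch of the two return modes, reusing the model and input_ids from the first sketch (the final assert only restates the relationship implied by greedy_decode above):

# Default: a tensor of generated token ids, shape (batch, max_length).
sequences = model.generate(input_ids, max_length=30)

# Dict mode: a GreedySearchDecoderOnlyOutput whose .scores is a tuple of per-step
# logits, each of shape (batch, vocab_size).
out = model.generate(input_ids, max_length=30,
                     return_dict_in_generate=True, output_scores=True)
assert len(out.scores) == out.sequences.shape[1] - input_ids.shape[1]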
@@ -23,16 +23,6 @@ def test_gpt2_state_dict(model_name):
         assert state_dict[k].shape == pretrained_state_dict[k].shape
 
 
-def get_hf_models(model_name, config, dtype):
-    pretrained_state_dict = state_dict_from_pretrained(model_name)
-    model_hf = GPT2LMHeadModelHF(config)
-    # Missing key(s) in state_dict: "bert.embeddings.position_ids", "cls.predictions.decoder.bias"
-    # position_ids is a buffer, and predictions.decoder.bias is tied to predictions.bias.
-    model_hf.load_state_dict(pretrained_state_dict, strict=False)
-    model_hf.cuda().to(dtype=dtype)
-    return model_hf
-
-
 @pytest.mark.parametrize('model_name', ["gpt2", "gpt2-medium"])
 # @pytest.mark.parametrize('model_name', ["gpt2"])
 def test_gpt2_non_optimized(model_name):
...
import re

import torch
import pytest

from einops import rearrange

from transformers import GPT2Config, GPT2Tokenizer
from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel as GPT2LMHeadModelHF

from flash_attn.models.gpt import GPTLMHeadModel
from flash_attn.models.gpt import remap_state_dict_gpt2
from flash_attn.utils.pretrained import state_dict_from_pretrained
from flash_attn.utils.generation import greedy_decode


# TODO: test with rotary embedding
@pytest.mark.parametrize('optimized', [False, True])
# @pytest.mark.parametrize('optimized', [False])
@pytest.mark.parametrize('model_name', ["gpt2"])
def test_greedy_decode(model_name, optimized):
    """Check that our implementation of GPT2 generation matches the HF implementation:
    the scores in fp16 should be around the same as the HF scores in fp16, when compared to
    the HF scores in fp32.
    """
    dtype = torch.float16
    rtol, atol = 3e-3, 3e-1
    config = GPT2Config.from_pretrained(model_name)
    if optimized:
        config.use_flash_attn = True
        config.fused_bias_fc = True
        config.fused_dense_gelu_dense = True
        config.fused_dropout_add_ln = True

    model = GPTLMHeadModel.from_pretrained(model_name, config)
    model = model.cuda().to(dtype=dtype)

    model_ref = GPT2LMHeadModelHF.from_pretrained(model_name).cuda()
    model_hf = GPT2LMHeadModelHF.from_pretrained(model_name).cuda().to(dtype=dtype)

    model.eval()
    model_ref.eval()
    model_hf.eval()

    torch.manual_seed(0)
    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    input_ids = tokenizer("Hello, my dog is cute and ", return_tensors="pt").input_ids.cuda()
    max_length = 30

    # Slow generation for reference
    sequences = []
    scores = []
    cur_input_ids = input_ids
    with torch.inference_mode():
        scores.append(model(cur_input_ids).logits[:, -1])
        sequences.append(scores[-1].argmax(dim=-1))
        for _ in range(input_ids.shape[1] + 1, max_length):
            cur_input_ids = torch.cat([cur_input_ids, rearrange(sequences[-1], 'b -> b 1')], dim=-1)
            scores.append(model(cur_input_ids).logits[:, -1])
            sequences.append(scores[-1].argmax(dim=-1))
    sequences = torch.cat([input_ids, torch.stack(sequences, dim=1)], dim=1)
    scores = tuple(scores)

    out = model.generate(input_ids=input_ids, max_length=max_length,
                         return_dict_in_generate=True, output_scores=True)

    out_hf = model_hf.generate(input_ids=input_ids, max_length=max_length,
                               return_dict_in_generate=True, output_scores=True)
    out_ref = model_ref.generate(input_ids=input_ids, max_length=max_length,
                                 return_dict_in_generate=True, output_scores=True)

    print(f'Scores max diff: {(torch.stack(out.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item()}')
    print(f'Scores mean diff: {(torch.stack(out.scores, 1) - torch.stack(out_ref.scores, 1)).abs().mean().item()}')
    print(f'HF fp16 max diff: {(torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item()}')
    print(f'HF fp16 mean diff: {(torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)).abs().mean().item()}')

    assert torch.all(out.sequences == sequences)
    assert torch.allclose(torch.stack(out.scores, dim=1), torch.stack(scores, dim=1),
                          rtol=rtol, atol=atol)
    assert torch.all(out.sequences == out_ref.sequences)
    assert torch.all(out.sequences == out_hf.sequences)
    assert (torch.stack(out.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item() < 3 * (torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item()