Commit 425dbcb6 authored by Tri Dao

[MHA] Implement MQA/GQA

parent ec9f74ab
 # Copyright (c) 2023, Tri Dao.
-from typing import Tuple
+from typing import Tuple, Optional
 import math
 import torch

@@ -151,6 +151,51 @@ class ApplyRotaryEmbQKV_(torch.autograd.Function):
 apply_rotary_emb_qkv_ = ApplyRotaryEmbQKV_.apply
+
+
+class ApplyRotaryEmbKV_(torch.autograd.Function):
+
+    @staticmethod
+    def forward(ctx, kv, cos, sin, interleaved=False):
+        """
+        kv: (batch_size, seqlen, 2, nheads, headdim)
+        cos, sin: (seqlen, rotary_dim / 2)
+        interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead of
+            1st half and 2nd half (GPT-NeoX style).
+        rotary_dim must be <= headdim
+        Apply rotary embedding *inplace* to the first rotary_dim of k.
+        """
+        batch, seqlen, two, nheads, headdim = kv.shape
+        assert two == 2
+        rotary_seqlen, rotary_dim = cos.shape
+        rotary_dim *= 2
+        assert rotary_dim <= headdim
+        assert seqlen <= rotary_seqlen
+        k_ro = kv[:, :, 0, :, :rotary_dim]
+        k1, k2 = k_ro.chunk(2, dim=-1) if not interleaved else (k_ro[..., ::2], k_ro[..., 1::2])
+        rotary_emb.apply_rotary(k1, k2, rearrange(cos[:seqlen], 's d -> s 1 d'),
+                                rearrange(sin[:seqlen], 's d -> s 1 d'), k1, k2,
+                                False)  # conj=False since this is the forward pass
+        ctx.save_for_backward(cos, sin)
+        ctx.interleaved = interleaved
+        return kv
+
+    @staticmethod
+    def backward(ctx, dkv):
+        cos, sin = ctx.saved_tensors
+        _, seqlen, _, _, headdim = dkv.shape
+        rotary_dim = cos.shape[-1]
+        rotary_dim *= 2
+        dk_ro = dkv[:, :, 0, :, :rotary_dim]
+        dk1, dk2 = (dk_ro.chunk(2, dim=-1) if not ctx.interleaved
+                    else (dk_ro[..., ::2], dk_ro[..., 1::2]))
+        rotary_emb.apply_rotary(dk1, dk2, rearrange(cos[:seqlen], 's d -> s 1 d'),
+                                rearrange(sin[:seqlen], 's d -> s 1 d'), dk1, dk2,
+                                True)  # conj=True since this is the backward pass
+        return dkv, None, None, None
+
+
+apply_rotary_emb_kv_ = ApplyRotaryEmbKV_.apply
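Illustrative note (not part of the commit): a minimal pure-PyTorch sketch of the rotation that ApplyRotaryEmbKV_ fuses, assuming the non-interleaved (GPT-NeoX) layout. Only the first rotary_dim of k (kv[:, :, 0]) is rotated by (cos, sin); v passes through unchanged. The helper name and sizes below are made up; the real class calls the fused rotary_emb.apply_rotary kernel in place.

import torch

def rotary_kv_reference(kv, cos, sin):
    # kv: (batch, seqlen, 2, nheads, headdim); cos, sin: (seqlen, rotary_dim // 2)
    rotary_dim = cos.shape[-1] * 2
    k_ro = kv[:, :, 0, :, :rotary_dim]
    k1, k2 = k_ro.chunk(2, dim=-1)              # GPT-NeoX style: first half / second half
    c = cos[:kv.shape[1]].unsqueeze(1)          # (seqlen, 1, rotary_dim // 2), broadcasts over heads
    s = sin[:kv.shape[1]].unsqueeze(1)
    out = kv.clone()
    out[:, :, 0, :, :rotary_dim] = torch.cat([k1 * c - k2 * s, k1 * s + k2 * c], dim=-1)
    return out

kv = torch.randn(2, 5, 2, 4, 64)                # rotary_dim = 16 <= headdim = 64
theta = 10000.0 ** (-torch.arange(0, 16, 2) / 16)
angles = torch.arange(5)[:, None] * theta[None, :]
out = rotary_kv_reference(kv, torch.cos(angles), torch.sin(angles))
assert torch.equal(out[:, :, 1], kv[:, :, 1])   # the values v are left untouched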

 class RotaryEmbedding(torch.nn.Module):
     """
     The rotary position embeddings from RoFormer_ (Su et. al).

@@ -249,21 +294,43 @@ class RotaryEmbedding(torch.nn.Module):
             self._cos_k_cached = (torch.cos(freqs) / scale).to(dtype)
             self._sin_k_cached = (torch.sin(freqs) / scale).to(dtype)

-    def forward(self, qkv: torch.Tensor, seqlen_offset: int = 0) -> Tuple[torch.Tensor, torch.Tensor]:
+    def forward(self, qkv: torch.Tensor, kv: Optional[torch.Tensor] = None,
+                seqlen_offset: int = 0) -> Tuple[torch.Tensor, torch.Tensor]:
         """
-        qkv: (batch, seqlen, 3, nheads, headdim)
+        qkv: (batch, seqlen, 3, nheads, headdim) if kv is none,
+            else it's just q of shape (batch, seqlen, nheads, headdim)
+        kv: (batch, seqlen, 2, nheads, headdim)
         seqlen_offset: can be used in generation where the qkv being passed in is only the last
             token in the batch.
         """
-        self._update_cos_sin_cache(qkv.shape[1] + seqlen_offset, device=qkv.device, dtype=qkv.dtype)
-        if self.scale is None:
-            return apply_rotary_emb_qkv_(
-                qkv, self._cos_cached[seqlen_offset:], self._sin_cached[seqlen_offset:],
-                None, None, self.interleaved
-            )
+        seqlen = qkv.shape[1]
+        self._update_cos_sin_cache(seqlen + seqlen_offset, device=qkv.device, dtype=qkv.dtype)
+        if kv is None:
+            if self.scale is None:
+                return apply_rotary_emb_qkv_(
+                    qkv, self._cos_cached[seqlen_offset:], self._sin_cached[seqlen_offset:],
+                    None, None, self.interleaved
+                )
+            else:
+                return apply_rotary_emb_qkv_(
+                    qkv, self._cos_cached[seqlen_offset:], self._sin_cached[seqlen_offset:],
+                    self._cos_k_cached[seqlen_offset:], self._sin_k_cached[seqlen_offset:],
+                    self.interleaved
+                )
         else:
-            return apply_rotary_emb_qkv_(
-                qkv, self._cos_cached[seqlen_offset:], self._sin_cached[seqlen_offset:],
-                self._cos_k_cached[seqlen_offset:], self._sin_k_cached[seqlen_offset:],
-                self.interleaved
-            )
+            q = qkv
+            q = apply_rotary_emb_func(
+                q, self._cos_cached[seqlen_offset:], self._sin_cached[seqlen_offset:],
+                self.interleaved, True
+            )
+            if self.scale is None:
+                kv = apply_rotary_emb_kv_(
+                    kv, self._cos_cached[seqlen_offset:], self._sin_cached[seqlen_offset:],
+                    self.interleaved
+                )
+            else:
+                kv = apply_rotary_emb_kv_(
+                    kv, self._cos_k_cached[seqlen_offset:], self._sin_k_cached[seqlen_offset:],
+                    self.interleaved
+                )
+            return q, kv
@@ -88,7 +88,9 @@ def create_mixer_cls(config, layer_idx=None, process_group=None, device=None, dt
     parallel_kwargs = ({'process_group': process_group,
                         'sequence_parallel': getattr(config, 'sequence_parallel', True)}
                        if process_group is not None else {})
+    num_heads_kv = getattr(config, "n_head_kv", None)
     mixer_cls = partial(mha_cls, num_heads=config.num_attention_heads,
+                        num_heads_kv=num_heads_kv,
                         qkv_proj_bias=qkv_proj_bias, out_proj_bias=out_proj_bias,
                         dropout=config.attn_pdrop,
                         softmax_scale=softmax_scale, causal=True, layer_idx=layer_idx,
@@ -503,20 +505,37 @@ def shard_state_dict_tp(state_dict, config, world_size, rank):
     assert inner_dim % world_size == 0

     def shard_first_dim(state_dict, key):
-        x = state_dict[key]
-        dim = x.shape[0] // world_size
-        state_dict[key] = x[rank * dim:(rank + 1) * dim]
+        if key in state_dict:
+            x = state_dict[key]
+            dim = x.shape[0] // world_size
+            state_dict[key] = x[rank * dim:(rank + 1) * dim]

     def shard_last_dim(state_dict, key):
-        x = state_dict[key]
-        dim = x.shape[-1] // world_size
-        state_dict[key] = x[..., rank * dim:(rank + 1) * dim]
+        if key in state_dict:
+            x = state_dict[key]
+            dim = x.shape[-1] // world_size
+            state_dict[key] = x[..., rank * dim:(rank + 1) * dim]

     def shard_qkv_headdim(state_dict, key):
-        x = rearrange(state_dict[key], '(three d) ... -> three d ...', three=3)
-        dim = x.shape[1] // world_size
-        state_dict[key] = rearrange(x[:, rank * dim:(rank + 1) * dim],
-                                    'three d ... -> (three d) ...')
+        if key in state_dict:
+            n_head = config.n_head
+            n_head_kv = getattr(config, 'n_head_kv', n_head)
+            assert n_head % world_size == 0 and n_head_kv % world_size == 0
+            if n_head_kv == n_head:
+                x = rearrange(state_dict[key], '(three d) ... -> three d ...', three=3)
+                dim = x.shape[1] // world_size
+                state_dict[key] = rearrange(x[:, rank * dim:(rank + 1) * dim],
+                                            'three d ... -> (three d) ...')
+            else:
+                n_head_per_rank = n_head // world_size
+                n_head_kv_per_rank = n_head_kv // world_size
+                x = rearrange(state_dict[key], '(nheadqkv headdim) ... -> nheadqkv headdim ...',
+                              nheadqkv=n_head + 2 * n_head_kv)
+                state_dict[key] = rearrange(torch.cat([
+                    x[rank * n_head_per_rank:(rank + 1) * n_head_per_rank],
+                    x[n_head + rank * n_head_kv_per_rank:n_head + (rank + 1) * n_head_kv_per_rank],
+                    x[n_head + n_head_kv + rank * n_head_kv_per_rank:n_head + n_head_kv + (rank + 1) * n_head_kv_per_rank],
+                ], dim=0), "nheadqkv headdim ... -> (nheadqkv headdim) ...")

     shard_first_dim(state_dict, 'transformer.embeddings.word_embeddings.weight')
     if 'lm_head.weight' in state_dict:
@@ -528,12 +547,12 @@ def shard_state_dict_tp(state_dict, config, world_size, rank):
         shard_qkv_headdim(state_dict, f'transformer.layers.{i}.mixer.Wqkv.bias')
         shard_last_dim(state_dict, f'transformer.layers.{i}.mixer.out_proj.weight')
         if rank != 0:
-            state_dict.pop(f'transformer.layers.{i}.mixer.out_proj.bias')
+            state_dict.pop(f'transformer.layers.{i}.mixer.out_proj.bias', None)
         shard_first_dim(state_dict, f'transformer.layers.{i}.mlp.fc1.weight')
         shard_first_dim(state_dict, f'transformer.layers.{i}.mlp.fc1.bias')
         shard_last_dim(state_dict, f'transformer.layers.{i}.mlp.fc2.weight')
         if rank != 0:
-            state_dict.pop(f'transformer.layers.{i}.mlp.fc2.bias')
+            state_dict.pop(f'transformer.layers.{i}.mlp.fc2.bias', None)
     return state_dict
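Illustrative note (not part of the commit): a self-contained sketch of the packed-weight layout that shard_qkv_headdim handles in its MQA/GQA branch. The Wqkv rows stack n_head query heads, then n_head_kv key heads, then n_head_kv value heads, and each tensor-parallel rank takes a contiguous slice of each of the three groups. All sizes below are made up.

import torch
from einops import rearrange

n_head, n_head_kv, head_dim, hidden = 8, 2, 4, 32
world_size = 2
wqkv = torch.randn((n_head + 2 * n_head_kv) * head_dim, hidden)   # packed [q | k | v] rows

def shard(w, rank):
    x = rearrange(w, '(nheadqkv headdim) ... -> nheadqkv headdim ...',
                  nheadqkv=n_head + 2 * n_head_kv)
    nh, nhkv = n_head // world_size, n_head_kv // world_size
    parts = [
        x[rank * nh:(rank + 1) * nh],                                        # this rank's q heads
        x[n_head + rank * nhkv:n_head + (rank + 1) * nhkv],                  # this rank's k heads
        x[n_head + n_head_kv + rank * nhkv:n_head + n_head_kv + (rank + 1) * nhkv],  # its v heads
    ]
    return rearrange(torch.cat(parts, dim=0), 'nheadqkv headdim ... -> (nheadqkv headdim) ...')

shards = [shard(wqkv, r) for r in range(world_size)]
# each rank holds (n_head + 2 * n_head_kv) / world_size head-blocks of rows
assert all(s.shape == ((n_head + 2 * n_head_kv) // world_size * head_dim, hidden) for s in shards)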
@@ -561,9 +580,23 @@ def combine_state_dicts_tp(state_dicts, config):
             state_dict[key] = torch.cat([s[key] for s in state_dicts], dim=dim)

     def combine_qkv_headdim(state_dicts, state_dict, key):
+        n_head = config.n_head
+        n_head_kv = getattr(config, 'n_head_kv', n_head)
+        assert n_head % world_size == 0 and n_head_kv % world_size == 0
+        n_head_per_rank = n_head // world_size
+        n_head_kv_per_rank = n_head_kv // world_size
         if key in state_dict:
-            xs = [rearrange(s[key], '(three d) ... -> three d ...', three=3) for s in state_dicts]
-            state_dict[key] = rearrange(torch.cat(xs, dim=1), 'three d ... -> (three d) ...')
+            if n_head_kv == n_head:
+                xs = [rearrange(s[key], '(three d) ... -> three d ...', three=3) for s in state_dicts]
+                state_dict[key] = rearrange(torch.cat(xs, dim=1), 'three d ... -> (three d) ...')
+            else:
+                xs = [rearrange(s[key], '(nheadqkv headdim) ... -> nheadqkv headdim ...',
+                                nheadqkv=n_head + 2 * n_head_kv) for s in state_dicts]
+                state_dict[key] = rearrange(torch.cat([
+                    torch.cat([x[:n_head_per_rank] for x in xs], dim=0),
+                    torch.cat([x[n_head_per_rank:n_head_per_rank + n_head_kv_per_rank] for x in xs], dim=0),
+                    torch.cat([x[-n_head_kv_per_rank:] for x in xs], dim=0),
+                ], dim=0), "nheadqkv headdim ... -> (nheadqkv headdim) ...")

     def combine_gated_mlp(state_dicts, state_dict, key):
         if key in state_dict:
...
@@ -7,7 +7,7 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F

-from einops import rearrange
+from einops import rearrange, repeat

 try:
     from flash_attn import flash_attn_varlen_qkvpacked_func, flash_attn_varlen_kvpacked_func

@@ -211,7 +211,7 @@ class CrossAttention(nn.Module):
         Arguments
         ---------
             q: The tensor containing the query. (B, Sq, H, D)
-            kv: The tensor containing the key and value. (B, Sk, 2, H, D)
+            kv: The tensor containing the key and value. (B, Sk, 2, H_k, D)
             causal: if passed, will override self.causal
             key_padding_mask: boolean mask to apply to the attention weights. True means to keep,
                 False means to mask out. (B, Sk)

@@ -219,7 +219,9 @@ class CrossAttention(nn.Module):
         batch_size, seqlen_q = q.shape[0], q.shape[1]
         causal = self.causal if causal is None else causal
         seqlen_k = kv.shape[1]
-        assert kv.shape[0] == batch_size and kv.shape[3] == q.shape[2] and kv.shape[4] == q.shape[3]
+        assert kv.shape[0] == batch_size and kv.shape[4] == q.shape[3]
+        if kv.shape[3] != q.shape[2]:  # MQA/GQA
+            kv = repeat(kv, "... hkv d -> ... (hkv g) d", g=q.shape[2] // kv.shape[3])
         k, v = kv.unbind(dim=2)
         softmax_scale = self.softmax_scale or 1.0 / math.sqrt(q.shape[-1])
         scores = torch.einsum('bthd,bshd->bhts', q, k * softmax_scale)
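Illustrative note (not part of the commit): a minimal, runnable sketch of the broadcast the non-flash CrossAttention path now performs for MQA/GQA, with made-up shapes. Each KV head is repeated for its group of query heads before the usual einsum.

import math
import torch
from einops import repeat

B, Sq, Sk, H, Hk, D = 2, 3, 5, 8, 2, 16       # 8 query heads sharing 2 KV heads (GQA)
q = torch.randn(B, Sq, H, D)
kv = torch.randn(B, Sk, 2, Hk, D)

kv_expanded = repeat(kv, "... hkv d -> ... (hkv g) d", g=H // Hk)   # (B, Sk, 2, H, D)
k, v = kv_expanded.unbind(dim=2)
scores = torch.einsum('bthd,bshd->bhts', q, k / math.sqrt(D))
out = torch.einsum('bhts,bshd->bthd', torch.softmax(scores, dim=-1), v)
assert out.shape == (B, Sq, H, D)
# Query heads 0..3 attend against KV head 0, query heads 4..7 against KV head 1.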
@@ -304,17 +306,52 @@ def _update_kv_cache(kv, inference_params, layer_idx):
     return kv


+def _apply_rotary_single_query_attention(qkv, inference_params, layer_idx, rotary_emb_dim,
+                                         rotary_emb_base, kv=None, rotary_emb_interleaved=False):
+    """
+    qkv: (batch_size, 1, 3, nheads, head_dim) if kv is None else it's just
+        q of shape (batch_size, 1, nheads, head_dim)
+    kv: (batch_size, 1, 2, nheads_kv, head_dim)
+    """
+    assert inference_params.fused_ft_kernel
+    assert ft_attention is not None
+    if kv is None:
+        q, k, v = rearrange(qkv, 'b 1 three h d -> b three h d').unbind(dim=1)
+    else:
+        q = rearrange(qkv, 'b 1 h d -> b h d')
+        k, v = rearrange(kv, 'b 1 two h d -> b two h d').unbind(dim=1)
+    batch_start = inference_params.batch_size_offset
+    batch_end = batch_start + q.shape[0]
+    k_cache, v_cache = inference_params.key_value_memory_dict[layer_idx]
+    lengths_per_sample = (inference_params.lengths_per_sample[batch_start:batch_end]
+                          if inference_params.lengths_per_sample is not None else None)
+    context = ft_attention.single_query_attention(
+        q, k, v,
+        k_cache[batch_start:batch_end],
+        v_cache[batch_start:batch_end],
+        lengths_per_sample,
+        None,  # rotary_cos_
+        None,  # rotary_sin_
+        None,  # nnz_head_idx
+        inference_params.sequence_len_offset,
+        rotary_emb_dim, rotary_emb_base,
+        not rotary_emb_interleaved  # neox_rotary_style
+    )
+    return rearrange(context, 'b h d -> b 1 h d')
+
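Illustrative note (not part of the commit): a rough pure-PyTorch picture of the decode step this helper hands to ft_attention.single_query_attention, assuming a KV cache with fewer heads than the query. It ignores the kernel's packed cache layout, rotary handling and lengths_per_sample; all names and sizes below are made up.

import math
import torch

B, H, Hkv, D, t = 2, 8, 2, 16, 10            # decoding token t, GQA with 2 KV heads
q = torch.randn(B, H, D)                     # query for the current token only
k_new = torch.randn(B, Hkv, D)
v_new = torch.randn(B, Hkv, D)
k_cache = torch.randn(B, Hkv, t, D)          # cached keys/values for tokens 0..t-1
v_cache = torch.randn(B, Hkv, t, D)

k = torch.cat([k_cache, k_new[:, :, None]], dim=2)        # (B, Hkv, t + 1, D)
v = torch.cat([v_cache, v_new[:, :, None]], dim=2)
k = k.repeat_interleave(H // Hkv, dim=1)                   # broadcast each KV head to its q-head group
v = v.repeat_interleave(H // Hkv, dim=1)
attn = torch.softmax(torch.einsum('bhd,bhsd->bhs', q, k) / math.sqrt(D), dim=-1)
context = torch.einsum('bhs,bhsd->bhd', attn, v)           # (B, H, D) for the new token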

 class MHA(nn.Module):
     """Multi-head self-attention and cross-attention
     """

-    def __init__(self, embed_dim, num_heads, cross_attn=False,
+    def __init__(self, embed_dim, num_heads, num_heads_kv=None, cross_attn=False,
                  qkv_proj_bias=True, out_proj_bias=True,
                  dropout=0.0, softmax_scale=None, causal=False, layer_idx=None, dwconv=False,
                  rotary_emb_dim=0, rotary_emb_base=10000.0, rotary_emb_scale_base=None,
                  rotary_emb_interleaved=False, fused_bias_fc=False, use_flash_attn=False,
                  return_residual=False, checkpointing=False, device=None, dtype=None) -> None:
         """
+        num_heads_kv: can be used to toggle MQA / GQA. If None, use num_heads.
         return_residual: whether to return the input x along with the output. This is for
             performance reason: for post-norm architecture, returning the input allows us
             to fuse the backward of nn.Linear with the residual connection.
@@ -332,8 +369,12 @@ class MHA(nn.Module):
         self.checkpointing = checkpointing

         self.num_heads = num_heads
-        assert self.embed_dim % num_heads == 0, "self.kdim must be divisible by num_heads"
+        self.num_heads_kv = num_heads_kv if num_heads_kv is not None else num_heads
+        assert self.num_heads % self.num_heads_kv == 0, "num_heads must be divisible by num_heads_kv"
+        assert self.embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
         self.head_dim = self.embed_dim // num_heads
+        qkv_dim = self.head_dim * (self.num_heads + 2 * self.num_heads_kv)
+        kv_dim = 2 * self.head_dim * self.num_heads_kv

         if self.rotary_emb_dim > 0:
             assert not cross_attn, 'MHA with rotary embedding does not support cross-attention yet'
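Illustrative note (not part of the commit): a quick numeric check of the projection sizes introduced above, for a hypothetical GQA configuration (embed_dim 4096, 32 query heads, 8 KV heads).

embed_dim, num_heads, num_heads_kv = 4096, 32, 8
head_dim = embed_dim // num_heads                        # 128
qkv_dim = head_dim * (num_heads + 2 * num_heads_kv)      # 6144 instead of 3 * 4096 = 12288
kv_dim = 2 * head_dim * num_heads_kv                     # 2048
# Wqkv maps 4096 -> 6144; q is the first num_heads * head_dim = 4096 output features,
# the remaining kv_dim = 2048 features are the packed k and v with only 8 heads each.
assert (qkv_dim, kv_dim) == (6144, 2048)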
@@ -347,31 +388,23 @@ class MHA(nn.Module):
         linear_cls = nn.Linear if not fused_bias_fc else FusedDense
         linear_resid_cls = (LinearResidual if not fused_bias_fc
                             else partial(FusedDense, return_residual=True))
+        wqkv_cls = linear_cls if not self.return_residual else linear_resid_cls
         inner_attn_cls = FlashSelfAttention if use_flash_attn else SelfAttention
         inner_cross_attn_cls = FlashCrossAttention if use_flash_attn else CrossAttention
         if not self.cross_attn:
-            if not self.return_residual:
-                self.Wqkv = linear_cls(embed_dim, 3 * embed_dim, bias=qkv_proj_bias,
-                                       **factory_kwargs)
-            else:
-                self.Wqkv = linear_resid_cls(embed_dim, 3 * embed_dim, bias=qkv_proj_bias,
-                                             **factory_kwargs)
-            if self.dwconv:
-                self.dwconv_qkv = nn.Conv1d(3 * embed_dim, 3 * embed_dim, kernel_size=3, padding=2,
-                                            groups=3 * embed_dim)
+            self.Wqkv = wqkv_cls(embed_dim, qkv_dim, bias=qkv_proj_bias, **factory_kwargs)
         else:
             self.Wq = linear_cls(embed_dim, embed_dim, bias=qkv_proj_bias, **factory_kwargs)
-            if not self.return_residual:
-                self.Wkv = linear_cls(embed_dim, 2 * embed_dim, bias=qkv_proj_bias,
-                                      **factory_kwargs)
-            else:
-                self.Wkv = linear_resid_cls(embed_dim, 2 * embed_dim, bias=qkv_proj_bias,
-                                            **factory_kwargs)
-            if self.dwconv:
-                self.dwconv_q = nn.Conv1d(embed_dim, embed_dim, kernel_size=3, padding=2,
-                                          groups=embed_dim)
-                self.dwconv_kv = nn.Conv1d(2 * embed_dim, 2 * embed_dim, kernel_size=3, padding=2,
-                                           groups=2 * embed_dim)
+            self.Wkv = wqkv_cls(embed_dim, kv_dim, bias=qkv_proj_bias, **factory_kwargs)
+        if self.dwconv:
+            if self.num_heads_kv == self.num_heads:
+                self.dwconv_qkv = nn.Conv1d(qkv_dim, qkv_dim, kernel_size=3, padding=2,
+                                            groups=qkv_dim)
+            else:
+                self.dwconv_q = nn.Conv1d(embed_dim, embed_dim, kernel_size=3, padding=2,
+                                          groups=embed_dim)
+                self.dwconv_kv = nn.Conv1d(kv_dim, kv_dim, kernel_size=3, padding=2,
+                                           groups=kv_dim)
         self.inner_attn = inner_attn_cls(causal=causal, softmax_scale=softmax_scale,
                                          attention_dropout=dropout)
         self.inner_cross_attn = inner_cross_attn_cls(causal=causal, softmax_scale=softmax_scale,
@@ -382,15 +415,15 @@ class MHA(nn.Module):
         dtype = self.out_proj.weight.dtype if dtype is None else dtype
         device = self.out_proj.weight.device
         if not fused_ft_kernel:
-            return torch.empty(batch_size, max_seqlen, 2, self.num_heads, self.head_dim,
+            return torch.empty(batch_size, max_seqlen, 2, self.num_heads_kv, self.head_dim,
                                dtype=dtype, device=device)
         else:
             assert dtype in [torch.float16, torch.bfloat16, torch.float32]
             packsize = 4 if dtype == torch.float32 else 8
             assert self.head_dim % packsize == 0
-            k_cache = torch.empty(batch_size, self.num_heads, self.head_dim // packsize, max_seqlen,
-                                  packsize, dtype=dtype, device=device)
-            v_cache = torch.empty(batch_size, self.num_heads, max_seqlen, self.head_dim,
+            k_cache = torch.empty(batch_size, self.num_heads_kv, self.head_dim // packsize,
+                                  max_seqlen, packsize, dtype=dtype, device=device)
+            v_cache = torch.empty(batch_size, self.num_heads_kv, max_seqlen, self.head_dim,
                                   dtype=dtype, device=device)
             return k_cache, v_cache
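Illustrative note (not part of the commit): a back-of-the-envelope look at why allocating the inference cache with num_heads_kv matters, using made-up sizes (fp16, one layer).

batch_size, max_seqlen, head_dim, bytes_per_el = 8, 2048, 128, 2

def kv_cache_bytes(n_kv_heads):
    # k_cache plus v_cache: batch * heads * seqlen * head_dim elements each
    return 2 * batch_size * n_kv_heads * max_seqlen * head_dim * bytes_per_el

mha_bytes = kv_cache_bytes(32)   # num_heads_kv == num_heads
gqa_bytes = kv_cache_bytes(8)    # MQA/GQA with 8 KV heads
assert mha_bytes // gqa_bytes == 4   # the cache shrinks by num_heads / num_heads_kv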
@@ -401,6 +434,18 @@ class MHA(nn.Module):
         assert self.layer_idx is not None, 'Generation requires layer_idx in the constructor'
         return _update_kv_cache(kv, inference_params, self.layer_idx)

+    def _apply_rotary_single_query_attention(self, qkv, inference_params, kv=None):
+        """
+        qkv: (batch_size, 1, 3, nheads, head_dim) if kv is None else it's just
+            q of shape (batch_size, 1, nheads, head_dim)
+        kv: (batch_size, 1, 2, nheads_kv, head_dim)
+        """
+        rotary_emb_base = self.rotary_emb.base if self.rotary_emb_dim > 0 else 0
+        return _apply_rotary_single_query_attention(
+            qkv, inference_params, self.layer_idx, self.rotary_emb_dim, rotary_emb_base, kv=kv,
+            rotary_emb_interleaved=self.rotary_emb.interleaved if self.rotary_emb_dim > 0 else False,
+        )
+
     def forward(self, x, x_kv=None, key_padding_mask=None, cu_seqlens=None, max_seqlen=None,
                 mixer_subset=None, inference_params=None, **kwargs):
         """
@@ -438,7 +483,8 @@ class MHA(nn.Module):
         kwargs = ({'cu_seqlens': cu_seqlens, 'max_seqlen': max_seqlen, **kwargs}
                   if self.use_flash_attn else {'key_padding_mask': key_padding_mask, **kwargs})
-        if not self.cross_attn:
+        seqlen_offset = 0 if inference_params is None else inference_params.sequence_len_offset
+        if not self.cross_attn and self.num_heads_kv == self.num_heads:
             assert x_kv is None and mixer_subset is None
             if not self.return_residual:
                 qkv = self.Wqkv(x)
@@ -448,71 +494,69 @@ class MHA(nn.Module):
                 qkv = rearrange(self.dwconv_qkv(rearrange(qkv, 'b s d -> b d s'))[..., :-2],
                                 'b d s -> b s d').contiguous()
             qkv = rearrange(qkv, '... (three h d) -> ... three h d', three=3, d=self.head_dim)
-            if inference_params is None:
+            if (inference_params is None or inference_params.sequence_len_offset == 0
+                    or not inference_params.fused_ft_kernel):
                 if self.rotary_emb_dim > 0:
-                    qkv = self.rotary_emb(qkv)
-                if not self.checkpointing:
-                    context = self.inner_attn(qkv, **kwargs)
-                else:
-                    context = torch.utils.checkpoint.checkpoint(self.inner_attn, qkv, **kwargs)
-            else:
-                if (not inference_params.fused_ft_kernel) or inference_params.sequence_len_offset == 0:
-                    if self.rotary_emb_dim > 0:
-                        qkv = self.rotary_emb(qkv, seqlen_offset=inference_params.sequence_len_offset)
+                    qkv = self.rotary_emb(qkv, seqlen_offset=seqlen_offset)
+                if inference_params is None:
+                    if not self.checkpointing:
+                        context = self.inner_attn(qkv, **kwargs)
+                    else:
+                        context = torch.utils.checkpoint.checkpoint(self.inner_attn, qkv,
+                                                                    **kwargs)
+                else:
                     q = qkv[:, :, 0]
                     kv = self._update_kv_cache(qkv[:, :, 1:], inference_params)
                     # If we're processing the prompt, causal=None (use self.causal).
                     # If we're decoding, then causal=False.
                     causal = None if inference_params.sequence_len_offset == 0 else False
                     context = self.inner_cross_attn(q, kv, causal=causal)
-                else:
-                    assert inference_params.fused_ft_kernel
-                    assert ft_attention is not None
-                    batch_start = inference_params.batch_size_offset
-                    batch_end = batch_start + qkv.shape[0]
-                    k_cache, v_cache = inference_params.key_value_memory_dict[self.layer_idx]
-                    lengths_per_sample = (inference_params.lengths_per_sample[batch_start:batch_end]
-                                          if inference_params.lengths_per_sample is not None else None)
-                    rotary_emb_base = self.rotary_emb.base if self.rotary_emb_dim > 0 else 0
-                    context = ft_attention.single_query_attention(
-                        *rearrange(qkv, 'b 1 three h d -> b three h d').unbind(dim=1),
-                        k_cache[batch_start:batch_end],
-                        v_cache[batch_start:batch_end],
-                        lengths_per_sample,
-                        None,  # rotary_cos_
-                        None,  # rotary_sin_
-                        None,  # nnz_head_idx
-                        inference_params.sequence_len_offset,
-                        self.rotary_emb_dim, rotary_emb_base,
-                        # neox_rotary_style
-                        (not self.rotary_emb.interleaved) if self.rotary_emb_dim > 0 else True
-                    )
-                    context = rearrange(context, 'b h d -> b 1 h d')
+            else:
+                context = self._apply_rotary_single_query_attention(qkv, inference_params)
         else:
-            if not self.return_residual:
-                q = self.Wq(x if mixer_subset is None else x[:, mixer_subset])
-                kv = self.Wkv(x_kv if x_kv is not None else x)
-            else:
-                if x_kv is not None:
-                    kv, x_kv = self.Wkv(x_kv)
-                else:
-                    kv, x = self.Wkv(x)
-                q = self.Wq(x if mixer_subset is None else x[:, mixer_subset])
+            if self.cross_attn:
+                if not self.return_residual:
+                    q = self.Wq(x if mixer_subset is None else x[:, mixer_subset])
+                    kv = self.Wkv(x_kv if x_kv is not None else x)
+                else:
+                    if x_kv is not None:
+                        kv, x_kv = self.Wkv(x_kv)
+                    else:
+                        kv, x = self.Wkv(x)
+                    q = self.Wq(x if mixer_subset is None else x[:, mixer_subset])
+            else:
+                assert self.num_heads_kv != self.num_heads
+                if not self.return_residual:
+                    qkv = self.Wqkv(x)
+                else:
+                    qkv, x = self.Wqkv(x)
+                q = qkv[..., :self.num_heads * self.head_dim]
+                kv = qkv[..., self.num_heads * self.head_dim:]
             q = rearrange(q, '... (h d) -> ... h d', d=self.head_dim)
-            kv = rearrange(kv, '... (two h d) -> ... two h d', two=2, d=self.head_dim)
+            kv = rearrange(kv, '... (two hkv d) -> ... two hkv d', two=2, d=self.head_dim)
             if self.dwconv:
                 q = rearrange(self.dwconv_q(rearrange(q, 'b s d -> b d s'))[..., :-2],
                               'b d s -> b s d').contiguous()
                 kv = rearrange(self.dwconv_kv(rearrange(kv, 'b s d -> b d s'))[..., :-2],
                                'b d s -> b s d').contiguous()
-            if inference_params is None:
-                if not self.checkpointing:
-                    context = self.inner_cross_attn(q, kv, **kwargs)
-                else:
-                    context = torch.utils.checkpoint.checkpoint(self.inner_cross_attn, q, kv, **kwargs)
-            else:
-                kv = self._update_kv_cache(kv)
-                context = self.inner_cross_attn(q, kv, causal=False)
+            if (inference_params is None or inference_params.sequence_len_offset == 0
+                    or not inference_params.fused_ft_kernel):
+                if self.rotary_emb_dim > 0:
+                    q, kv = self.rotary_emb(q, kv, seqlen_offset=seqlen_offset)
+                if inference_params is None:
+                    if not self.checkpointing:
+                        context = self.inner_cross_attn(q, kv, **kwargs)
+                    else:
+                        context = torch.utils.checkpoint.checkpoint(self.inner_cross_attn, q, kv,
+                                                                    **kwargs)
+                else:
+                    kv = self._update_kv_cache(kv, inference_params)
+                    # If we're processing the prompt, causal=None (use self.causal).
+                    # If we're decoding, then causal=False.
+                    causal = None if inference_params.sequence_len_offset == 0 else False
+                    context = self.inner_cross_attn(q, kv, causal=causal)
+            else:
+                context = self._apply_rotary_single_query_attention(q, inference_params, kv=kv)
         out = self.out_proj(rearrange(context, '... h d -> ... (h d)'))
         return out if not self.return_residual else (out, x)
@@ -521,7 +565,8 @@ class ParallelMHA(nn.Module):
     """Multi-head self-attention and cross-attention
     """

-    def __init__(self, embed_dim, num_heads, process_group, qkv_proj_bias=True, out_proj_bias=True,
+    def __init__(self, embed_dim, num_heads, process_group, num_heads_kv=None,
+                 qkv_proj_bias=True, out_proj_bias=True,
                  dropout=0.0, softmax_scale=None, causal=False, layer_idx=None,
                  rotary_emb_dim=0, rotary_emb_base=10000.0, rotary_emb_scale_base=None,
                  rotary_emb_interleaved=False, use_flash_attn=False, checkpointing=False,
@@ -534,10 +579,19 @@ class ParallelMHA(nn.Module):
         self.rotary_emb_dim = rotary_emb_dim
         self.use_flash_attn = use_flash_attn
         self.checkpointing = checkpointing
+        self.process_group = process_group
+        self.world_size = process_group.size() if process_group is not None else 1

         self.num_heads = num_heads
-        assert self.embed_dim % num_heads == 0, "self.kdim must be divisible by num_heads"
+        self.num_heads_kv = num_heads_kv if num_heads_kv is not None else num_heads
+        self.num_heads_per_rank = num_heads // self.world_size
+        self.num_heads_kv_per_rank = self.num_heads_kv // self.world_size
+        assert self.num_heads % self.num_heads_kv == 0, "num_heads must be divisible by num_heads_kv"
+        assert self.embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
+        assert self.num_heads_kv % self.world_size == 0, "num_heads_kv must be divisible by world_size"
         self.head_dim = self.embed_dim // num_heads
+        qkv_dim = self.head_dim * (self.num_heads + 2 * self.num_heads_kv)
+        kv_dim = 2 * self.head_dim * self.num_heads_kv

         if self.rotary_emb_dim > 0:
             assert RotaryEmbedding is not None, 'rotary_emb is not installed'
@@ -547,7 +601,7 @@ class ParallelMHA(nn.Module):
         if ColumnParallelLinear is None or RowParallelLinear is None:
             raise ImportError('fused_dense is not installed')
-        self.Wqkv = ColumnParallelLinear(embed_dim, 3 * embed_dim, process_group,
+        self.Wqkv = ColumnParallelLinear(embed_dim, qkv_dim, process_group,
                                          bias=qkv_proj_bias,
                                          sequence_parallel=sequence_parallel, **factory_kwargs)
         inner_attn_cls = FlashSelfAttention if use_flash_attn else SelfAttention
@@ -560,6 +614,41 @@ class ParallelMHA(nn.Module):
                                           bias=out_proj_bias,
                                           sequence_parallel=sequence_parallel, **factory_kwargs)

+    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, fused_ft_kernel=True):
+        dtype = self.out_proj.weight.dtype if dtype is None else dtype
+        device = self.out_proj.weight.device
+        if not fused_ft_kernel:
+            return torch.empty(batch_size, max_seqlen, 2, self.num_heads_kv_per_rank,
+                               self.head_dim, dtype=dtype, device=device)
+        else:
+            assert dtype in [torch.float16, torch.bfloat16, torch.float32]
+            packsize = 4 if dtype == torch.float32 else 8
+            assert self.head_dim % packsize == 0
+            k_cache = torch.empty(batch_size, self.num_heads_kv_per_rank,
+                                  self.head_dim // packsize,
+                                  max_seqlen, packsize, dtype=dtype, device=device)
+            v_cache = torch.empty(batch_size, self.num_heads_kv_per_rank, max_seqlen,
+                                  self.head_dim, dtype=dtype, device=device)
+            return k_cache, v_cache
+
+    def _update_kv_cache(self, kv, inference_params):
+        """kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)
+        """
+        assert self.layer_idx is not None, 'Generation requires layer_idx in the constructor'
+        return _update_kv_cache(kv, inference_params, self.layer_idx)
+
+    def _apply_rotary_single_query_attention(self, qkv, inference_params, kv=None):
+        """
+        qkv: (batch_size, 1, 3, nheads, head_dim) if kv is None else it's just
+            q of shape (batch_size, 1, nheads, head_dim)
+        kv: (batch_size, 1, 2, nheads_kv, head_dim)
+        """
+        rotary_emb_base = self.rotary_emb.base if self.rotary_emb_dim > 0 else 0
+        return _apply_rotary_single_query_attention(
+            qkv, inference_params, self.layer_idx, self.rotary_emb_dim, rotary_emb_base, kv=kv,
+            rotary_emb_interleaved=self.rotary_emb.interleaved if self.rotary_emb_dim > 0 else False,
+        )
+
     def forward(self, x, seqlen=None, inference_params=None, **kwargs):
         """
         Arguments:
@@ -569,55 +658,54 @@ class ParallelMHA(nn.Module):
                 (in case batch is small).
         """
         qkv = self.Wqkv(x)
-        if seqlen is None:
-            qkv = rearrange(qkv, 'b s (three h d) -> b s three h d', three=3, d=self.head_dim)
-        else:
-            qkv = rearrange(qkv, '(b s) (three h d) -> b s three h d', s=seqlen, three=3,
-                            d=self.head_dim)
-        if inference_params is None:
-            if self.rotary_emb_dim > 0:
-                qkv = self.rotary_emb(qkv)
-            if not self.checkpointing:
-                context = self.inner_attn(qkv, **kwargs)
+        if seqlen is not None:
+            qkv = rearrange(qkv, "(b s) ... -> b s ...", s=seqlen)
+        seqlen_offset = 0 if inference_params is None else inference_params.sequence_len_offset
+        if self.num_heads_kv == self.num_heads:
+            qkv = rearrange(qkv, 'b s (three h d) -> b s three h d', three=3, d=self.head_dim)
+            if (inference_params is None or inference_params.sequence_len_offset == 0
+                    or not inference_params.fused_ft_kernel):
+                if self.rotary_emb_dim > 0:
+                    qkv = self.rotary_emb(qkv, seqlen_offset=seqlen_offset)
+                if inference_params is None:
+                    if not self.checkpointing:
+                        context = self.inner_attn(qkv, **kwargs)
+                    else:
+                        context = torch.utils.checkpoint.checkpoint(self.inner_attn, qkv, **kwargs)
+                else:
+                    q = qkv[:, :, 0]
+                    kv = _update_kv_cache(qkv[:, :, 1:], inference_params, self.layer_idx)
+                    # If we're processing the prompt, causal=None (use self.causal).
+                    # If we're decoding, then causal=False.
+                    causal = None if inference_params.sequence_len_offset == 0 else False
+                    context = self.inner_cross_attn(q, kv, causal=causal)
             else:
-                context = torch.utils.checkpoint.checkpoint(self.inner_attn, qkv, **kwargs)
+                context = self._apply_rotary_single_query_attention(qkv, inference_params)
         else:
-            if (not inference_params.fused_ft_kernel) or inference_params.sequence_len_offset == 0:
+            q = rearrange(qkv[..., :self.num_heads_per_rank * self.head_dim],
+                          "... (h d) -> ... h d", d=self.head_dim)
+            kv = rearrange(qkv[..., self.num_heads_per_rank * self.head_dim:],
+                           "... (two hkv d) -> ... two hkv d", two=2, d=self.head_dim)
+            if (inference_params is None or inference_params.sequence_len_offset == 0
+                    or not inference_params.fused_ft_kernel):
                 if self.rotary_emb_dim > 0:
-                    qkv = self.rotary_emb(qkv, seqlen_offset=inference_params.sequence_len_offset)
-                q = qkv[:, :, 0]
-                assert self.layer_idx is not None, 'Generation requires layer_idx in the constructor'
-                kv = _update_kv_cache(qkv[:, :, 1:], inference_params, self.layer_idx)
-                # If we're processing the prompt, causal=None (use self.causal).
-                # If we're decoding, then causal=False.
-                causal = None if inference_params.sequence_len_offset == 0 else False
-                context = self.inner_cross_attn(q, kv, causal=causal)
+                    q, kv = self.rotary_emb(q, kv, seqlen_offset=seqlen_offset)
+                if inference_params is None:
+                    if not self.checkpointing:
+                        context = self.inner_cross_attn(q, kv, **kwargs)
+                    else:
+                        context = torch.utils.checkpoint.checkpoint(self.inner_cross_attn, q, kv,
+                                                                    **kwargs)
+                else:
+                    kv = self._update_kv_cache(kv, inference_params)
+                    # If we're processing the prompt, causal=None (use self.causal).
+                    # If we're decoding, then causal=False.
+                    causal = None if inference_params.sequence_len_offset == 0 else False
+                    context = self.inner_cross_attn(q, kv, causal=causal)
             else:
-                assert inference_params.fused_ft_kernel
-                assert ft_attention is not None
-                batch_start = inference_params.batch_size_offset
-                batch_end = batch_start + qkv.shape[0]
-                k_cache, v_cache = inference_params.key_value_memory_dict[self.layer_idx]
-                lengths_per_sample = (inference_params.lengths_per_sample[batch_start:batch_end]
-                                      if inference_params.lengths_per_sample is not None else None)
-                rotary_emb_base = self.rotary_emb.base if self.rotary_emb_dim > 0 else 0
-                context = ft_attention.single_query_attention(
-                    *rearrange(qkv, 'b 1 three h d -> b three h d').unbind(dim=1),
-                    k_cache[batch_start:batch_end],
-                    v_cache[batch_start:batch_end],
-                    lengths_per_sample,
-                    None,  # rotary_cos_
-                    None,  # rotary_sin_
-                    None,  # nnz_head_idx
-                    inference_params.sequence_len_offset,
-                    self.rotary_emb_dim, rotary_emb_base,
-                    # neox_rotary_style
-                    (not self.rotary_emb.interleaved) if self.rotary_emb_dim > 0 else True
-                )
-                context = rearrange(context, 'b h d -> b 1 h d')
-        if seqlen is None:
-            context = rearrange(context, 'b s h d -> b s (h d)')
-        else:
-            context = rearrange(context, 'b s h d -> (b s) (h d)')
+                context = self._apply_rotary_single_query_attention(q, inference_params, kv=kv)
+        context = rearrange(context, 'b s h d -> b s (h d)')
+        if seqlen is not None:
+            context = rearrange(context, 'b s d -> (b s) d')
         out = self.out_proj(context)
         return out
@@ -60,9 +60,9 @@ def test_greedy_decode_gpt2(model_name, rotary, optimized, fused_ft_kernel):
     torch.manual_seed(0)
     tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
-    input_ids = tokenizer("Hello, my dog is cute and",
+    input_ids = tokenizer("Hello, my dog is cute and he",
                           return_tensors="pt").input_ids.to(device=device)
-    max_length = 30
+    max_length = 25
     # input_ids = torch.randint(0, 100, (2, 10), dtype=torch.long, device='cuda')
     # max_length = input_ids.shape[1] + 40

@@ -143,9 +143,9 @@ def test_greedy_decode_opt(model_name):
     tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
     eos_token_id = tokenizer.eos_token_id
-    input_ids = tokenizer("Hello, my dog is cute and",
+    input_ids = tokenizer("Hello, my dog is cute and he",
                           return_tensors="pt").input_ids.to(device=device)
-    max_length = 60
+    max_length = 25
     # input_ids = torch.randint(0, 100, (2, 10), dtype=torch.long, device='cuda')
     # max_length = input_ids.shape[1] + 40
...
@@ -48,7 +48,6 @@ def test_gptj_optimized(model_name):
     torch.manual_seed(0)
     batch_size = 2
     max_seqlen = 256
-    seqlens = torch.randint(max_seqlen // 2, max_seqlen + 1, (batch_size,), device=device)
     input_ids = torch.randint(0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long,
                               device=device)
     with torch.no_grad():
...