Commit 425dbcb6 authored by Tri Dao

[MHA] Implement MQA/GQA

parent ec9f74ab
# Copyright (c) 2023, Tri Dao.
from typing import Tuple, Optional
import math
import torch
...@@ -151,6 +151,51 @@ class ApplyRotaryEmbQKV_(torch.autograd.Function):
apply_rotary_emb_qkv_ = ApplyRotaryEmbQKV_.apply

class ApplyRotaryEmbKV_(torch.autograd.Function):

    @staticmethod
    def forward(ctx, kv, cos, sin, interleaved=False):
        """
        kv: (batch_size, seqlen, 2, nheads, headdim)
        cos, sin: (seqlen, rotary_dim / 2)
        interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead of
            1st half and 2nd half (GPT-NeoX style).
        rotary_dim must be <= headdim
        Apply rotary embedding *inplace* to the first rotary_dim of k.
        """
        batch, seqlen, two, nheads, headdim = kv.shape
        assert two == 2
        rotary_seqlen, rotary_dim = cos.shape
        rotary_dim *= 2
        assert rotary_dim <= headdim
        assert seqlen <= rotary_seqlen
        k_ro = kv[:, :, 0, :, :rotary_dim]
        k1, k2 = k_ro.chunk(2, dim=-1) if not interleaved else (k_ro[..., ::2], k_ro[..., 1::2])
        rotary_emb.apply_rotary(k1, k2, rearrange(cos[:seqlen], 's d -> s 1 d'),
                                rearrange(sin[:seqlen], 's d -> s 1 d'), k1, k2,
                                False)  # conj=False since this is the forward pass
        ctx.save_for_backward(cos, sin)
        ctx.interleaved = interleaved
        return kv

    @staticmethod
    def backward(ctx, dkv):
        cos, sin = ctx.saved_tensors
        _, seqlen, _, _, headdim = dkv.shape
        rotary_dim = cos.shape[-1]
        rotary_dim *= 2
        dk_ro = dkv[:, :, 0, :, :rotary_dim]
        dk1, dk2 = (dk_ro.chunk(2, dim=-1) if not ctx.interleaved
                    else (dk_ro[..., ::2], dk_ro[..., 1::2]))
        rotary_emb.apply_rotary(dk1, dk2, rearrange(cos[:seqlen], 's d -> s 1 d'),
                                rearrange(sin[:seqlen], 's d -> s 1 d'), dk1, dk2,
                                True)  # conj=True since this is the backward pass
        return dkv, None, None, None

apply_rotary_emb_kv_ = ApplyRotaryEmbKV_.apply
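
Not part of the commit: a minimal usage sketch of apply_rotary_emb_kv_ as defined above. Shapes are illustrative, the cos/sin construction mirrors what RotaryEmbedding caches, and it needs the rotary_emb CUDA extension plus a GPU.

# Illustrative only: 2 KV heads, head dim 64, rotary over the full head dim.
import torch

batch, seqlen, nheads_kv, headdim = 2, 128, 2, 64
rotary_dim = headdim
kv = torch.randn(batch, seqlen, 2, nheads_kv, headdim, dtype=torch.float16, device='cuda')

t = torch.arange(seqlen, device='cuda', dtype=torch.float32)
inv_freq = 1.0 / (10000 ** (torch.arange(0, rotary_dim, 2, device='cuda').float() / rotary_dim))
freqs = torch.outer(t, inv_freq)                      # (seqlen, rotary_dim / 2)
cos, sin = freqs.cos().half(), freqs.sin().half()

kv = apply_rotary_emb_kv_(kv, cos, sin)               # rotates k (kv[:, :, 0]) in place; v untouched
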
class RotaryEmbedding(torch.nn.Module):
    """
    The rotary position embeddings from RoFormer_ (Su et. al).
...@@ -249,21 +294,43 @@ class RotaryEmbedding(torch.nn.Module):
            self._cos_k_cached = (torch.cos(freqs) / scale).to(dtype)
            self._sin_k_cached = (torch.sin(freqs) / scale).to(dtype)
    def forward(self, qkv: torch.Tensor, kv: Optional[torch.Tensor] = None,
                seqlen_offset: int = 0) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        qkv: (batch, seqlen, 3, nheads, headdim) if kv is None,
            else it's just q of shape (batch, seqlen, nheads, headdim)
        kv: (batch, seqlen, 2, nheads, headdim)
        seqlen_offset: can be used in generation where the qkv being passed in is only the last
            token in the batch.
        """
        seqlen = qkv.shape[1]
        self._update_cos_sin_cache(seqlen + seqlen_offset, device=qkv.device, dtype=qkv.dtype)
        if kv is None:
            if self.scale is None:
                return apply_rotary_emb_qkv_(
                    qkv, self._cos_cached[seqlen_offset:], self._sin_cached[seqlen_offset:],
                    None, None, self.interleaved
                )
            else:
                return apply_rotary_emb_qkv_(
                    qkv, self._cos_cached[seqlen_offset:], self._sin_cached[seqlen_offset:],
                    self._cos_k_cached[seqlen_offset:], self._sin_k_cached[seqlen_offset:],
                    self.interleaved
                )
        else:
            q = qkv
            q = apply_rotary_emb_func(
                q, self._cos_cached[seqlen_offset:], self._sin_cached[seqlen_offset:],
                self.interleaved, True  # inplace=True
            )
            if self.scale is None:
                kv = apply_rotary_emb_kv_(
                    kv, self._cos_cached[seqlen_offset:], self._sin_cached[seqlen_offset:],
                    self.interleaved
                )
            else:
                kv = apply_rotary_emb_kv_(
                    kv, self._cos_k_cached[seqlen_offset:], self._sin_k_cached[seqlen_offset:],
                    self.interleaved
                )
            return q, kv
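
Not part of the commit: a sketch of the two calling conventions the new forward supports. The import path and constructor arguments are assumptions (they match the flash_attn layout at the time); tensors are illustrative.

import torch
from flash_attn.layers.rotary import RotaryEmbedding  # assumed module path

rotary = RotaryEmbedding(dim=64, interleaved=False, device='cuda')
batch, seqlen, nheads, nheads_kv, headdim = 2, 128, 16, 2, 64

# Standard MHA: one packed (batch, seqlen, 3, nheads, headdim) tensor, rotated and returned.
qkv = torch.randn(batch, seqlen, 3, nheads, headdim, dtype=torch.float16, device='cuda')
qkv = rotary(qkv)

# MQA/GQA: q and kv are separate tensors because kv has fewer heads than q.
q = torch.randn(batch, seqlen, nheads, headdim, dtype=torch.float16, device='cuda')
kv = torch.randn(batch, seqlen, 2, nheads_kv, headdim, dtype=torch.float16, device='cuda')
q, kv = rotary(q, kv, seqlen_offset=0)
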
...@@ -88,7 +88,9 @@ def create_mixer_cls(config, layer_idx=None, process_group=None, device=None, dt
    parallel_kwargs = ({'process_group': process_group,
                        'sequence_parallel': getattr(config, 'sequence_parallel', True)}
                       if process_group is not None else {})
    num_heads_kv = getattr(config, "n_head_kv", None)
    mixer_cls = partial(mha_cls, num_heads=config.num_attention_heads,
                        num_heads_kv=num_heads_kv,
                        qkv_proj_bias=qkv_proj_bias, out_proj_bias=out_proj_bias,
                        dropout=config.attn_pdrop,
                        softmax_scale=softmax_scale, causal=True, layer_idx=layer_idx,
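
Not part of the commit: a sketch of how the new num_heads_kv plumbing would be driven from the config, assuming the flash_attn GPT model is configured through a GPT2Config with n_head_kv set as an extra attribute (the attribute name n_head_kv comes from the getattr call above).

from transformers import GPT2Config

config = GPT2Config(n_embd=2048, n_head=16, n_layer=24)
config.n_head_kv = 4   # 4 KV heads shared by 16 query heads -> GQA; 1 would be MQA
# Leaving n_head_kv unset keeps num_heads_kv=None, i.e. regular multi-head attention.
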
...@@ -503,20 +505,37 @@ def shard_state_dict_tp(state_dict, config, world_size, rank):
    assert inner_dim % world_size == 0

    def shard_first_dim(state_dict, key):
        if key in state_dict:
            x = state_dict[key]
            dim = x.shape[0] // world_size
            state_dict[key] = x[rank * dim:(rank + 1) * dim]

    def shard_last_dim(state_dict, key):
        if key in state_dict:
            x = state_dict[key]
            dim = x.shape[-1] // world_size
            state_dict[key] = x[..., rank * dim:(rank + 1) * dim]

    def shard_qkv_headdim(state_dict, key):
        if key in state_dict:
            n_head = config.n_head
            n_head_kv = getattr(config, 'n_head_kv', n_head)
            assert n_head % world_size == 0 and n_head_kv % world_size == 0
            if n_head_kv == n_head:
                x = rearrange(state_dict[key], '(three d) ... -> three d ...', three=3)
                dim = x.shape[1] // world_size
                state_dict[key] = rearrange(x[:, rank * dim:(rank + 1) * dim],
                                            'three d ... -> (three d) ...')
            else:
                n_head_per_rank = n_head // world_size
                n_head_kv_per_rank = n_head_kv // world_size
                x = rearrange(state_dict[key], '(nheadqkv headdim) ... -> nheadqkv headdim ...',
                              nheadqkv=n_head + 2 * n_head_kv)
                state_dict[key] = rearrange(torch.cat([
                    x[rank * n_head_per_rank:(rank + 1) * n_head_per_rank],
                    x[n_head + rank * n_head_kv_per_rank:n_head + (rank + 1) * n_head_kv_per_rank],
                    x[n_head + n_head_kv + rank * n_head_kv_per_rank
                      :n_head + n_head_kv + (rank + 1) * n_head_kv_per_rank],
                ], dim=0), "nheadqkv headdim ... -> (nheadqkv headdim) ...")
    shard_first_dim(state_dict, 'transformer.embeddings.word_embeddings.weight')
    if 'lm_head.weight' in state_dict:
...@@ -528,12 +547,12 @@ def shard_state_dict_tp(state_dict, config, world_size, rank):
        shard_qkv_headdim(state_dict, f'transformer.layers.{i}.mixer.Wqkv.bias')
        shard_last_dim(state_dict, f'transformer.layers.{i}.mixer.out_proj.weight')
        if rank != 0:
            state_dict.pop(f'transformer.layers.{i}.mixer.out_proj.bias', None)
        shard_first_dim(state_dict, f'transformer.layers.{i}.mlp.fc1.weight')
        shard_first_dim(state_dict, f'transformer.layers.{i}.mlp.fc1.bias')
        shard_last_dim(state_dict, f'transformer.layers.{i}.mlp.fc2.weight')
        if rank != 0:
            state_dict.pop(f'transformer.layers.{i}.mlp.fc2.bias', None)
    return state_dict
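
Not part of the commit: a toy, CPU-only illustration of the GQA branch of shard_qkv_headdim above. The packed Wqkv rows are assumed to be laid out as all query heads, then all key heads, then all value heads; the sizes are made up.

import torch
from einops import rearrange

n_head, n_head_kv, headdim, hidden, world_size, rank = 8, 2, 4, 32, 2, 0
wqkv = torch.randn((n_head + 2 * n_head_kv) * headdim, hidden)   # packed [Q | K | V] rows

x = rearrange(wqkv, '(nheadqkv headdim) d -> nheadqkv headdim d',
              nheadqkv=n_head + 2 * n_head_kv)
n_head_per_rank, n_head_kv_per_rank = n_head // world_size, n_head_kv // world_size
shard = torch.cat([
    x[rank * n_head_per_rank:(rank + 1) * n_head_per_rank],                          # this rank's Q heads
    x[n_head + rank * n_head_kv_per_rank:n_head + (rank + 1) * n_head_kv_per_rank],  # its K heads
    x[n_head + n_head_kv + rank * n_head_kv_per_rank:
      n_head + n_head_kv + (rank + 1) * n_head_kv_per_rank],                         # its V heads
], dim=0)
print(shard.shape)   # torch.Size([6, 4, 32]): 4 Q + 1 K + 1 V heads for rank 0
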
...@@ -561,9 +580,23 @@ def combine_state_dicts_tp(state_dicts, config):
            state_dict[key] = torch.cat([s[key] for s in state_dicts], dim=dim)

    def combine_qkv_headdim(state_dicts, state_dict, key):
        n_head = config.n_head
        n_head_kv = getattr(config, 'n_head_kv', n_head)
        assert n_head % world_size == 0 and n_head_kv % world_size == 0
        n_head_per_rank = n_head // world_size
        n_head_kv_per_rank = n_head_kv // world_size
        if key in state_dict:
            if n_head_kv == n_head:
                xs = [rearrange(s[key], '(three d) ... -> three d ...', three=3) for s in state_dicts]
                state_dict[key] = rearrange(torch.cat(xs, dim=1), 'three d ... -> (three d) ...')
            else:
                # Each rank's shard is laid out as [Q_per_rank | K_per_rank | V_per_rank] heads,
                # so a shard holds n_head_per_rank + 2 * n_head_kv_per_rank head blocks.
                xs = [rearrange(s[key], '(nheadqkv headdim) ... -> nheadqkv headdim ...',
                                nheadqkv=n_head_per_rank + 2 * n_head_kv_per_rank)
                      for s in state_dicts]
                state_dict[key] = rearrange(torch.cat([
                    torch.cat([x[:n_head_per_rank] for x in xs], dim=0),
                    torch.cat([x[n_head_per_rank:n_head_per_rank + n_head_kv_per_rank] for x in xs], dim=0),
                    torch.cat([x[-n_head_kv_per_rank:] for x in xs], dim=0),
                ], dim=0), "nheadqkv headdim ... -> (nheadqkv headdim) ...")
    def combine_gated_mlp(state_dicts, state_dict, key):
        if key in state_dict:
...
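
Not part of the commit: the matching toy illustration for the GQA branch of combine_qkv_headdim, which undoes the sharding shown above by concatenating the per-rank Q blocks, then the K blocks, then the V blocks (same made-up sizes).

import torch
from einops import rearrange

n_head, n_head_kv, headdim, hidden, world_size = 8, 2, 4, 32, 2
n_head_per_rank, n_head_kv_per_rank = n_head // world_size, n_head_kv // world_size
shards = [torch.randn((n_head_per_rank + 2 * n_head_kv_per_rank) * headdim, hidden)
          for _ in range(world_size)]                  # one [Q | K | V] block per rank

xs = [rearrange(s, '(nheadqkv headdim) d -> nheadqkv headdim d',
                nheadqkv=n_head_per_rank + 2 * n_head_kv_per_rank) for s in shards]
wqkv = rearrange(torch.cat([
    torch.cat([x[:n_head_per_rank] for x in xs], dim=0),                                      # all Q heads
    torch.cat([x[n_head_per_rank:n_head_per_rank + n_head_kv_per_rank] for x in xs], dim=0),  # all K heads
    torch.cat([x[-n_head_kv_per_rank:] for x in xs], dim=0),                                  # all V heads
], dim=0), 'nheadqkv headdim d -> (nheadqkv headdim) d')
print(wqkv.shape)   # torch.Size([48, 32]) == (8 + 2 + 2) * 4 packed rows
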
...@@ -60,9 +60,9 @@ def test_greedy_decode_gpt2(model_name, rotary, optimized, fused_ft_kernel):
    torch.manual_seed(0)
    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    input_ids = tokenizer("Hello, my dog is cute and he",
                          return_tensors="pt").input_ids.to(device=device)
    max_length = 25
    # input_ids = torch.randint(0, 100, (2, 10), dtype=torch.long, device='cuda')
    # max_length = input_ids.shape[1] + 40
...@@ -143,9 +143,9 @@ def test_greedy_decode_opt(model_name):
    tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
    eos_token_id = tokenizer.eos_token_id
    input_ids = tokenizer("Hello, my dog is cute and he",
                          return_tensors="pt").input_ids.to(device=device)
    max_length = 25
    # input_ids = torch.randint(0, 100, (2, 10), dtype=torch.long, device='cuda')
    # max_length = input_ids.shape[1] + 40
...
...@@ -48,7 +48,6 @@ def test_gptj_optimized(model_name):
    torch.manual_seed(0)
    batch_size = 2
    max_seqlen = 256
    seqlens = torch.randint(max_seqlen // 2, max_seqlen + 1, (batch_size,), device=device)
    input_ids = torch.randint(0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long,
                              device=device)
    with torch.no_grad():
...