Commit f1a73d07 authored by Tri Dao's avatar Tri Dao
Browse files

Run isort and black on python files

parent cbb4cf5f
__version__ = "2.0.8"
from flash_attn.flash_attn_interface import flash_attn_func
from flash_attn.flash_attn_interface import flash_attn_kvpacked_func
from flash_attn.flash_attn_interface import flash_attn_qkvpacked_func
from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func
from flash_attn.flash_attn_interface import flash_attn_varlen_kvpacked_func
from flash_attn.flash_attn_interface import flash_attn_varlen_func
from flash_attn.flash_attn_interface import (
flash_attn_func,
flash_attn_kvpacked_func,
flash_attn_qkvpacked_func,
flash_attn_varlen_func,
flash_attn_varlen_kvpacked_func,
flash_attn_varlen_qkvpacked_func,
)
......@@ -2,12 +2,10 @@
import torch
import torch.nn.functional as F
from einops import rearrange, repeat
class IndexFirstAxis(torch.autograd.Function):
@staticmethod
def forward(ctx, input, indices):
ctx.save_for_backward(indices)
......@@ -16,20 +14,24 @@ class IndexFirstAxis(torch.autograd.Function):
second_dim = other_shape.numel()
# TD [2022-03-04] For some reason torch.gather is a bit faster than indexing.
# return input[indices]
return torch.gather(rearrange(input, 'b ... -> b (...)'), 0,
repeat(indices, 'z -> z d', d=second_dim)).reshape(-1, *other_shape)
return torch.gather(
rearrange(input, "b ... -> b (...)"), 0, repeat(indices, "z -> z d", d=second_dim)
).reshape(-1, *other_shape)
@staticmethod
def backward(ctx, grad_output):
indices, = ctx.saved_tensors
(indices,) = ctx.saved_tensors
assert grad_output.ndim >= 2
other_shape = grad_output.shape[1:]
grad_output = rearrange(grad_output, 'b ... -> b (...)')
grad_input = torch.zeros([ctx.first_axis_dim, grad_output.shape[1]],
device=grad_output.device, dtype=grad_output.dtype)
grad_output = rearrange(grad_output, "b ... -> b (...)")
grad_input = torch.zeros(
[ctx.first_axis_dim, grad_output.shape[1]],
device=grad_output.device,
dtype=grad_output.dtype,
)
# TD [2022-03-04] For some reason torch.scatter is a bit faster than indexing.
# grad_input[indices] = grad_output
grad_input.scatter_(0, repeat(indices, 'z -> z d', d=grad_output.shape[1]), grad_output)
grad_input.scatter_(0, repeat(indices, "z -> z d", d=grad_output.shape[1]), grad_output)
return grad_input.reshape(ctx.first_axis_dim, *other_shape), None
......@@ -37,14 +39,14 @@ index_first_axis = IndexFirstAxis.apply
class IndexPutFirstAxis(torch.autograd.Function):
@staticmethod
def forward(ctx, values, indices, first_axis_dim):
ctx.save_for_backward(indices)
assert indices.ndim == 1
assert values.ndim >= 2
output = torch.zeros(first_axis_dim, *values.shape[1:], device=values.device,
dtype=values.dtype)
output = torch.zeros(
first_axis_dim, *values.shape[1:], device=values.device, dtype=values.dtype
)
# TD [2022-03-04] For some reason torch.scatter is a bit faster than indexing.
output[indices] = values
# output.scatter_(0, repeat(indices, 'z -> z d', d=values.shape[1]), values)
......@@ -52,7 +54,7 @@ class IndexPutFirstAxis(torch.autograd.Function):
@staticmethod
def backward(ctx, grad_output):
indices, = ctx.saved_tensors
(indices,) = ctx.saved_tensors
# TD [2022-03-04] For some reason torch.gather is a bit faster than indexing.
grad_values = grad_output[indices]
# grad_values = torch.gather(grad_output, 0, repeat(indices, 'z -> z d', d=grad_output.shape[1]))
......@@ -63,7 +65,6 @@ index_put_first_axis = IndexPutFirstAxis.apply
class IndexFirstAxisResidual(torch.autograd.Function):
@staticmethod
def forward(ctx, input, indices):
ctx.save_for_backward(indices)
......@@ -79,7 +80,7 @@ class IndexFirstAxisResidual(torch.autograd.Function):
@staticmethod
def backward(ctx, grad_output, grad_residual):
indices, = ctx.saved_tensors
(indices,) = ctx.saved_tensors
assert grad_output.ndim >= 2
other_shape = grad_output.shape[1:]
assert grad_residual.shape[1:] == other_shape
......@@ -113,8 +114,12 @@ def unpad_input(hidden_states, attention_mask):
# times larger than it needs to be, wasting memory. It's faster and more memory-efficient to
# index with integer indices. Moreover, torch's index is a bit slower than it needs to be,
# so we write custom forward and backward to make it a bit faster.
return (index_first_axis(rearrange(hidden_states, 'b s ... -> (b s) ...'), indices), indices,
cu_seqlens, max_seqlen_in_batch)
return (
index_first_axis(rearrange(hidden_states, "b s ... -> (b s) ..."), indices),
indices,
cu_seqlens,
max_seqlen_in_batch,
)
def pad_input(hidden_states, indices, batch, seqlen):
......@@ -129,4 +134,4 @@ def pad_input(hidden_states, indices, batch, seqlen):
# output = torch.zeros((batch * seqlen), dim, device=hidden_states.device, dtype=hidden_states.dtype)
# output[indices] = hidden_states
output = index_put_first_axis(hidden_states, indices, batch * seqlen)
return rearrange(output, '(b s) ... -> b s ...', b=batch)
return rearrange(output, "(b s) ... -> b s ...", b=batch)
This diff is collapsed.
This diff is collapsed.
......@@ -11,22 +11,41 @@ This is a Triton implementation of the Flash Attention algorithm
import pytest
import torch
import triton
import triton.language as tl
@triton.jit
def _fwd_kernel(
Q, K, V, sm_scale,
TMP, L, M, # NOTE: TMP is a scratchpad buffer to workaround a compiler bug
Q,
K,
V,
sm_scale,
TMP,
L,
M, # NOTE: TMP is a scratchpad buffer to workaround a compiler bug
Out,
stride_qz, stride_qh, stride_qm, stride_qk,
stride_kz, stride_kh, stride_kn, stride_kk,
stride_vz, stride_vh, stride_vk, stride_vn,
stride_oz, stride_oh, stride_om, stride_on,
Z, H, N_CTX,
BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
stride_qz,
stride_qh,
stride_qm,
stride_qk,
stride_kz,
stride_kh,
stride_kn,
stride_kk,
stride_vz,
stride_vh,
stride_vk,
stride_vn,
stride_oz,
stride_oh,
stride_om,
stride_on,
Z,
H,
N_CTX,
BLOCK_M: tl.constexpr,
BLOCK_DMODEL: tl.constexpr,
BLOCK_N: tl.constexpr,
):
start_m = tl.program_id(0)
......@@ -100,9 +119,13 @@ def _fwd_kernel(
@triton.jit
def _bwd_preprocess(
Out, DO, L,
NewDO, Delta,
BLOCK_M: tl.constexpr, D_HEAD: tl.constexpr,
Out,
DO,
L,
NewDO,
Delta,
BLOCK_M: tl.constexpr,
D_HEAD: tl.constexpr,
):
off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
off_n = tl.arange(0, D_HEAD)
......@@ -120,16 +143,36 @@ def _bwd_preprocess(
@triton.jit
def _bwd_kernel(
Q, K, V, sm_scale, Out, DO,
DQ, DK, DV,
L, M,
Q,
K,
V,
sm_scale,
Out,
DO,
DQ,
DK,
DV,
L,
M,
D,
stride_qz, stride_qh, stride_qm, stride_qk,
stride_kz, stride_kh, stride_kn, stride_kk,
stride_vz, stride_vh, stride_vk, stride_vn,
Z, H, N_CTX,
stride_qz,
stride_qh,
stride_qm,
stride_qk,
stride_kz,
stride_kh,
stride_kn,
stride_kk,
stride_vz,
stride_vh,
stride_vk,
stride_vn,
Z,
H,
N_CTX,
num_block,
BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
BLOCK_M: tl.constexpr,
BLOCK_DMODEL: tl.constexpr,
BLOCK_N: tl.constexpr,
):
off_hz = tl.program_id(0)
......@@ -203,7 +246,6 @@ def _bwd_kernel(
class _attention(torch.autograd.Function):
@staticmethod
def forward(ctx, q, k, v, sm_scale):
BLOCK = 128
......@@ -213,22 +255,45 @@ class _attention(torch.autograd.Function):
assert Lk in {16, 32, 64, 128}
o = torch.empty_like(q)
grid = (triton.cdiv(q.shape[2], BLOCK), q.shape[0] * q.shape[1])
tmp = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
tmp = torch.empty(
(q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32
)
L = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
m = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
num_warps = 4 if Lk <= 64 else 8
_fwd_kernel[grid](
q, k, v, sm_scale,
tmp, L, m,
q,
k,
v,
sm_scale,
tmp,
L,
m,
o,
q.stride(0), q.stride(1), q.stride(2), q.stride(3),
k.stride(0), k.stride(1), k.stride(2), k.stride(3),
v.stride(0), v.stride(1), v.stride(2), v.stride(3),
o.stride(0), o.stride(1), o.stride(2), o.stride(3),
q.shape[0], q.shape[1], q.shape[2],
BLOCK_M=BLOCK, BLOCK_N=BLOCK,
BLOCK_DMODEL=Lk, num_warps=num_warps,
q.stride(0),
q.stride(1),
q.stride(2),
q.stride(3),
k.stride(0),
k.stride(1),
k.stride(2),
k.stride(3),
v.stride(0),
v.stride(1),
v.stride(2),
v.stride(3),
o.stride(0),
o.stride(1),
o.stride(2),
o.stride(3),
q.shape[0],
q.shape[1],
q.shape[2],
BLOCK_M=BLOCK,
BLOCK_N=BLOCK,
BLOCK_DMODEL=Lk,
num_warps=num_warps,
num_stages=1,
)
ctx.save_for_backward(q, k, v, o, L, m)
......@@ -247,27 +312,51 @@ class _attention(torch.autograd.Function):
dv = torch.empty_like(v)
do_scaled = torch.empty_like(do)
delta = torch.empty_like(l)
_bwd_preprocess[(ctx.grid[0] * ctx.grid[1], )](
o, do, l,
do_scaled, delta,
BLOCK_M=ctx.BLOCK, D_HEAD=ctx.BLOCK_DMODEL,
_bwd_preprocess[(ctx.grid[0] * ctx.grid[1],)](
o,
do,
l,
do_scaled,
delta,
BLOCK_M=ctx.BLOCK,
D_HEAD=ctx.BLOCK_DMODEL,
)
# NOTE: kernel currently buggy for other values of `num_warps`
num_warps = 8
_bwd_kernel[(ctx.grid[1],)](
q, k, v, ctx.sm_scale,
o, do_scaled,
dq, dk, dv,
l, m,
q,
k,
v,
ctx.sm_scale,
o,
do_scaled,
dq,
dk,
dv,
l,
m,
delta,
q.stride(0), q.stride(1), q.stride(2), q.stride(3),
k.stride(0), k.stride(1), k.stride(2), k.stride(3),
v.stride(0), v.stride(1), v.stride(2), v.stride(3),
q.shape[0], q.shape[1], q.shape[2],
q.stride(0),
q.stride(1),
q.stride(2),
q.stride(3),
k.stride(0),
k.stride(1),
k.stride(2),
k.stride(3),
v.stride(0),
v.stride(1),
v.stride(2),
v.stride(3),
q.shape[0],
q.shape[1],
q.shape[2],
ctx.grid[0],
BLOCK_M=ctx.BLOCK, BLOCK_N=ctx.BLOCK,
BLOCK_DMODEL=ctx.BLOCK_DMODEL, num_warps=num_warps,
BLOCK_M=ctx.BLOCK,
BLOCK_N=ctx.BLOCK,
BLOCK_DMODEL=ctx.BLOCK_DMODEL,
num_warps=num_warps,
num_stages=1,
)
return dq.to(q.dtype), dk, dv, None
......
import math
import hydra
import torch
import torch.nn as nn
from einops import rearrange
import hydra
from flash_attn.flash_blocksparse_attn_interface import flash_blocksparse_attn_func
from flash_attn.flash_blocksparse_attn_interface import convert_blockmask
from flash_attn.bert_padding import unpad_input, pad_input, index_first_axis
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input
from flash_attn.flash_blocksparse_attn_interface import (
convert_blockmask,
flash_blocksparse_attn_func,
)
class FlashBlocksparseAttention(nn.Module):
......@@ -21,8 +22,16 @@ class FlashBlocksparseAttention(nn.Module):
attention_dropout: The dropout rate to apply to the attention
(default: 0.1)
"""
def __init__(self, sparsity_config, softmax_temp=None, attention_dropout=0.0,
max_seq_length=2048, device=None, dtype=None):
def __init__(
self,
sparsity_config,
softmax_temp=None,
attention_dropout=0.0,
max_seq_length=2048,
device=None,
dtype=None,
):
super().__init__()
self.sparsity_config = hydra.utils.instantiate(sparsity_config)
self.softmax_temp = softmax_temp
......@@ -36,8 +45,17 @@ class FlashBlocksparseAttention(nn.Module):
self.register_buffer("blockmask_converted", blockmask_converted)
# logger.info(f'Attention class {self.__class__}: saving={self.layout.float().mean()}')
def forward(self, qkv, attn_mask=None, key_padding_mask=None, causal=False, cu_seqlens=None,
max_s=None, need_weights=False, convert_mask=True):
def forward(
self,
qkv,
attn_mask=None,
key_padding_mask=None,
causal=False,
cu_seqlens=None,
max_s=None,
need_weights=False,
convert_mask=True,
):
"""Implements the multihead softmax attention.
Arguments
---------
......@@ -57,47 +75,76 @@ class FlashBlocksparseAttention(nn.Module):
seqlen = qkv.shape[1]
# Convert mask to take a subset
seqlen_rounded = ((seqlen + 256 - 1) // 256) * 256
assert seqlen_rounded // 16 <= self.layout.shape[0], seqlen_rounded // 256 <= self.layout.shape[1]
blockmask = self.layout[:seqlen_rounded // 16, :seqlen_rounded // 256]
assert seqlen_rounded // 16 <= self.layout.shape[0], (
seqlen_rounded // 256 <= self.layout.shape[1]
)
blockmask = self.layout[: seqlen_rounded // 16, : seqlen_rounded // 256]
if key_padding_mask is None:
qkv = rearrange(qkv, 'b s ... -> (b s) ...')
qkv = rearrange(qkv, "b s ... -> (b s) ...")
max_s = seqlen
cu_seqlens = torch.arange(0, (batch_size + 1) * seqlen, step=seqlen, dtype=torch.int32,
device=qkv.device)
cu_seqlens = torch.arange(
0, (batch_size + 1) * seqlen, step=seqlen, dtype=torch.int32, device=qkv.device
)
output = flash_blocksparse_attn_func(
qkv, cu_seqlens, blockmask, self.dropout_p if self.training else 0.0,
max_s, softmax_scale=self.softmax_temp, causal=causal
qkv,
cu_seqlens,
blockmask,
self.dropout_p if self.training else 0.0,
max_s,
softmax_scale=self.softmax_temp,
causal=causal,
)
output = rearrange(output, '(b s) ... -> b s ...', b=batch_size)
output = rearrange(output, "(b s) ... -> b s ...", b=batch_size)
else:
key_padding_mask_bool = key_padding_mask.bool_matrix
nheads = qkv.shape[-2]
x = rearrange(qkv, 'b s three h d -> b s (three h d)')
x = rearrange(qkv, "b s three h d -> b s (three h d)")
x_unpad, indices, cu_seqlens, max_s = unpad_input(x, key_padding_mask_bool)
x_unpad = rearrange(x_unpad, 'nnz (three h d) -> nnz three h d', three=3, h=nheads)
x_unpad = rearrange(x_unpad, "nnz (three h d) -> nnz three h d", three=3, h=nheads)
output_unpad = flash_blocksparse_attn_func(
x_unpad, cu_seqlens, blockmask, self.dropout_p if self.training else 0.0,
max_s, softmax_scale=self.softmax_temp, causal=causal
x_unpad,
cu_seqlens,
blockmask,
self.dropout_p if self.training else 0.0,
max_s,
softmax_scale=self.softmax_temp,
causal=causal,
)
output = rearrange(
pad_input(
rearrange(output_unpad, "nnz h d -> nnz (h d)"), indices, batch_size, seqlen
),
"b s (h d) -> b s h d",
h=nheads,
)
output = rearrange(pad_input(rearrange(output_unpad, 'nnz h d -> nnz (h d)'),
indices, batch_size, seqlen),
'b s (h d) -> b s h d', h=nheads)
else:
assert max_s is not None
seqlen = max_s
# Convert mask to take a subset
seqlen_rounded = ((seqlen + 256 - 1) // 256) * 256
assert seqlen_rounded // 16 <= self.layout.shape[0], seqlen_rounded // 256 <= self.layout.shape[1]
blockmask = self.layout[:seqlen_rounded // 16, :seqlen_rounded // 256]
assert seqlen_rounded // 16 <= self.layout.shape[0], (
seqlen_rounded // 256 <= self.layout.shape[1]
)
blockmask = self.layout[: seqlen_rounded // 16, : seqlen_rounded // 256]
if convert_mask:
output = flash_blocksparse_attn_func(
qkv, cu_seqlens, blockmask, self.dropout_p if self.training else 0.0,
max_s, softmax_scale=self.softmax_temp, causal=causal
qkv,
cu_seqlens,
blockmask,
self.dropout_p if self.training else 0.0,
max_s,
softmax_scale=self.softmax_temp,
causal=causal,
)
else:
output = flash_blocksparse_attn_func(
qkv, cu_seqlens, self.blockmask_converted, self.dropout_p if self.training else 0.0,
max_s, softmax_scale=self.softmax_temp, causal=causal,
qkv,
cu_seqlens,
self.blockmask_converted,
self.dropout_p if self.training else 0.0,
max_s,
softmax_scale=self.softmax_temp,
causal=causal,
convert_mask=False,
)
......@@ -105,12 +152,22 @@ class FlashBlocksparseAttention(nn.Module):
class FlashBlocksparseMHA(nn.Module):
def __init__(self, embed_dim, num_heads, sparsity_config, bias=True, batch_first=True,
attention_dropout=0.0, causal=False, max_seq_length=2048,
device=None, dtype=None, **kwargs) -> None:
def __init__(
self,
embed_dim,
num_heads,
sparsity_config,
bias=True,
batch_first=True,
attention_dropout=0.0,
causal=False,
max_seq_length=2048,
device=None,
dtype=None,
**kwargs,
) -> None:
assert batch_first
factory_kwargs = {'device': device, 'dtype': dtype}
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
self.embed_dim = embed_dim
self.causal = causal
......@@ -122,15 +179,19 @@ class FlashBlocksparseMHA(nn.Module):
self.Wqkv = nn.Linear(embed_dim, 3 * embed_dim, bias=bias, **factory_kwargs)
self.inner_attn = FlashBlocksparseAttention(
sparsity_config, attention_dropout=attention_dropout,
max_seq_length=max_seq_length, **factory_kwargs
sparsity_config,
attention_dropout=attention_dropout,
max_seq_length=max_seq_length,
**factory_kwargs,
)
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias, **factory_kwargs)
def forward(self, x, x_ignored_, x_ignored_1_, attn_mask=None, key_padding_mask=None,
need_weights=False):
def forward(
self, x, x_ignored_, x_ignored_1_, attn_mask=None, key_padding_mask=None, need_weights=False
):
qkv = self.Wqkv(x)
qkv = rearrange(qkv, 'b s (three h d) -> b s three h d', three=3, h=self.num_heads)
context, attn_weights = self.inner_attn(qkv, key_padding_mask=key_padding_mask,
need_weights=need_weights, causal=self.causal)
return self.out_proj(rearrange(context, 'b s h d -> b s (h d)')), attn_weights
qkv = rearrange(qkv, "b s (three h d) -> b s three h d", three=3, h=self.num_heads)
context, attn_weights = self.inner_attn(
qkv, key_padding_mask=key_padding_mask, need_weights=need_weights, causal=self.causal
)
return self.out_proj(rearrange(context, "b s h d -> b s (h d)")), attn_weights
# Adapted from https://github.com/mlcommons/training_results_v1.1/blob/main/NVIDIA/benchmarks/bert/implementations/pytorch/fmha.py
import flash_attn_cuda
import torch
import torch.nn as nn
import flash_attn_cuda
def convert_blockmask(blockmask, causal):
"""Convert from the 0-1 format to the format used by the CUDA code.
......@@ -40,29 +39,51 @@ def convert_blockmask(blockmask, causal):
return nonzero_idx.T.contiguous().to(dtype=torch.int32)
def _flash_blocksparse_attn_forward(qkv, cu_seqlens, blockmask, dropout_p, max_s, softmax_scale,
causal, return_softmax):
context, softmax_lse, *rest = flash_attn_cuda.fwd_block(qkv, cu_seqlens, blockmask, dropout_p,
max_s, softmax_scale, causal,
return_softmax, None)
def _flash_blocksparse_attn_forward(
qkv, cu_seqlens, blockmask, dropout_p, max_s, softmax_scale, causal, return_softmax
):
context, softmax_lse, *rest = flash_attn_cuda.fwd_block(
qkv, cu_seqlens, blockmask, dropout_p, max_s, softmax_scale, causal, return_softmax, None
)
# if context.isnan().any() or softmax_lse.isnan().any():
# breakpoint()
S_dmask = rest[0] if return_softmax else None
return context, softmax_lse, S_dmask
def _flash_blocksparse_attn_backward(dout, qkv, out, S_dmask, softmax_lse, cu_seqlens, blockmask,
dropout_p, max_s, softmax_scale, causal):
dqkv, dp, softmax_d = flash_attn_cuda.bwd_block(dout, qkv, out, S_dmask, softmax_lse, cu_seqlens,
blockmask, dropout_p, softmax_scale, max_s,
causal, None)
def _flash_blocksparse_attn_backward(
dout,
qkv,
out,
S_dmask,
softmax_lse,
cu_seqlens,
blockmask,
dropout_p,
max_s,
softmax_scale,
causal,
):
dqkv, dp, softmax_d = flash_attn_cuda.bwd_block(
dout,
qkv,
out,
S_dmask,
softmax_lse,
cu_seqlens,
blockmask,
dropout_p,
softmax_scale,
max_s,
causal,
None,
)
# if dqkv.isnan().any() or softmax_d.isnan().any():
# breakpoint()
return dqkv
class FlashBlocksparseAttnFun(torch.autograd.Function):
@staticmethod
def forward(ctx, qkv, cu_seqlens, blockmask, dropout_p, max_s, softmax_scale, causal):
# Save rng_state because the backward pass will regenerate the dropout mask
......@@ -70,8 +91,14 @@ class FlashBlocksparseAttnFun(torch.autograd.Function):
if softmax_scale is None:
softmax_scale = qkv.shape[-1] ** (-0.5)
context, softmax_lse, S_dmask = _flash_blocksparse_attn_forward(
qkv, cu_seqlens, blockmask, dropout_p, max_s, softmax_scale, causal=causal,
return_softmax=False
qkv,
cu_seqlens,
blockmask,
dropout_p,
max_s,
softmax_scale,
causal=causal,
return_softmax=False,
)
ctx.save_for_backward(qkv, context, S_dmask, softmax_lse, cu_seqlens, blockmask, rng_state)
ctx.dropout_p = dropout_p
......@@ -88,8 +115,17 @@ class FlashBlocksparseAttnFun(torch.autograd.Function):
torch.cuda.set_rng_state(rng_state)
# S_dmask is None, temporarily use another tensor just to get it running
dqkv = _flash_blocksparse_attn_backward(
dout, qkv, context, context, softmax_lse, cu_seqlens, blockmask, ctx.dropout_p,
ctx.max_s, ctx.softmax_scale, ctx.causal
dout,
qkv,
context,
context,
softmax_lse,
cu_seqlens,
blockmask,
ctx.dropout_p,
ctx.max_s,
ctx.softmax_scale,
ctx.causal,
)
if rng_state is not None:
torch.cuda.set_rng_state(cur_rng_state)
......@@ -99,7 +135,6 @@ class FlashBlocksparseAttnFun(torch.autograd.Function):
# We duplicate code to return both the output and the softmax for testing
# Returning both makes backward a bit slower, so we want to keep using the other version for speed.
class FlashBlocksparseAttnFunWithS(torch.autograd.Function):
@staticmethod
def forward(ctx, qkv, cu_seqlens, blockmask, dropout_p, max_s, softmax_scale, causal):
# Save rng_state because the backward pass is gonna regenerate the dropout mask
......@@ -107,8 +142,14 @@ class FlashBlocksparseAttnFunWithS(torch.autograd.Function):
if softmax_scale is None:
softmax_scale = qkv.shape[-1] ** (-0.5)
context, softmax_lse, S_dmask = _flash_blocksparse_attn_forward(
qkv, cu_seqlens, blockmask, dropout_p, max_s, softmax_scale, causal=causal,
return_softmax=True
qkv,
cu_seqlens,
blockmask,
dropout_p,
max_s,
softmax_scale,
causal=causal,
return_softmax=True,
)
ctx.save_for_backward(qkv, context, S_dmask, softmax_lse, cu_seqlens, blockmask, rng_state)
ctx.dropout_p = dropout_p
......@@ -124,18 +165,35 @@ class FlashBlocksparseAttnFunWithS(torch.autograd.Function):
cur_rng_state = torch.cuda.get_rng_state()
torch.cuda.set_rng_state(rng_state)
dqkv = _flash_blocksparse_attn_backward(
dout, qkv, context, S_dmask, softmax_lse, cu_seqlens, blockmask, ctx.dropout_p,
ctx.max_s, ctx.softmax_scale, ctx.causal
dout,
qkv,
context,
S_dmask,
softmax_lse,
cu_seqlens,
blockmask,
ctx.dropout_p,
ctx.max_s,
ctx.softmax_scale,
ctx.causal,
)
if rng_state is not None:
torch.cuda.set_rng_state(cur_rng_state)
return dqkv, None, None, None, None, None, None
def flash_blocksparse_attn_func(qkv, cu_seqlens, blockmask, dropout_p, max_s, softmax_scale=None,
causal=False, return_attn_probs=False, convert_mask=True):
"""dropout_p should be set to 0.0 during evaluation
"""
def flash_blocksparse_attn_func(
qkv,
cu_seqlens,
blockmask,
dropout_p,
max_s,
softmax_scale=None,
causal=False,
return_attn_probs=False,
convert_mask=True,
):
"""dropout_p should be set to 0.0 during evaluation"""
func = FlashBlocksparseAttnFun if not return_attn_probs else FlashBlocksparseAttnFunWithS
if convert_mask:
blockmask = convert_blockmask(blockmask, causal=causal)
......
......@@ -17,13 +17,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from apex._autocast_utils import _cast_if_autocast_enabled
from apex.transformer.enums import AttnMaskType
from fused_softmax_lib import scaled_masked_softmax_forward, scaled_masked_softmax_backward
from fused_softmax_lib import scaled_masked_softmax_get_batch_per_block
from fused_softmax_lib import scaled_upper_triang_masked_softmax_forward, scaled_upper_triang_masked_softmax_backward
from fused_softmax_lib import (
scaled_masked_softmax_backward,
scaled_masked_softmax_forward,
scaled_masked_softmax_get_batch_per_block,
scaled_upper_triang_masked_softmax_backward,
scaled_upper_triang_masked_softmax_forward,
)
class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function):
......@@ -37,9 +39,7 @@ class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function):
@staticmethod
def forward(ctx, inputs, scale):
scale_t = torch.tensor([scale])
softmax_results = scaled_upper_triang_masked_softmax_forward(
inputs, scale_t[0]
)
softmax_results = scaled_upper_triang_masked_softmax_forward(inputs, scale_t[0])
ctx.save_for_backward(softmax_results, scale_t)
return softmax_results
......@@ -81,9 +81,7 @@ class ScaledMaskedSoftmax(torch.autograd.Function):
@staticmethod
def backward(ctx, output_grads):
softmax_results, scale_t = ctx.saved_tensors
input_grads = scaled_masked_softmax_backward(
output_grads, softmax_results, scale_t[0]
)
input_grads = scaled_masked_softmax_backward(output_grads, softmax_results, scale_t[0])
return input_grads, None, None
......@@ -122,9 +120,7 @@ class FusedScaleMaskSoftmax(torch.nn.Module):
self.input_in_fp16 = input_in_fp16
self.input_in_bf16 = input_in_bf16
if self.input_in_fp16 and self.input_in_bf16:
raise RuntimeError(
"both fp16 and bf16 flags cannot be active at the same time."
)
raise RuntimeError("both fp16 and bf16 flags cannot be active at the same time.")
self.input_in_float16 = self.input_in_fp16 or self.input_in_bf16
self.attn_mask_type = attn_mask_type
self.scaled_masked_softmax_fusion = scaled_masked_softmax_fusion
......
......@@ -4,11 +4,10 @@
from functools import partial
import torch.nn as nn
from einops import rearrange
from torch import _assert
from torch.nn.modules.utils import _pair
from einops import rearrange
try:
from flash_attn.ops.fused_dense import FusedDense
except ImportError:
......@@ -16,8 +15,8 @@ except ImportError:
class PatchEmbed(nn.Module):
""" 2D Image to Patch Embedding
"""
"""2D Image to Patch Embedding"""
def __init__(
self,
img_size=224,
......@@ -38,7 +37,7 @@ class PatchEmbed(nn.Module):
self.num_patches = self.grid_size[0] * self.grid_size[1]
self.flatten = flatten
if fused_bias_fc and FusedDense is None:
raise ImportError('fused_dense is not installed')
raise ImportError("fused_dense is not installed")
linear_cls = nn.Linear if not fused_bias_fc or not bias else FusedDense
self.proj = linear_cls(in_chans * patch_size[0] * patch_size[1], embed_dim, bias=bias)
......@@ -46,11 +45,23 @@ class PatchEmbed(nn.Module):
def forward(self, x):
_, _, H, W = x.shape
_assert(H == self.img_size[0], f"Input image height ({H}) doesn't match model ({self.img_size[0]}).")
_assert(W == self.img_size[1], f"Input image width ({W}) doesn't match model ({self.img_size[1]}).")
x = self.proj(rearrange(x, 'b c (h p1) (w p2) -> b h w (c p1 p2)',
p1=self.patch_size[0], p2=self.patch_size[1]))
_assert(
H == self.img_size[0],
f"Input image height ({H}) doesn't match model ({self.img_size[0]}).",
)
_assert(
W == self.img_size[1],
f"Input image width ({W}) doesn't match model ({self.img_size[1]}).",
)
x = self.proj(
rearrange(
x,
"b c (h p1) (w p2) -> b h w (c p1 p2)",
p1=self.patch_size[0],
p2=self.patch_size[1],
)
)
if self.flatten:
x = rearrange(x, 'b h w c -> b (h w) c')
x = rearrange(x, "b h w c -> b (h w) c")
x = self.norm(x)
return x
# Copyright (c) 2023, Tri Dao.
from typing import Tuple, Optional
import math
from typing import Optional, Tuple
import rotary_emb
import torch
from einops import rearrange, repeat
import rotary_emb
def rotate_half(x, interleaved=False):
if not interleaved:
......@@ -16,7 +14,7 @@ def rotate_half(x, interleaved=False):
return torch.cat((-x2, x1), dim=-1)
else:
x1, x2 = x[..., ::2], x[..., 1::2]
return rearrange(torch.stack((-x2, x1), dim=-1), '... d two -> ... (d two)', two=2)
return rearrange(torch.stack((-x2, x1), dim=-1), "... d two -> ... (d two)", two=2)
def apply_rotary_emb_torch(x, cos, sin, interleaved=False):
......@@ -26,14 +24,15 @@ def apply_rotary_emb_torch(x, cos, sin, interleaved=False):
"""
ro_dim = cos.shape[-1] * 2
assert ro_dim <= x.shape[-1]
cos = repeat(cos, 's d -> s 1 (2 d)')
sin = repeat(sin, 's d -> s 1 (2 d)')
return torch.cat([x[..., :ro_dim] * cos + rotate_half(x[..., :ro_dim], interleaved) * sin,
x[..., ro_dim:]], dim=-1)
cos = repeat(cos, "s d -> s 1 (2 d)")
sin = repeat(sin, "s d -> s 1 (2 d)")
return torch.cat(
[x[..., :ro_dim] * cos + rotate_half(x[..., :ro_dim], interleaved) * sin, x[..., ro_dim:]],
dim=-1,
)
class ApplyRotaryEmb(torch.autograd.Function):
@staticmethod
def forward(ctx, x, cos, sin, interleaved=False, inplace=False):
"""
......@@ -57,10 +56,20 @@ class ApplyRotaryEmb(torch.autograd.Function):
if inplace:
o1, o2 = x1, x2
else:
o1, o2 = (out_ro.chunk(2, dim=-1) if not interleaved
else (out_ro[..., ::2], out_ro[..., 1::2]))
rotary_emb.apply_rotary(x1, x2, rearrange(cos[:seqlen], 's d -> s 1 d'),
rearrange(sin[:seqlen], 's d -> s 1 d'), o1, o2, False)
o1, o2 = (
out_ro.chunk(2, dim=-1)
if not interleaved
else (out_ro[..., ::2], out_ro[..., 1::2])
)
rotary_emb.apply_rotary(
x1,
x2,
rearrange(cos[:seqlen], "s d -> s 1 d"),
rearrange(sin[:seqlen], "s d -> s 1 d"),
o1,
o2,
False,
)
if not inplace and rotary_dim < headdim:
out[..., rotary_dim:].copy_(x[..., rotary_dim:])
ctx.save_for_backward(cos, sin)
......@@ -76,17 +85,28 @@ class ApplyRotaryEmb(torch.autograd.Function):
rotary_dim *= 2
inplace = ctx.inplace
do_ro = do[..., :rotary_dim]
do1, do2 = (do_ro.chunk(2, dim=-1) if not ctx.interleaved
else (do_ro[..., ::2], do_ro[..., 1::2]))
do1, do2 = (
do_ro.chunk(2, dim=-1) if not ctx.interleaved else (do_ro[..., ::2], do_ro[..., 1::2])
)
dx = torch.empty_like(do) if not inplace else do
if inplace:
dx1, dx2 = do1, do2
else:
dx_ro = dx[..., :rotary_dim]
dx1, dx2 = (dx_ro.chunk(2, dim=-1) if not ctx.interleaved
else (dx_ro[..., ::2], dx_ro[..., 1::2]))
rotary_emb.apply_rotary(do1, do2, rearrange(cos[:seqlen], 's d -> s 1 d'),
rearrange(sin[:seqlen], 's d -> s 1 d'), dx1, dx2, True)
dx1, dx2 = (
dx_ro.chunk(2, dim=-1)
if not ctx.interleaved
else (dx_ro[..., ::2], dx_ro[..., 1::2])
)
rotary_emb.apply_rotary(
do1,
do2,
rearrange(cos[:seqlen], "s d -> s 1 d"),
rearrange(sin[:seqlen], "s d -> s 1 d"),
dx1,
dx2,
True,
)
if not inplace and rotary_dim < headdim:
dx[..., rotary_dim:].copy_(do[..., rotary_dim:])
return dx, None, None, None, None
......@@ -96,7 +116,6 @@ apply_rotary_emb_func = ApplyRotaryEmb.apply
class ApplyRotaryEmbQKV_(torch.autograd.Function):
@staticmethod
def forward(ctx, qkv, cos, sin, cos_k=None, sin_k=None, interleaved=False):
"""
......@@ -119,12 +138,26 @@ class ApplyRotaryEmbQKV_(torch.autograd.Function):
assert sin.shape == cos_k.shape == sin_k.shape == (rotary_seqlen, rotary_dim // 2)
q_ro = qkv[:, :, 0, :, :rotary_dim]
q1, q2 = q_ro.chunk(2, dim=-1) if not interleaved else (q_ro[..., ::2], q_ro[..., 1::2])
rotary_emb.apply_rotary(q1, q2, rearrange(cos[:seqlen], 's d -> s 1 d'),
rearrange(sin[:seqlen], 's d -> s 1 d'), q1, q2, False)
rotary_emb.apply_rotary(
q1,
q2,
rearrange(cos[:seqlen], "s d -> s 1 d"),
rearrange(sin[:seqlen], "s d -> s 1 d"),
q1,
q2,
False,
)
k_ro = qkv[:, :, 1, :, :rotary_dim]
k1, k2 = k_ro.chunk(2, dim=-1) if not interleaved else (k_ro[..., ::2], k_ro[..., 1::2])
rotary_emb.apply_rotary(k1, k2, rearrange(cos_k[:seqlen], 's d -> s 1 d'),
rearrange(sin_k[:seqlen], 's d -> s 1 d'), k1, k2, False)
rotary_emb.apply_rotary(
k1,
k2,
rearrange(cos_k[:seqlen], "s d -> s 1 d"),
rearrange(sin_k[:seqlen], "s d -> s 1 d"),
k1,
k2,
False,
)
ctx.save_for_backward(cos, sin, cos_k, sin_k)
ctx.interleaved = interleaved
return qkv
......@@ -136,15 +169,31 @@ class ApplyRotaryEmbQKV_(torch.autograd.Function):
rotary_dim = cos.shape[-1]
rotary_dim *= 2
dq_ro = dqkv[:, :, 0, :, :rotary_dim]
dq1, dq2 = (dq_ro.chunk(2, dim=-1) if not ctx.interleaved
else (dq_ro[..., ::2], dq_ro[..., 1::2]))
rotary_emb.apply_rotary(dq1, dq2, rearrange(cos[:seqlen], 's d -> s 1 d'),
rearrange(sin[:seqlen], 's d -> s 1 d'), dq1, dq2, True)
dq1, dq2 = (
dq_ro.chunk(2, dim=-1) if not ctx.interleaved else (dq_ro[..., ::2], dq_ro[..., 1::2])
)
rotary_emb.apply_rotary(
dq1,
dq2,
rearrange(cos[:seqlen], "s d -> s 1 d"),
rearrange(sin[:seqlen], "s d -> s 1 d"),
dq1,
dq2,
True,
)
dk_ro = dqkv[:, :, 1, :, :rotary_dim]
dk1, dk2 = (dk_ro.chunk(2, dim=-1) if not ctx.interleaved
else (dk_ro[..., ::2], dk_ro[..., 1::2]))
rotary_emb.apply_rotary(dk1, dk2, rearrange(cos_k[:seqlen], 's d -> s 1 d'),
rearrange(sin_k[:seqlen], 's d -> s 1 d'), dk1, dk2, True)
dk1, dk2 = (
dk_ro.chunk(2, dim=-1) if not ctx.interleaved else (dk_ro[..., ::2], dk_ro[..., 1::2])
)
rotary_emb.apply_rotary(
dk1,
dk2,
rearrange(cos_k[:seqlen], "s d -> s 1 d"),
rearrange(sin_k[:seqlen], "s d -> s 1 d"),
dk1,
dk2,
True,
)
return dqkv, None, None, None, None, None
......@@ -152,7 +201,6 @@ apply_rotary_emb_qkv_ = ApplyRotaryEmbQKV_.apply
class ApplyRotaryEmbKV_(torch.autograd.Function):
@staticmethod
def forward(ctx, kv, cos, sin, interleaved=False):
"""
......@@ -171,9 +219,15 @@ class ApplyRotaryEmbKV_(torch.autograd.Function):
assert seqlen <= rotary_seqlen
k_ro = kv[:, :, 0, :, :rotary_dim]
k1, k2 = k_ro.chunk(2, dim=-1) if not interleaved else (k_ro[..., ::2], k_ro[..., 1::2])
rotary_emb.apply_rotary(k1, k2, rearrange(cos[:seqlen], 's d -> s 1 d'),
rearrange(sin[:seqlen], 's d -> s 1 d'), k1, k2,
False) # conj=False since this is the forward pass
rotary_emb.apply_rotary(
k1,
k2,
rearrange(cos[:seqlen], "s d -> s 1 d"),
rearrange(sin[:seqlen], "s d -> s 1 d"),
k1,
k2,
False,
) # conj=False since this is the forward pass
ctx.save_for_backward(cos, sin)
ctx.interleaved = interleaved
return kv
......@@ -185,11 +239,18 @@ class ApplyRotaryEmbKV_(torch.autograd.Function):
rotary_dim = cos.shape[-1]
rotary_dim *= 2
dk_ro = dkv[:, :, 0, :, :rotary_dim]
dk1, dk2 = (dk_ro.chunk(2, dim=-1) if not ctx.interleaved
else (dk_ro[..., ::2], dk_ro[..., 1::2]))
rotary_emb.apply_rotary(dk1, dk2, rearrange(cos[:seqlen], 's d -> s 1 d'),
rearrange(sin[:seqlen], 's d -> s 1 d'), dk1, dk2,
True) # conj=True since this is the backward pass
dk1, dk2 = (
dk_ro.chunk(2, dim=-1) if not ctx.interleaved else (dk_ro[..., ::2], dk_ro[..., 1::2])
)
rotary_emb.apply_rotary(
dk1,
dk2,
rearrange(cos[:seqlen], "s d -> s 1 d"),
rearrange(sin[:seqlen], "s d -> s 1 d"),
dk1,
dk2,
True,
) # conj=True since this is the backward pass
return dkv, None, None, None
......@@ -214,8 +275,15 @@ class RotaryEmbedding(torch.nn.Module):
Reference: https://github.com/sunyt32/torchscale/blob/main/torchscale/component/xpos_relative_position.py
"""
def __init__(self, dim: int, base=10000.0, interleaved=False, scale_base=None,
pos_idx_in_fp32=True, device=None):
def __init__(
self,
dim: int,
base=10000.0,
interleaved=False,
scale_base=None,
pos_idx_in_fp32=True,
device=None,
):
"""
interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead
of 1st half and 2nd half (GPT-NeoX style).
......@@ -239,8 +307,11 @@ class RotaryEmbedding(torch.nn.Module):
self.register_buffer("inv_freq", inv_freq, persistent=False)
self.interleaved = interleaved
self.scale_base = scale_base
scale = ((torch.arange(0, dim, 2, device=device, dtype=torch.float32) + 0.4 * dim)
/ (1.4 * dim) if scale_base is not None else None)
scale = (
(torch.arange(0, dim, 2, device=device, dtype=torch.float32) + 0.4 * dim) / (1.4 * dim)
if scale_base is not None
else None
)
self.register_buffer("scale", scale, persistent=False)
self._seq_len_cached = 0
......@@ -250,17 +321,21 @@ class RotaryEmbedding(torch.nn.Module):
self._sin_k_cached = None
def _compute_inv_freq(self, device=None):
return 1.0 / (self.base ** (torch.arange(0, self.dim, 2, device=device,
dtype=torch.float32) / self.dim))
return 1.0 / (
self.base
** (torch.arange(0, self.dim, 2, device=device, dtype=torch.float32) / self.dim)
)
def _update_cos_sin_cache(self, seqlen, device=None, dtype=None):
# Reset the tables if the sequence length has changed,
# if we're on a new device (possibly due to tracing for instance),
# or if we're switching from inference mode to training
if (seqlen > self._seq_len_cached or self._cos_cached.device != device
if (
seqlen > self._seq_len_cached
or self._cos_cached.device != device
or self._cos_cached.dtype != dtype
or (self.training and self._cos_cached.is_inference())):
or (self.training and self._cos_cached.is_inference())
):
self._seq_len_cached = seqlen
# We want fp32 here, not self.inv_freq.dtype, since the model could be loaded in bf16
# And the output of arange can be quite large, so bf16 would lose a lot of precision.
......@@ -285,17 +360,20 @@ class RotaryEmbedding(torch.nn.Module):
self._cos_cached = torch.cos(freqs).to(dtype)
self._sin_cached = torch.sin(freqs).to(dtype)
else:
power = ((torch.arange(seqlen, dtype=self.scale.dtype, device=self.scale.device)
- seqlen // 2) / self.scale_base)
scale = self.scale.to(device=power.device) ** rearrange(power, 's -> s 1')
power = (
torch.arange(seqlen, dtype=self.scale.dtype, device=self.scale.device)
- seqlen // 2
) / self.scale_base
scale = self.scale.to(device=power.device) ** rearrange(power, "s -> s 1")
# We want the multiplication by scale to happen in fp32
self._cos_cached = (torch.cos(freqs) * scale).to(dtype)
self._sin_cached = (torch.sin(freqs) * scale).to(dtype)
self._cos_k_cached = (torch.cos(freqs) / scale).to(dtype)
self._sin_k_cached = (torch.sin(freqs) / scale).to(dtype)
def forward(self, qkv: torch.Tensor, kv: Optional[torch.Tensor] = None,
seqlen_offset: int = 0) -> Tuple[torch.Tensor, torch.Tensor]:
def forward(
self, qkv: torch.Tensor, kv: Optional[torch.Tensor] = None, seqlen_offset: int = 0
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
qkv: (batch, seqlen, 3, nheads, headdim) if kv is none,
else it's just q of shape (batch, seqlen, nheads, headdim)
......@@ -308,29 +386,43 @@ class RotaryEmbedding(torch.nn.Module):
if kv is None:
if self.scale is None:
return apply_rotary_emb_qkv_(
qkv, self._cos_cached[seqlen_offset:], self._sin_cached[seqlen_offset:],
None, None, self.interleaved
qkv,
self._cos_cached[seqlen_offset:],
self._sin_cached[seqlen_offset:],
None,
None,
self.interleaved,
)
else:
return apply_rotary_emb_qkv_(
qkv, self._cos_cached[seqlen_offset:], self._sin_cached[seqlen_offset:],
self._cos_k_cached[seqlen_offset:], self._sin_k_cached[seqlen_offset:],
self.interleaved
qkv,
self._cos_cached[seqlen_offset:],
self._sin_cached[seqlen_offset:],
self._cos_k_cached[seqlen_offset:],
self._sin_k_cached[seqlen_offset:],
self.interleaved,
)
else:
q = qkv
q = apply_rotary_emb_func(
q, self._cos_cached[seqlen_offset:], self._sin_cached[seqlen_offset:],
self.interleaved, True
q,
self._cos_cached[seqlen_offset:],
self._sin_cached[seqlen_offset:],
self.interleaved,
True,
)
if self.scale is None:
kv = apply_rotary_emb_kv_(
kv, self._cos_cached[seqlen_offset:], self._sin_cached[seqlen_offset:],
self.interleaved
kv,
self._cos_cached[seqlen_offset:],
self._sin_cached[seqlen_offset:],
self.interleaved,
)
else:
kv = apply_rotary_emb_kv_(
kv, self._cos_k_cached[seqlen_offset:], self._sin_k_cached[seqlen_offset:],
self.interleaved
kv,
self._cos_k_cached[seqlen_offset:],
self._sin_k_cached[seqlen_offset:],
self.interleaved,
)
return q, kv
......@@ -5,7 +5,6 @@
# The original xentropy interface is here: https://github.com/NVIDIA/apex/blob/master/apex/contrib/xentropy/softmax_xentropy.py
import torch
import torch.nn as nn
import xentropy_cuda_lib
# `all_gather_into_tensor` and `reduce_scatter_tensor` are new placeholders for
......@@ -17,10 +16,16 @@ if "all_gather_into_tensor" not in dir(torch.distributed):
class SoftmaxCrossEntropyLossFn(torch.autograd.Function):
@staticmethod
def forward(ctx, logits, labels, smoothing=0.0, ignored_index=-100, inplace_backward=False,
process_group=None):
def forward(
ctx,
logits,
labels,
smoothing=0.0,
ignored_index=-100,
inplace_backward=False,
process_group=None,
):
"""
logits: (batch, vocab_size)
labels: (batch,)
......@@ -34,7 +39,7 @@ class SoftmaxCrossEntropyLossFn(torch.autograd.Function):
if world_size == 1:
losses, lse = xentropy_cuda_lib.forward(logits, labels, smoothing)
losses.masked_fill_(labels==ignored_index, 0)
losses.masked_fill_(labels == ignored_index, 0)
labels_local = labels
else:
rank = torch.distributed.get_rank(process_group)
......@@ -48,8 +53,9 @@ class SoftmaxCrossEntropyLossFn(torch.autograd.Function):
# For tensor parallel cross entropy with smoothing, we want to pass in the total number
# of classes so that smoothing can be applied correctly. If total_classes=-1, use the
# last dimension of the input tensor.
losses, lse_local = xentropy_cuda_lib.forward(logits, labels_local, smoothing,
world_size * vocab_size)
losses, lse_local = xentropy_cuda_lib.forward(
logits, labels_local, smoothing, world_size * vocab_size
)
assert lse_local.shape == (batch,)
assert losses.shape == (batch,)
losses.masked_fill_(ignored_mask, 0)
......@@ -61,10 +67,12 @@ class SoftmaxCrossEntropyLossFn(torch.autograd.Function):
# For labels not in the vocab of this partition, losses contains
# 0.1 * (lse_local - sum logit / total_classes).
lse_allgather = torch.empty(world_size, batch, dtype=lse_local.dtype,
device=lse_local.device)
torch.distributed.all_gather_into_tensor(lse_allgather, lse_local.contiguous(),
group=process_group)
lse_allgather = torch.empty(
world_size, batch, dtype=lse_local.dtype, device=lse_local.device
)
torch.distributed.all_gather_into_tensor(
lse_allgather, lse_local.contiguous(), group=process_group
)
handle_losses = torch.distributed.all_reduce(
losses, op=torch.distributed.ReduceOp.SUM, group=process_group, async_op=True
)
......@@ -74,16 +82,18 @@ class SoftmaxCrossEntropyLossFn(torch.autograd.Function):
# If there's smoothing=0.1, the total losses are
# 0.9 * (lse_local - predicted_logit) + 0.1 * (sum of all lse_local - sum logit / total_classes)
# We want 0.9 * (lse - predicted_logit) + 0.1 * (lse - sum logit / total_classes).
rank_per_sample = torch.div(labels, vocab_size, rounding_mode='floor')
lse_local = lse_allgather[rank_per_sample,
torch.arange(batch, device=lse_allgather.device)]
rank_per_sample = torch.div(labels, vocab_size, rounding_mode="floor")
lse_local = lse_allgather[
rank_per_sample, torch.arange(batch, device=lse_allgather.device)
]
handle_losses.wait()
if smoothing == 0.0:
losses += lse - lse_local
else:
losses += ((1 - smoothing) * (lse - lse_local)
+ smoothing * (lse - lse_allgather.sum(dim=0)))
losses += (1 - smoothing) * (lse - lse_local) + smoothing * (
lse - lse_allgather.sum(dim=0)
)
losses.masked_fill_(ignored_mask, 0)
ctx.save_for_backward(logits, lse, labels_local)
......@@ -96,19 +106,24 @@ class SoftmaxCrossEntropyLossFn(torch.autograd.Function):
def backward(ctx, grad_loss):
logits, lse, labels = ctx.saved_tensors
grad_loss = grad_loss.contiguous()
grad_loss.masked_fill_(labels==ctx.ignored_index, 0)
grad_logits = xentropy_cuda_lib.backward(grad_loss, logits, lse, labels,
ctx.smoothing, ctx.inplace_backward,
ctx.total_classes)
grad_loss.masked_fill_(labels == ctx.ignored_index, 0)
grad_logits = xentropy_cuda_lib.backward(
grad_loss, logits, lse, labels, ctx.smoothing, ctx.inplace_backward, ctx.total_classes
)
return grad_logits, None, None, None, None, None, None
class CrossEntropyLoss(nn.Module):
def __init__(self, ignore_index=-100, reduction='mean', label_smoothing=0.0,
inplace_backward=False, process_group=None):
def __init__(
self,
ignore_index=-100,
reduction="mean",
label_smoothing=0.0,
inplace_backward=False,
process_group=None,
):
super().__init__()
if reduction not in ['mean', 'none']:
if reduction not in ["mean", "none"]:
raise NotImplementedError("Only support reduction = 'mean' or 'none'")
self.ignore_index = ignore_index
self.reduction = reduction
......@@ -120,10 +135,14 @@ class CrossEntropyLoss(nn.Module):
assert input.is_cuda and target.is_cuda
# SoftmaxCrossEntropyLoss implicitly casts to float
loss = SoftmaxCrossEntropyLossFn.apply(
input, target, self.label_smoothing, self.ignore_index, self.inplace_backward,
self.process_group
input,
target,
self.label_smoothing,
self.ignore_index,
self.inplace_backward,
self.process_group,
)
if self.reduction == 'mean':
if self.reduction == "mean":
return loss.sum() / (target != self.ignore_index).sum()
else:
return loss
This diff is collapsed.
......@@ -2,93 +2,114 @@
import math
import re
from collections import OrderedDict
import torch
import torch.nn.functional as F
from einops import rearrange
from transformers import GPT2Config, FalconConfig
from transformers import FalconConfig, GPT2Config
def remap_state_dict_hf_falcon(state_dict, config):
def key_mapping_layers(key):
return re.sub(r'^transformer.h.', 'transformer.layers.', key)
return re.sub(r"^transformer.h.", "transformer.layers.", key)
state_dict = OrderedDict((key_mapping_layers(k), v) for k, v in state_dict.items())
# Word embedding
def key_mapping_emb(key):
return re.sub(r'^transformer.word_embeddings.', 'transformer.embeddings.word_embeddings.', key)
return re.sub(
r"^transformer.word_embeddings.", "transformer.embeddings.word_embeddings.", key
)
state_dict = OrderedDict((key_mapping_emb(k), v) for k, v in state_dict.items())
word_embeddings = state_dict.pop('transformer.embeddings.word_embeddings.weight')
word_embeddings = state_dict.pop("transformer.embeddings.word_embeddings.weight")
# It's possible that vocab_size is padded to be a multiple of 8, for example.
pad_vocab_size_multiple = getattr(config, 'pad_vocab_size_multiple', 1)
vocab_size = (math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple)
state_dict['transformer.embeddings.word_embeddings.weight'] = F.pad(
pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
vocab_size = math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple
state_dict["transformer.embeddings.word_embeddings.weight"] = F.pad(
word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0])
)
if getattr(config, 'tie_word_embeddings'):
state_dict['lm_head.weight'] = state_dict['transformer.embeddings.word_embeddings.weight']
if getattr(config, "tie_word_embeddings"):
state_dict["lm_head.weight"] = state_dict["transformer.embeddings.word_embeddings.weight"]
else:
output_embeddings = state_dict.pop('lm_head.weight')
output_embeddings = state_dict.pop("lm_head.weight")
# It's possible that vocab_size is padded to be a multiple of 8, for example.
state_dict['lm_head.weight'] = F.pad(
state_dict["lm_head.weight"] = F.pad(
output_embeddings, (0, 0, 0, vocab_size - output_embeddings.shape[0])
)
output_embeddings_bias = state_dict.pop('lm_head.bias')
state_dict['lm_head.bias'] = F.pad(
output_embeddings_bias = state_dict.pop("lm_head.bias")
state_dict["lm_head.bias"] = F.pad(
output_embeddings_bias, (0, vocab_size - output_embeddings_bias.shape[0])
)
# LayerNorm
def key_mapping_ln(key):
key = re.sub(r'^transformer.layers.(\d+).input_layernorm.',
r'transformer.layers.\1.norm1.', key)
key = re.sub(r'^transformer.layers.(\d+).post_attention_layernorm.',
r'transformer.layers.\1.norm2.', key)
key = re.sub(r'^transformer.layers.(\d+).ln_attn.', r'transformer.layers.\1.norm1.', key)
key = re.sub(r'^transformer.layers.(\d+).ln_mlp.', r'transformer.layers.\1.norm2.', key)
key = re.sub(
r"^transformer.layers.(\d+).input_layernorm.", r"transformer.layers.\1.norm1.", key
)
key = re.sub(
r"^transformer.layers.(\d+).post_attention_layernorm.",
r"transformer.layers.\1.norm2.",
key,
)
key = re.sub(r"^transformer.layers.(\d+).ln_attn.", r"transformer.layers.\1.norm1.", key)
key = re.sub(r"^transformer.layers.(\d+).ln_mlp.", r"transformer.layers.\1.norm2.", key)
return key
state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())
# MLP
def key_mapping_mlp(key):
key = re.sub(r'^transformer.layers.(\d+).mlp.dense_h_to_4h.',
r'transformer.layers.\1.mlp.fc1.', key)
key = re.sub(r'^transformer.layers.(\d+).mlp.dense_4h_to_h.',
r'transformer.layers.\1.mlp.fc2.', key)
key = re.sub(
r"^transformer.layers.(\d+).mlp.dense_h_to_4h.", r"transformer.layers.\1.mlp.fc1.", key
)
key = re.sub(
r"^transformer.layers.(\d+).mlp.dense_4h_to_h.", r"transformer.layers.\1.mlp.fc2.", key
)
return key
state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())
def key_mapping_attn(key):
key = re.sub(r'^transformer.layers.(\d+).self_attention.query_key_value.',
r'transformer.layers.\1.mixer.Wqkv.', key)
key = re.sub(r'^transformer.layers.(\d+).self_attention.dense.',
r'transformer.layers.\1.mixer.out_proj.', key)
key = re.sub(
r"^transformer.layers.(\d+).self_attention.query_key_value.",
r"transformer.layers.\1.mixer.Wqkv.",
key,
)
key = re.sub(
r"^transformer.layers.(\d+).self_attention.dense.",
r"transformer.layers.\1.mixer.out_proj.",
key,
)
return key
state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())
n_head = config.n_head
n_head_kv = getattr(config, "n_head_kv", 1)
headdim = config.hidden_size // n_head
for l in range(config.n_layer):
# The weights are stored in a different layout compared to our implementation
Wqkv = rearrange(state_dict.pop(f'transformer.layers.{l}.mixer.Wqkv.weight'),
Wqkv = rearrange(
state_dict.pop(f"transformer.layers.{l}.mixer.Wqkv.weight"),
"(group ratio headdim) ... -> group ratio headdim ...",
ratio=n_head // n_head_kv + 2, headdim=headdim)
ratio=n_head // n_head_kv + 2,
headdim=headdim,
)
Wq = rearrange(Wqkv[:, :-2], "group ratio headdim ... -> (group ratio headdim) ...")
Wk = rearrange(Wqkv[:, [-2]], "group ratio headdim ... -> (group ratio headdim) ...")
Wv = rearrange(Wqkv[:, [-1]], "group ratio headdim ... -> (group ratio headdim) ...")
state_dict[f'transformer.layers.{l}.mixer.Wqkv.weight'] = torch.cat([Wq, Wk, Wv], dim=0)
state_dict[f"transformer.layers.{l}.mixer.Wqkv.weight"] = torch.cat([Wq, Wk, Wv], dim=0)
return state_dict
def falcon_config_to_gpt2_config(falcon_config: FalconConfig) -> GPT2Config:
# The 40b config uses "n_head_kv" instead of "num_kv_heads"
n_head_kv = getattr(falcon_config, "n_head_kv",
1 if getattr(falcon_config, "multi_query", False)
else falcon_config.n_head)
n_head_kv = getattr(
falcon_config,
"n_head_kv",
1 if getattr(falcon_config, "multi_query", False) else falcon_config.n_head,
)
# HACK: the 40b config has 2 LN per layer instead of 1, but that's not reflected in the config.
# So we have to infer it from the number of heads in the key/value block
parallel_block_tied_norm = n_head_kv == 1
......
......@@ -11,6 +11,8 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from transformers import GPT2Config
from flash_attn.models.falcon import remap_state_dict_hf_falcon
from flash_attn.models.gpt_neox import remap_state_dict_hf_gpt_neox
from flash_attn.models.gptj import remap_state_dict_hf_gptj
......@@ -27,10 +29,9 @@ from flash_attn.modules.mlp import (
ParallelMLP,
)
from flash_attn.ops.activations import sqrelu_fwd
from flash_attn.utils.distributed import all_gather_raw, sync_shared_params, get_dim_for_local_rank
from flash_attn.utils.distributed import all_gather_raw, get_dim_for_local_rank, sync_shared_params
from flash_attn.utils.generation import GenerationMixin
from flash_attn.utils.pretrained import state_dict_from_pretrained
from transformers import GPT2Config
try:
from flash_attn.ops.fused_dense import ColumnParallelLinear
......@@ -690,7 +691,7 @@ def shard_state_dict_tp(state_dict, config, world_size, rank):
if key in state_dict:
x = state_dict[key]
dim = x.shape[0] // world_size
state_dict[key] = x[rank * dim: (rank + 1) * dim]
state_dict[key] = x[rank * dim : (rank + 1) * dim]
def shard_last_dim(state_dict, key, multiple_of=1):
if key in state_dict:
......@@ -707,17 +708,19 @@ def shard_state_dict_tp(state_dict, config, world_size, rank):
x = state_dict[key]
dim = x.shape[0] // world_size // 2
state_dict[key] = rearrange(
rearrange(x, "(two o) ... -> two o ...", two=2)[:, rank * dim: (rank + 1) * dim],
rearrange(x, "(two o) ... -> two o ...", two=2)[:, rank * dim : (rank + 1) * dim],
"two o ... -> (two o) ...",
)
def shard_qkv_headdim(state_dict, key):
if key in state_dict:
n_head_each_rank = [
get_dim_for_local_rank(n_head, world_size, local_rank) for local_rank in range(world_size)
get_dim_for_local_rank(n_head, world_size, local_rank)
for local_rank in range(world_size)
]
n_head_kv_each_rank = [
get_dim_for_local_rank(n_head_kv, world_size, local_rank) for local_rank in range(world_size)
get_dim_for_local_rank(n_head_kv, world_size, local_rank)
for local_rank in range(world_size)
]
beg_n_head = sum(n_head_each_rank[:rank])
......@@ -729,7 +732,8 @@ def shard_state_dict_tp(state_dict, config, world_size, rank):
if n_head_kv == n_head:
x = rearrange(state_dict[key], "(three d) ... -> three d ...", three=3)
state_dict[key] = rearrange(
x[:, beg_n_head * head_dim : end_n_head * head_dim], "three d ... -> (three d) ..."
x[:, beg_n_head * head_dim : end_n_head * head_dim],
"three d ... -> (three d) ...",
)
else:
x = rearrange(
......@@ -741,8 +745,14 @@ def shard_state_dict_tp(state_dict, config, world_size, rank):
torch.cat(
[
x[beg_n_head:end_n_head],
x[n_head + beg_n_head_kv: n_head + end_n_head_kv],
x[n_head + n_head_kv + beg_n_head_kv: n_head + n_head_kv + end_n_head_kv],
x[n_head + beg_n_head_kv : n_head + end_n_head_kv],
x[
n_head
+ n_head_kv
+ beg_n_head_kv : n_head
+ n_head_kv
+ end_n_head_kv
],
],
dim=0,
),
......@@ -824,7 +834,7 @@ def combine_state_dicts_tp(state_dicts, config):
torch.cat([x[:n_head_per_rank] for x in xs], dim=0),
torch.cat(
[
x[n_head_per_rank: n_head_per_rank + n_head_kv_per_rank]
x[n_head_per_rank : n_head_per_rank + n_head_kv_per_rank]
for x in xs
],
dim=0,
......
......@@ -2,80 +2,100 @@
import math
import re
from collections import OrderedDict
import torch
import torch.nn.functional as F
from einops import rearrange
from transformers import GPT2Config, GPTNeoXConfig
def remap_state_dict_hf_gpt_neox(state_dict, config):
def key_mapping_layers(key):
return re.sub(r'^gpt_neox.', 'transformer.', key)
return re.sub(r"^gpt_neox.", "transformer.", key)
state_dict = OrderedDict((key_mapping_layers(k), v) for k, v in state_dict.items())
# Word embedding
def key_mapping_emb(key):
return re.sub(r'^transformer.embed_in.', 'transformer.embeddings.word_embeddings.', key)
return re.sub(r"^transformer.embed_in.", "transformer.embeddings.word_embeddings.", key)
state_dict = OrderedDict((key_mapping_emb(k), v) for k, v in state_dict.items())
word_embeddings = state_dict.pop('transformer.embeddings.word_embeddings.weight')
word_embeddings = state_dict.pop("transformer.embeddings.word_embeddings.weight")
# It's possible that vocab_size is padded to be a multiple of 8, for example.
pad_vocab_size_multiple = getattr(config, 'pad_vocab_size_multiple', 1)
vocab_size = (math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple)
state_dict['transformer.embeddings.word_embeddings.weight'] = F.pad(
pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
vocab_size = math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple
state_dict["transformer.embeddings.word_embeddings.weight"] = F.pad(
word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0])
)
if getattr(config, 'tie_word_embeddings'):
state_dict['lm_head.weight'] = state_dict['transformer.embeddings.word_embeddings.weight']
if getattr(config, "tie_word_embeddings"):
state_dict["lm_head.weight"] = state_dict["transformer.embeddings.word_embeddings.weight"]
else:
output_embeddings = state_dict.pop('embed_out.weight')
output_embeddings = state_dict.pop("embed_out.weight")
# It's possible that vocab_size is padded to be a multiple of 8, for example.
state_dict['lm_head.weight'] = F.pad(
state_dict["lm_head.weight"] = F.pad(
output_embeddings, (0, 0, 0, vocab_size - output_embeddings.shape[0])
)
# LayerNorm
def key_mapping_ln(key):
key = re.sub(r'^transformer.final_layer_norm.', r'transformer.ln_f.', key)
key = re.sub(r'^transformer.layers.(\d+).input_layernorm.', r'transformer.layers.\1.norm1.', key)
key = re.sub(r'^transformer.layers.(\d+).post_attention_layernorm.', r'transformer.layers.\1.norm2.', key)
key = re.sub(r"^transformer.final_layer_norm.", r"transformer.ln_f.", key)
key = re.sub(
r"^transformer.layers.(\d+).input_layernorm.", r"transformer.layers.\1.norm1.", key
)
key = re.sub(
r"^transformer.layers.(\d+).post_attention_layernorm.",
r"transformer.layers.\1.norm2.",
key,
)
return key
state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())
# MLP
def key_mapping_mlp(key):
key = re.sub(r'^transformer.layers.(\d+).mlp.dense_h_to_4h.', r'transformer.layers.\1.mlp.fc1.', key)
key = re.sub(r'^transformer.layers.(\d+).mlp.dense_4h_to_h.', r'transformer.layers.\1.mlp.fc2.', key)
key = re.sub(
r"^transformer.layers.(\d+).mlp.dense_h_to_4h.", r"transformer.layers.\1.mlp.fc1.", key
)
key = re.sub(
r"^transformer.layers.(\d+).mlp.dense_4h_to_h.", r"transformer.layers.\1.mlp.fc2.", key
)
return key
state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())
# Attention
for l in range(config.n_layer):
# We don't store these biases
state_dict.pop(f'transformer.layers.{l}.attention.bias')
state_dict.pop(f'transformer.layers.{l}.attention.masked_bias')
state_dict.pop(f"transformer.layers.{l}.attention.bias")
state_dict.pop(f"transformer.layers.{l}.attention.masked_bias")
# GPT-NeoX stores Wqkv as ((nheads 3 headdim), hidden_dim)
# while we store Wqkv as ((3 nheads headdim), hidden_dim)
headdim = config.hidden_size // config.num_attention_heads
Wqkv = state_dict.pop(f'transformer.layers.{l}.attention.query_key_value.weight')
state_dict[f'transformer.layers.{l}.mixer.Wqkv.weight'] = rearrange(
Wqkv, '(nheads three headdim) ... -> (three nheads headdim) ...',
three=3, headdim=headdim
Wqkv = state_dict.pop(f"transformer.layers.{l}.attention.query_key_value.weight")
state_dict[f"transformer.layers.{l}.mixer.Wqkv.weight"] = rearrange(
Wqkv,
"(nheads three headdim) ... -> (three nheads headdim) ...",
three=3,
headdim=headdim,
)
bqkv = state_dict.pop(f'transformer.layers.{l}.attention.query_key_value.bias')
state_dict[f'transformer.layers.{l}.mixer.Wqkv.bias'] = rearrange(
bqkv, '(nheads three headdim) -> (three nheads headdim)',
three=3, headdim=headdim
bqkv = state_dict.pop(f"transformer.layers.{l}.attention.query_key_value.bias")
state_dict[f"transformer.layers.{l}.mixer.Wqkv.bias"] = rearrange(
bqkv, "(nheads three headdim) -> (three nheads headdim)", three=3, headdim=headdim
)
def key_mapping_attn(key):
key = re.sub(r'^transformer.layers.(\d+).attention.dense.',
r'transformer.layers.\1.mixer.out_proj.', key)
key = re.sub(r'^transformer.layers.(\d+).attention.rotary_emb.',
r'transformer.layers.\1.mixer.rotary_emb.', key)
key = re.sub(
r"^transformer.layers.(\d+).attention.dense.",
r"transformer.layers.\1.mixer.out_proj.",
key,
)
key = re.sub(
r"^transformer.layers.(\d+).attention.rotary_emb.",
r"transformer.layers.\1.mixer.rotary_emb.",
key,
)
return key
state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())
return state_dict
......
......@@ -2,67 +2,78 @@
import math
import re
from collections import OrderedDict
import torch
import torch.nn.functional as F
from transformers import GPT2Config, GPTJConfig
def remap_state_dict_hf_gptj(state_dict, config):
def key_mapping_layers(key):
return re.sub(r'^transformer.h.', 'transformer.layers.', key)
return re.sub(r"^transformer.h.", "transformer.layers.", key)
state_dict = OrderedDict((key_mapping_layers(k), v) for k, v in state_dict.items())
# Word embedding
def key_mapping_emb(key):
return re.sub(r'^transformer.wte.', 'transformer.embeddings.word_embeddings.', key)
return re.sub(r"^transformer.wte.", "transformer.embeddings.word_embeddings.", key)
state_dict = OrderedDict((key_mapping_emb(k), v) for k, v in state_dict.items())
word_embeddings = state_dict.pop('transformer.embeddings.word_embeddings.weight')
word_embeddings = state_dict.pop("transformer.embeddings.word_embeddings.weight")
# It's possible that vocab_size is padded to be a multiple of 8, for example.
pad_vocab_size_multiple = getattr(config, 'pad_vocab_size_multiple', 1)
vocab_size = (math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple)
state_dict['transformer.embeddings.word_embeddings.weight'] = F.pad(
pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
vocab_size = math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple
state_dict["transformer.embeddings.word_embeddings.weight"] = F.pad(
word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0])
)
if getattr(config, 'tie_word_embeddings'):
state_dict['lm_head.weight'] = state_dict['transformer.embeddings.word_embeddings.weight']
if getattr(config, "tie_word_embeddings"):
state_dict["lm_head.weight"] = state_dict["transformer.embeddings.word_embeddings.weight"]
else:
output_embeddings = state_dict.pop('lm_head.weight')
output_embeddings = state_dict.pop("lm_head.weight")
# It's possible that vocab_size is padded to be a multiple of 8, for example.
state_dict['lm_head.weight'] = F.pad(
state_dict["lm_head.weight"] = F.pad(
output_embeddings, (0, 0, 0, vocab_size - output_embeddings.shape[0])
)
output_embeddings_bias = state_dict.pop('lm_head.bias')
state_dict['lm_head.bias'] = F.pad(
output_embeddings_bias = state_dict.pop("lm_head.bias")
state_dict["lm_head.bias"] = F.pad(
output_embeddings_bias, (0, vocab_size - output_embeddings_bias.shape[0])
)
# LayerNorm
def key_mapping_ln(key):
return re.sub(r'^transformer.layers.(\d+).ln_1.', r'transformer.layers.\1.norm1.', key)
return re.sub(r"^transformer.layers.(\d+).ln_1.", r"transformer.layers.\1.norm1.", key)
state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())
# MLP
def key_mapping_mlp(key):
key = re.sub(r'^transformer.layers.(\d+).mlp.fc_in.', r'transformer.layers.\1.mlp.fc1.', key)
key = re.sub(r'^transformer.layers.(\d+).mlp.fc_out.', r'transformer.layers.\1.mlp.fc2.', key)
key = re.sub(
r"^transformer.layers.(\d+).mlp.fc_in.", r"transformer.layers.\1.mlp.fc1.", key
)
key = re.sub(
r"^transformer.layers.(\d+).mlp.fc_out.", r"transformer.layers.\1.mlp.fc2.", key
)
return key
state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())
# Attention
for l in range(config.n_layer):
Wq = state_dict.pop(f'transformer.layers.{l}.attn.q_proj.weight')
Wk = state_dict.pop(f'transformer.layers.{l}.attn.k_proj.weight')
Wv = state_dict.pop(f'transformer.layers.{l}.attn.v_proj.weight')
state_dict[f'transformer.layers.{l}.mixer.Wqkv.weight'] = torch.cat([Wq, Wk, Wv], dim=0)
Wq = state_dict.pop(f"transformer.layers.{l}.attn.q_proj.weight")
Wk = state_dict.pop(f"transformer.layers.{l}.attn.k_proj.weight")
Wv = state_dict.pop(f"transformer.layers.{l}.attn.v_proj.weight")
state_dict[f"transformer.layers.{l}.mixer.Wqkv.weight"] = torch.cat([Wq, Wk, Wv], dim=0)
# We don't store these biases
state_dict.pop(f'transformer.layers.{l}.attn.bias')
state_dict.pop(f'transformer.layers.{l}.attn.masked_bias')
state_dict.pop(f"transformer.layers.{l}.attn.bias")
state_dict.pop(f"transformer.layers.{l}.attn.masked_bias")
def key_mapping_attn(key):
return re.sub(r'^transformer.layers.(\d+).attn.out_proj.',
r'transformer.layers.\1.mixer.out_proj.', key)
return re.sub(
r"^transformer.layers.(\d+).attn.out_proj.",
r"transformer.layers.\1.mixer.out_proj.",
key,
)
state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())
return state_dict
......
......@@ -15,63 +15,81 @@ from transformers import GPT2Config, LlamaConfig
def remap_state_dict_meta_llama(state_dict, config):
def key_mapping_layers(key):
return f'transformer.{key}' if not key.startswith('output.') else key
return f"transformer.{key}" if not key.startswith("output.") else key
state_dict = OrderedDict((key_mapping_layers(k), v) for k, v in state_dict.items())
# Word embedding
def key_mapping_emb(key):
return re.sub(r'^transformer.tok_embeddings.', 'transformer.embeddings.word_embeddings.', key)
return re.sub(
r"^transformer.tok_embeddings.", "transformer.embeddings.word_embeddings.", key
)
state_dict = OrderedDict((key_mapping_emb(k), v) for k, v in state_dict.items())
word_embeddings = state_dict.pop('transformer.embeddings.word_embeddings.weight')
word_embeddings = state_dict.pop("transformer.embeddings.word_embeddings.weight")
# It's possible that vocab_size is padded to be a multiple of 8, for example.
pad_vocab_size_multiple = getattr(config, 'pad_vocab_size_multiple', 1)
vocab_size = (math.ceil(word_embeddings.shape[0] / pad_vocab_size_multiple)
* pad_vocab_size_multiple)
state_dict['transformer.embeddings.word_embeddings.weight'] = F.pad(
pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
vocab_size = (
math.ceil(word_embeddings.shape[0] / pad_vocab_size_multiple) * pad_vocab_size_multiple
)
state_dict["transformer.embeddings.word_embeddings.weight"] = F.pad(
word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0])
)
if getattr(config, 'tie_word_embeddings'):
state_dict['lm_head.weight'] = state_dict['transformer.embeddings.word_embeddings.weight']
if getattr(config, "tie_word_embeddings"):
state_dict["lm_head.weight"] = state_dict["transformer.embeddings.word_embeddings.weight"]
else:
output_embeddings = state_dict.pop('output.weight')
output_embeddings = state_dict.pop("output.weight")
# Need to recompute vocab_size since LLaMa shards the word embeddings and output embeddings
# differently.
vocab_size = (math.ceil(output_embeddings.shape[0] / pad_vocab_size_multiple)
* pad_vocab_size_multiple)
vocab_size = (
math.ceil(output_embeddings.shape[0] / pad_vocab_size_multiple)
* pad_vocab_size_multiple
)
# It's possible that vocab_size is padded to be a multiple of 8, for example.
state_dict['lm_head.weight'] = F.pad(
state_dict["lm_head.weight"] = F.pad(
output_embeddings, (0, 0, 0, vocab_size - output_embeddings.shape[0])
)
# LayerNorm
def key_mapping_ln(key):
key = re.sub(r'^transformer.norm.', r'transformer.ln_f.', key)
key = re.sub(r'^transformer.layers.(\d+).attention_norm.', r'transformer.layers.\1.norm1.', key)
key = re.sub(r'^transformer.layers.(\d+).ffn_norm.', r'transformer.layers.\1.norm2.', key)
key = re.sub(r"^transformer.norm.", r"transformer.ln_f.", key)
key = re.sub(
r"^transformer.layers.(\d+).attention_norm.", r"transformer.layers.\1.norm1.", key
)
key = re.sub(r"^transformer.layers.(\d+).ffn_norm.", r"transformer.layers.\1.norm2.", key)
return key
state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())
# MLP
for l in range(config.n_layer):
w1 = state_dict.pop(f'transformer.layers.{l}.feed_forward.w1.weight')
w3 = state_dict.pop(f'transformer.layers.{l}.feed_forward.w3.weight')
w1 = state_dict.pop(f"transformer.layers.{l}.feed_forward.w1.weight")
w3 = state_dict.pop(f"transformer.layers.{l}.feed_forward.w3.weight")
# Our ordering is different
state_dict[f'transformer.layers.{l}.mlp.fc1.weight'] = torch.cat([w3, w1], dim=0)
state_dict[f"transformer.layers.{l}.mlp.fc1.weight"] = torch.cat([w3, w1], dim=0)
def key_mapping_mlp(key):
return re.sub(r'^transformer.layers.(\d+).feed_forward.w2.',
r'transformer.layers.\1.mlp.fc2.', key)
return re.sub(
r"^transformer.layers.(\d+).feed_forward.w2.", r"transformer.layers.\1.mlp.fc2.", key
)
state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())
# Attention
for l in range(config.n_layer):
Wq = state_dict.pop(f'transformer.layers.{l}.attention.wq.weight')
Wk = state_dict.pop(f'transformer.layers.{l}.attention.wk.weight')
Wv = state_dict.pop(f'transformer.layers.{l}.attention.wv.weight')
state_dict[f'transformer.layers.{l}.mixer.Wqkv.weight'] = torch.cat([Wq, Wk, Wv], dim=0)
Wq = state_dict.pop(f"transformer.layers.{l}.attention.wq.weight")
Wk = state_dict.pop(f"transformer.layers.{l}.attention.wk.weight")
Wv = state_dict.pop(f"transformer.layers.{l}.attention.wv.weight")
state_dict[f"transformer.layers.{l}.mixer.Wqkv.weight"] = torch.cat([Wq, Wk, Wv], dim=0)
# We don't store these
state_dict.pop(f'transformer.layers.{l}.attention.inner_attention.rope.freqs', None)
state_dict.pop(f"transformer.layers.{l}.attention.inner_attention.rope.freqs", None)
def key_mapping_attn(key):
return re.sub(r'^transformer.layers.(\d+).attention.wo.',
r'transformer.layers.\1.mixer.out_proj.', key)
return re.sub(
r"^transformer.layers.(\d+).attention.wo.",
r"transformer.layers.\1.mixer.out_proj.",
key,
)
state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())
state_dict.pop("transformer.rope.freqs", None)
......@@ -82,29 +100,32 @@ def remap_state_dict_meta_llama(state_dict, config):
def remap_state_dict_hf_llama(state_dict, config):
# Embedding
def key_mapping_emb(key):
return re.sub(r'^model.embed_tokens.', 'transformer.embeddings.word_embeddings.', key)
return re.sub(r"^model.embed_tokens.", "transformer.embeddings.word_embeddings.", key)
state_dict = OrderedDict((key_mapping_emb(k), v) for k, v in state_dict.items())
word_embeddings = state_dict.pop('transformer.embeddings.word_embeddings.weight')
word_embeddings = state_dict.pop("transformer.embeddings.word_embeddings.weight")
# It's possible that vocab_size is padded to be a multiple of 8, for example.
pad_vocab_size_multiple = getattr(config, 'pad_vocab_size_multiple', 1)
vocab_size = (math.ceil(word_embeddings.shape[0] / pad_vocab_size_multiple)
* pad_vocab_size_multiple)
state_dict['transformer.embeddings.word_embeddings.weight'] = F.pad(
pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
vocab_size = (
math.ceil(word_embeddings.shape[0] / pad_vocab_size_multiple) * pad_vocab_size_multiple
)
state_dict["transformer.embeddings.word_embeddings.weight"] = F.pad(
word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0])
)
# LM head
if getattr(config, 'tie_word_embeddings'):
state_dict['lm_head.weight'] = state_dict['transformer.embeddings.word_embeddings.weight']
if getattr(config, "tie_word_embeddings"):
state_dict["lm_head.weight"] = state_dict["transformer.embeddings.word_embeddings.weight"]
else:
output_embeddings = state_dict.pop('lm_head.weight')
output_embeddings = state_dict.pop("lm_head.weight")
# Need to recompute vocab_size since LLaMa shards the word embeddings and output embeddings
# differently.
vocab_size = (math.ceil(output_embeddings.shape[0] / pad_vocab_size_multiple)
* pad_vocab_size_multiple)
vocab_size = (
math.ceil(output_embeddings.shape[0] / pad_vocab_size_multiple)
* pad_vocab_size_multiple
)
# It's possible that vocab_size is padded to be a multiple of 8, for example.
state_dict['lm_head.weight'] = F.pad(
state_dict["lm_head.weight"] = F.pad(
output_embeddings, (0, 0, 0, vocab_size - output_embeddings.shape[0])
)
......@@ -113,21 +134,22 @@ def remap_state_dict_hf_llama(state_dict, config):
# Fusing weights this way based on difference in the following:
# https://github.com/huggingface/transformers/blob/b42010bb1d3cbf262d27e0a328661885be46dfdb/src/transformers/models/llama/modeling_llama.py#L220
# https://github.com/Dao-AILab/flash-attention/blob/c60851a8253257eb970e06a022c82517a8033e8c/flash_attn/modules/mlp.py#L115
w1 = state_dict.pop(f'model.layers.{l}.mlp.gate_proj.weight')
w3 = state_dict.pop(f'model.layers.{l}.mlp.up_proj.weight')
state_dict[f'transformer.layers.{l}.mlp.fc1.weight'] = torch.cat([w3, w1], dim=0)
w1 = state_dict.pop(f"model.layers.{l}.mlp.gate_proj.weight")
w3 = state_dict.pop(f"model.layers.{l}.mlp.up_proj.weight")
state_dict[f"transformer.layers.{l}.mlp.fc1.weight"] = torch.cat([w3, w1], dim=0)
def key_mapping_mlp(key):
return re.sub(r'^model.layers.(\d+).mlp.down_proj.',
r'transformer.layers.\1.mlp.fc2.', key)
return re.sub(r"^model.layers.(\d+).mlp.down_proj.", r"transformer.layers.\1.mlp.fc2.", key)
state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())
# LayerNorm
def key_mapping_ln(key):
key = re.sub(r'^model.norm.', r'transformer.ln_f.', key)
key = re.sub(r'^model.layers.(\d+).input_layernorm.', r'transformer.layers.\1.norm1.', key)
key = re.sub(r'^model.layers.(\d+).post_attention_layernorm.', r'transformer.layers.\1.norm2.', key)
key = re.sub(r"^model.norm.", r"transformer.ln_f.", key)
key = re.sub(r"^model.layers.(\d+).input_layernorm.", r"transformer.layers.\1.norm1.", key)
key = re.sub(
r"^model.layers.(\d+).post_attention_layernorm.", r"transformer.layers.\1.norm2.", key
)
return key
state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())
......@@ -135,42 +157,52 @@ def remap_state_dict_hf_llama(state_dict, config):
def inv_permute(w):
# Inverse of permute implemented in:
# https://github.com/huggingface/transformers/blob/b42010bb1d3cbf262d27e0a328661885be46dfdb/src/transformers/models/llama/convert_llama_weights_to_hf.py#L114
return w.reshape(
config.n_head, 2, config.n_embd // config.n_head // 2, config.n_embd
).transpose(1, 2).reshape(config.n_embd, config.n_embd)
return (
w.reshape(config.n_head, 2, config.n_embd // config.n_head // 2, config.n_embd)
.transpose(1, 2)
.reshape(config.n_embd, config.n_embd)
)
# Attention
for l in range(config.n_layer):
Wq = state_dict.pop(f'model.layers.{l}.self_attn.q_proj.weight')
Wk = state_dict.pop(f'model.layers.{l}.self_attn.k_proj.weight')
Wv = state_dict.pop(f'model.layers.{l}.self_attn.v_proj.weight')
state_dict[f'transformer.layers.{l}.mixer.Wqkv.weight'] = torch.cat(
Wq = state_dict.pop(f"model.layers.{l}.self_attn.q_proj.weight")
Wk = state_dict.pop(f"model.layers.{l}.self_attn.k_proj.weight")
Wv = state_dict.pop(f"model.layers.{l}.self_attn.v_proj.weight")
state_dict[f"transformer.layers.{l}.mixer.Wqkv.weight"] = torch.cat(
[inv_permute(Wq), inv_permute(Wk), Wv], dim=0
)
# We don't store these
state_dict.pop(f'model.layers.{l}.self_attn.rotary_emb.inv_freq', None)
state_dict.pop(f"model.layers.{l}.self_attn.rotary_emb.inv_freq", None)
def key_mapping_attn(key):
return re.sub(r'^model.layers.(\d+).self_attn.o_proj.',
r'transformer.layers.\1.mixer.out_proj.', key)
return re.sub(
r"^model.layers.(\d+).self_attn.o_proj.", r"transformer.layers.\1.mixer.out_proj.", key
)
state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())
return state_dict
def config_from_meta_checkpoint(checkpoint_path: Union[str, os.PathLike], model_name: str) -> LlamaConfig:
def config_from_meta_checkpoint(
checkpoint_path: Union[str, os.PathLike], model_name: str
) -> LlamaConfig:
"""Load a LlamaConfig from a checkpoint path."""
with open(Path(checkpoint_path) / model_name / 'params.json') as f:
with open(Path(checkpoint_path) / model_name / "params.json") as f:
params = json.load(f)
config = LlamaConfig(hidden_size=params['dim'], intermediate_size=None,
num_attention_heads=params['n_heads'],
num_hidden_layers=params['n_layers'],
rms_norm_eps=params['norm_eps'])
config = LlamaConfig(
hidden_size=params["dim"],
intermediate_size=None,
num_attention_heads=params["n_heads"],
num_hidden_layers=params["n_layers"],
rms_norm_eps=params["norm_eps"],
)
return config
def config_from_hf_checkpoint(checkpoint_path: Union[str, os.PathLike], model_name: str) -> LlamaConfig:
return LlamaConfig.from_pretrained(Path(checkpoint_path) / f'{model_name}-hf' / "config.json")
def config_from_hf_checkpoint(
checkpoint_path: Union[str, os.PathLike], model_name: str
) -> LlamaConfig:
return LlamaConfig.from_pretrained(Path(checkpoint_path) / f"{model_name}-hf" / "config.json")
def config_from_checkpoint(
......@@ -182,10 +214,14 @@ def config_from_checkpoint(
return config_from_hf_checkpoint(checkpoint_path, model_name)
def state_dicts_from_checkpoint(checkpoint_path: Union[str, os.PathLike], model_name: str) -> list[dict]:
def state_dicts_from_checkpoint(
checkpoint_path: Union[str, os.PathLike], model_name: str
) -> list[dict]:
# Need to sort, otherwise we mess up the ordering and the weights are wrong
return [torch.load(path, map_location='cpu')
for path in sorted((Path(checkpoint_path) / model_name).glob('consolidated.*.pth'))]
return [
torch.load(path, map_location="cpu")
for path in sorted((Path(checkpoint_path) / model_name).glob("consolidated.*.pth"))
]
def llama_config_to_gpt2_config(llama_config: LlamaConfig) -> GPT2Config:
......@@ -196,7 +232,7 @@ def llama_config_to_gpt2_config(llama_config: LlamaConfig) -> GPT2Config:
n_layer=llama_config.num_hidden_layers,
n_head=llama_config.num_attention_heads,
n_inner=llama_config.intermediate_size,
activation_function='swiglu', # Hardcode since HF calls it 'silu'
activation_function="swiglu", # Hardcode since HF calls it 'silu'
# Llama doesn't have dropout, idk if it's because they only release the inference code
resid_pdrop=0.0,
embd_pdrop=0.0,
......
......@@ -2,75 +2,86 @@
import math
import re
from collections import OrderedDict
import torch
import torch.nn.functional as F
from transformers import GPT2Config, OPTConfig
def remap_state_dict_hf_opt(state_dict, config):
def key_mapping_model(key):
key = re.sub(r'^model.decoder.', 'transformer.', key)
key = re.sub(r"^model.decoder.", "transformer.", key)
# The OPT-350m model uses '^decoder' instead of '^model.decoder'
key = re.sub(r'^decoder.', 'transformer.', key)
key = re.sub(r"^decoder.", "transformer.", key)
return key
state_dict = OrderedDict((key_mapping_model(k), v) for k, v in state_dict.items())
# Word embedding and position embedding
def key_mapping_emb(key):
key = re.sub(r'^transformer.embed_tokens.', 'transformer.embeddings.word_embeddings.', key)
key = re.sub(r"^transformer.embed_tokens.", "transformer.embeddings.word_embeddings.", key)
# The OPT-350m model uses has project_in and project_out
key = re.sub(r'^transformer.project_in.', 'transformer.embeddings.project_in.', key)
key = re.sub(r'^transformer.project_out.', 'project_out.', key)
key = re.sub(r'^transformer.embed_positions.',
'transformer.embeddings.position_embeddings.', key)
key = re.sub(r"^transformer.project_in.", "transformer.embeddings.project_in.", key)
key = re.sub(r"^transformer.project_out.", "project_out.", key)
key = re.sub(
r"^transformer.embed_positions.", "transformer.embeddings.position_embeddings.", key
)
return key
state_dict = OrderedDict((key_mapping_emb(k), v) for k, v in state_dict.items())
# OPT uses the first 2 indices of pos_emb for padding tokens
pos_embeddings = state_dict.pop('transformer.embeddings.position_embeddings.weight')
state_dict['transformer.embeddings.position_embeddings.weight'] = pos_embeddings[2:]
word_embeddings = state_dict.pop('transformer.embeddings.word_embeddings.weight')
pos_embeddings = state_dict.pop("transformer.embeddings.position_embeddings.weight")
state_dict["transformer.embeddings.position_embeddings.weight"] = pos_embeddings[2:]
word_embeddings = state_dict.pop("transformer.embeddings.word_embeddings.weight")
# It's possible that vocab_size is padded to be a multiple of 8, for example.
pad_vocab_size_multiple = getattr(config, 'pad_vocab_size_multiple', 1)
vocab_size = (math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple)
state_dict['transformer.embeddings.word_embeddings.weight'] = F.pad(
pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
vocab_size = math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple
state_dict["transformer.embeddings.word_embeddings.weight"] = F.pad(
word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0])
)
state_dict['lm_head.weight'] = state_dict['transformer.embeddings.word_embeddings.weight']
state_dict["lm_head.weight"] = state_dict["transformer.embeddings.word_embeddings.weight"]
# LayerNorm
def key_mapping_ln(key):
key = re.sub(r'^transformer.final_layer_norm.', r'transformer.ln_f.', key)
key = re.sub(r"^transformer.final_layer_norm.", r"transformer.ln_f.", key)
# The OPT-175B checkpoint calls this 'decoder.layer_norm' instead of 'decoder.final_layer_norm'
key = re.sub(r'^transformer.layer_norm.', r'transformer.ln_f.', key)
key = re.sub(r'^transformer.layers.(\d+).self_attn_layer_norm.',
r'transformer.layers.\1.norm1.', key)
key = re.sub(r'^transformer.layers.(\d+).final_layer_norm.',
r'transformer.layers.\1.norm2.', key)
key = re.sub(r"^transformer.layer_norm.", r"transformer.ln_f.", key)
key = re.sub(
r"^transformer.layers.(\d+).self_attn_layer_norm.", r"transformer.layers.\1.norm1.", key
)
key = re.sub(
r"^transformer.layers.(\d+).final_layer_norm.", r"transformer.layers.\1.norm2.", key
)
return key
state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())
# MLP
def key_mapping_mlp(key):
return re.sub(r'^transformer.layers.(\d+).fc(1|2).',
r'transformer.layers.\1.mlp.fc\2.', key)
return re.sub(
r"^transformer.layers.(\d+).fc(1|2).", r"transformer.layers.\1.mlp.fc\2.", key
)
state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())
# Attention
for l in range(config.n_layer):
Wq = state_dict.pop(f'transformer.layers.{l}.self_attn.q_proj.weight')
Wk = state_dict.pop(f'transformer.layers.{l}.self_attn.k_proj.weight')
Wv = state_dict.pop(f'transformer.layers.{l}.self_attn.v_proj.weight')
bq = state_dict.pop(f'transformer.layers.{l}.self_attn.q_proj.bias')
bk = state_dict.pop(f'transformer.layers.{l}.self_attn.k_proj.bias')
bv = state_dict.pop(f'transformer.layers.{l}.self_attn.v_proj.bias')
state_dict[f'transformer.layers.{l}.mixer.Wqkv.weight'] = torch.cat([Wq, Wk, Wv], dim=0)
state_dict[f'transformer.layers.{l}.mixer.Wqkv.bias'] = torch.cat([bq, bk, bv], dim=0)
Wq = state_dict.pop(f"transformer.layers.{l}.self_attn.q_proj.weight")
Wk = state_dict.pop(f"transformer.layers.{l}.self_attn.k_proj.weight")
Wv = state_dict.pop(f"transformer.layers.{l}.self_attn.v_proj.weight")
bq = state_dict.pop(f"transformer.layers.{l}.self_attn.q_proj.bias")
bk = state_dict.pop(f"transformer.layers.{l}.self_attn.k_proj.bias")
bv = state_dict.pop(f"transformer.layers.{l}.self_attn.v_proj.bias")
state_dict[f"transformer.layers.{l}.mixer.Wqkv.weight"] = torch.cat([Wq, Wk, Wv], dim=0)
state_dict[f"transformer.layers.{l}.mixer.Wqkv.bias"] = torch.cat([bq, bk, bv], dim=0)
def key_mapping_attn(key):
return re.sub(r'^transformer.layers.(\d+).self_attn.out_proj.',
r'transformer.layers.\1.mixer.out_proj.', key)
return re.sub(
r"^transformer.layers.(\d+).self_attn.out_proj.",
r"transformer.layers.\1.mixer.out_proj.",
key,
)
state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())
return state_dict
......@@ -79,8 +90,11 @@ def remap_state_dict_hf_opt(state_dict, config):
def opt_config_to_gpt2_config(opt_config: OPTConfig) -> GPT2Config:
assert opt_config.layerdrop == 0.0
assert opt_config.layer_norm_elementwise_affine
word_embed_proj_dim = (None if opt_config.word_embed_proj_dim == opt_config.hidden_size
else opt_config.word_embed_proj_dim)
word_embed_proj_dim = (
None
if opt_config.word_embed_proj_dim == opt_config.hidden_size
else opt_config.word_embed_proj_dim
)
return GPT2Config(
vocab_size=opt_config.vocab_size,
n_positions=opt_config.max_position_embeddings,
......@@ -98,5 +112,5 @@ def opt_config_to_gpt2_config(opt_config: OPTConfig) -> GPT2Config:
eos_token_id=opt_config.eos_token_id,
# These are new arguments not in the original GPT2Config
prenorm=opt_config.do_layer_norm_before,
word_embed_proj_dim=word_embed_proj_dim
word_embed_proj_dim=word_embed_proj_dim,
)
......@@ -10,13 +10,14 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from timm.models.helpers import named_apply
from torch.nn.init import trunc_normal_
from torchvision.ops import StochasticDepth
from flash_attn.layers.patch_embed import PatchEmbed
from flash_attn.modules.block import Block
from flash_attn.modules.mha import MHA
from flash_attn.modules.mlp import FusedMLP, Mlp
from timm.models.helpers import named_apply
from torch.nn.init import trunc_normal_
from torchvision.ops import StochasticDepth
try:
from flash_attn.ops.layer_norm import dropout_add_layer_norm
......
# Copyright (c) 2022, Tri Dao.
from typing import Optional
from functools import partial
from typing import Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torchvision.ops import StochasticDepth
from flash_attn.modules.mha import MHA
......@@ -35,11 +34,24 @@ except ImportError:
class Block(nn.Module):
def __init__(self, dim, mixer_cls=None, mlp_cls=None, norm_cls=nn.LayerNorm,
dropout_cls=nn.Dropout, prenorm=True, resid_dropout1=0., resid_dropout2=0.,
drop_path1=0., drop_path2=0., fused_dropout_add_ln=False, return_residual=False,
residual_in_fp32=False, sequence_parallel=False, mark_shared_params=False):
def __init__(
self,
dim,
mixer_cls=None,
mlp_cls=None,
norm_cls=nn.LayerNorm,
dropout_cls=nn.Dropout,
prenorm=True,
resid_dropout1=0.0,
resid_dropout2=0.0,
drop_path1=0.0,
drop_path2=0.0,
fused_dropout_add_ln=False,
return_residual=False,
residual_in_fp32=False,
sequence_parallel=False,
mark_shared_params=False,
):
"""
For prenorm=True, this Block has a slightly different structure compared to a regular
prenorm Transformer block.
......@@ -63,26 +75,27 @@ class Block(nn.Module):
self.return_residual = return_residual
self.residual_in_fp32 = residual_in_fp32
if self.residual_in_fp32:
assert self.prenorm, 'residual_in_fp32 is only compatible with prenorm=True'
assert self.prenorm, "residual_in_fp32 is only compatible with prenorm=True"
if mixer_cls is None:
mixer_cls = partial(MHA, num_heads=dim // 64)
if mlp_cls is None:
mlp_cls = partial(Mlp, hidden_features=4 * dim)
self.mixer = mixer_cls(dim)
self.dropout1 = dropout_cls(resid_dropout1)
self.drop_path1 = StochasticDepth(drop_path1, mode='row')
self.drop_path1 = StochasticDepth(drop_path1, mode="row")
self.norm1 = norm_cls(dim)
self.mlp = mlp_cls(dim)
if not isinstance(self.mlp, nn.Identity):
self.dropout2 = dropout_cls(resid_dropout2)
self.drop_path2 = StochasticDepth(drop_path2, mode='row')
self.drop_path2 = StochasticDepth(drop_path2, mode="row")
self.norm2 = norm_cls(dim)
if self.fused_dropout_add_ln:
assert dropout_add_layer_norm is not None, 'dropout_layer_norm is not installed'
assert dropout_add_rms_norm is not None, 'dropout_layer_norm is not installed'
assert (isinstance(self.norm1, (nn.LayerNorm, RMSNorm))
and isinstance(self.dropout1, nn.Dropout))
assert dropout_add_layer_norm is not None, "dropout_layer_norm is not installed"
assert dropout_add_rms_norm is not None, "dropout_layer_norm is not installed"
assert isinstance(self.norm1, (nn.LayerNorm, RMSNorm)) and isinstance(
self.dropout1, nn.Dropout
)
# TD [2023-01-07]: TODO: During training, if sequence_parallel is False and dropout != 0.0,
# then the input to each worker in the tensor parallel group will be different.
......@@ -94,22 +107,27 @@ class Block(nn.Module):
if sequence_parallel:
for p in self.norm1.parameters():
p._sequence_parallel = True
if hasattr(self, 'norm2'):
if hasattr(self, "norm2"):
for p in self.norm2.parameters():
p._sequence_parallel = True
# Mark the norm parameters as "shared_params" so that we sync their values at init.
if mark_shared_params:
for p in self.norm1.parameters():
p._shared_params = True
if hasattr(self, 'norm2'):
if hasattr(self, "norm2"):
for p in self.norm2.parameters():
p._shared_params = True
def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)
def forward(self, hidden_states: Tensor, residual: Optional[Tensor] = None,
mixer_subset=None, mixer_kwargs=None):
def forward(
self,
hidden_states: Tensor,
residual: Optional[Tensor] = None,
mixer_subset=None,
mixer_kwargs=None,
):
r"""Pass the input through the encoder layer.
Args:
......@@ -119,8 +137,11 @@ class Block(nn.Module):
before applying the query projection. Useful for e.g., ViT where we only care
about the CLS token in the last layer.
"""
fused_add_norm_fn = (dropout_add_rms_norm if RMSNorm and isinstance(self.norm1, RMSNorm)
else dropout_add_layer_norm)
fused_add_norm_fn = (
dropout_add_rms_norm
if RMSNorm and isinstance(self.norm1, RMSNorm)
else dropout_add_layer_norm
)
if self.prenorm:
if not self.fused_dropout_add_ln:
dropped = self.drop_path1(self.dropout1(hidden_states))
......@@ -132,19 +153,28 @@ class Block(nn.Module):
if self.drop_path1.p == 0 or not self.training:
rowscale1 = None
else:
rowscale1 = self.drop_path1(torch.ones(
hidden_states.shape[:-1], device=hidden_states.device,
dtype=hidden_states.dtype)
rowscale1 = self.drop_path1(
torch.ones(
hidden_states.shape[:-1],
device=hidden_states.device,
dtype=hidden_states.dtype,
)
)
hidden_states, residual = fused_add_norm_fn(
hidden_states, residual, self.norm1.weight, self.norm1.bias,
self.dropout1.p if self.training else 0.0, self.norm1.eps,
rowscale=rowscale1, prenorm=True, residual_in_fp32=self.residual_in_fp32
hidden_states,
residual,
self.norm1.weight,
self.norm1.bias,
self.dropout1.p if self.training else 0.0,
self.norm1.eps,
rowscale=rowscale1,
prenorm=True,
residual_in_fp32=self.residual_in_fp32,
)
if mixer_kwargs is None:
mixer_kwargs = {}
if mixer_subset is not None:
mixer_kwargs['mixer_subset'] = mixer_subset
mixer_kwargs["mixer_subset"] = mixer_subset
hidden_states = self.mixer(hidden_states, **mixer_kwargs)
if mixer_subset is not None:
residual = residual[:, mixer_subset]
......@@ -159,14 +189,23 @@ class Block(nn.Module):
if self.drop_path2.p == 0 or not self.training:
rowscale2 = None
else:
rowscale2 = self.drop_path2(torch.ones(
hidden_states.shape[:-1], device=hidden_states.device,
dtype=hidden_states.dtype)
rowscale2 = self.drop_path2(
torch.ones(
hidden_states.shape[:-1],
device=hidden_states.device,
dtype=hidden_states.dtype,
)
)
hidden_states, residual = fused_add_norm_fn(
hidden_states, residual, self.norm2.weight, self.norm2.bias,
self.dropout2.p if self.training else 0.0, self.norm2.eps,
rowscale=rowscale2, prenorm=True, residual_in_fp32=self.residual_in_fp32
hidden_states,
residual,
self.norm2.weight,
self.norm2.bias,
self.dropout2.p if self.training else 0.0,
self.norm2.eps,
rowscale=rowscale2,
prenorm=True,
residual_in_fp32=self.residual_in_fp32,
)
hidden_states = self.mlp(hidden_states)
return hidden_states, residual
......@@ -178,38 +217,58 @@ class Block(nn.Module):
if self.return_residual: # mixer out is actually a pair here
mixer_out, hidden_states = mixer_out
if not self.fused_dropout_add_ln:
hidden_states = self.norm1((self.drop_path1(self.dropout1(mixer_out))
+ hidden_states).to(dtype=self.norm1.weight.dtype))
hidden_states = self.norm1(
(self.drop_path1(self.dropout1(mixer_out)) + hidden_states).to(
dtype=self.norm1.weight.dtype
)
)
else:
if self.drop_path1.p == 0 or not self.training:
rowscale1 = None
else:
rowscale1 = self.drop_path1(torch.ones(
mixer_out.shape[:-1], device=mixer_out.device, dtype=mixer_out.dtype)
rowscale1 = self.drop_path1(
torch.ones(
mixer_out.shape[:-1], device=mixer_out.device, dtype=mixer_out.dtype
)
)
hidden_states = fused_add_norm_fn(
mixer_out, hidden_states, self.norm1.weight, self.norm1.bias,
self.dropout1.p if self.training else 0.0, self.norm1.eps,
rowscale=rowscale1, prenorm=False
mixer_out,
hidden_states,
self.norm1.weight,
self.norm1.bias,
self.dropout1.p if self.training else 0.0,
self.norm1.eps,
rowscale=rowscale1,
prenorm=False,
)
if not isinstance(self.mlp, nn.Identity):
mlp_out = self.mlp(hidden_states)
if self.return_residual: # mlp out is actually a pair here
mlp_out, hidden_states = mlp_out
if not self.fused_dropout_add_ln:
hidden_states = self.norm2((self.drop_path2(self.dropout2(mlp_out))
+ hidden_states).to(dtype=self.norm2.weight.dtype))
hidden_states = self.norm2(
(self.drop_path2(self.dropout2(mlp_out)) + hidden_states).to(
dtype=self.norm2.weight.dtype
)
)
else:
if self.drop_path2.p == 0 or not self.training:
rowscale2 = None
else:
rowscale2 = self.drop_path2(torch.ones(
mlp_out.shape[:-1], device=mlp_out.device, dtype=mlp_out.dtype)
rowscale2 = self.drop_path2(
torch.ones(
mlp_out.shape[:-1], device=mlp_out.device, dtype=mlp_out.dtype
)
)
hidden_states = fused_add_norm_fn(
mlp_out, hidden_states, self.norm2.weight, self.norm2.bias,
self.dropout2.p if self.training else 0.0, self.norm2.eps,
rowscale=rowscale2, prenorm=False
mlp_out,
hidden_states,
self.norm2.weight,
self.norm2.bias,
self.dropout2.p if self.training else 0.0,
self.norm2.eps,
rowscale=rowscale2,
prenorm=False,
)
return hidden_states
......@@ -219,10 +278,21 @@ class ParallelBlock(nn.Module):
and PaLM.
"""
def __init__(self, dim, mixer_cls=None, mlp_cls=None, norm_cls=nn.LayerNorm,
dropout_cls=nn.Dropout, resid_dropout1=0., resid_dropout2=0.,
tied_norm=False, fused_dropout_add_ln=False, residual_in_fp32=False,
sequence_parallel=False, mark_shared_params=False):
def __init__(
self,
dim,
mixer_cls=None,
mlp_cls=None,
norm_cls=nn.LayerNorm,
dropout_cls=nn.Dropout,
resid_dropout1=0.0,
resid_dropout2=0.0,
tied_norm=False,
fused_dropout_add_ln=False,
residual_in_fp32=False,
sequence_parallel=False,
mark_shared_params=False,
):
"""
This Block has a slightly different structure compared to a regular
prenorm Transformer block.
......@@ -250,10 +320,15 @@ class ParallelBlock(nn.Module):
self.norm2 = norm_cls(dim)
if self.fused_dropout_add_ln:
assert dropout_add_layer_norm_parallel_residual is not None, 'dropout_layer_norm is not installed'
assert dropout_add_rms_norm_parallel_residual is not None, 'dropout_layer_norm is not installed'
assert (isinstance(self.norm1, (nn.LayerNorm, RMSNorm))
and isinstance(self.dropout1, nn.Dropout))
assert (
dropout_add_layer_norm_parallel_residual is not None
), "dropout_layer_norm is not installed"
assert (
dropout_add_rms_norm_parallel_residual is not None
), "dropout_layer_norm is not installed"
assert isinstance(self.norm1, (nn.LayerNorm, RMSNorm)) and isinstance(
self.dropout1, nn.Dropout
)
# TD [2023-01-07]: TODO: During training, if sequence_parallel is False and dropout != 0.0,
# then the input to each worker in the tensor parallel group will be different.
......@@ -265,22 +340,27 @@ class ParallelBlock(nn.Module):
if sequence_parallel:
for p in self.norm1.parameters():
p._sequence_parallel = True
if hasattr(self, 'norm2'):
if hasattr(self, "norm2"):
for p in self.norm2.parameters():
p._sequence_parallel = True
# Mark the norm parameters as "shared_params" so that we sync their values at init.
if mark_shared_params:
for p in self.norm1.parameters():
p._shared_params = True
if hasattr(self, 'norm2'):
if hasattr(self, "norm2"):
for p in self.norm2.parameters():
p._shared_params = True
def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)
def forward(self, hidden_states1: Tensor, hidden_states2: Optional[Tensor] = None,
residual: Optional[Tensor] = None, mixer_kwargs=None):
def forward(
self,
hidden_states1: Tensor,
hidden_states2: Optional[Tensor] = None,
residual: Optional[Tensor] = None,
mixer_kwargs=None,
):
r"""Pass the input through the encoder layer.
Args:
......@@ -290,30 +370,47 @@ class ParallelBlock(nn.Module):
"""
# TODO: Ideally we should only do the allgather / allreduce once for
# the Linear to MLP & Attention
fused_add_norm_fn = (dropout_add_rms_norm_parallel_residual
fused_add_norm_fn = (
dropout_add_rms_norm_parallel_residual
if isinstance(self.norm1, RMSNorm)
else dropout_add_layer_norm_parallel_residual)
else dropout_add_layer_norm_parallel_residual
)
if not self.fused_dropout_add_ln:
dropped1 = self.dropout1(hidden_states1)
# For the very 1st block, we only want 1 dropout, not two different dropouts
if hidden_states2 is not None:
dropped2 = self.dropout2(hidden_states2)
residual = ((residual + dropped1 + dropped2)
if residual is not None else dropped1 + dropped2)
residual = (
(residual + dropped1 + dropped2)
if residual is not None
else dropped1 + dropped2
)
else:
residual = (residual + dropped1) if residual is not None else dropped1
hidden_states1 = self.norm1(residual.to(dtype=self.norm1.weight.dtype))
hidden_states2 = (self.norm2(residual.to(dtype=self.norm2.weight.dtype))
if not self.tied_norm else hidden_states1)
hidden_states2 = (
self.norm2(residual.to(dtype=self.norm2.weight.dtype))
if not self.tied_norm
else hidden_states1
)
if self.residual_in_fp32:
residual = residual.to(torch.float32)
else:
weight2, bias2 = ((self.norm2.weight, self.norm2.bias)
if not self.tied_norm else (None, None))
weight2, bias2 = (
(self.norm2.weight, self.norm2.bias) if not self.tied_norm else (None, None)
)
hidden_states1, hidden_states2, residual = fused_add_norm_fn(
hidden_states1, hidden_states2, residual, self.norm1.weight, self.norm1.bias,
weight2, bias2, self.dropout1.p if self.training else 0.0, self.norm1.eps,
prenorm=True, residual_in_fp32=self.residual_in_fp32
hidden_states1,
hidden_states2,
residual,
self.norm1.weight,
self.norm1.bias,
weight2,
bias2,
self.dropout1.p if self.training else 0.0,
self.norm1.eps,
prenorm=True,
residual_in_fp32=self.residual_in_fp32,
)
if self.tied_norm:
hidden_states2 = hidden_states1
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment