Commit f1a73d07 authored by Tri Dao

Run isort and black on python files

parent cbb4cf5f
__version__ = "2.0.8"
from flash_attn.flash_attn_interface import flash_attn_func
from flash_attn.flash_attn_interface import flash_attn_kvpacked_func
from flash_attn.flash_attn_interface import flash_attn_qkvpacked_func
from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func
from flash_attn.flash_attn_interface import flash_attn_varlen_kvpacked_func
from flash_attn.flash_attn_interface import flash_attn_varlen_func
from flash_attn.flash_attn_interface import (
flash_attn_func,
flash_attn_kvpacked_func,
flash_attn_qkvpacked_func,
flash_attn_varlen_func,
flash_attn_varlen_kvpacked_func,
flash_attn_varlen_qkvpacked_func,
)
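As a quick orientation for the re-sorted exports above, here is a minimal usage sketch of the top-level API, assuming a CUDA device and fp16/bf16 tensors; shapes follow the (batch, seqlen, nheads, headdim) convention used throughout the repo, and the sizes are illustrative only.

import torch
from flash_attn import flash_attn_func

# Illustrative shapes only.
batch, seqlen, nheads, headdim = 2, 1024, 16, 64
q = torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)

# causal=True applies a lower-triangular mask; dropout_p should be 0.0 at eval time.
out = flash_attn_func(q, k, v, dropout_p=0.0, causal=True)  # (batch, seqlen, nheads, headdim)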
......@@ -2,12 +2,10 @@
import torch
import torch.nn.functional as F
from einops import rearrange, repeat
class IndexFirstAxis(torch.autograd.Function):
@staticmethod
def forward(ctx, input, indices):
ctx.save_for_backward(indices)
......@@ -16,20 +14,24 @@ class IndexFirstAxis(torch.autograd.Function):
second_dim = other_shape.numel()
# TD [2022-03-04] For some reason torch.gather is a bit faster than indexing.
# return input[indices]
return torch.gather(rearrange(input, 'b ... -> b (...)'), 0,
repeat(indices, 'z -> z d', d=second_dim)).reshape(-1, *other_shape)
return torch.gather(
rearrange(input, "b ... -> b (...)"), 0, repeat(indices, "z -> z d", d=second_dim)
).reshape(-1, *other_shape)
@staticmethod
def backward(ctx, grad_output):
indices, = ctx.saved_tensors
(indices,) = ctx.saved_tensors
assert grad_output.ndim >= 2
other_shape = grad_output.shape[1:]
grad_output = rearrange(grad_output, 'b ... -> b (...)')
grad_input = torch.zeros([ctx.first_axis_dim, grad_output.shape[1]],
device=grad_output.device, dtype=grad_output.dtype)
grad_output = rearrange(grad_output, "b ... -> b (...)")
grad_input = torch.zeros(
[ctx.first_axis_dim, grad_output.shape[1]],
device=grad_output.device,
dtype=grad_output.dtype,
)
# TD [2022-03-04] For some reason torch.scatter is a bit faster than indexing.
# grad_input[indices] = grad_output
grad_input.scatter_(0, repeat(indices, 'z -> z d', d=grad_output.shape[1]), grad_output)
grad_input.scatter_(0, repeat(indices, "z -> z d", d=grad_output.shape[1]), grad_output)
return grad_input.reshape(ctx.first_axis_dim, *other_shape), None
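For reference, index_first_axis is just a gather along the first axis with a custom backward. A small sketch (illustrative shapes, plain torch/einops) showing that the gather in IndexFirstAxis.forward matches ordinary integer indexing:

import torch
from einops import rearrange, repeat

x = torch.randn(6, 3, 4)                      # (first_axis_dim, ...)
indices = torch.tensor([0, 2, 5])

# Same result as x[indices], written with gather as in IndexFirstAxis.forward.
flat = rearrange(x, "b ... -> b (...)")       # (6, 12)
gathered = torch.gather(flat, 0, repeat(indices, "z -> z d", d=flat.shape[1]))
out = gathered.reshape(-1, 3, 4)              # (3, 3, 4)

assert torch.equal(out, x[indices])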
......@@ -37,14 +39,14 @@ index_first_axis = IndexFirstAxis.apply
class IndexPutFirstAxis(torch.autograd.Function):
@staticmethod
def forward(ctx, values, indices, first_axis_dim):
ctx.save_for_backward(indices)
assert indices.ndim == 1
assert values.ndim >= 2
output = torch.zeros(first_axis_dim, *values.shape[1:], device=values.device,
dtype=values.dtype)
output = torch.zeros(
first_axis_dim, *values.shape[1:], device=values.device, dtype=values.dtype
)
# TD [2022-03-04] For some reason torch.scatter is a bit faster than indexing.
output[indices] = values
# output.scatter_(0, repeat(indices, 'z -> z d', d=values.shape[1]), values)
......@@ -52,7 +54,7 @@ class IndexPutFirstAxis(torch.autograd.Function):
@staticmethod
def backward(ctx, grad_output):
indices, = ctx.saved_tensors
(indices,) = ctx.saved_tensors
# TD [2022-03-04] For some reason torch.gather is a bit faster than indexing.
grad_values = grad_output[indices]
# grad_values = torch.gather(grad_output, 0, repeat(indices, 'z -> z d', d=grad_output.shape[1]))
......@@ -63,7 +65,6 @@ index_put_first_axis = IndexPutFirstAxis.apply
class IndexFirstAxisResidual(torch.autograd.Function):
@staticmethod
def forward(ctx, input, indices):
ctx.save_for_backward(indices)
......@@ -79,7 +80,7 @@ class IndexFirstAxisResidual(torch.autograd.Function):
@staticmethod
def backward(ctx, grad_output, grad_residual):
indices, = ctx.saved_tensors
(indices,) = ctx.saved_tensors
assert grad_output.ndim >= 2
other_shape = grad_output.shape[1:]
assert grad_residual.shape[1:] == other_shape
......@@ -113,8 +114,12 @@ def unpad_input(hidden_states, attention_mask):
# times larger than it needs to be, wasting memory. It's faster and more memory-efficient to
# index with integer indices. Moreover, torch's index is a bit slower than it needs to be,
# so we write custom forward and backward to make it a bit faster.
return (index_first_axis(rearrange(hidden_states, 'b s ... -> (b s) ...'), indices), indices,
cu_seqlens, max_seqlen_in_batch)
return (
index_first_axis(rearrange(hidden_states, "b s ... -> (b s) ..."), indices),
indices,
cu_seqlens,
max_seqlen_in_batch,
)
def pad_input(hidden_states, indices, batch, seqlen):
......@@ -129,4 +134,4 @@ def pad_input(hidden_states, indices, batch, seqlen):
# output = torch.zeros((batch * seqlen), dim, device=hidden_states.device, dtype=hidden_states.dtype)
# output[indices] = hidden_states
output = index_put_first_axis(hidden_states, indices, batch * seqlen)
return rearrange(output, '(b s) ... -> b s ...', b=batch)
return rearrange(output, "(b s) ... -> b s ...", b=batch)
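A hedged round-trip sketch of unpad_input/pad_input as defined above. It assumes the usual BERT-style attention mask of shape (batch, seqlen) with True/1 for valid tokens; padded positions come back as zeros after re-padding.

import torch
from flash_attn.bert_padding import pad_input, unpad_input

batch, seqlen, dim = 2, 4, 8
hidden = torch.randn(batch, seqlen, dim)
# Second sequence only has 2 valid tokens.
mask = torch.tensor([[1, 1, 1, 1],
                     [1, 1, 0, 0]], dtype=torch.bool)

hidden_unpad, indices, cu_seqlens, max_s = unpad_input(hidden, mask)
# hidden_unpad: (total_valid_tokens, dim); cu_seqlens: (batch + 1,) int32 offsets.

repadded = pad_input(hidden_unpad, indices, batch, seqlen)
# Valid positions survive the round trip exactly; padded positions are zero.
assert torch.equal(repadded[mask], hidden[mask])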
......@@ -11,22 +11,41 @@ This is a Triton implementation of the Flash Attention algorithm
import pytest
import torch
import triton
import triton.language as tl
@triton.jit
def _fwd_kernel(
Q, K, V, sm_scale,
TMP, L, M, # NOTE: TMP is a scratchpad buffer to workaround a compiler bug
Q,
K,
V,
sm_scale,
TMP,
L,
M, # NOTE: TMP is a scratchpad buffer to workaround a compiler bug
Out,
stride_qz, stride_qh, stride_qm, stride_qk,
stride_kz, stride_kh, stride_kn, stride_kk,
stride_vz, stride_vh, stride_vk, stride_vn,
stride_oz, stride_oh, stride_om, stride_on,
Z, H, N_CTX,
BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
stride_qz,
stride_qh,
stride_qm,
stride_qk,
stride_kz,
stride_kh,
stride_kn,
stride_kk,
stride_vz,
stride_vh,
stride_vk,
stride_vn,
stride_oz,
stride_oh,
stride_om,
stride_on,
Z,
H,
N_CTX,
BLOCK_M: tl.constexpr,
BLOCK_DMODEL: tl.constexpr,
BLOCK_N: tl.constexpr,
):
start_m = tl.program_id(0)
......@@ -100,9 +119,13 @@ def _fwd_kernel(
@triton.jit
def _bwd_preprocess(
Out, DO, L,
NewDO, Delta,
BLOCK_M: tl.constexpr, D_HEAD: tl.constexpr,
Out,
DO,
L,
NewDO,
Delta,
BLOCK_M: tl.constexpr,
D_HEAD: tl.constexpr,
):
off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
off_n = tl.arange(0, D_HEAD)
......@@ -120,16 +143,36 @@ def _bwd_preprocess(
@triton.jit
def _bwd_kernel(
Q, K, V, sm_scale, Out, DO,
DQ, DK, DV,
L, M,
Q,
K,
V,
sm_scale,
Out,
DO,
DQ,
DK,
DV,
L,
M,
D,
stride_qz, stride_qh, stride_qm, stride_qk,
stride_kz, stride_kh, stride_kn, stride_kk,
stride_vz, stride_vh, stride_vk, stride_vn,
Z, H, N_CTX,
stride_qz,
stride_qh,
stride_qm,
stride_qk,
stride_kz,
stride_kh,
stride_kn,
stride_kk,
stride_vz,
stride_vh,
stride_vk,
stride_vn,
Z,
H,
N_CTX,
num_block,
BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
BLOCK_M: tl.constexpr,
BLOCK_DMODEL: tl.constexpr,
BLOCK_N: tl.constexpr,
):
off_hz = tl.program_id(0)
......@@ -203,7 +246,6 @@ def _bwd_kernel(
class _attention(torch.autograd.Function):
@staticmethod
def forward(ctx, q, k, v, sm_scale):
BLOCK = 128
......@@ -213,22 +255,45 @@ class _attention(torch.autograd.Function):
assert Lk in {16, 32, 64, 128}
o = torch.empty_like(q)
grid = (triton.cdiv(q.shape[2], BLOCK), q.shape[0] * q.shape[1])
tmp = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
tmp = torch.empty(
(q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32
)
L = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
m = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
num_warps = 4 if Lk <= 64 else 8
_fwd_kernel[grid](
q, k, v, sm_scale,
tmp, L, m,
q,
k,
v,
sm_scale,
tmp,
L,
m,
o,
q.stride(0), q.stride(1), q.stride(2), q.stride(3),
k.stride(0), k.stride(1), k.stride(2), k.stride(3),
v.stride(0), v.stride(1), v.stride(2), v.stride(3),
o.stride(0), o.stride(1), o.stride(2), o.stride(3),
q.shape[0], q.shape[1], q.shape[2],
BLOCK_M=BLOCK, BLOCK_N=BLOCK,
BLOCK_DMODEL=Lk, num_warps=num_warps,
q.stride(0),
q.stride(1),
q.stride(2),
q.stride(3),
k.stride(0),
k.stride(1),
k.stride(2),
k.stride(3),
v.stride(0),
v.stride(1),
v.stride(2),
v.stride(3),
o.stride(0),
o.stride(1),
o.stride(2),
o.stride(3),
q.shape[0],
q.shape[1],
q.shape[2],
BLOCK_M=BLOCK,
BLOCK_N=BLOCK,
BLOCK_DMODEL=Lk,
num_warps=num_warps,
num_stages=1,
)
ctx.save_for_backward(q, k, v, o, L, m)
......@@ -247,27 +312,51 @@ class _attention(torch.autograd.Function):
dv = torch.empty_like(v)
do_scaled = torch.empty_like(do)
delta = torch.empty_like(l)
_bwd_preprocess[(ctx.grid[0] * ctx.grid[1], )](
o, do, l,
do_scaled, delta,
BLOCK_M=ctx.BLOCK, D_HEAD=ctx.BLOCK_DMODEL,
_bwd_preprocess[(ctx.grid[0] * ctx.grid[1],)](
o,
do,
l,
do_scaled,
delta,
BLOCK_M=ctx.BLOCK,
D_HEAD=ctx.BLOCK_DMODEL,
)
# NOTE: kernel currently buggy for other values of `num_warps`
num_warps = 8
_bwd_kernel[(ctx.grid[1],)](
q, k, v, ctx.sm_scale,
o, do_scaled,
dq, dk, dv,
l, m,
q,
k,
v,
ctx.sm_scale,
o,
do_scaled,
dq,
dk,
dv,
l,
m,
delta,
q.stride(0), q.stride(1), q.stride(2), q.stride(3),
k.stride(0), k.stride(1), k.stride(2), k.stride(3),
v.stride(0), v.stride(1), v.stride(2), v.stride(3),
q.shape[0], q.shape[1], q.shape[2],
q.stride(0),
q.stride(1),
q.stride(2),
q.stride(3),
k.stride(0),
k.stride(1),
k.stride(2),
k.stride(3),
v.stride(0),
v.stride(1),
v.stride(2),
v.stride(3),
q.shape[0],
q.shape[1],
q.shape[2],
ctx.grid[0],
BLOCK_M=ctx.BLOCK, BLOCK_N=ctx.BLOCK,
BLOCK_DMODEL=ctx.BLOCK_DMODEL, num_warps=num_warps,
BLOCK_M=ctx.BLOCK,
BLOCK_N=ctx.BLOCK,
BLOCK_DMODEL=ctx.BLOCK_DMODEL,
num_warps=num_warps,
num_stages=1,
)
return dq.to(q.dtype), dk, dv, None
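For context, the Triton kernels above are driven through the _attention autograd Function. A hedged launch sketch, assuming the module-level alias attention = _attention.apply that this file defines and the module path flash_attn.flash_attn_triton_og (both are assumptions). Per the asserts, the head dimension must be one of {16, 32, 64, 128}, tensors are laid out as (batch, nheads, seqlen, headdim), and this original implementation expects the sequence length to be a multiple of the 128-wide block.

import math
import torch
from flash_attn.flash_attn_triton_og import attention  # alias for _attention.apply; path assumed

batch, nheads, seqlen, headdim = 2, 8, 1024, 64
q = torch.randn(batch, nheads, seqlen, headdim, device="cuda", dtype=torch.float16, requires_grad=True)
k = torch.randn(batch, nheads, seqlen, headdim, device="cuda", dtype=torch.float16, requires_grad=True)
v = torch.randn(batch, nheads, seqlen, headdim, device="cuda", dtype=torch.float16, requires_grad=True)

out = attention(q, k, v, 1.0 / math.sqrt(headdim))  # forward launches _fwd_kernel
out.sum().backward()                                 # backward launches _bwd_preprocess and _bwd_kernel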
......
import math
import hydra
import torch
import torch.nn as nn
from einops import rearrange
import hydra
from flash_attn.flash_blocksparse_attn_interface import flash_blocksparse_attn_func
from flash_attn.flash_blocksparse_attn_interface import convert_blockmask
from flash_attn.bert_padding import unpad_input, pad_input, index_first_axis
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input
from flash_attn.flash_blocksparse_attn_interface import (
convert_blockmask,
flash_blocksparse_attn_func,
)
class FlashBlocksparseAttention(nn.Module):
......@@ -21,8 +22,16 @@ class FlashBlocksparseAttention(nn.Module):
attention_dropout: The dropout rate to apply to the attention
(default: 0.1)
"""
def __init__(self, sparsity_config, softmax_temp=None, attention_dropout=0.0,
max_seq_length=2048, device=None, dtype=None):
def __init__(
self,
sparsity_config,
softmax_temp=None,
attention_dropout=0.0,
max_seq_length=2048,
device=None,
dtype=None,
):
super().__init__()
self.sparsity_config = hydra.utils.instantiate(sparsity_config)
self.softmax_temp = softmax_temp
......@@ -36,8 +45,17 @@ class FlashBlocksparseAttention(nn.Module):
self.register_buffer("blockmask_converted", blockmask_converted)
# logger.info(f'Attention class {self.__class__}: saving={self.layout.float().mean()}')
def forward(self, qkv, attn_mask=None, key_padding_mask=None, causal=False, cu_seqlens=None,
max_s=None, need_weights=False, convert_mask=True):
def forward(
self,
qkv,
attn_mask=None,
key_padding_mask=None,
causal=False,
cu_seqlens=None,
max_s=None,
need_weights=False,
convert_mask=True,
):
"""Implements the multihead softmax attention.
Arguments
---------
......@@ -57,47 +75,76 @@ class FlashBlocksparseAttention(nn.Module):
seqlen = qkv.shape[1]
# Convert mask to take a subset
seqlen_rounded = ((seqlen + 256 - 1) // 256) * 256
assert seqlen_rounded // 16 <= self.layout.shape[0], seqlen_rounded // 256 <= self.layout.shape[1]
blockmask = self.layout[:seqlen_rounded // 16, :seqlen_rounded // 256]
assert seqlen_rounded // 16 <= self.layout.shape[0], (
seqlen_rounded // 256 <= self.layout.shape[1]
)
blockmask = self.layout[: seqlen_rounded // 16, : seqlen_rounded // 256]
if key_padding_mask is None:
qkv = rearrange(qkv, 'b s ... -> (b s) ...')
qkv = rearrange(qkv, "b s ... -> (b s) ...")
max_s = seqlen
cu_seqlens = torch.arange(0, (batch_size + 1) * seqlen, step=seqlen, dtype=torch.int32,
device=qkv.device)
cu_seqlens = torch.arange(
0, (batch_size + 1) * seqlen, step=seqlen, dtype=torch.int32, device=qkv.device
)
output = flash_blocksparse_attn_func(
qkv, cu_seqlens, blockmask, self.dropout_p if self.training else 0.0,
max_s, softmax_scale=self.softmax_temp, causal=causal
qkv,
cu_seqlens,
blockmask,
self.dropout_p if self.training else 0.0,
max_s,
softmax_scale=self.softmax_temp,
causal=causal,
)
output = rearrange(output, '(b s) ... -> b s ...', b=batch_size)
output = rearrange(output, "(b s) ... -> b s ...", b=batch_size)
else:
key_padding_mask_bool = key_padding_mask.bool_matrix
nheads = qkv.shape[-2]
x = rearrange(qkv, 'b s three h d -> b s (three h d)')
x = rearrange(qkv, "b s three h d -> b s (three h d)")
x_unpad, indices, cu_seqlens, max_s = unpad_input(x, key_padding_mask_bool)
x_unpad = rearrange(x_unpad, 'nnz (three h d) -> nnz three h d', three=3, h=nheads)
x_unpad = rearrange(x_unpad, "nnz (three h d) -> nnz three h d", three=3, h=nheads)
output_unpad = flash_blocksparse_attn_func(
x_unpad, cu_seqlens, blockmask, self.dropout_p if self.training else 0.0,
max_s, softmax_scale=self.softmax_temp, causal=causal
x_unpad,
cu_seqlens,
blockmask,
self.dropout_p if self.training else 0.0,
max_s,
softmax_scale=self.softmax_temp,
causal=causal,
)
output = rearrange(
pad_input(
rearrange(output_unpad, "nnz h d -> nnz (h d)"), indices, batch_size, seqlen
),
"b s (h d) -> b s h d",
h=nheads,
)
output = rearrange(pad_input(rearrange(output_unpad, 'nnz h d -> nnz (h d)'),
indices, batch_size, seqlen),
'b s (h d) -> b s h d', h=nheads)
else:
assert max_s is not None
seqlen = max_s
# Convert mask to take a subset
seqlen_rounded = ((seqlen + 256 - 1) // 256) * 256
assert seqlen_rounded // 16 <= self.layout.shape[0], seqlen_rounded // 256 <= self.layout.shape[1]
blockmask = self.layout[:seqlen_rounded // 16, :seqlen_rounded // 256]
assert seqlen_rounded // 16 <= self.layout.shape[0], (
seqlen_rounded // 256 <= self.layout.shape[1]
)
blockmask = self.layout[: seqlen_rounded // 16, : seqlen_rounded // 256]
if convert_mask:
output = flash_blocksparse_attn_func(
qkv, cu_seqlens, blockmask, self.dropout_p if self.training else 0.0,
max_s, softmax_scale=self.softmax_temp, causal=causal
qkv,
cu_seqlens,
blockmask,
self.dropout_p if self.training else 0.0,
max_s,
softmax_scale=self.softmax_temp,
causal=causal,
)
else:
output = flash_blocksparse_attn_func(
qkv, cu_seqlens, self.blockmask_converted, self.dropout_p if self.training else 0.0,
max_s, softmax_scale=self.softmax_temp, causal=causal,
qkv,
cu_seqlens,
self.blockmask_converted,
self.dropout_p if self.training else 0.0,
max_s,
softmax_scale=self.softmax_temp,
causal=causal,
convert_mask=False,
)
......@@ -105,12 +152,22 @@ class FlashBlocksparseAttention(nn.Module):
class FlashBlocksparseMHA(nn.Module):
def __init__(self, embed_dim, num_heads, sparsity_config, bias=True, batch_first=True,
attention_dropout=0.0, causal=False, max_seq_length=2048,
device=None, dtype=None, **kwargs) -> None:
def __init__(
self,
embed_dim,
num_heads,
sparsity_config,
bias=True,
batch_first=True,
attention_dropout=0.0,
causal=False,
max_seq_length=2048,
device=None,
dtype=None,
**kwargs,
) -> None:
assert batch_first
factory_kwargs = {'device': device, 'dtype': dtype}
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
self.embed_dim = embed_dim
self.causal = causal
......@@ -122,15 +179,19 @@ class FlashBlocksparseMHA(nn.Module):
self.Wqkv = nn.Linear(embed_dim, 3 * embed_dim, bias=bias, **factory_kwargs)
self.inner_attn = FlashBlocksparseAttention(
sparsity_config, attention_dropout=attention_dropout,
max_seq_length=max_seq_length, **factory_kwargs
sparsity_config,
attention_dropout=attention_dropout,
max_seq_length=max_seq_length,
**factory_kwargs,
)
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias, **factory_kwargs)
def forward(self, x, x_ignored_, x_ignored_1_, attn_mask=None, key_padding_mask=None,
need_weights=False):
def forward(
self, x, x_ignored_, x_ignored_1_, attn_mask=None, key_padding_mask=None, need_weights=False
):
qkv = self.Wqkv(x)
qkv = rearrange(qkv, 'b s (three h d) -> b s three h d', three=3, h=self.num_heads)
context, attn_weights = self.inner_attn(qkv, key_padding_mask=key_padding_mask,
need_weights=need_weights, causal=self.causal)
return self.out_proj(rearrange(context, 'b s h d -> b s (h d)')), attn_weights
qkv = rearrange(qkv, "b s (three h d) -> b s three h d", three=3, h=self.num_heads)
context, attn_weights = self.inner_attn(
qkv, key_padding_mask=key_padding_mask, need_weights=need_weights, causal=self.causal
)
return self.out_proj(rearrange(context, "b s h d -> b s (h d)")), attn_weights
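The MHA wrapper above packs Q, K and V into a single projection and splits them with einops. A small shape sketch of that packing, with toy sizes not tied to any particular config:

import torch
from einops import rearrange

batch, seqlen, nheads, headdim = 2, 16, 4, 8
embed_dim = nheads * headdim

x = torch.randn(batch, seqlen, embed_dim)
Wqkv = torch.nn.Linear(embed_dim, 3 * embed_dim)

qkv = Wqkv(x)                                                      # (b, s, 3 * h * d)
qkv = rearrange(qkv, "b s (three h d) -> b s three h d", three=3, h=nheads)
q, k, v = qkv.unbind(dim=2)                                        # each (b, s, h, d)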
# Adapted from https://github.com/mlcommons/training_results_v1.1/blob/main/NVIDIA/benchmarks/bert/implementations/pytorch/fmha.py
import flash_attn_cuda
import torch
import torch.nn as nn
import flash_attn_cuda
def convert_blockmask(blockmask, causal):
"""Convert from the 0-1 format to the format used by the CUDA code.
......@@ -40,29 +39,51 @@ def convert_blockmask(blockmask, causal):
return nonzero_idx.T.contiguous().to(dtype=torch.int32)
def _flash_blocksparse_attn_forward(qkv, cu_seqlens, blockmask, dropout_p, max_s, softmax_scale,
causal, return_softmax):
context, softmax_lse, *rest = flash_attn_cuda.fwd_block(qkv, cu_seqlens, blockmask, dropout_p,
max_s, softmax_scale, causal,
return_softmax, None)
def _flash_blocksparse_attn_forward(
qkv, cu_seqlens, blockmask, dropout_p, max_s, softmax_scale, causal, return_softmax
):
context, softmax_lse, *rest = flash_attn_cuda.fwd_block(
qkv, cu_seqlens, blockmask, dropout_p, max_s, softmax_scale, causal, return_softmax, None
)
# if context.isnan().any() or softmax_lse.isnan().any():
# breakpoint()
S_dmask = rest[0] if return_softmax else None
return context, softmax_lse, S_dmask
def _flash_blocksparse_attn_backward(dout, qkv, out, S_dmask, softmax_lse, cu_seqlens, blockmask,
dropout_p, max_s, softmax_scale, causal):
dqkv, dp, softmax_d = flash_attn_cuda.bwd_block(dout, qkv, out, S_dmask, softmax_lse, cu_seqlens,
blockmask, dropout_p, softmax_scale, max_s,
causal, None)
def _flash_blocksparse_attn_backward(
dout,
qkv,
out,
S_dmask,
softmax_lse,
cu_seqlens,
blockmask,
dropout_p,
max_s,
softmax_scale,
causal,
):
dqkv, dp, softmax_d = flash_attn_cuda.bwd_block(
dout,
qkv,
out,
S_dmask,
softmax_lse,
cu_seqlens,
blockmask,
dropout_p,
softmax_scale,
max_s,
causal,
None,
)
# if dqkv.isnan().any() or softmax_d.isnan().any():
# breakpoint()
return dqkv
class FlashBlocksparseAttnFun(torch.autograd.Function):
@staticmethod
def forward(ctx, qkv, cu_seqlens, blockmask, dropout_p, max_s, softmax_scale, causal):
# Save rng_state because the backward pass will regenerate the dropout mask
......@@ -70,8 +91,14 @@ class FlashBlocksparseAttnFun(torch.autograd.Function):
if softmax_scale is None:
softmax_scale = qkv.shape[-1] ** (-0.5)
context, softmax_lse, S_dmask = _flash_blocksparse_attn_forward(
qkv, cu_seqlens, blockmask, dropout_p, max_s, softmax_scale, causal=causal,
return_softmax=False
qkv,
cu_seqlens,
blockmask,
dropout_p,
max_s,
softmax_scale,
causal=causal,
return_softmax=False,
)
ctx.save_for_backward(qkv, context, S_dmask, softmax_lse, cu_seqlens, blockmask, rng_state)
ctx.dropout_p = dropout_p
......@@ -88,8 +115,17 @@ class FlashBlocksparseAttnFun(torch.autograd.Function):
torch.cuda.set_rng_state(rng_state)
# S_dmask is None, temporarily use another tensor just to get it running
dqkv = _flash_blocksparse_attn_backward(
dout, qkv, context, context, softmax_lse, cu_seqlens, blockmask, ctx.dropout_p,
ctx.max_s, ctx.softmax_scale, ctx.causal
dout,
qkv,
context,
context,
softmax_lse,
cu_seqlens,
blockmask,
ctx.dropout_p,
ctx.max_s,
ctx.softmax_scale,
ctx.causal,
)
if rng_state is not None:
torch.cuda.set_rng_state(cur_rng_state)
......@@ -99,7 +135,6 @@ class FlashBlocksparseAttnFun(torch.autograd.Function):
# We duplicate code to return both the output and the softmax for testing
# Returning both makes backward a bit slower, so we want to keep using the other version for speed.
class FlashBlocksparseAttnFunWithS(torch.autograd.Function):
@staticmethod
def forward(ctx, qkv, cu_seqlens, blockmask, dropout_p, max_s, softmax_scale, causal):
# Save rng_state because the backward pass is gonna regenerate the dropout mask
......@@ -107,8 +142,14 @@ class FlashBlocksparseAttnFunWithS(torch.autograd.Function):
if softmax_scale is None:
softmax_scale = qkv.shape[-1] ** (-0.5)
context, softmax_lse, S_dmask = _flash_blocksparse_attn_forward(
qkv, cu_seqlens, blockmask, dropout_p, max_s, softmax_scale, causal=causal,
return_softmax=True
qkv,
cu_seqlens,
blockmask,
dropout_p,
max_s,
softmax_scale,
causal=causal,
return_softmax=True,
)
ctx.save_for_backward(qkv, context, S_dmask, softmax_lse, cu_seqlens, blockmask, rng_state)
ctx.dropout_p = dropout_p
......@@ -124,18 +165,35 @@ class FlashBlocksparseAttnFunWithS(torch.autograd.Function):
cur_rng_state = torch.cuda.get_rng_state()
torch.cuda.set_rng_state(rng_state)
dqkv = _flash_blocksparse_attn_backward(
dout, qkv, context, S_dmask, softmax_lse, cu_seqlens, blockmask, ctx.dropout_p,
ctx.max_s, ctx.softmax_scale, ctx.causal
dout,
qkv,
context,
S_dmask,
softmax_lse,
cu_seqlens,
blockmask,
ctx.dropout_p,
ctx.max_s,
ctx.softmax_scale,
ctx.causal,
)
if rng_state is not None:
torch.cuda.set_rng_state(cur_rng_state)
return dqkv, None, None, None, None, None, None
def flash_blocksparse_attn_func(qkv, cu_seqlens, blockmask, dropout_p, max_s, softmax_scale=None,
causal=False, return_attn_probs=False, convert_mask=True):
"""dropout_p should be set to 0.0 during evaluation
"""
def flash_blocksparse_attn_func(
qkv,
cu_seqlens,
blockmask,
dropout_p,
max_s,
softmax_scale=None,
causal=False,
return_attn_probs=False,
convert_mask=True,
):
"""dropout_p should be set to 0.0 during evaluation"""
func = FlashBlocksparseAttnFun if not return_attn_probs else FlashBlocksparseAttnFunWithS
if convert_mask:
blockmask = convert_blockmask(blockmask, causal=causal)
......
......@@ -17,13 +17,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from apex._autocast_utils import _cast_if_autocast_enabled
from apex.transformer.enums import AttnMaskType
from fused_softmax_lib import scaled_masked_softmax_forward, scaled_masked_softmax_backward
from fused_softmax_lib import scaled_masked_softmax_get_batch_per_block
from fused_softmax_lib import scaled_upper_triang_masked_softmax_forward, scaled_upper_triang_masked_softmax_backward
from fused_softmax_lib import (
scaled_masked_softmax_backward,
scaled_masked_softmax_forward,
scaled_masked_softmax_get_batch_per_block,
scaled_upper_triang_masked_softmax_backward,
scaled_upper_triang_masked_softmax_forward,
)
class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function):
......@@ -37,9 +39,7 @@ class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function):
@staticmethod
def forward(ctx, inputs, scale):
scale_t = torch.tensor([scale])
softmax_results = scaled_upper_triang_masked_softmax_forward(
inputs, scale_t[0]
)
softmax_results = scaled_upper_triang_masked_softmax_forward(inputs, scale_t[0])
ctx.save_for_backward(softmax_results, scale_t)
return softmax_results
......@@ -81,9 +81,7 @@ class ScaledMaskedSoftmax(torch.autograd.Function):
@staticmethod
def backward(ctx, output_grads):
softmax_results, scale_t = ctx.saved_tensors
input_grads = scaled_masked_softmax_backward(
output_grads, softmax_results, scale_t[0]
)
input_grads = scaled_masked_softmax_backward(output_grads, softmax_results, scale_t[0])
return input_grads, None, None
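As a reference point for what the fused ops above compute, here is a plain-PyTorch sketch of scaled masked softmax and its causal (upper-triangular) variant. This is a sketch of the semantics only, not the fused CUDA implementation, and the mask convention (True means masked out) is an assumption.

import torch

def scaled_masked_softmax_ref(inputs, mask, scale):
    # inputs: (batch, heads, sq, sk); mask: broadcastable bool, True = masked out.
    scores = inputs * scale
    scores = scores.masked_fill(mask, float("-inf"))
    return torch.softmax(scores, dim=-1)

def scaled_upper_triang_masked_softmax_ref(inputs, scale):
    # Causal variant: each position attends only to itself and earlier positions.
    sq, sk = inputs.shape[-2:]
    causal_mask = torch.triu(
        torch.ones(sq, sk, dtype=torch.bool, device=inputs.device), diagonal=1
    )
    return scaled_masked_softmax_ref(inputs, causal_mask, scale)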
......@@ -122,9 +120,7 @@ class FusedScaleMaskSoftmax(torch.nn.Module):
self.input_in_fp16 = input_in_fp16
self.input_in_bf16 = input_in_bf16
if self.input_in_fp16 and self.input_in_bf16:
raise RuntimeError(
"both fp16 and bf16 flags cannot be active at the same time."
)
raise RuntimeError("both fp16 and bf16 flags cannot be active at the same time.")
self.input_in_float16 = self.input_in_fp16 or self.input_in_bf16
self.attn_mask_type = attn_mask_type
self.scaled_masked_softmax_fusion = scaled_masked_softmax_fusion
......
......@@ -4,11 +4,10 @@
from functools import partial
import torch.nn as nn
from einops import rearrange
from torch import _assert
from torch.nn.modules.utils import _pair
from einops import rearrange
try:
from flash_attn.ops.fused_dense import FusedDense
except ImportError:
......@@ -16,18 +15,18 @@ except ImportError:
class PatchEmbed(nn.Module):
""" 2D Image to Patch Embedding
"""
"""2D Image to Patch Embedding"""
def __init__(
self,
img_size=224,
patch_size=16,
in_chans=3,
embed_dim=768,
norm_layer=None,
flatten=True,
bias=True,
fused_bias_fc=False,
self,
img_size=224,
patch_size=16,
in_chans=3,
embed_dim=768,
norm_layer=None,
flatten=True,
bias=True,
fused_bias_fc=False,
):
super().__init__()
img_size = _pair(img_size)
......@@ -38,7 +37,7 @@ class PatchEmbed(nn.Module):
self.num_patches = self.grid_size[0] * self.grid_size[1]
self.flatten = flatten
if fused_bias_fc and FusedDense is None:
raise ImportError('fused_dense is not installed')
raise ImportError("fused_dense is not installed")
linear_cls = nn.Linear if not fused_bias_fc or not bias else FusedDense
self.proj = linear_cls(in_chans * patch_size[0] * patch_size[1], embed_dim, bias=bias)
......@@ -46,11 +45,23 @@ class PatchEmbed(nn.Module):
def forward(self, x):
_, _, H, W = x.shape
_assert(H == self.img_size[0], f"Input image height ({H}) doesn't match model ({self.img_size[0]}).")
_assert(W == self.img_size[1], f"Input image width ({W}) doesn't match model ({self.img_size[1]}).")
x = self.proj(rearrange(x, 'b c (h p1) (w p2) -> b h w (c p1 p2)',
p1=self.patch_size[0], p2=self.patch_size[1]))
_assert(
H == self.img_size[0],
f"Input image height ({H}) doesn't match model ({self.img_size[0]}).",
)
_assert(
W == self.img_size[1],
f"Input image width ({W}) doesn't match model ({self.img_size[1]}).",
)
x = self.proj(
rearrange(
x,
"b c (h p1) (w p2) -> b h w (c p1 p2)",
p1=self.patch_size[0],
p2=self.patch_size[1],
)
)
if self.flatten:
x = rearrange(x, 'b h w c -> b (h w) c')
x = rearrange(x, "b h w c -> b (h w) c")
x = self.norm(x)
return x
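A minimal usage sketch of the PatchEmbed module above; the import path is assumed from the repo layout, fused_bias_fc is left off so only nn.Linear is needed, and the image/patch sizes are illustrative.

import torch
from flash_attn.layers.patch_embed import PatchEmbed  # path assumed

patch_embed = PatchEmbed(img_size=224, patch_size=16, in_chans=3, embed_dim=768)
images = torch.randn(4, 3, 224, 224)
tokens = patch_embed(images)   # (4, 196, 768): 14 x 14 patches, each projected to embed_dim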
# Copyright (c) 2023, Tri Dao.
from typing import Tuple, Optional
import math
from typing import Optional, Tuple
import rotary_emb
import torch
from einops import rearrange, repeat
import rotary_emb
def rotate_half(x, interleaved=False):
if not interleaved:
......@@ -16,7 +14,7 @@ def rotate_half(x, interleaved=False):
return torch.cat((-x2, x1), dim=-1)
else:
x1, x2 = x[..., ::2], x[..., 1::2]
return rearrange(torch.stack((-x2, x1), dim=-1), '... d two -> ... (d two)', two=2)
return rearrange(torch.stack((-x2, x1), dim=-1), "... d two -> ... (d two)", two=2)
def apply_rotary_emb_torch(x, cos, sin, interleaved=False):
......@@ -26,14 +24,15 @@ def apply_rotary_emb_torch(x, cos, sin, interleaved=False):
"""
ro_dim = cos.shape[-1] * 2
assert ro_dim <= x.shape[-1]
cos = repeat(cos, 's d -> s 1 (2 d)')
sin = repeat(sin, 's d -> s 1 (2 d)')
return torch.cat([x[..., :ro_dim] * cos + rotate_half(x[..., :ro_dim], interleaved) * sin,
x[..., ro_dim:]], dim=-1)
cos = repeat(cos, "s d -> s 1 (2 d)")
sin = repeat(sin, "s d -> s 1 (2 d)")
return torch.cat(
[x[..., :ro_dim] * cos + rotate_half(x[..., :ro_dim], interleaved) * sin, x[..., ro_dim:]],
dim=-1,
)
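A small sketch of the pure-PyTorch rotary reference above, building cos/sin from inverse frequencies in the usual way. Shapes assumed: x is (batch, seqlen, nheads, headdim), cos/sin are (seqlen, rotary_dim / 2), and only the first rotary_dim features are rotated; the import path is assumed from the repo layout.

import torch
from flash_attn.layers.rotary import apply_rotary_emb_torch  # path assumed

def build_cos_sin(seqlen, rotary_dim, base=10000.0):
    inv_freq = 1.0 / (base ** (torch.arange(0, rotary_dim, 2).float() / rotary_dim))
    t = torch.arange(seqlen, dtype=torch.float32)
    freqs = torch.outer(t, inv_freq)             # (seqlen, rotary_dim / 2)
    return torch.cos(freqs), torch.sin(freqs)

batch, seqlen, nheads, headdim = 2, 8, 4, 16
x = torch.randn(batch, seqlen, nheads, headdim)
cos, sin = build_cos_sin(seqlen, rotary_dim=headdim)
x_rot = apply_rotary_emb_torch(x, cos, sin)      # same shape; first rotary_dim features rotated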
class ApplyRotaryEmb(torch.autograd.Function):
@staticmethod
def forward(ctx, x, cos, sin, interleaved=False, inplace=False):
"""
......@@ -57,10 +56,20 @@ class ApplyRotaryEmb(torch.autograd.Function):
if inplace:
o1, o2 = x1, x2
else:
o1, o2 = (out_ro.chunk(2, dim=-1) if not interleaved
else (out_ro[..., ::2], out_ro[..., 1::2]))
rotary_emb.apply_rotary(x1, x2, rearrange(cos[:seqlen], 's d -> s 1 d'),
rearrange(sin[:seqlen], 's d -> s 1 d'), o1, o2, False)
o1, o2 = (
out_ro.chunk(2, dim=-1)
if not interleaved
else (out_ro[..., ::2], out_ro[..., 1::2])
)
rotary_emb.apply_rotary(
x1,
x2,
rearrange(cos[:seqlen], "s d -> s 1 d"),
rearrange(sin[:seqlen], "s d -> s 1 d"),
o1,
o2,
False,
)
if not inplace and rotary_dim < headdim:
out[..., rotary_dim:].copy_(x[..., rotary_dim:])
ctx.save_for_backward(cos, sin)
......@@ -76,17 +85,28 @@ class ApplyRotaryEmb(torch.autograd.Function):
rotary_dim *= 2
inplace = ctx.inplace
do_ro = do[..., :rotary_dim]
do1, do2 = (do_ro.chunk(2, dim=-1) if not ctx.interleaved
else (do_ro[..., ::2], do_ro[..., 1::2]))
do1, do2 = (
do_ro.chunk(2, dim=-1) if not ctx.interleaved else (do_ro[..., ::2], do_ro[..., 1::2])
)
dx = torch.empty_like(do) if not inplace else do
if inplace:
dx1, dx2 = do1, do2
else:
dx_ro = dx[..., :rotary_dim]
dx1, dx2 = (dx_ro.chunk(2, dim=-1) if not ctx.interleaved
else (dx_ro[..., ::2], dx_ro[..., 1::2]))
rotary_emb.apply_rotary(do1, do2, rearrange(cos[:seqlen], 's d -> s 1 d'),
rearrange(sin[:seqlen], 's d -> s 1 d'), dx1, dx2, True)
dx1, dx2 = (
dx_ro.chunk(2, dim=-1)
if not ctx.interleaved
else (dx_ro[..., ::2], dx_ro[..., 1::2])
)
rotary_emb.apply_rotary(
do1,
do2,
rearrange(cos[:seqlen], "s d -> s 1 d"),
rearrange(sin[:seqlen], "s d -> s 1 d"),
dx1,
dx2,
True,
)
if not inplace and rotary_dim < headdim:
dx[..., rotary_dim:].copy_(do[..., rotary_dim:])
return dx, None, None, None, None
......@@ -96,7 +116,6 @@ apply_rotary_emb_func = ApplyRotaryEmb.apply
class ApplyRotaryEmbQKV_(torch.autograd.Function):
@staticmethod
def forward(ctx, qkv, cos, sin, cos_k=None, sin_k=None, interleaved=False):
"""
......@@ -119,12 +138,26 @@ class ApplyRotaryEmbQKV_(torch.autograd.Function):
assert sin.shape == cos_k.shape == sin_k.shape == (rotary_seqlen, rotary_dim // 2)
q_ro = qkv[:, :, 0, :, :rotary_dim]
q1, q2 = q_ro.chunk(2, dim=-1) if not interleaved else (q_ro[..., ::2], q_ro[..., 1::2])
rotary_emb.apply_rotary(q1, q2, rearrange(cos[:seqlen], 's d -> s 1 d'),
rearrange(sin[:seqlen], 's d -> s 1 d'), q1, q2, False)
rotary_emb.apply_rotary(
q1,
q2,
rearrange(cos[:seqlen], "s d -> s 1 d"),
rearrange(sin[:seqlen], "s d -> s 1 d"),
q1,
q2,
False,
)
k_ro = qkv[:, :, 1, :, :rotary_dim]
k1, k2 = k_ro.chunk(2, dim=-1) if not interleaved else (k_ro[..., ::2], k_ro[..., 1::2])
rotary_emb.apply_rotary(k1, k2, rearrange(cos_k[:seqlen], 's d -> s 1 d'),
rearrange(sin_k[:seqlen], 's d -> s 1 d'), k1, k2, False)
rotary_emb.apply_rotary(
k1,
k2,
rearrange(cos_k[:seqlen], "s d -> s 1 d"),
rearrange(sin_k[:seqlen], "s d -> s 1 d"),
k1,
k2,
False,
)
ctx.save_for_backward(cos, sin, cos_k, sin_k)
ctx.interleaved = interleaved
return qkv
......@@ -136,15 +169,31 @@ class ApplyRotaryEmbQKV_(torch.autograd.Function):
rotary_dim = cos.shape[-1]
rotary_dim *= 2
dq_ro = dqkv[:, :, 0, :, :rotary_dim]
dq1, dq2 = (dq_ro.chunk(2, dim=-1) if not ctx.interleaved
else (dq_ro[..., ::2], dq_ro[..., 1::2]))
rotary_emb.apply_rotary(dq1, dq2, rearrange(cos[:seqlen], 's d -> s 1 d'),
rearrange(sin[:seqlen], 's d -> s 1 d'), dq1, dq2, True)
dq1, dq2 = (
dq_ro.chunk(2, dim=-1) if not ctx.interleaved else (dq_ro[..., ::2], dq_ro[..., 1::2])
)
rotary_emb.apply_rotary(
dq1,
dq2,
rearrange(cos[:seqlen], "s d -> s 1 d"),
rearrange(sin[:seqlen], "s d -> s 1 d"),
dq1,
dq2,
True,
)
dk_ro = dqkv[:, :, 1, :, :rotary_dim]
dk1, dk2 = (dk_ro.chunk(2, dim=-1) if not ctx.interleaved
else (dk_ro[..., ::2], dk_ro[..., 1::2]))
rotary_emb.apply_rotary(dk1, dk2, rearrange(cos_k[:seqlen], 's d -> s 1 d'),
rearrange(sin_k[:seqlen], 's d -> s 1 d'), dk1, dk2, True)
dk1, dk2 = (
dk_ro.chunk(2, dim=-1) if not ctx.interleaved else (dk_ro[..., ::2], dk_ro[..., 1::2])
)
rotary_emb.apply_rotary(
dk1,
dk2,
rearrange(cos_k[:seqlen], "s d -> s 1 d"),
rearrange(sin_k[:seqlen], "s d -> s 1 d"),
dk1,
dk2,
True,
)
return dqkv, None, None, None, None, None
......@@ -152,7 +201,6 @@ apply_rotary_emb_qkv_ = ApplyRotaryEmbQKV_.apply
class ApplyRotaryEmbKV_(torch.autograd.Function):
@staticmethod
def forward(ctx, kv, cos, sin, interleaved=False):
"""
......@@ -171,9 +219,15 @@ class ApplyRotaryEmbKV_(torch.autograd.Function):
assert seqlen <= rotary_seqlen
k_ro = kv[:, :, 0, :, :rotary_dim]
k1, k2 = k_ro.chunk(2, dim=-1) if not interleaved else (k_ro[..., ::2], k_ro[..., 1::2])
rotary_emb.apply_rotary(k1, k2, rearrange(cos[:seqlen], 's d -> s 1 d'),
rearrange(sin[:seqlen], 's d -> s 1 d'), k1, k2,
False) # conj=False since this is the forward pass
rotary_emb.apply_rotary(
k1,
k2,
rearrange(cos[:seqlen], "s d -> s 1 d"),
rearrange(sin[:seqlen], "s d -> s 1 d"),
k1,
k2,
False,
) # conj=False since this is the forward pass
ctx.save_for_backward(cos, sin)
ctx.interleaved = interleaved
return kv
......@@ -185,11 +239,18 @@ class ApplyRotaryEmbKV_(torch.autograd.Function):
rotary_dim = cos.shape[-1]
rotary_dim *= 2
dk_ro = dkv[:, :, 0, :, :rotary_dim]
dk1, dk2 = (dk_ro.chunk(2, dim=-1) if not ctx.interleaved
else (dk_ro[..., ::2], dk_ro[..., 1::2]))
rotary_emb.apply_rotary(dk1, dk2, rearrange(cos[:seqlen], 's d -> s 1 d'),
rearrange(sin[:seqlen], 's d -> s 1 d'), dk1, dk2,
True) # conj=True since this is the backward pass
dk1, dk2 = (
dk_ro.chunk(2, dim=-1) if not ctx.interleaved else (dk_ro[..., ::2], dk_ro[..., 1::2])
)
rotary_emb.apply_rotary(
dk1,
dk2,
rearrange(cos[:seqlen], "s d -> s 1 d"),
rearrange(sin[:seqlen], "s d -> s 1 d"),
dk1,
dk2,
True,
) # conj=True since this is the backward pass
return dkv, None, None, None
......@@ -214,21 +275,28 @@ class RotaryEmbedding(torch.nn.Module):
Reference: https://github.com/sunyt32/torchscale/blob/main/torchscale/component/xpos_relative_position.py
"""
def __init__(self, dim: int, base=10000.0, interleaved=False, scale_base=None,
pos_idx_in_fp32=True, device=None):
def __init__(
self,
dim: int,
base=10000.0,
interleaved=False,
scale_base=None,
pos_idx_in_fp32=True,
device=None,
):
"""
interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead
of 1st half and 2nd half (GPT-NeoX style).
pos_idx_in_fp32: if True, the position indices [0.0, ..., seqlen - 1] are in fp32,
otherwise they might be in lower precision.
This option was added because previously (before 2023-07-02), when we construct
the position indices, we use the dtype of self.inv_freq. In most cases this would
be fp32, but if the model is trained in pure bf16 (not mixed precision), then
self.inv_freq would be bf16, and the position indices are also in bf16.
Because of the limited precision of bf16 (e.g. 1995.0 is rounded to 2000.0), the
embeddings for some positions will coincide.
To maintain compatibility with models previously trained in pure bf16,
we add this option.
interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead
of 1st half and 2nd half (GPT-NeoX style).
pos_idx_in_fp32: if True, the position indices [0.0, ..., seqlen - 1] are in fp32,
otherwise they might be in lower precision.
This option was added because previously (before 2023-07-02), when we construct
the position indices, we use the dtype of self.inv_freq. In most cases this would
be fp32, but if the model is trained in pure bf16 (not mixed precision), then
self.inv_freq would be bf16, and the position indices are also in bf16.
Because of the limited precision of bf16 (e.g. 1995.0 is rounded to 2000.0), the
embeddings for some positions will coincide.
To maintain compatibility with models previously trained in pure bf16,
we add this option.
"""
super().__init__()
self.dim = dim
......@@ -239,8 +307,11 @@ class RotaryEmbedding(torch.nn.Module):
self.register_buffer("inv_freq", inv_freq, persistent=False)
self.interleaved = interleaved
self.scale_base = scale_base
scale = ((torch.arange(0, dim, 2, device=device, dtype=torch.float32) + 0.4 * dim)
/ (1.4 * dim) if scale_base is not None else None)
scale = (
(torch.arange(0, dim, 2, device=device, dtype=torch.float32) + 0.4 * dim) / (1.4 * dim)
if scale_base is not None
else None
)
self.register_buffer("scale", scale, persistent=False)
self._seq_len_cached = 0
......@@ -250,17 +321,21 @@ class RotaryEmbedding(torch.nn.Module):
self._sin_k_cached = None
def _compute_inv_freq(self, device=None):
return 1.0 / (self.base ** (torch.arange(0, self.dim, 2, device=device,
dtype=torch.float32) / self.dim))
return 1.0 / (
self.base
** (torch.arange(0, self.dim, 2, device=device, dtype=torch.float32) / self.dim)
)
def _update_cos_sin_cache(self, seqlen, device=None, dtype=None):
# Reset the tables if the sequence length has changed,
# if we're on a new device (possibly due to tracing for instance),
# or if we're switching from inference mode to training
if (seqlen > self._seq_len_cached or self._cos_cached.device != device
if (
seqlen > self._seq_len_cached
or self._cos_cached.device != device
or self._cos_cached.dtype != dtype
or (self.training and self._cos_cached.is_inference())):
or (self.training and self._cos_cached.is_inference())
):
self._seq_len_cached = seqlen
# We want fp32 here, not self.inv_freq.dtype, since the model could be loaded in bf16
# And the output of arange can be quite large, so bf16 would lose a lot of precision.
......@@ -285,17 +360,20 @@ class RotaryEmbedding(torch.nn.Module):
self._cos_cached = torch.cos(freqs).to(dtype)
self._sin_cached = torch.sin(freqs).to(dtype)
else:
power = ((torch.arange(seqlen, dtype=self.scale.dtype, device=self.scale.device)
- seqlen // 2) / self.scale_base)
scale = self.scale.to(device=power.device) ** rearrange(power, 's -> s 1')
power = (
torch.arange(seqlen, dtype=self.scale.dtype, device=self.scale.device)
- seqlen // 2
) / self.scale_base
scale = self.scale.to(device=power.device) ** rearrange(power, "s -> s 1")
# We want the multiplication by scale to happen in fp32
self._cos_cached = (torch.cos(freqs) * scale).to(dtype)
self._sin_cached = (torch.sin(freqs) * scale).to(dtype)
self._cos_k_cached = (torch.cos(freqs) / scale).to(dtype)
self._sin_k_cached = (torch.sin(freqs) / scale).to(dtype)
def forward(self, qkv: torch.Tensor, kv: Optional[torch.Tensor] = None,
seqlen_offset: int = 0) -> Tuple[torch.Tensor, torch.Tensor]:
def forward(
self, qkv: torch.Tensor, kv: Optional[torch.Tensor] = None, seqlen_offset: int = 0
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
qkv: (batch, seqlen, 3, nheads, headdim) if kv is none,
else it's just q of shape (batch, seqlen, nheads, headdim)
......@@ -308,29 +386,43 @@ class RotaryEmbedding(torch.nn.Module):
if kv is None:
if self.scale is None:
return apply_rotary_emb_qkv_(
qkv, self._cos_cached[seqlen_offset:], self._sin_cached[seqlen_offset:],
None, None, self.interleaved
qkv,
self._cos_cached[seqlen_offset:],
self._sin_cached[seqlen_offset:],
None,
None,
self.interleaved,
)
else:
return apply_rotary_emb_qkv_(
qkv, self._cos_cached[seqlen_offset:], self._sin_cached[seqlen_offset:],
self._cos_k_cached[seqlen_offset:], self._sin_k_cached[seqlen_offset:],
self.interleaved
qkv,
self._cos_cached[seqlen_offset:],
self._sin_cached[seqlen_offset:],
self._cos_k_cached[seqlen_offset:],
self._sin_k_cached[seqlen_offset:],
self.interleaved,
)
else:
q = qkv
q = apply_rotary_emb_func(
q, self._cos_cached[seqlen_offset:], self._sin_cached[seqlen_offset:],
self.interleaved, True
q,
self._cos_cached[seqlen_offset:],
self._sin_cached[seqlen_offset:],
self.interleaved,
True,
)
if self.scale is None:
kv = apply_rotary_emb_kv_(
kv, self._cos_cached[seqlen_offset:], self._sin_cached[seqlen_offset:],
self.interleaved
kv,
self._cos_cached[seqlen_offset:],
self._sin_cached[seqlen_offset:],
self.interleaved,
)
else:
kv = apply_rotary_emb_kv_(
kv, self._cos_k_cached[seqlen_offset:], self._sin_k_cached[seqlen_offset:],
self.interleaved
kv,
self._cos_k_cached[seqlen_offset:],
self._sin_k_cached[seqlen_offset:],
self.interleaved,
)
return q, kv
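A hedged usage sketch of the RotaryEmbedding module above on a packed qkv tensor. It assumes the rotary_emb CUDA extension is installed and that the import path matches the repo layout; q and k are rotated in place while v is left untouched.

import torch
from flash_attn.layers.rotary import RotaryEmbedding  # path assumed

batch, seqlen, nheads, headdim = 2, 128, 8, 64
rotary = RotaryEmbedding(dim=headdim, interleaved=False, device="cuda")

qkv = torch.randn(batch, seqlen, 3, nheads, headdim, device="cuda", dtype=torch.float16)
qkv = rotary(qkv)   # rotary applied to the q and k slices; v unchanged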
......@@ -5,7 +5,6 @@
# The original xentropy interface is here: https://github.com/NVIDIA/apex/blob/master/apex/contrib/xentropy/softmax_xentropy.py
import torch
import torch.nn as nn
import xentropy_cuda_lib
# `all_gather_into_tensor` and `reduce_scatter_tensor` are new placeholders for
......@@ -17,10 +16,16 @@ if "all_gather_into_tensor" not in dir(torch.distributed):
class SoftmaxCrossEntropyLossFn(torch.autograd.Function):
@staticmethod
def forward(ctx, logits, labels, smoothing=0.0, ignored_index=-100, inplace_backward=False,
process_group=None):
def forward(
ctx,
logits,
labels,
smoothing=0.0,
ignored_index=-100,
inplace_backward=False,
process_group=None,
):
"""
logits: (batch, vocab_size)
labels: (batch,)
......@@ -34,7 +39,7 @@ class SoftmaxCrossEntropyLossFn(torch.autograd.Function):
if world_size == 1:
losses, lse = xentropy_cuda_lib.forward(logits, labels, smoothing)
losses.masked_fill_(labels==ignored_index, 0)
losses.masked_fill_(labels == ignored_index, 0)
labels_local = labels
else:
rank = torch.distributed.get_rank(process_group)
......@@ -48,8 +53,9 @@ class SoftmaxCrossEntropyLossFn(torch.autograd.Function):
# For tensor parallel cross entropy with smoothing, we want to pass in the total number
# of classes so that smoothing can be applied correctly. If total_classes=-1, use the
# last dimension of the input tensor.
losses, lse_local = xentropy_cuda_lib.forward(logits, labels_local, smoothing,
world_size * vocab_size)
losses, lse_local = xentropy_cuda_lib.forward(
logits, labels_local, smoothing, world_size * vocab_size
)
assert lse_local.shape == (batch,)
assert losses.shape == (batch,)
losses.masked_fill_(ignored_mask, 0)
......@@ -61,10 +67,12 @@ class SoftmaxCrossEntropyLossFn(torch.autograd.Function):
# For labels not in the vocab of this partition, losses contains
# 0.1 * (lse_local - sum logit / total_classes).
lse_allgather = torch.empty(world_size, batch, dtype=lse_local.dtype,
device=lse_local.device)
torch.distributed.all_gather_into_tensor(lse_allgather, lse_local.contiguous(),
group=process_group)
lse_allgather = torch.empty(
world_size, batch, dtype=lse_local.dtype, device=lse_local.device
)
torch.distributed.all_gather_into_tensor(
lse_allgather, lse_local.contiguous(), group=process_group
)
handle_losses = torch.distributed.all_reduce(
losses, op=torch.distributed.ReduceOp.SUM, group=process_group, async_op=True
)
......@@ -74,16 +82,18 @@ class SoftmaxCrossEntropyLossFn(torch.autograd.Function):
# If there's smoothing=0.1, the total losses are
# 0.9 * (lse_local - predicted_logit) + 0.1 * (sum of all lse_local - sum logit / total_classes)
# We want 0.9 * (lse - predicted_logit) + 0.1 * (lse - sum logit / total_classes).
rank_per_sample = torch.div(labels, vocab_size, rounding_mode='floor')
lse_local = lse_allgather[rank_per_sample,
torch.arange(batch, device=lse_allgather.device)]
rank_per_sample = torch.div(labels, vocab_size, rounding_mode="floor")
lse_local = lse_allgather[
rank_per_sample, torch.arange(batch, device=lse_allgather.device)
]
handle_losses.wait()
if smoothing == 0.0:
losses += lse - lse_local
else:
losses += ((1 - smoothing) * (lse - lse_local)
+ smoothing * (lse - lse_allgather.sum(dim=0)))
losses += (1 - smoothing) * (lse - lse_local) + smoothing * (
lse - lse_allgather.sum(dim=0)
)
losses.masked_fill_(ignored_mask, 0)
ctx.save_for_backward(logits, lse, labels_local)
......@@ -96,19 +106,24 @@ class SoftmaxCrossEntropyLossFn(torch.autograd.Function):
def backward(ctx, grad_loss):
logits, lse, labels = ctx.saved_tensors
grad_loss = grad_loss.contiguous()
grad_loss.masked_fill_(labels==ctx.ignored_index, 0)
grad_logits = xentropy_cuda_lib.backward(grad_loss, logits, lse, labels,
ctx.smoothing, ctx.inplace_backward,
ctx.total_classes)
grad_loss.masked_fill_(labels == ctx.ignored_index, 0)
grad_logits = xentropy_cuda_lib.backward(
grad_loss, logits, lse, labels, ctx.smoothing, ctx.inplace_backward, ctx.total_classes
)
return grad_logits, None, None, None, None, None, None
class CrossEntropyLoss(nn.Module):
def __init__(self, ignore_index=-100, reduction='mean', label_smoothing=0.0,
inplace_backward=False, process_group=None):
def __init__(
self,
ignore_index=-100,
reduction="mean",
label_smoothing=0.0,
inplace_backward=False,
process_group=None,
):
super().__init__()
if reduction not in ['mean', 'none']:
if reduction not in ["mean", "none"]:
raise NotImplementedError("Only support reduction = 'mean' or 'none'")
self.ignore_index = ignore_index
self.reduction = reduction
......@@ -120,10 +135,14 @@ class CrossEntropyLoss(nn.Module):
assert input.is_cuda and target.is_cuda
# SoftmaxCrossEntropyLoss implicitly casts to float
loss = SoftmaxCrossEntropyLossFn.apply(
input, target, self.label_smoothing, self.ignore_index, self.inplace_backward,
self.process_group
input,
target,
self.label_smoothing,
self.ignore_index,
self.inplace_backward,
self.process_group,
)
if self.reduction == 'mean':
if self.reduction == "mean":
return loss.sum() / (target != self.ignore_index).sum()
else:
return loss
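A minimal usage sketch of the CrossEntropyLoss wrapper above. Inputs must live on the GPU since the op dispatches to xentropy_cuda_lib; the shapes are illustrative, and positions labeled with ignore_index contribute zero loss.

import torch
from flash_attn.losses.cross_entropy import CrossEntropyLoss  # path assumed

loss_fn = CrossEntropyLoss(label_smoothing=0.1, inplace_backward=True)

batch, vocab_size = 8, 50257
logits = torch.randn(batch, vocab_size, device="cuda", requires_grad=True)
labels = torch.randint(0, vocab_size, (batch,), device="cuda")
labels[0] = -100                 # ignored position, masked out of the loss

loss = loss_fn(logits, labels)   # scalar: mean over non-ignored targets
loss.backward()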
......@@ -2,93 +2,114 @@
import math
import re
from collections import OrderedDict
import torch
import torch.nn.functional as F
from einops import rearrange
from transformers import GPT2Config, FalconConfig
from transformers import FalconConfig, GPT2Config
def remap_state_dict_hf_falcon(state_dict, config):
def key_mapping_layers(key):
return re.sub(r'^transformer.h.', 'transformer.layers.', key)
return re.sub(r"^transformer.h.", "transformer.layers.", key)
state_dict = OrderedDict((key_mapping_layers(k), v) for k, v in state_dict.items())
# Word embedding
def key_mapping_emb(key):
return re.sub(r'^transformer.word_embeddings.', 'transformer.embeddings.word_embeddings.', key)
return re.sub(
r"^transformer.word_embeddings.", "transformer.embeddings.word_embeddings.", key
)
state_dict = OrderedDict((key_mapping_emb(k), v) for k, v in state_dict.items())
word_embeddings = state_dict.pop('transformer.embeddings.word_embeddings.weight')
word_embeddings = state_dict.pop("transformer.embeddings.word_embeddings.weight")
# It's possible that vocab_size is padded to be a multiple of 8, for example.
pad_vocab_size_multiple = getattr(config, 'pad_vocab_size_multiple', 1)
vocab_size = (math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple)
state_dict['transformer.embeddings.word_embeddings.weight'] = F.pad(
pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
vocab_size = math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple
state_dict["transformer.embeddings.word_embeddings.weight"] = F.pad(
word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0])
)
if getattr(config, 'tie_word_embeddings'):
state_dict['lm_head.weight'] = state_dict['transformer.embeddings.word_embeddings.weight']
if getattr(config, "tie_word_embeddings"):
state_dict["lm_head.weight"] = state_dict["transformer.embeddings.word_embeddings.weight"]
else:
output_embeddings = state_dict.pop('lm_head.weight')
output_embeddings = state_dict.pop("lm_head.weight")
# It's possible that vocab_size is padded to be a multiple of 8, for example.
state_dict['lm_head.weight'] = F.pad(
state_dict["lm_head.weight"] = F.pad(
output_embeddings, (0, 0, 0, vocab_size - output_embeddings.shape[0])
)
output_embeddings_bias = state_dict.pop('lm_head.bias')
state_dict['lm_head.bias'] = F.pad(
output_embeddings_bias = state_dict.pop("lm_head.bias")
state_dict["lm_head.bias"] = F.pad(
output_embeddings_bias, (0, vocab_size - output_embeddings_bias.shape[0])
)
# LayerNorm
def key_mapping_ln(key):
key = re.sub(r'^transformer.layers.(\d+).input_layernorm.',
r'transformer.layers.\1.norm1.', key)
key = re.sub(r'^transformer.layers.(\d+).post_attention_layernorm.',
r'transformer.layers.\1.norm2.', key)
key = re.sub(r'^transformer.layers.(\d+).ln_attn.', r'transformer.layers.\1.norm1.', key)
key = re.sub(r'^transformer.layers.(\d+).ln_mlp.', r'transformer.layers.\1.norm2.', key)
key = re.sub(
r"^transformer.layers.(\d+).input_layernorm.", r"transformer.layers.\1.norm1.", key
)
key = re.sub(
r"^transformer.layers.(\d+).post_attention_layernorm.",
r"transformer.layers.\1.norm2.",
key,
)
key = re.sub(r"^transformer.layers.(\d+).ln_attn.", r"transformer.layers.\1.norm1.", key)
key = re.sub(r"^transformer.layers.(\d+).ln_mlp.", r"transformer.layers.\1.norm2.", key)
return key
state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())
# MLP
def key_mapping_mlp(key):
key = re.sub(r'^transformer.layers.(\d+).mlp.dense_h_to_4h.',
r'transformer.layers.\1.mlp.fc1.', key)
key = re.sub(r'^transformer.layers.(\d+).mlp.dense_4h_to_h.',
r'transformer.layers.\1.mlp.fc2.', key)
key = re.sub(
r"^transformer.layers.(\d+).mlp.dense_h_to_4h.", r"transformer.layers.\1.mlp.fc1.", key
)
key = re.sub(
r"^transformer.layers.(\d+).mlp.dense_4h_to_h.", r"transformer.layers.\1.mlp.fc2.", key
)
return key
state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())
def key_mapping_attn(key):
key = re.sub(r'^transformer.layers.(\d+).self_attention.query_key_value.',
r'transformer.layers.\1.mixer.Wqkv.', key)
key = re.sub(r'^transformer.layers.(\d+).self_attention.dense.',
r'transformer.layers.\1.mixer.out_proj.', key)
key = re.sub(
r"^transformer.layers.(\d+).self_attention.query_key_value.",
r"transformer.layers.\1.mixer.Wqkv.",
key,
)
key = re.sub(
r"^transformer.layers.(\d+).self_attention.dense.",
r"transformer.layers.\1.mixer.out_proj.",
key,
)
return key
state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())
n_head = config.n_head
n_head_kv = getattr(config, "n_head_kv", 1)
headdim = config.hidden_size // n_head
for l in range(config.n_layer):
# The weights are stored in a different layout compared to our implementation
Wqkv = rearrange(state_dict.pop(f'transformer.layers.{l}.mixer.Wqkv.weight'),
"(group ratio headdim) ... -> group ratio headdim ...",
ratio=n_head // n_head_kv + 2, headdim=headdim)
Wqkv = rearrange(
state_dict.pop(f"transformer.layers.{l}.mixer.Wqkv.weight"),
"(group ratio headdim) ... -> group ratio headdim ...",
ratio=n_head // n_head_kv + 2,
headdim=headdim,
)
Wq = rearrange(Wqkv[:, :-2], "group ratio headdim ... -> (group ratio headdim) ...")
Wk = rearrange(Wqkv[:, [-2]], "group ratio headdim ... -> (group ratio headdim) ...")
Wv = rearrange(Wqkv[:, [-1]], "group ratio headdim ... -> (group ratio headdim) ...")
state_dict[f'transformer.layers.{l}.mixer.Wqkv.weight'] = torch.cat([Wq, Wk, Wv], dim=0)
state_dict[f"transformer.layers.{l}.mixer.Wqkv.weight"] = torch.cat([Wq, Wk, Wv], dim=0)
return state_dict
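A hedged sketch of how the remap function above is typically used: take a Hugging Face Falcon checkpoint, derive a GPT2-style config, and rename/reshape the weights into this repo's layer layout. The checkpoint name is an illustrative assumption.

from transformers import FalconConfig
from flash_attn.models.falcon import falcon_config_to_gpt2_config, remap_state_dict_hf_falcon
from flash_attn.utils.pretrained import state_dict_from_pretrained

model_name = "tiiuae/falcon-7b"   # illustrative checkpoint name
falcon_config = FalconConfig.from_pretrained(model_name)
config = falcon_config_to_gpt2_config(falcon_config)

state_dict = state_dict_from_pretrained(model_name)
state_dict = remap_state_dict_hf_falcon(state_dict, config)
# state_dict now uses transformer.layers.<i>.mixer / .mlp / .norm1 / .norm2 names.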
def falcon_config_to_gpt2_config(falcon_config: FalconConfig) -> GPT2Config:
# The 40b config uses "n_head_kv" instead of "num_kv_heads"
n_head_kv = getattr(falcon_config, "n_head_kv",
1 if getattr(falcon_config, "multi_query", False)
else falcon_config.n_head)
n_head_kv = getattr(
falcon_config,
"n_head_kv",
1 if getattr(falcon_config, "multi_query", False) else falcon_config.n_head,
)
# HACK: the 40b config has 2 LN per layer instead of 1, but that's not reflected in the config.
# So we have to infer it from the number of heads in the key/value block
parallel_block_tied_norm = n_head_kv == 1
......
......@@ -11,6 +11,8 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from transformers import GPT2Config
from flash_attn.models.falcon import remap_state_dict_hf_falcon
from flash_attn.models.gpt_neox import remap_state_dict_hf_gpt_neox
from flash_attn.models.gptj import remap_state_dict_hf_gptj
......@@ -27,10 +29,9 @@ from flash_attn.modules.mlp import (
ParallelMLP,
)
from flash_attn.ops.activations import sqrelu_fwd
from flash_attn.utils.distributed import all_gather_raw, sync_shared_params, get_dim_for_local_rank
from flash_attn.utils.distributed import all_gather_raw, get_dim_for_local_rank, sync_shared_params
from flash_attn.utils.generation import GenerationMixin
from flash_attn.utils.pretrained import state_dict_from_pretrained
from transformers import GPT2Config
try:
from flash_attn.ops.fused_dense import ColumnParallelLinear
......@@ -690,7 +691,7 @@ def shard_state_dict_tp(state_dict, config, world_size, rank):
if key in state_dict:
x = state_dict[key]
dim = x.shape[0] // world_size
state_dict[key] = x[rank * dim: (rank + 1) * dim]
state_dict[key] = x[rank * dim : (rank + 1) * dim]
def shard_last_dim(state_dict, key, multiple_of=1):
if key in state_dict:
......@@ -707,17 +708,19 @@ def shard_state_dict_tp(state_dict, config, world_size, rank):
x = state_dict[key]
dim = x.shape[0] // world_size // 2
state_dict[key] = rearrange(
rearrange(x, "(two o) ... -> two o ...", two=2)[:, rank * dim: (rank + 1) * dim],
rearrange(x, "(two o) ... -> two o ...", two=2)[:, rank * dim : (rank + 1) * dim],
"two o ... -> (two o) ...",
)
def shard_qkv_headdim(state_dict, key):
if key in state_dict:
n_head_each_rank = [
get_dim_for_local_rank(n_head, world_size, local_rank) for local_rank in range(world_size)
get_dim_for_local_rank(n_head, world_size, local_rank)
for local_rank in range(world_size)
]
n_head_kv_each_rank = [
get_dim_for_local_rank(n_head_kv, world_size, local_rank) for local_rank in range(world_size)
get_dim_for_local_rank(n_head_kv, world_size, local_rank)
for local_rank in range(world_size)
]
beg_n_head = sum(n_head_each_rank[:rank])
......@@ -729,7 +732,8 @@ def shard_state_dict_tp(state_dict, config, world_size, rank):
if n_head_kv == n_head:
x = rearrange(state_dict[key], "(three d) ... -> three d ...", three=3)
state_dict[key] = rearrange(
x[:, beg_n_head * head_dim : end_n_head * head_dim], "three d ... -> (three d) ..."
x[:, beg_n_head * head_dim : end_n_head * head_dim],
"three d ... -> (three d) ...",
)
else:
x = rearrange(
......@@ -741,8 +745,14 @@ def shard_state_dict_tp(state_dict, config, world_size, rank):
torch.cat(
[
x[beg_n_head:end_n_head],
x[n_head + beg_n_head_kv: n_head + end_n_head_kv],
x[n_head + n_head_kv + beg_n_head_kv: n_head + n_head_kv + end_n_head_kv],
x[n_head + beg_n_head_kv : n_head + end_n_head_kv],
x[
n_head
+ n_head_kv
+ beg_n_head_kv : n_head
+ n_head_kv
+ end_n_head_kv
],
],
dim=0,
),
......@@ -824,7 +834,7 @@ def combine_state_dicts_tp(state_dicts, config):
torch.cat([x[:n_head_per_rank] for x in xs], dim=0),
torch.cat(
[
x[n_head_per_rank: n_head_per_rank + n_head_kv_per_rank]
x[n_head_per_rank : n_head_per_rank + n_head_kv_per_rank]
for x in xs
],
dim=0,
......
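The sharding helpers above split parameters across tensor-parallel ranks along one dimension. A small standalone sketch of the first-dimension case (world size and shapes are illustrative), mirroring shard_first_dim:

import torch

def shard_first_dim_ref(x, world_size, rank):
    # Each rank keeps a contiguous slice of rows, as in shard_first_dim above.
    dim = x.shape[0] // world_size
    return x[rank * dim : (rank + 1) * dim]

weight = torch.arange(8).reshape(8, 1).float()   # pretend (out_features, in_features)
shards = [shard_first_dim_ref(weight, world_size=4, rank=r) for r in range(4)]
assert torch.equal(torch.cat(shards, dim=0), weight)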
......@@ -2,80 +2,100 @@
import math
import re
from collections import OrderedDict
import torch
import torch.nn.functional as F
from einops import rearrange
from transformers import GPT2Config, GPTNeoXConfig
def remap_state_dict_hf_gpt_neox(state_dict, config):
def key_mapping_layers(key):
return re.sub(r'^gpt_neox.', 'transformer.', key)
return re.sub(r"^gpt_neox.", "transformer.", key)
state_dict = OrderedDict((key_mapping_layers(k), v) for k, v in state_dict.items())
# Word embedding
def key_mapping_emb(key):
return re.sub(r'^transformer.embed_in.', 'transformer.embeddings.word_embeddings.', key)
return re.sub(r"^transformer.embed_in.", "transformer.embeddings.word_embeddings.", key)
state_dict = OrderedDict((key_mapping_emb(k), v) for k, v in state_dict.items())
word_embeddings = state_dict.pop('transformer.embeddings.word_embeddings.weight')
word_embeddings = state_dict.pop("transformer.embeddings.word_embeddings.weight")
# It's possible that vocab_size is padded to be a multiple of 8, for example.
pad_vocab_size_multiple = getattr(config, 'pad_vocab_size_multiple', 1)
vocab_size = (math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple)
state_dict['transformer.embeddings.word_embeddings.weight'] = F.pad(
pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
vocab_size = math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple
state_dict["transformer.embeddings.word_embeddings.weight"] = F.pad(
word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0])
)
if getattr(config, 'tie_word_embeddings'):
state_dict['lm_head.weight'] = state_dict['transformer.embeddings.word_embeddings.weight']
if getattr(config, "tie_word_embeddings"):
state_dict["lm_head.weight"] = state_dict["transformer.embeddings.word_embeddings.weight"]
else:
output_embeddings = state_dict.pop('embed_out.weight')
output_embeddings = state_dict.pop("embed_out.weight")
# It's possible that vocab_size is padded to be a multiple of 8, for example.
state_dict['lm_head.weight'] = F.pad(
state_dict["lm_head.weight"] = F.pad(
output_embeddings, (0, 0, 0, vocab_size - output_embeddings.shape[0])
)
# LayerNorm
def key_mapping_ln(key):
key = re.sub(r'^transformer.final_layer_norm.', r'transformer.ln_f.', key)
key = re.sub(r'^transformer.layers.(\d+).input_layernorm.', r'transformer.layers.\1.norm1.', key)
key = re.sub(r'^transformer.layers.(\d+).post_attention_layernorm.', r'transformer.layers.\1.norm2.', key)
key = re.sub(r"^transformer.final_layer_norm.", r"transformer.ln_f.", key)
key = re.sub(
r"^transformer.layers.(\d+).input_layernorm.", r"transformer.layers.\1.norm1.", key
)
key = re.sub(
r"^transformer.layers.(\d+).post_attention_layernorm.",
r"transformer.layers.\1.norm2.",
key,
)
return key
state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())
# MLP
def key_mapping_mlp(key):
key = re.sub(r'^transformer.layers.(\d+).mlp.dense_h_to_4h.', r'transformer.layers.\1.mlp.fc1.', key)
key = re.sub(r'^transformer.layers.(\d+).mlp.dense_4h_to_h.', r'transformer.layers.\1.mlp.fc2.', key)
key = re.sub(
r"^transformer.layers.(\d+).mlp.dense_h_to_4h.", r"transformer.layers.\1.mlp.fc1.", key
)
key = re.sub(
r"^transformer.layers.(\d+).mlp.dense_4h_to_h.", r"transformer.layers.\1.mlp.fc2.", key
)
return key
state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())
# Attention
for l in range(config.n_layer):
# We don't store these biases
state_dict.pop(f'transformer.layers.{l}.attention.bias')
state_dict.pop(f'transformer.layers.{l}.attention.masked_bias')
state_dict.pop(f"transformer.layers.{l}.attention.bias")
state_dict.pop(f"transformer.layers.{l}.attention.masked_bias")
# GPT-NeoX stores Wqkv as ((nheads 3 headdim), hidden_dim)
# while we store Wqkv as ((3 nheads headdim), hidden_dim)
headdim = config.hidden_size // config.num_attention_heads
Wqkv = state_dict.pop(f'transformer.layers.{l}.attention.query_key_value.weight')
state_dict[f'transformer.layers.{l}.mixer.Wqkv.weight'] = rearrange(
Wqkv, '(nheads three headdim) ... -> (three nheads headdim) ...',
three=3, headdim=headdim
Wqkv = state_dict.pop(f"transformer.layers.{l}.attention.query_key_value.weight")
state_dict[f"transformer.layers.{l}.mixer.Wqkv.weight"] = rearrange(
Wqkv,
"(nheads three headdim) ... -> (three nheads headdim) ...",
three=3,
headdim=headdim,
)
bqkv = state_dict.pop(f'transformer.layers.{l}.attention.query_key_value.bias')
state_dict[f'transformer.layers.{l}.mixer.Wqkv.bias'] = rearrange(
bqkv, '(nheads three headdim) -> (three nheads headdim)',
three=3, headdim=headdim
bqkv = state_dict.pop(f"transformer.layers.{l}.attention.query_key_value.bias")
state_dict[f"transformer.layers.{l}.mixer.Wqkv.bias"] = rearrange(
bqkv, "(nheads three headdim) -> (three nheads headdim)", three=3, headdim=headdim
)
def key_mapping_attn(key):
key = re.sub(r'^transformer.layers.(\d+).attention.dense.',
r'transformer.layers.\1.mixer.out_proj.', key)
key = re.sub(r'^transformer.layers.(\d+).attention.rotary_emb.',
r'transformer.layers.\1.mixer.rotary_emb.', key)
key = re.sub(
r"^transformer.layers.(\d+).attention.dense.",
r"transformer.layers.\1.mixer.out_proj.",
key,
)
key = re.sub(
r"^transformer.layers.(\d+).attention.rotary_emb.",
r"transformer.layers.\1.mixer.rotary_emb.",
key,
)
return key
state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())
return state_dict
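As the comments in the attention loop above note, GPT-NeoX stores Wqkv with Q/K/V interleaved per head ((nheads 3 headdim), hidden_dim), while this codebase groups all Q heads, then all K heads, then all V heads ((3 nheads headdim), hidden_dim). A small self-contained check of that rearrange, with made-up sizes:

import torch
from einops import rearrange

nheads, headdim, hidden = 2, 3, 4
wqkv_neox = torch.arange(nheads * 3 * headdim * hidden, dtype=torch.float32).reshape(
    nheads * 3 * headdim, hidden
)
wqkv_ours = rearrange(
    wqkv_neox, "(nheads three headdim) d -> (three nheads headdim) d", three=3, headdim=headdim
)
# Head 0's Q rows stay at the front in both layouts; head 1's Q rows move up to follow them,
# whereas GPT-NeoX stored them after head 0's K and V rows.
assert torch.equal(wqkv_ours[headdim : 2 * headdim], wqkv_neox[3 * headdim : 4 * headdim])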
......
......@@ -2,67 +2,78 @@
import math
import re
from collections import OrderedDict
import torch
import torch.nn.functional as F
from transformers import GPT2Config, GPTJConfig
def remap_state_dict_hf_gptj(state_dict, config):
def key_mapping_layers(key):
return re.sub(r'^transformer.h.', 'transformer.layers.', key)
return re.sub(r"^transformer.h.", "transformer.layers.", key)
state_dict = OrderedDict((key_mapping_layers(k), v) for k, v in state_dict.items())
# Word embedding
def key_mapping_emb(key):
return re.sub(r'^transformer.wte.', 'transformer.embeddings.word_embeddings.', key)
return re.sub(r"^transformer.wte.", "transformer.embeddings.word_embeddings.", key)
state_dict = OrderedDict((key_mapping_emb(k), v) for k, v in state_dict.items())
word_embeddings = state_dict.pop('transformer.embeddings.word_embeddings.weight')
word_embeddings = state_dict.pop("transformer.embeddings.word_embeddings.weight")
# It's possible that vocab_size is padded to be a multiple of 8, for example.
pad_vocab_size_multiple = getattr(config, 'pad_vocab_size_multiple', 1)
vocab_size = (math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple)
state_dict['transformer.embeddings.word_embeddings.weight'] = F.pad(
pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
vocab_size = math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple
state_dict["transformer.embeddings.word_embeddings.weight"] = F.pad(
word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0])
)
if getattr(config, 'tie_word_embeddings'):
state_dict['lm_head.weight'] = state_dict['transformer.embeddings.word_embeddings.weight']
if getattr(config, "tie_word_embeddings"):
state_dict["lm_head.weight"] = state_dict["transformer.embeddings.word_embeddings.weight"]
else:
output_embeddings = state_dict.pop('lm_head.weight')
output_embeddings = state_dict.pop("lm_head.weight")
# It's possible that vocab_size is padded to be a multiple of 8, for example.
state_dict['lm_head.weight'] = F.pad(
state_dict["lm_head.weight"] = F.pad(
output_embeddings, (0, 0, 0, vocab_size - output_embeddings.shape[0])
)
output_embeddings_bias = state_dict.pop('lm_head.bias')
state_dict['lm_head.bias'] = F.pad(
output_embeddings_bias = state_dict.pop("lm_head.bias")
state_dict["lm_head.bias"] = F.pad(
output_embeddings_bias, (0, vocab_size - output_embeddings_bias.shape[0])
)
# LayerNorm
def key_mapping_ln(key):
return re.sub(r'^transformer.layers.(\d+).ln_1.', r'transformer.layers.\1.norm1.', key)
return re.sub(r"^transformer.layers.(\d+).ln_1.", r"transformer.layers.\1.norm1.", key)
state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())
# MLP
def key_mapping_mlp(key):
key = re.sub(r'^transformer.layers.(\d+).mlp.fc_in.', r'transformer.layers.\1.mlp.fc1.', key)
key = re.sub(r'^transformer.layers.(\d+).mlp.fc_out.', r'transformer.layers.\1.mlp.fc2.', key)
key = re.sub(
r"^transformer.layers.(\d+).mlp.fc_in.", r"transformer.layers.\1.mlp.fc1.", key
)
key = re.sub(
r"^transformer.layers.(\d+).mlp.fc_out.", r"transformer.layers.\1.mlp.fc2.", key
)
return key
state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())
# Attention
for l in range(config.n_layer):
Wq = state_dict.pop(f'transformer.layers.{l}.attn.q_proj.weight')
Wk = state_dict.pop(f'transformer.layers.{l}.attn.k_proj.weight')
Wv = state_dict.pop(f'transformer.layers.{l}.attn.v_proj.weight')
state_dict[f'transformer.layers.{l}.mixer.Wqkv.weight'] = torch.cat([Wq, Wk, Wv], dim=0)
Wq = state_dict.pop(f"transformer.layers.{l}.attn.q_proj.weight")
Wk = state_dict.pop(f"transformer.layers.{l}.attn.k_proj.weight")
Wv = state_dict.pop(f"transformer.layers.{l}.attn.v_proj.weight")
state_dict[f"transformer.layers.{l}.mixer.Wqkv.weight"] = torch.cat([Wq, Wk, Wv], dim=0)
# We don't store these biases
state_dict.pop(f'transformer.layers.{l}.attn.bias')
state_dict.pop(f'transformer.layers.{l}.attn.masked_bias')
state_dict.pop(f"transformer.layers.{l}.attn.bias")
state_dict.pop(f"transformer.layers.{l}.attn.masked_bias")
def key_mapping_attn(key):
return re.sub(r'^transformer.layers.(\d+).attn.out_proj.',
r'transformer.layers.\1.mixer.out_proj.', key)
return re.sub(
r"^transformer.layers.(\d+).attn.out_proj.",
r"transformer.layers.\1.mixer.out_proj.",
key,
)
state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())
return state_dict
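The same vocab-padding pattern appears in each of these remap functions: the embedding matrix is padded with zero rows so the vocab size becomes a multiple of pad_vocab_size_multiple. A minimal sketch of what that F.pad call does, with made-up sizes:

import math

import torch
import torch.nn.functional as F

vocab_size, hidden, multiple = 50257, 8, 8
padded_vocab = math.ceil(vocab_size / multiple) * multiple  # 50264
emb = torch.randn(vocab_size, hidden)
# (0, 0, 0, n) pads the last (hidden) dim by nothing and appends n zero rows.
emb_padded = F.pad(emb, (0, 0, 0, padded_vocab - emb.shape[0]))
assert emb_padded.shape == (padded_vocab, hidden)
assert torch.all(emb_padded[vocab_size:] == 0)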
......
......@@ -15,63 +15,81 @@ from transformers import GPT2Config, LlamaConfig
def remap_state_dict_meta_llama(state_dict, config):
def key_mapping_layers(key):
return f'transformer.{key}' if not key.startswith('output.') else key
return f"transformer.{key}" if not key.startswith("output.") else key
state_dict = OrderedDict((key_mapping_layers(k), v) for k, v in state_dict.items())
# Word embedding
def key_mapping_emb(key):
return re.sub(r'^transformer.tok_embeddings.', 'transformer.embeddings.word_embeddings.', key)
return re.sub(
r"^transformer.tok_embeddings.", "transformer.embeddings.word_embeddings.", key
)
state_dict = OrderedDict((key_mapping_emb(k), v) for k, v in state_dict.items())
word_embeddings = state_dict.pop('transformer.embeddings.word_embeddings.weight')
word_embeddings = state_dict.pop("transformer.embeddings.word_embeddings.weight")
# It's possible that vocab_size is padded to be a multiple of 8, for example.
pad_vocab_size_multiple = getattr(config, 'pad_vocab_size_multiple', 1)
vocab_size = (math.ceil(word_embeddings.shape[0] / pad_vocab_size_multiple)
* pad_vocab_size_multiple)
state_dict['transformer.embeddings.word_embeddings.weight'] = F.pad(
pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
vocab_size = (
math.ceil(word_embeddings.shape[0] / pad_vocab_size_multiple) * pad_vocab_size_multiple
)
state_dict["transformer.embeddings.word_embeddings.weight"] = F.pad(
word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0])
)
if getattr(config, 'tie_word_embeddings'):
state_dict['lm_head.weight'] = state_dict['transformer.embeddings.word_embeddings.weight']
if getattr(config, "tie_word_embeddings"):
state_dict["lm_head.weight"] = state_dict["transformer.embeddings.word_embeddings.weight"]
else:
output_embeddings = state_dict.pop('output.weight')
output_embeddings = state_dict.pop("output.weight")
# Need to recompute vocab_size since LLaMa shards the word embeddings and output embeddings
# differently.
vocab_size = (math.ceil(output_embeddings.shape[0] / pad_vocab_size_multiple)
* pad_vocab_size_multiple)
vocab_size = (
math.ceil(output_embeddings.shape[0] / pad_vocab_size_multiple)
* pad_vocab_size_multiple
)
# It's possible that vocab_size is padded to be a multiple of 8, for example.
state_dict['lm_head.weight'] = F.pad(
state_dict["lm_head.weight"] = F.pad(
output_embeddings, (0, 0, 0, vocab_size - output_embeddings.shape[0])
)
# LayerNorm
def key_mapping_ln(key):
key = re.sub(r'^transformer.norm.', r'transformer.ln_f.', key)
key = re.sub(r'^transformer.layers.(\d+).attention_norm.', r'transformer.layers.\1.norm1.', key)
key = re.sub(r'^transformer.layers.(\d+).ffn_norm.', r'transformer.layers.\1.norm2.', key)
key = re.sub(r"^transformer.norm.", r"transformer.ln_f.", key)
key = re.sub(
r"^transformer.layers.(\d+).attention_norm.", r"transformer.layers.\1.norm1.", key
)
key = re.sub(r"^transformer.layers.(\d+).ffn_norm.", r"transformer.layers.\1.norm2.", key)
return key
state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())
# MLP
for l in range(config.n_layer):
w1 = state_dict.pop(f'transformer.layers.{l}.feed_forward.w1.weight')
w3 = state_dict.pop(f'transformer.layers.{l}.feed_forward.w3.weight')
w1 = state_dict.pop(f"transformer.layers.{l}.feed_forward.w1.weight")
w3 = state_dict.pop(f"transformer.layers.{l}.feed_forward.w3.weight")
# Our ordering is different
state_dict[f'transformer.layers.{l}.mlp.fc1.weight'] = torch.cat([w3, w1], dim=0)
state_dict[f"transformer.layers.{l}.mlp.fc1.weight"] = torch.cat([w3, w1], dim=0)
def key_mapping_mlp(key):
return re.sub(r'^transformer.layers.(\d+).feed_forward.w2.',
r'transformer.layers.\1.mlp.fc2.', key)
return re.sub(
r"^transformer.layers.(\d+).feed_forward.w2.", r"transformer.layers.\1.mlp.fc2.", key
)
state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())
# Attention
for l in range(config.n_layer):
Wq = state_dict.pop(f'transformer.layers.{l}.attention.wq.weight')
Wk = state_dict.pop(f'transformer.layers.{l}.attention.wk.weight')
Wv = state_dict.pop(f'transformer.layers.{l}.attention.wv.weight')
state_dict[f'transformer.layers.{l}.mixer.Wqkv.weight'] = torch.cat([Wq, Wk, Wv], dim=0)
Wq = state_dict.pop(f"transformer.layers.{l}.attention.wq.weight")
Wk = state_dict.pop(f"transformer.layers.{l}.attention.wk.weight")
Wv = state_dict.pop(f"transformer.layers.{l}.attention.wv.weight")
state_dict[f"transformer.layers.{l}.mixer.Wqkv.weight"] = torch.cat([Wq, Wk, Wv], dim=0)
# We don't store these
state_dict.pop(f'transformer.layers.{l}.attention.inner_attention.rope.freqs', None)
state_dict.pop(f"transformer.layers.{l}.attention.inner_attention.rope.freqs", None)
def key_mapping_attn(key):
return re.sub(r'^transformer.layers.(\d+).attention.wo.',
r'transformer.layers.\1.mixer.out_proj.', key)
return re.sub(
r"^transformer.layers.(\d+).attention.wo.",
r"transformer.layers.\1.mixer.out_proj.",
key,
)
state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())
state_dict.pop("transformer.rope.freqs", None)
......@@ -82,29 +100,32 @@ def remap_state_dict_meta_llama(state_dict, config):
def remap_state_dict_hf_llama(state_dict, config):
# Embedding
def key_mapping_emb(key):
return re.sub(r'^model.embed_tokens.', 'transformer.embeddings.word_embeddings.', key)
return re.sub(r"^model.embed_tokens.", "transformer.embeddings.word_embeddings.", key)
state_dict = OrderedDict((key_mapping_emb(k), v) for k, v in state_dict.items())
word_embeddings = state_dict.pop('transformer.embeddings.word_embeddings.weight')
word_embeddings = state_dict.pop("transformer.embeddings.word_embeddings.weight")
# It's possible that vocab_size is padded to be a multiple of 8, for example.
pad_vocab_size_multiple = getattr(config, 'pad_vocab_size_multiple', 1)
vocab_size = (math.ceil(word_embeddings.shape[0] / pad_vocab_size_multiple)
* pad_vocab_size_multiple)
state_dict['transformer.embeddings.word_embeddings.weight'] = F.pad(
pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
vocab_size = (
math.ceil(word_embeddings.shape[0] / pad_vocab_size_multiple) * pad_vocab_size_multiple
)
state_dict["transformer.embeddings.word_embeddings.weight"] = F.pad(
word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0])
)
# LM head
if getattr(config, 'tie_word_embeddings'):
state_dict['lm_head.weight'] = state_dict['transformer.embeddings.word_embeddings.weight']
if getattr(config, "tie_word_embeddings"):
state_dict["lm_head.weight"] = state_dict["transformer.embeddings.word_embeddings.weight"]
else:
output_embeddings = state_dict.pop('lm_head.weight')
output_embeddings = state_dict.pop("lm_head.weight")
# Need to recompute vocab_size since LLaMa shards the word embeddings and output embeddings
# differently.
vocab_size = (math.ceil(output_embeddings.shape[0] / pad_vocab_size_multiple)
* pad_vocab_size_multiple)
vocab_size = (
math.ceil(output_embeddings.shape[0] / pad_vocab_size_multiple)
* pad_vocab_size_multiple
)
# It's possible that vocab_size is padded to be a multiple of 8, for example.
state_dict['lm_head.weight'] = F.pad(
state_dict["lm_head.weight"] = F.pad(
output_embeddings, (0, 0, 0, vocab_size - output_embeddings.shape[0])
)
......@@ -113,21 +134,22 @@ def remap_state_dict_hf_llama(state_dict, config):
# Fusing weights this way based on difference in the following:
# https://github.com/huggingface/transformers/blob/b42010bb1d3cbf262d27e0a328661885be46dfdb/src/transformers/models/llama/modeling_llama.py#L220
# https://github.com/Dao-AILab/flash-attention/blob/c60851a8253257eb970e06a022c82517a8033e8c/flash_attn/modules/mlp.py#L115
w1 = state_dict.pop(f'model.layers.{l}.mlp.gate_proj.weight')
w3 = state_dict.pop(f'model.layers.{l}.mlp.up_proj.weight')
state_dict[f'transformer.layers.{l}.mlp.fc1.weight'] = torch.cat([w3, w1], dim=0)
w1 = state_dict.pop(f"model.layers.{l}.mlp.gate_proj.weight")
w3 = state_dict.pop(f"model.layers.{l}.mlp.up_proj.weight")
state_dict[f"transformer.layers.{l}.mlp.fc1.weight"] = torch.cat([w3, w1], dim=0)
def key_mapping_mlp(key):
return re.sub(r'^model.layers.(\d+).mlp.down_proj.',
r'transformer.layers.\1.mlp.fc2.', key)
return re.sub(r"^model.layers.(\d+).mlp.down_proj.", r"transformer.layers.\1.mlp.fc2.", key)
state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())
# LayerNorm
def key_mapping_ln(key):
key = re.sub(r'^model.norm.', r'transformer.ln_f.', key)
key = re.sub(r'^model.layers.(\d+).input_layernorm.', r'transformer.layers.\1.norm1.', key)
key = re.sub(r'^model.layers.(\d+).post_attention_layernorm.', r'transformer.layers.\1.norm2.', key)
key = re.sub(r"^model.norm.", r"transformer.ln_f.", key)
key = re.sub(r"^model.layers.(\d+).input_layernorm.", r"transformer.layers.\1.norm1.", key)
key = re.sub(
r"^model.layers.(\d+).post_attention_layernorm.", r"transformer.layers.\1.norm2.", key
)
return key
state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())
......@@ -135,42 +157,52 @@ def remap_state_dict_hf_llama(state_dict, config):
def inv_permute(w):
# Inverse of permute implemented in:
# https://github.com/huggingface/transformers/blob/b42010bb1d3cbf262d27e0a328661885be46dfdb/src/transformers/models/llama/convert_llama_weights_to_hf.py#L114
return w.reshape(
config.n_head, 2, config.n_embd // config.n_head // 2, config.n_embd
).transpose(1, 2).reshape(config.n_embd, config.n_embd)
return (
w.reshape(config.n_head, 2, config.n_embd // config.n_head // 2, config.n_embd)
.transpose(1, 2)
.reshape(config.n_embd, config.n_embd)
)
# Attention
for l in range(config.n_layer):
Wq = state_dict.pop(f'model.layers.{l}.self_attn.q_proj.weight')
Wk = state_dict.pop(f'model.layers.{l}.self_attn.k_proj.weight')
Wv = state_dict.pop(f'model.layers.{l}.self_attn.v_proj.weight')
state_dict[f'transformer.layers.{l}.mixer.Wqkv.weight'] = torch.cat(
Wq = state_dict.pop(f"model.layers.{l}.self_attn.q_proj.weight")
Wk = state_dict.pop(f"model.layers.{l}.self_attn.k_proj.weight")
Wv = state_dict.pop(f"model.layers.{l}.self_attn.v_proj.weight")
state_dict[f"transformer.layers.{l}.mixer.Wqkv.weight"] = torch.cat(
[inv_permute(Wq), inv_permute(Wk), Wv], dim=0
)
# We don't store these
state_dict.pop(f'model.layers.{l}.self_attn.rotary_emb.inv_freq', None)
state_dict.pop(f"model.layers.{l}.self_attn.rotary_emb.inv_freq", None)
def key_mapping_attn(key):
return re.sub(r'^model.layers.(\d+).self_attn.o_proj.',
r'transformer.layers.\1.mixer.out_proj.', key)
return re.sub(
r"^model.layers.(\d+).self_attn.o_proj.", r"transformer.layers.\1.mixer.out_proj.", key
)
state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())
return state_dict
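The inv_permute helper above undoes the head-dimension permutation that HF's LLaMA conversion script applies to the Q and K projection weights to match its rotary-embedding convention. An illustrative round trip; the forward permute is sketched from the linked conversion script and the sizes are made up:

import torch

def permute(w, n_head, n_embd):
    # Forward permutation, as sketched from convert_llama_weights_to_hf.py (see link above).
    return w.view(n_head, n_embd // n_head // 2, 2, n_embd).transpose(1, 2).reshape(n_embd, n_embd)

def inv_permute(w, n_head, n_embd):
    return (
        w.reshape(n_head, 2, n_embd // n_head // 2, n_embd)
        .transpose(1, 2)
        .reshape(n_embd, n_embd)
    )

n_head, n_embd = 4, 16
w = torch.randn(n_embd, n_embd)
assert torch.equal(inv_permute(permute(w, n_head, n_embd), n_head, n_embd), w)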
def config_from_meta_checkpoint(checkpoint_path: Union[str, os.PathLike], model_name: str) -> LlamaConfig:
def config_from_meta_checkpoint(
checkpoint_path: Union[str, os.PathLike], model_name: str
) -> LlamaConfig:
"""Load a LlamaConfig from a checkpoint path."""
with open(Path(checkpoint_path) / model_name / 'params.json') as f:
with open(Path(checkpoint_path) / model_name / "params.json") as f:
params = json.load(f)
config = LlamaConfig(hidden_size=params['dim'], intermediate_size=None,
num_attention_heads=params['n_heads'],
num_hidden_layers=params['n_layers'],
rms_norm_eps=params['norm_eps'])
config = LlamaConfig(
hidden_size=params["dim"],
intermediate_size=None,
num_attention_heads=params["n_heads"],
num_hidden_layers=params["n_layers"],
rms_norm_eps=params["norm_eps"],
)
return config
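For reference, a made-up params.json and the fields config_from_meta_checkpoint reads from it (values here are illustrative, not copied from any released checkpoint):

# params.json: {"dim": 4096, "n_heads": 32, "n_layers": 32, "norm_eps": 1e-6, ...}
# -> LlamaConfig(hidden_size=4096, intermediate_size=None, num_attention_heads=32,
#                num_hidden_layers=32, rms_norm_eps=1e-6)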
def config_from_hf_checkpoint(checkpoint_path: Union[str, os.PathLike], model_name: str) -> LlamaConfig:
return LlamaConfig.from_pretrained(Path(checkpoint_path) / f'{model_name}-hf' / "config.json")
def config_from_hf_checkpoint(
checkpoint_path: Union[str, os.PathLike], model_name: str
) -> LlamaConfig:
return LlamaConfig.from_pretrained(Path(checkpoint_path) / f"{model_name}-hf" / "config.json")
def config_from_checkpoint(
......@@ -182,10 +214,14 @@ def config_from_checkpoint(
return config_from_hf_checkpoint(checkpoint_path, model_name)
def state_dicts_from_checkpoint(checkpoint_path: Union[str, os.PathLike], model_name: str) -> list[dict]:
def state_dicts_from_checkpoint(
checkpoint_path: Union[str, os.PathLike], model_name: str
) -> list[dict]:
# Need to sort, otherwise we mess up the ordering and the weights are wrong
return [torch.load(path, map_location='cpu')
for path in sorted((Path(checkpoint_path) / model_name).glob('consolidated.*.pth'))]
return [
torch.load(path, map_location="cpu")
for path in sorted((Path(checkpoint_path) / model_name).glob("consolidated.*.pth"))
]
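The sorted() in state_dicts_from_checkpoint matters because Path.glob returns files in an arbitrary, filesystem-dependent order, while the consolidated.*.pth shards must be combined in rank order. A tiny illustration with made-up file names:

from pathlib import Path

paths = [Path("consolidated.01.pth"), Path("consolidated.00.pth")]
assert [p.name for p in sorted(paths)] == ["consolidated.00.pth", "consolidated.01.pth"]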
def llama_config_to_gpt2_config(llama_config: LlamaConfig) -> GPT2Config:
......@@ -196,7 +232,7 @@ def llama_config_to_gpt2_config(llama_config: LlamaConfig) -> GPT2Config:
n_layer=llama_config.num_hidden_layers,
n_head=llama_config.num_attention_heads,
n_inner=llama_config.intermediate_size,
activation_function='swiglu', # Hardcode since HF calls it 'silu'
activation_function="swiglu", # Hardcode since HF calls it 'silu'
# Llama doesn't have dropout, idk if it's because they only release the inference code
resid_pdrop=0.0,
embd_pdrop=0.0,
......
......@@ -2,75 +2,86 @@
import math
import re
from collections import OrderedDict
import torch
import torch.nn.functional as F
from transformers import GPT2Config, OPTConfig
def remap_state_dict_hf_opt(state_dict, config):
def key_mapping_model(key):
key = re.sub(r'^model.decoder.', 'transformer.', key)
key = re.sub(r"^model.decoder.", "transformer.", key)
# The OPT-350m model uses '^decoder' instead of '^model.decoder'
key = re.sub(r'^decoder.', 'transformer.', key)
key = re.sub(r"^decoder.", "transformer.", key)
return key
state_dict = OrderedDict((key_mapping_model(k), v) for k, v in state_dict.items())
# Word embedding and position embedding
def key_mapping_emb(key):
key = re.sub(r'^transformer.embed_tokens.', 'transformer.embeddings.word_embeddings.', key)
key = re.sub(r"^transformer.embed_tokens.", "transformer.embeddings.word_embeddings.", key)
# The OPT-350m model has project_in and project_out
key = re.sub(r'^transformer.project_in.', 'transformer.embeddings.project_in.', key)
key = re.sub(r'^transformer.project_out.', 'project_out.', key)
key = re.sub(r'^transformer.embed_positions.',
'transformer.embeddings.position_embeddings.', key)
key = re.sub(r"^transformer.project_in.", "transformer.embeddings.project_in.", key)
key = re.sub(r"^transformer.project_out.", "project_out.", key)
key = re.sub(
r"^transformer.embed_positions.", "transformer.embeddings.position_embeddings.", key
)
return key
state_dict = OrderedDict((key_mapping_emb(k), v) for k, v in state_dict.items())
# OPT uses the first 2 indices of pos_emb for padding tokens
pos_embeddings = state_dict.pop('transformer.embeddings.position_embeddings.weight')
state_dict['transformer.embeddings.position_embeddings.weight'] = pos_embeddings[2:]
word_embeddings = state_dict.pop('transformer.embeddings.word_embeddings.weight')
pos_embeddings = state_dict.pop("transformer.embeddings.position_embeddings.weight")
state_dict["transformer.embeddings.position_embeddings.weight"] = pos_embeddings[2:]
word_embeddings = state_dict.pop("transformer.embeddings.word_embeddings.weight")
# It's possible that vocab_size is padded to be a multiple of 8, for example.
pad_vocab_size_multiple = getattr(config, 'pad_vocab_size_multiple', 1)
vocab_size = (math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple)
state_dict['transformer.embeddings.word_embeddings.weight'] = F.pad(
pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
vocab_size = math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple
state_dict["transformer.embeddings.word_embeddings.weight"] = F.pad(
word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0])
)
state_dict['lm_head.weight'] = state_dict['transformer.embeddings.word_embeddings.weight']
state_dict["lm_head.weight"] = state_dict["transformer.embeddings.word_embeddings.weight"]
# LayerNorm
def key_mapping_ln(key):
key = re.sub(r'^transformer.final_layer_norm.', r'transformer.ln_f.', key)
key = re.sub(r"^transformer.final_layer_norm.", r"transformer.ln_f.", key)
# The OPT-175B checkpoint calls this 'decoder.layer_norm' instead of 'decoder.final_layer_norm'
key = re.sub(r'^transformer.layer_norm.', r'transformer.ln_f.', key)
key = re.sub(r'^transformer.layers.(\d+).self_attn_layer_norm.',
r'transformer.layers.\1.norm1.', key)
key = re.sub(r'^transformer.layers.(\d+).final_layer_norm.',
r'transformer.layers.\1.norm2.', key)
key = re.sub(r"^transformer.layer_norm.", r"transformer.ln_f.", key)
key = re.sub(
r"^transformer.layers.(\d+).self_attn_layer_norm.", r"transformer.layers.\1.norm1.", key
)
key = re.sub(
r"^transformer.layers.(\d+).final_layer_norm.", r"transformer.layers.\1.norm2.", key
)
return key
state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())
# MLP
def key_mapping_mlp(key):
return re.sub(r'^transformer.layers.(\d+).fc(1|2).',
r'transformer.layers.\1.mlp.fc\2.', key)
return re.sub(
r"^transformer.layers.(\d+).fc(1|2).", r"transformer.layers.\1.mlp.fc\2.", key
)
state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())
# Attention
for l in range(config.n_layer):
Wq = state_dict.pop(f'transformer.layers.{l}.self_attn.q_proj.weight')
Wk = state_dict.pop(f'transformer.layers.{l}.self_attn.k_proj.weight')
Wv = state_dict.pop(f'transformer.layers.{l}.self_attn.v_proj.weight')
bq = state_dict.pop(f'transformer.layers.{l}.self_attn.q_proj.bias')
bk = state_dict.pop(f'transformer.layers.{l}.self_attn.k_proj.bias')
bv = state_dict.pop(f'transformer.layers.{l}.self_attn.v_proj.bias')
state_dict[f'transformer.layers.{l}.mixer.Wqkv.weight'] = torch.cat([Wq, Wk, Wv], dim=0)
state_dict[f'transformer.layers.{l}.mixer.Wqkv.bias'] = torch.cat([bq, bk, bv], dim=0)
Wq = state_dict.pop(f"transformer.layers.{l}.self_attn.q_proj.weight")
Wk = state_dict.pop(f"transformer.layers.{l}.self_attn.k_proj.weight")
Wv = state_dict.pop(f"transformer.layers.{l}.self_attn.v_proj.weight")
bq = state_dict.pop(f"transformer.layers.{l}.self_attn.q_proj.bias")
bk = state_dict.pop(f"transformer.layers.{l}.self_attn.k_proj.bias")
bv = state_dict.pop(f"transformer.layers.{l}.self_attn.v_proj.bias")
state_dict[f"transformer.layers.{l}.mixer.Wqkv.weight"] = torch.cat([Wq, Wk, Wv], dim=0)
state_dict[f"transformer.layers.{l}.mixer.Wqkv.bias"] = torch.cat([bq, bk, bv], dim=0)
def key_mapping_attn(key):
return re.sub(r'^transformer.layers.(\d+).self_attn.out_proj.',
r'transformer.layers.\1.mixer.out_proj.', key)
return re.sub(
r"^transformer.layers.(\d+).self_attn.out_proj.",
r"transformer.layers.\1.mixer.out_proj.",
key,
)
state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())
return state_dict
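The pos_embeddings[2:] slice above accounts for OPT's learned position embeddings being stored with a 2-row offset: HF's OPT implementation adds 2 to every position index before the lookup, so the first two rows never correspond to real positions. A small sketch of the equivalence, with made-up sizes:

import torch

pos_emb_hf = torch.randn(2 + 10, 4)   # HF table: 2 offset rows + max_position rows
pos_emb_ours = pos_emb_hf[2:]         # what the remap keeps
positions = torch.arange(5)
assert torch.equal(pos_emb_hf[positions + 2], pos_emb_ours[positions])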
......@@ -79,8 +90,11 @@ def remap_state_dict_hf_opt(state_dict, config):
def opt_config_to_gpt2_config(opt_config: OPTConfig) -> GPT2Config:
assert opt_config.layerdrop == 0.0
assert opt_config.layer_norm_elementwise_affine
word_embed_proj_dim = (None if opt_config.word_embed_proj_dim == opt_config.hidden_size
else opt_config.word_embed_proj_dim)
word_embed_proj_dim = (
None
if opt_config.word_embed_proj_dim == opt_config.hidden_size
else opt_config.word_embed_proj_dim
)
return GPT2Config(
vocab_size=opt_config.vocab_size,
n_positions=opt_config.max_position_embeddings,
......@@ -98,5 +112,5 @@ def opt_config_to_gpt2_config(opt_config: OPTConfig) -> GPT2Config:
eos_token_id=opt_config.eos_token_id,
# These are new arguments not in the original GPT2Config
prenorm=opt_config.do_layer_norm_before,
word_embed_proj_dim=word_embed_proj_dim
word_embed_proj_dim=word_embed_proj_dim,
)
......@@ -10,13 +10,14 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from timm.models.helpers import named_apply
from torch.nn.init import trunc_normal_
from torchvision.ops import StochasticDepth
from flash_attn.layers.patch_embed import PatchEmbed
from flash_attn.modules.block import Block
from flash_attn.modules.mha import MHA
from flash_attn.modules.mlp import FusedMLP, Mlp
from timm.models.helpers import named_apply
from torch.nn.init import trunc_normal_
from torchvision.ops import StochasticDepth
try:
from flash_attn.ops.layer_norm import dropout_add_layer_norm
......