"vscode:/vscode.git/clone" did not exist on "79292ff3e06219bc67d06932cee98a5ad2ce5c04"
Commit f1a73d07 authored by Tri Dao

Run isort and black on python files

parent cbb4cf5f
__version__ = "2.0.8"

from flash_attn.flash_attn_interface import (
    flash_attn_func,
    flash_attn_kvpacked_func,
    flash_attn_qkvpacked_func,
    flash_attn_varlen_func,
    flash_attn_varlen_kvpacked_func,
    flash_attn_varlen_qkvpacked_func,
)
@@ -2,12 +2,10 @@
import torch
import torch.nn.functional as F
from einops import rearrange, repeat


class IndexFirstAxis(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, indices):
        ctx.save_for_backward(indices)
@@ -16,20 +14,24 @@ class IndexFirstAxis(torch.autograd.Function):
        second_dim = other_shape.numel()
        # TD [2022-03-04] For some reason torch.gather is a bit faster than indexing.
        # return input[indices]
        return torch.gather(
            rearrange(input, "b ... -> b (...)"), 0, repeat(indices, "z -> z d", d=second_dim)
        ).reshape(-1, *other_shape)

    @staticmethod
    def backward(ctx, grad_output):
        (indices,) = ctx.saved_tensors
        assert grad_output.ndim >= 2
        other_shape = grad_output.shape[1:]
        grad_output = rearrange(grad_output, "b ... -> b (...)")
        grad_input = torch.zeros(
            [ctx.first_axis_dim, grad_output.shape[1]],
            device=grad_output.device,
            dtype=grad_output.dtype,
        )
        # TD [2022-03-04] For some reason torch.scatter is a bit faster than indexing.
        # grad_input[indices] = grad_output
        grad_input.scatter_(0, repeat(indices, "z -> z d", d=grad_output.shape[1]), grad_output)
        return grad_input.reshape(ctx.first_axis_dim, *other_shape), None
@@ -37,14 +39,14 @@ index_first_axis = IndexFirstAxis.apply
class IndexPutFirstAxis(torch.autograd.Function):
    @staticmethod
    def forward(ctx, values, indices, first_axis_dim):
        ctx.save_for_backward(indices)
        assert indices.ndim == 1
        assert values.ndim >= 2
        output = torch.zeros(
            first_axis_dim, *values.shape[1:], device=values.device, dtype=values.dtype
        )
        # TD [2022-03-04] For some reason torch.scatter is a bit faster than indexing.
        output[indices] = values
        # output.scatter_(0, repeat(indices, 'z -> z d', d=values.shape[1]), values)
@@ -52,7 +54,7 @@ class IndexPutFirstAxis(torch.autograd.Function):
    @staticmethod
    def backward(ctx, grad_output):
        (indices,) = ctx.saved_tensors
        # TD [2022-03-04] For some reason torch.gather is a bit faster than indexing.
        grad_values = grad_output[indices]
        # grad_values = torch.gather(grad_output, 0, repeat(indices, 'z -> z d', d=grad_output.shape[1]))
@@ -63,7 +65,6 @@ index_put_first_axis = IndexPutFirstAxis.apply
class IndexFirstAxisResidual(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, indices):
        ctx.save_for_backward(indices)
@@ -79,7 +80,7 @@ class IndexFirstAxisResidual(torch.autograd.Function):
    @staticmethod
    def backward(ctx, grad_output, grad_residual):
        (indices,) = ctx.saved_tensors
        assert grad_output.ndim >= 2
        other_shape = grad_output.shape[1:]
        assert grad_residual.shape[1:] == other_shape
@@ -113,8 +114,12 @@ def unpad_input(hidden_states, attention_mask):
    # times larger than it needs to be, wasting memory. It's faster and more memory-efficient to
    # index with integer indices. Moreover, torch's index is a bit slower than it needs to be,
    # so we write custom forward and backward to make it a bit faster.
    return (
        index_first_axis(rearrange(hidden_states, "b s ... -> (b s) ..."), indices),
        indices,
        cu_seqlens,
        max_seqlen_in_batch,
    )


def pad_input(hidden_states, indices, batch, seqlen):
@@ -129,4 +134,4 @@ def pad_input(hidden_states, indices, batch, seqlen):
    # output = torch.zeros((batch * seqlen), dim, device=hidden_states.device, dtype=hidden_states.dtype)
    # output[indices] = hidden_states
    output = index_put_first_axis(hidden_states, indices, batch * seqlen)
    return rearrange(output, "(b s) ... -> b s ...", b=batch)
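
The unpad/pad pair above is what feeds the varlen attention entry points: unpad_input packs only the non-masked tokens and returns the bookkeeping needed to undo it. A minimal round-trip sketch (illustrative, not taken from this diff), assuming the helpers live in flash_attn.bert_padding and that hidden_states is (batch, seqlen, ...) with a boolean attention_mask:

```python
import torch

from flash_attn.bert_padding import pad_input, unpad_input  # assumed module path

batch, seqlen, nheads, headdim = 2, 8, 4, 16
hidden_states = torch.randn(batch, seqlen, nheads, headdim)
# Sequence lengths 5 and 8: True marks real tokens, False marks padding.
attention_mask = torch.arange(seqlen)[None, :] < torch.tensor([5, 8])[:, None]

# Pack: (batch, seqlen, ...) -> (total_tokens, ...), plus indices, cu_seqlens, max_seqlen
hidden_unpad, indices, cu_seqlens, max_seqlen = unpad_input(hidden_states, attention_mask)

# Unpack: padded positions come back as zeros, real tokens are untouched.
repadded = pad_input(hidden_unpad, indices, batch, seqlen)
```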
import torch
import torch.nn as nn
import flash_attn_2_cuda as flash_attn_cuda
from einops import rearrange
@@ -45,40 +44,109 @@ def _flash_attn_forward(q, k, v, dropout_p, softmax_scale, causal, return_softma
    return out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state


def _flash_attn_varlen_forward(
    q,
    k,
    v,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q,
    max_seqlen_k,
    dropout_p,
    softmax_scale,
    causal,
    return_softmax,
):
    maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
    q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
    out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = flash_attn_cuda.varlen_fwd(
        q,
        k,
        v,
        None,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        dropout_p,
        softmax_scale,
        False,
        causal,
        return_softmax,
        None,
    )
    # if out.isnan().any() or softmax_lse.isnan().any():
    #     breakpoint()
    return out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state


def _flash_attn_backward(
    dout, q, k, v, out, softmax_lse, dq, dk, dv, dropout_p, softmax_scale, causal, rng_state=None
):
    maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
    # dq, dk, dv are allocated by us so they should already be contiguous
    dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
    dq, dk, dv, softmax_d, = flash_attn_cuda.bwd(
        dout,
        q,
        k,
        v,
        out,
        softmax_lse,
        dq,
        dk,
        dv,
        dropout_p,
        softmax_scale,
        causal,
        None,
        rng_state,
    )
    return dq, dk, dv, softmax_d


def _flash_attn_varlen_backward(
    dout,
    q,
    k,
    v,
    out,
    softmax_lse,
    dq,
    dk,
    dv,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q,
    max_seqlen_k,
    dropout_p,
    softmax_scale,
    causal,
    rng_state=None,
):
    maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
    # dq, dk, dv are allocated by us so they should already be contiguous
    dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
    dq, dk, dv, softmax_d, = flash_attn_cuda.varlen_bwd(
        dout,
        q,
        k,
        v,
        out,
        softmax_lse,
        dq,
        dk,
        dv,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        dropout_p,
        softmax_scale,
        False,
        causal,
        None,
        rng_state,
    )
    # if dk.isnan().any() or dk.isnan().any() or dv.isnan().any() or softmax_d.isnan().any():
    #     breakpoint()
@@ -86,14 +154,18 @@ def _flash_attn_varlen_backward(dout, q, k, v, out, softmax_lse, dq, dk, dv,
class FlashAttnQKVPackedFunc(torch.autograd.Function):
    @staticmethod
    def forward(ctx, qkv, dropout_p, softmax_scale, causal, return_softmax):
        if softmax_scale is None:
            softmax_scale = qkv.shape[-1] ** (-0.5)
        out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_forward(
            qkv[:, :, 0],
            qkv[:, :, 1],
            qkv[:, :, 2],
            dropout_p,
            softmax_scale,
            causal=causal,
            return_softmax=return_softmax and dropout_p > 0,
        )
        ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state)
        ctx.dropout_p = dropout_p
@@ -107,22 +179,41 @@ class FlashAttnQKVPackedFunc(torch.autograd.Function):
        qkv_shape = q.shape[:-2] + (3, *q.shape[-2:])
        dqkv = torch.empty(qkv_shape, dtype=q.dtype, device=q.device)
        _flash_attn_backward(
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            dqkv[:, :, 0],
            dqkv[:, :, 1],
            dqkv[:, :, 2],
            ctx.dropout_p,
            ctx.softmax_scale,
            ctx.causal,
            rng_state=rng_state,
        )
        dqkv = dqkv[..., : dout.shape[-1]]  # We could have padded the head dimension
        return dqkv, None, None, None, None


class FlashAttnVarlenQKVPackedFunc(torch.autograd.Function):
    @staticmethod
    def forward(ctx, qkv, cu_seqlens, max_seqlen, dropout_p, softmax_scale, causal, return_softmax):
        if softmax_scale is None:
            softmax_scale = qkv.shape[-1] ** (-0.5)
        out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_varlen_forward(
            qkv[:, 0],
            qkv[:, 1],
            qkv[:, 2],
            cu_seqlens,
            cu_seqlens,
            max_seqlen,
            max_seqlen,
            dropout_p,
            softmax_scale,
            causal=causal,
            return_softmax=return_softmax and dropout_p > 0,
        )
        ctx.save_for_backward(q, k, v, out_padded, softmax_lse, cu_seqlens, rng_state)
        ctx.dropout_p = dropout_p
@@ -137,23 +228,41 @@ class FlashAttnVarlenQKVPackedFunc(torch.autograd.Function):
        qkv_shape = q.shape[:-2] + (3, *q.shape[-2:])
        dqkv = torch.empty(qkv_shape, dtype=q.dtype, device=q.device)
        _flash_attn_varlen_backward(
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            dqkv[:, 0],
            dqkv[:, 1],
            dqkv[:, 2],
            cu_seqlens,
            cu_seqlens,
            ctx.max_seqlen,
            ctx.max_seqlen,
            ctx.dropout_p,
            ctx.softmax_scale,
            ctx.causal,
            rng_state=rng_state,
        )
        dqkv = dqkv[..., : dout.shape[-1]]  # We could have padded the head dimension
        return dqkv, None, None, None, None, None, None


class FlashAttnKVPackedFunc(torch.autograd.Function):
    @staticmethod
    def forward(ctx, q, kv, dropout_p, softmax_scale, causal, return_softmax):
        if softmax_scale is None:
            softmax_scale = q.shape[-1] ** (-0.5)
        out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_forward(
            q,
            kv[:, :, 0],
            kv[:, :, 1],
            dropout_p,
            softmax_scale,
            causal=causal,
            return_softmax=return_softmax and dropout_p > 0,
        )
        ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state)
        ctx.dropout_p = dropout_p
@@ -168,28 +277,58 @@ class FlashAttnKVPackedFunc(torch.autograd.Function):
        kv_shape = k.shape[:-2] + (2, *k.shape[-2:])
        dkv = torch.empty(kv_shape, dtype=k.dtype, device=k.device)
        _flash_attn_backward(
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            dq,
            dkv[:, :, 0],
            dkv[:, :, 1],
            ctx.dropout_p,
            ctx.softmax_scale,
            ctx.causal,
            rng_state=rng_state,
        )
        dq = dq[..., : dout.shape[-1]]  # We could have padded the head dimension
        dkv = dkv[..., : dout.shape[-1]]
        return dq, dkv, None, None, None, None


class FlashAttnVarlenKVPackedFunc(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx,
        q,
        kv,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        dropout_p,
        softmax_scale,
        causal,
        return_softmax,
    ):
        if softmax_scale is None:
            softmax_scale = q.shape[-1] ** (-0.5)
        out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_varlen_forward(
            q,
            kv[:, 0],
            kv[:, 1],
            cu_seqlens_q,
            cu_seqlens_k,
            max_seqlen_q,
            max_seqlen_k,
            dropout_p,
            softmax_scale,
            causal=causal,
            return_softmax=return_softmax and dropout_p > 0,
        )
        ctx.save_for_backward(
            q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state
        )
        ctx.dropout_p = dropout_p
        ctx.max_seqlen_q = max_seqlen_q
        ctx.max_seqlen_k = max_seqlen_k
@@ -204,24 +343,42 @@ class FlashAttnVarlenKVPackedFunc(torch.autograd.Function):
        kv_shape = k.shape[:-2] + (2, *k.shape[-2:])
        dkv = torch.empty(kv_shape, dtype=k.dtype, device=k.device)
        _flash_attn_varlen_backward(
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            dq,
            dkv[:, 0],
            dkv[:, 1],
            cu_seqlens_q,
            cu_seqlens_k,
            ctx.max_seqlen_q,
            ctx.max_seqlen_k,
            ctx.dropout_p,
            ctx.softmax_scale,
            ctx.causal,
            rng_state=rng_state,
        )
        dq = dq[..., : dout.shape[-1]]  # We could have padded the head dimension
        dkv = dkv[..., : dout.shape[-1]]
        return dq, dkv, None, None, None, None, None, None, None, None


class FlashAttnFunc(torch.autograd.Function):
    @staticmethod
    def forward(ctx, q, k, v, dropout_p, softmax_scale, causal, return_softmax):
        if softmax_scale is None:
            softmax_scale = q.shape[-1] ** (-0.5)
        out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_forward(
            q,
            k,
            v,
            dropout_p,
            softmax_scale,
            causal=causal,
            return_softmax=return_softmax and dropout_p > 0,
        )
        ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state)
        ctx.dropout_p = dropout_p
@@ -234,29 +391,60 @@ class FlashAttnFunc(torch.autograd.Function):
        q, k, v, out, softmax_lse, rng_state = ctx.saved_tensors
        dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)
        _flash_attn_backward(
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            dq,
            dk,
            dv,
            ctx.dropout_p,
            ctx.softmax_scale,
            ctx.causal,
            rng_state=rng_state,
        )
        dq = dq[..., : dout.shape[-1]]  # We could have padded the head dimension
        dk = dk[..., : dout.shape[-1]]
        dv = dv[..., : dout.shape[-1]]
        return dq, dk, dv, None, None, None, None, None, None, None, None


class FlashAttnVarlenFunc(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx,
        q,
        k,
        v,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        dropout_p,
        softmax_scale,
        causal,
        return_softmax,
    ):
        if softmax_scale is None:
            softmax_scale = q.shape[-1] ** (-0.5)
        out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_varlen_forward(
            q,
            k,
            v,
            cu_seqlens_q,
            cu_seqlens_k,
            max_seqlen_q,
            max_seqlen_k,
            dropout_p,
            softmax_scale,
            causal=causal,
            return_softmax=return_softmax and dropout_p > 0,
        )
        ctx.save_for_backward(
            q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state
        )
        ctx.dropout_p = dropout_p
        ctx.max_seqlen_q = max_seqlen_q
        ctx.max_seqlen_k = max_seqlen_k
@@ -269,18 +457,33 @@ class FlashAttnVarlenFunc(torch.autograd.Function):
        q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state = ctx.saved_tensors
        dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)
        _flash_attn_varlen_backward(
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            dq,
            dk,
            dv,
            cu_seqlens_q,
            cu_seqlens_k,
            ctx.max_seqlen_q,
            ctx.max_seqlen_k,
            ctx.dropout_p,
            ctx.softmax_scale,
            ctx.causal,
            rng_state=rng_state,
        )
        dq = dq[..., : dout.shape[-1]]  # We could have padded the head dimension
        dk = dk[..., : dout.shape[-1]]
        dv = dv[..., : dout.shape[-1]]
        return dq, dk, dv, None, None, None, None, None, None, None, None


def flash_attn_qkvpacked_func(
    qkv, dropout_p=0.0, softmax_scale=None, causal=False, return_attn_probs=False
):
    """dropout_p should be set to 0.0 during evaluation
    If Q, K, V are already stacked into 1 tensor, this function will be faster than
    calling flash_attn_func on Q, K, V since the backward pass avoids explicit concatenation
@@ -309,8 +512,9 @@ def flash_attn_qkvpacked_func(qkv, dropout_p=0.0, softmax_scale=None, causal=Fal
    return FlashAttnQKVPackedFunc.apply(qkv, dropout_p, softmax_scale, causal, return_attn_probs)
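
The hunk above only shows the head and tail of flash_attn_qkvpacked_func, so here is a hedged usage sketch (not part of the diff); the (batch, seqlen, 3, nheads, headdim) packed layout is inferred from the qkv[:, :, i] indexing in FlashAttnQKVPackedFunc.forward, and the CUDA/fp16 choices are assumptions about what the kernels expect:

```python
import torch

from flash_attn import flash_attn_qkvpacked_func

# Packed QKV: (batch, seqlen, 3, nheads, headdim)
qkv = torch.randn(2, 1024, 3, 8, 64, device="cuda", dtype=torch.float16, requires_grad=True)

out = flash_attn_qkvpacked_func(qkv, dropout_p=0.0, causal=True)  # (2, 1024, 8, 64)
out.sum().backward()  # gradient flows back into qkv through FlashAttnQKVPackedFunc
```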


def flash_attn_kvpacked_func(
    q, kv, dropout_p=0.0, softmax_scale=None, causal=False, return_attn_probs=False
):
    """dropout_p should be set to 0.0 during evaluation
    If K, V are already stacked into 1 tensor, this function will be faster than
    calling flash_attn_func on Q, K, V since the backward pass avoids explicit concatenation
@@ -342,8 +546,9 @@ def flash_attn_kvpacked_func(q, kv, dropout_p=0.0, softmax_scale=None, causal=Fa
    return FlashAttnKVPackedFunc.apply(q, kv, dropout_p, softmax_scale, causal, return_attn_probs)


def flash_attn_func(
    q, k, v, dropout_p=0.0, softmax_scale=None, causal=False, return_attn_probs=False
):
    """dropout_p should be set to 0.0 during evaluation
    Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
    than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
@@ -373,8 +578,15 @@ def flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=None, causal=False,
    return FlashAttnFunc.apply(q, k, v, dropout_p, softmax_scale, causal, return_attn_probs)
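
As a companion to the MQA/GQA note in the flash_attn_func docstring, a hedged sketch with fewer KV heads than query heads; the shapes and the 4-to-1 grouping below are illustrative assumptions, not values from the diff:

```python
import torch

from flash_attn import flash_attn_func

q = torch.randn(2, 1024, 8, 64, device="cuda", dtype=torch.float16)  # 8 query heads
k = torch.randn(2, 1024, 2, 64, device="cuda", dtype=torch.float16)  # 2 KV heads (GQA)
v = torch.randn(2, 1024, 2, 64, device="cuda", dtype=torch.float16)

# Each group of 4 query heads attends to one KV head; output keeps the query layout.
out = flash_attn_func(q, k, v, causal=True)  # (2, 1024, 8, 64)
```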


def flash_attn_varlen_qkvpacked_func(
    qkv,
    cu_seqlens,
    max_seqlen,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    return_attn_probs=False,
):
    """dropout_p should be set to 0.0 during evaluation
    If Q, K, V are already stacked into 1 tensor, this function will be faster than
    calling flash_attn_varlen_func on Q, K, V since the backward pass avoids explicit concatenation
@@ -408,9 +620,18 @@ def flash_attn_varlen_qkvpacked_func(qkv, cu_seqlens, max_seqlen, dropout_p=0.0,
    )


def flash_attn_varlen_kvpacked_func(
    q,
    kv,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q,
    max_seqlen_k,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    return_attn_probs=False,
):
    """dropout_p should be set to 0.0 during evaluation
    If K, V are already stacked into 1 tensor, this function will be faster than
    calling flash_attn_func on Q, K, V since the backward pass avoids explicit concatenation
@@ -446,14 +667,32 @@ def flash_attn_varlen_kvpacked_func(q, kv, cu_seqlens_q, cu_seqlens_k, max_seqle
        pattern (negative means that location was dropped, nonnegative means it was kept).
    """
    return FlashAttnVarlenKVPackedFunc.apply(
        q,
        kv,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        dropout_p,
        softmax_scale,
        causal,
        return_attn_probs,
    )


def flash_attn_varlen_func(
    q,
    k,
    v,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q,
    max_seqlen_k,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    return_attn_probs=False,
):
    """dropout_p should be set to 0.0 during evaluation
    Supports multi-query and grouped-query attention (MQA/GQA) by passing in K, V with fewer heads
    than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
@@ -487,6 +726,15 @@ def flash_attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, ma
        pattern (negative means that location was dropped, nonnegative means it was kept).
    """
    return FlashAttnVarlenFunc.apply(
        q,
        k,
        v,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        dropout_p,
        softmax_scale,
        causal,
        return_attn_probs,
    )
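
To close out the interface hunks, a hedged sketch of the varlen entry point; it assumes two sequences of lengths 3 and 5 packed into one (total_tokens, nheads, headdim) tensor, with cu_seqlens being int32 cumulative offsets (in practice these usually come from unpad_input above):

```python
import torch

from flash_attn import flash_attn_varlen_func

seqlens = [3, 5]
total, nheads, headdim = sum(seqlens), 8, 64
q = torch.randn(total, nheads, headdim, device="cuda", dtype=torch.float16)
k = torch.randn(total, nheads, headdim, device="cuda", dtype=torch.float16)
v = torch.randn(total, nheads, headdim, device="cuda", dtype=torch.float16)

cu_seqlens = torch.tensor([0, 3, 8], device="cuda", dtype=torch.int32)  # cumulative sums of seqlens
max_seqlen = max(seqlens)

out = flash_attn_varlen_func(q, k, v, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen, causal=True)
# out keeps the packed layout of q: (total, nheads, headdim)
```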

@@ -42,7 +42,6 @@ than CUDA forward + backward.
import math
import torch
import triton
import triton.language as tl
@@ -65,21 +64,44 @@ import triton.language as tl
)
@triton.jit
def _fwd_kernel(
    Q,
    K,
    V,
    Bias,
    Out,
    Lse,
    TMP,  # NOTE: TMP is a scratchpad buffer to workaround a compiler bug
    softmax_scale,
    stride_qb,
    stride_qh,
    stride_qm,
    stride_kb,
    stride_kh,
    stride_kn,
    stride_vb,
    stride_vh,
    stride_vn,
    stride_bb,
    stride_bh,
    stride_bm,
    stride_ob,
    stride_oh,
    stride_om,
    nheads,
    seqlen_q,
    seqlen_k,
    seqlen_q_rounded,
    headdim,
    CACHE_KEY_SEQLEN_Q,
    CACHE_KEY_SEQLEN_K,
    BIAS_TYPE: tl.constexpr,
    IS_CAUSAL: tl.constexpr,
    BLOCK_HEADDIM: tl.constexpr,
    EVEN_M: tl.constexpr,
    EVEN_N: tl.constexpr,
    EVEN_HEADDIM: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    start_m = tl.program_id(0)
    off_hb = tl.program_id(1)
@@ -96,13 +118,24 @@ def _fwd_kernel(
    # Adding parenthesis around indexing might use int32 math instead of int64 math?
    # https://github.com/openai/triton/issues/741
    # I'm seeing a tiny bit of difference (5-7us)
    q_ptrs = (
        Q + off_b * stride_qb + off_h * stride_qh + (offs_m[:, None] * stride_qm + offs_d[None, :])
    )
    k_ptrs = (
        K + off_b * stride_kb + off_h * stride_kh + (offs_n[:, None] * stride_kn + offs_d[None, :])
    )
    v_ptrs = (
        V + off_b * stride_vb + off_h * stride_vh + (offs_n[:, None] * stride_vn + offs_d[None, :])
    )
    if BIAS_TYPE == "vector":
        b_ptrs = Bias + off_b * stride_bb + off_h * stride_bh + offs_n
    elif BIAS_TYPE == "matrix":
        b_ptrs = (
            Bias
            + off_b * stride_bb
            + off_h * stride_bh
            + (offs_m[:, None] * stride_bm + offs_n[None, :])
        )
    # initialize pointer to m and l
    t_ptrs = TMP + off_hb * seqlen_q_rounded + offs_m
    lse_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
@@ -120,8 +153,9 @@ def _fwd_kernel(
        if EVEN_HEADDIM:
            q = tl.load(q_ptrs, mask=offs_m[:, None] < seqlen_q, other=0.0)
        else:
            q = tl.load(
                q_ptrs, mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim), other=0.0
            )
    # loop over k, v and update accumulator
    end_n = seqlen_k if not IS_CAUSAL else tl.minimum((start_m + 1) * BLOCK_M, seqlen_k)
    for start_n in range(0, end_n, BLOCK_N):
@@ -134,12 +168,17 @@ def _fwd_kernel(
                k = tl.load(k_ptrs + start_n * stride_kn, mask=offs_d[None, :] < headdim, other=0.0)
        else:
            if EVEN_HEADDIM:
                k = tl.load(
                    k_ptrs + start_n * stride_kn,
                    mask=(start_n + offs_n)[:, None] < seqlen_k,
                    other=0.0,
                )
            else:
                k = tl.load(
                    k_ptrs + start_n * stride_kn,
                    mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim),
                    other=0.0,
                )
        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
        qk += tl.dot(q, k, trans_b=True)
        # Trying to combine the two masks seem to make the result wrong
@@ -147,21 +186,25 @@ def _fwd_kernel(
            qk += tl.where((start_n + offs_n)[None, :] < seqlen_k, 0, float("-inf"))
        if IS_CAUSAL:
            qk += tl.where(offs_m[:, None] >= (start_n + offs_n)[None, :], 0, float("-inf"))
        if BIAS_TYPE != "none":
            if BIAS_TYPE == "vector":
                if EVEN_N:
                    bias = tl.load(b_ptrs + start_n).to(tl.float32)
                else:
                    bias = tl.load(
                        b_ptrs + start_n, mask=(start_n + offs_n) < seqlen_k, other=0.0
                    ).to(tl.float32)
                bias = bias[None, :]
            elif BIAS_TYPE == "matrix":
                if EVEN_M & EVEN_N:
                    bias = tl.load(b_ptrs + start_n).to(tl.float32)
                else:
                    bias = tl.load(
                        b_ptrs + start_n,
                        mask=(offs_m[:, None] < seqlen_q)
                        & ((start_n + offs_n)[None, :] < seqlen_k),
                        other=0.0,
                    ).to(tl.float32)
            # Slightly faster to multiply the softmax_scale in the tl.exp below since the compiler
            # can then fuse the mult and add into an fma instruction. But if we have bias we need to
            # to multiply with softmax_scale here.
@@ -189,12 +232,17 @@ def _fwd_kernel(
                v = tl.load(v_ptrs + start_n * stride_vn, mask=offs_d[None, :] < headdim, other=0.0)
        else:
            if EVEN_HEADDIM:
                v = tl.load(
                    v_ptrs + start_n * stride_vn,
                    mask=(start_n + offs_n)[:, None] < seqlen_k,
                    other=0.0,
                )
            else:
                v = tl.load(
                    v_ptrs + start_n * stride_vn,
                    mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim),
                    other=0.0,
                )
        p = p.to(v.dtype)
        acc_o += tl.dot(p, v)
@@ -216,7 +264,12 @@ def _fwd_kernel(
    tl.store(lse_ptrs, lse_i)
    # initialize pointers to output
    offs_d = tl.arange(0, BLOCK_HEADDIM)
    out_ptrs = (
        Out
        + off_b * stride_ob
        + off_h * stride_oh
        + (offs_m[:, None] * stride_om + offs_d[None, :])
    )
    if EVEN_M:
        if EVEN_HEADDIM:
            tl.store(out_ptrs, acc_o)
@@ -226,17 +279,28 @@ def _fwd_kernel(
        if EVEN_HEADDIM:
            tl.store(out_ptrs, acc_o, mask=offs_m[:, None] < seqlen_q)
        else:
            tl.store(
                out_ptrs, acc_o, mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim)
            )


@triton.jit
def _bwd_preprocess_do_o_dot(
    Out,
    DO,
    Delta,
    stride_ob,
    stride_oh,
    stride_om,
    stride_dob,
    stride_doh,
    stride_dom,
    nheads,
    seqlen_q,
    seqlen_q_rounded,
    headdim,
    BLOCK_M: tl.constexpr,
    BLOCK_HEADDIM: tl.constexpr,
):
    start_m = tl.program_id(0)
    off_hb = tl.program_id(1)
@@ -246,10 +310,20 @@ def _bwd_preprocess_do_o_dot(
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_d = tl.arange(0, BLOCK_HEADDIM)
    # load
    o = tl.load(
        Out + off_b * stride_ob + off_h * stride_oh + offs_m[:, None] * stride_om + offs_d[None, :],
        mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
        other=0.0,
    ).to(tl.float32)
    do = tl.load(
        DO
        + off_b * stride_dob
        + off_h * stride_doh
        + offs_m[:, None] * stride_dom
        + offs_d[None, :],
        mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
        other=0.0,
    ).to(tl.float32)
    delta = tl.sum(o * do, axis=1)
    # write-back
    tl.store(Delta + off_hb * seqlen_q_rounded + offs_m, delta)
@@ -257,8 +331,17 @@
@triton.jit
def _bwd_store_dk_dv(
    dk_ptrs,
    dv_ptrs,
    dk,
    dv,
    offs_n,
    offs_d,
    seqlen_k,
    headdim,
    EVEN_M: tl.constexpr,
    EVEN_N: tl.constexpr,
    EVEN_HEADDIM: tl.constexpr,
):
    # [2022-11-01] TD: Same bug. In the case of EVEN_N=True and EVEN_M=False,
    # if we just call tl.store(dv_ptrs), there's a race condition
@@ -281,19 +364,37 @@ def _bwd_store_dk_dv(
@triton.jit
def _bwd_kernel_one_col_block(
    start_n,
    Q,
    K,
    V,
    Bias,
    DO,
    DQ,
    DK,
    DV,
    LSE,
    D,
    softmax_scale,
    stride_qm,
    stride_kn,
    stride_vn,
    stride_bm,
    stride_dom,
    stride_dqm,
    stride_dkn,
    stride_dvn,
    seqlen_q,
    seqlen_k,
    headdim,
    ATOMIC_ADD: tl.constexpr,
    BIAS_TYPE: tl.constexpr,
    IS_CAUSAL: tl.constexpr,
    BLOCK_HEADDIM: tl.constexpr,
    EVEN_M: tl.constexpr,
    EVEN_N: tl.constexpr,
    EVEN_HEADDIM: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    # We need to make sure begin_m is a multiple of BLOCK_M (not BLOCK_N)
    begin_m = 0 if not IS_CAUSAL else ((start_n * BLOCK_N) // BLOCK_M) * BLOCK_M
@@ -308,9 +409,9 @@ def _bwd_kernel_one_col_block(
    v_ptrs = V + (offs_n[:, None] * stride_vn + offs_d[None, :])
    do_ptrs = DO + (offs_qm[:, None] * stride_dom + offs_d[None, :])
    dq_ptrs = DQ + (offs_qm[:, None] * stride_dqm + offs_d[None, :])
    if BIAS_TYPE == "vector":
        b_ptrs = Bias + offs_n
    elif BIAS_TYPE == "matrix":
        b_ptrs = Bias + (offs_qm[:, None] * stride_bm + offs_n[None, :])
    # initialize dv and dk
    dv = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32)
@@ -322,8 +423,19 @@ def _bwd_kernel_one_col_block(
    if begin_m >= seqlen_q:
        dv_ptrs = DV + (offs_n[:, None] * stride_dvn + offs_d[None, :])
        dk_ptrs = DK + (offs_n[:, None] * stride_dkn + offs_d[None, :])
        _bwd_store_dk_dv(
            dk_ptrs,
            dv_ptrs,
            dk,
            dv,
            offs_n,
            offs_d,
            seqlen_k,
            headdim,
            EVEN_M=EVEN_M,
            EVEN_N=EVEN_N,
            EVEN_HEADDIM=EVEN_HEADDIM,
        )
        return
    # k and v stay in SRAM throughout
    # [2022-10-30] TD: Same bug as the fwd. In the case of EVEN_N=True and EVEN_M=False,
@@ -340,10 +452,12 @@ def _bwd_kernel_one_col_block(
            k = tl.load(k_ptrs, mask=offs_n[:, None] < seqlen_k, other=0.0)
            v = tl.load(v_ptrs, mask=offs_n[:, None] < seqlen_k, other=0.0)
        else:
            k = tl.load(
                k_ptrs, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim), other=0.0
            )
            v = tl.load(
                v_ptrs, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim), other=0.0
            )
    # loop over rows
    num_block_m = tl.cdiv(seqlen_q, BLOCK_M)
    for start_m in range(begin_m, num_block_m * BLOCK_M, BLOCK_M):
@@ -357,8 +471,11 @@ def _bwd_kernel_one_col_block(
            if EVEN_HEADDIM:
                q = tl.load(q_ptrs, mask=offs_m_curr[:, None] < seqlen_q, other=0.0)
            else:
                q = tl.load(
                    q_ptrs,
                    mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
                    other=0.0,
                )
        # recompute p = softmax(qk, dim=-1).T
        qk = tl.dot(q, k, trans_b=True)
        # Trying to combine the two masks seem to make the result wrong
@@ -366,29 +483,30 @@ def _bwd_kernel_one_col_block(
            qk = tl.where(offs_n[None, :] < seqlen_k, qk, float("-inf"))
        if IS_CAUSAL:
            qk = tl.where(offs_m_curr[:, None] >= (offs_n[None, :]), qk, float("-inf"))
        if BIAS_TYPE != "none":
            tl.debug_barrier()  # Race condition otherwise
            if BIAS_TYPE == "vector":
                if EVEN_N:
                    bias = tl.load(b_ptrs).to(tl.float32)
                else:
                    bias = tl.load(b_ptrs, mask=offs_n < seqlen_k, other=0.0).to(tl.float32)
                bias = bias[None, :]
            elif BIAS_TYPE == "matrix":
                if EVEN_M & EVEN_N:
                    bias = tl.load(b_ptrs).to(tl.float32)
                else:
                    bias = tl.load(
                        b_ptrs,
                        mask=(offs_m_curr[:, None] < seqlen_q) & (offs_n[None, :] < seqlen_k),
                        other=0.0,
                    ).to(tl.float32)
            qk = qk * softmax_scale + bias
        # There seems to be a race condition when headdim=48/96, and dq, dk, dv are wrong.
        # Also wrong for headdim=64.
        if not (EVEN_M & EVEN_HEADDIM):
            tl.debug_barrier()
        lse_i = tl.load(LSE + offs_m_curr)
        if BIAS_TYPE == "none":
            p = tl.exp(qk * softmax_scale - lse_i[:, None])
        else:
            p = tl.exp(qk - lse_i[:, None])
@@ -401,8 +519,11 @@ def _bwd_kernel_one_col_block(
            do = tl.load(do_ptrs)
        else:
            # [2022-11-01] TD: Triton bug, there's a race condition if we just use m_mask and not d_mask.
            do = tl.load(
                do_ptrs,
                mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
                other=0.0,
            )
        # if EVEN_M:
        #     if EVEN_HEADDIM:
        #         do = tl.load(do_ptrs)
@@ -434,7 +555,9 @@ def _bwd_kernel_one_col_block(
        # compute dk = dot(ds.T, q)
        dk += tl.dot(ds, q, trans_a=True)
        # compute dq
        if not (
            EVEN_M & EVEN_HEADDIM
        ):  # Otherewise there's a race condition when BIAS_TYPE='matrix'
            tl.debug_barrier()
        if not ATOMIC_ADD:
            if EVEN_M & EVEN_HEADDIM:  # Race condition if we just do EVEN_M
@@ -443,19 +566,33 @@ def _bwd_kernel_one_col_block(
                tl.store(dq_ptrs, dq, eviction_policy="evict_last")
            else:
                if EVEN_HEADDIM:
                    dq = tl.load(
                        dq_ptrs,
                        mask=offs_m_curr[:, None] < seqlen_q,
                        other=0.0,
                        eviction_policy="evict_last",
                    )
                    dq += tl.dot(ds, k)
                    tl.store(
                        dq_ptrs,
                        dq,
                        mask=offs_m_curr[:, None] < seqlen_q,
                        eviction_policy="evict_last",
                    )
                else:
                    dq = tl.load(
                        dq_ptrs,
                        mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
                        other=0.0,
                        eviction_policy="evict_last",
                    )
                    dq += tl.dot(ds, k)
                    tl.store(
                        dq_ptrs,
                        dq,
                        mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
                        eviction_policy="evict_last",
                    )
        else:  # If we're parallelizing across the seqlen_k dimension
            dq = tl.dot(ds, k)
            if EVEN_M & EVEN_HEADDIM:  # Race condition if we just do EVEN_M
@@ -464,19 +601,33 @@ def _bwd_kernel_one_col_block(
                if EVEN_HEADDIM:
                    tl.atomic_add(dq_ptrs, dq, mask=offs_m_curr[:, None] < seqlen_q)
                else:
                    tl.atomic_add(
                        dq_ptrs,
                        dq,
                        mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
                    )
        # increment pointers
        dq_ptrs += BLOCK_M * stride_dqm
        q_ptrs += BLOCK_M * stride_qm
        do_ptrs += BLOCK_M * stride_dom
        if BIAS_TYPE == "matrix":
            b_ptrs += BLOCK_M * stride_bm
    # write-back
    dv_ptrs = DV + (offs_n[:, None] * stride_dvn + offs_d[None, :])
    dk_ptrs = DK + (offs_n[:, None] * stride_dkn + offs_d[None, :])
    _bwd_store_dk_dv(
        dk_ptrs,
        dv_ptrs,
        dk,
        dv,
        offs_n,
        offs_d,
        seqlen_k,
        headdim,
        EVEN_M=EVEN_M,
        EVEN_N=EVEN_N,
        EVEN_HEADDIM=EVEN_HEADDIM,
    )


def init_to_zero(name):
@@ -485,8 +636,18 @@ def init_to_zero(name):
@triton.autotune(
    configs=[
        triton.Config(
            {"BLOCK_M": 128, "BLOCK_N": 128, "SEQUENCE_PARALLEL": False},
            num_warps=8,
            num_stages=1,
            pre_hook=init_to_zero("DQ"),
        ),
        triton.Config(
            {"BLOCK_M": 128, "BLOCK_N": 128, "SEQUENCE_PARALLEL": True},
            num_warps=8,
            num_stages=1,
            pre_hook=init_to_zero("DQ"),
        ),
        # Other configs seem to give wrong results when seqlen_q % 128 != 0, disabling them for now
        # # Kernel is buggy (give wrong result) if we set BLOCK_m=128, BLOCK_n=64, num_warps=*4*
        # triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "SEQUENCE_PARALLEL": False}, num_warps=8, num_stages=1, pre_hook=init_to_zero('DQ')),
@@ -494,7 +655,7 @@ def init_to_zero(name):
        # triton.Config({"BLOCK_M": 64, "BLOCK_N": 64, "SEQUENCE_PARALLEL": False}, num_warps=4, num_stages=1, pre_hook=init_to_zero('DQ')),
        # triton.Config({"BLOCK_M": 64, "BLOCK_N": 64, "SEQUENCE_PARALLEL": True}, num_warps=4, num_stages=1, pre_hook=init_to_zero('DQ')),
    ],
    key=["CACHE_KEY_SEQLEN_Q", "CACHE_KEY_SEQLEN_K", "BIAS_TYPE", "IS_CAUSAL", "BLOCK_HEADDIM"],
)
@triton.heuristics(
    {
@@ -505,26 +666,57 @@ def init_to_zero(name):
)
@triton.jit
def _bwd_kernel( def _bwd_kernel(
Q, K, V, Bias, Q,
DO, DQ, DK, DV, K,
LSE, D, V,
Bias,
DO,
DQ,
DK,
DV,
LSE,
D,
softmax_scale,
stride_qb,
stride_qh,
stride_qm,
stride_kb,
stride_kh,
stride_kn,
stride_vb,
stride_vh,
stride_vn,
stride_bb,
stride_bh,
stride_bm,
stride_dob,
stride_doh,
stride_dom,
stride_dqb,
stride_dqh,
stride_dqm,
stride_dkb,
stride_dkh,
stride_dkn,
stride_dvb,
stride_dvh,
stride_dvn,
nheads,
seqlen_q,
seqlen_k,
seqlen_q_rounded,
headdim,
CACHE_KEY_SEQLEN_Q,
CACHE_KEY_SEQLEN_K,
BIAS_TYPE: tl.constexpr,
IS_CAUSAL: tl.constexpr,
BLOCK_HEADDIM: tl.constexpr,
SEQUENCE_PARALLEL: tl.constexpr,
EVEN_M: tl.constexpr,
EVEN_N: tl.constexpr,
EVEN_HEADDIM: tl.constexpr,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
):
off_hb = tl.program_id(1)
off_b = off_hb // nheads
...@@ -537,7 +729,7 @@ def _bwd_kernel(
DQ += off_b * stride_dqb + off_h * stride_dqh
DK += off_b * stride_dkb + off_h * stride_dkh
DV += off_b * stride_dvb + off_h * stride_dvh
if BIAS_TYPE != "none":
Bias += off_b * stride_bb + off_h * stride_bh
# pointer to row-wise quantities in value-like data
D += off_hb * seqlen_q_rounded
...@@ -547,37 +739,73 @@ def _bwd_kernel(
for start_n in range(0, num_block_n):
_bwd_kernel_one_col_block(
start_n,
Q,
K,
V,
Bias,
DO,
DQ,
DK,
DV,
LSE,
D,
softmax_scale,
stride_qm,
stride_kn,
stride_vn,
stride_bm,
stride_dom,
stride_dqm,
stride_dkn,
stride_dvn,
seqlen_q,
seqlen_k,
headdim,
ATOMIC_ADD=False,
BIAS_TYPE=BIAS_TYPE,
IS_CAUSAL=IS_CAUSAL,
BLOCK_HEADDIM=BLOCK_HEADDIM,
EVEN_M=EVEN_M,
EVEN_N=EVEN_N,
EVEN_HEADDIM=EVEN_HEADDIM,
BLOCK_M=BLOCK_M,
BLOCK_N=BLOCK_N,
)
else:
start_n = tl.program_id(0)
_bwd_kernel_one_col_block(
start_n,
Q,
K,
V,
Bias,
DO,
DQ,
DK,
DV,
LSE,
D,
softmax_scale,
stride_qm,
stride_kn,
stride_vn,
stride_bm,
stride_dom,
stride_dqm,
stride_dkn,
stride_dvn,
seqlen_q,
seqlen_k,
headdim,
ATOMIC_ADD=True,
BIAS_TYPE=BIAS_TYPE,
IS_CAUSAL=IS_CAUSAL,
BLOCK_HEADDIM=BLOCK_HEADDIM,
EVEN_M=EVEN_M,
EVEN_N=EVEN_N,
EVEN_HEADDIM=EVEN_HEADDIM,
BLOCK_M=BLOCK_M,
BLOCK_N=BLOCK_N,
)
...@@ -587,14 +815,14 @@ def _flash_attn_forward(q, k, v, bias=None, causal=False, softmax_scale=None):
_, seqlen_k, _, _ = k.shape
assert k.shape == (batch, seqlen_k, nheads, d)
assert v.shape == (batch, seqlen_k, nheads, d)
assert d <= 128, "FlashAttention only support head dimensions up to 128"
assert q.dtype == k.dtype == v.dtype, "All tensors must have the same type"
assert q.dtype in [torch.float16, torch.bfloat16], "Only support fp16 and bf16"
assert q.is_cuda and k.is_cuda and v.is_cuda
softmax_scale = softmax_scale or 1.0 / math.sqrt(d)
has_bias = bias is not None
bias_type = "none"
if has_bias:
assert bias.dtype in [q.dtype, torch.float]
assert bias.is_cuda
...@@ -602,12 +830,13 @@ def _flash_attn_forward(q, k, v, bias=None, causal=False, softmax_scale=None):
if bias.stride(-1) != 1:
bias = bias.contiguous()
if bias.shape[2:] == (1, seqlen_k):
bias_type = "vector"
elif bias.shape[2:] == (seqlen_q, seqlen_k):
bias_type = "matrix"
else:
raise RuntimeError(
"Last 2 dimensions of bias must be (1, seqlen_k)" " or (seqlen_q, seqlen_k)"
)
bias = bias.expand(batch, nheads, seqlen_q, seqlen_k)
bias_strides = (bias.stride(0), bias.stride(1), bias.stride(2)) if has_bias else (0, 0, 0)
...@@ -621,27 +850,50 @@ def _flash_attn_forward(q, k, v, bias=None, causal=False, softmax_scale=None):
num_warps = 4 if d <= 64 else 8
grid = lambda META: (triton.cdiv(seqlen_q, META["BLOCK_M"]), batch * nheads)
_fwd_kernel[grid](
q,
k,
v,
bias,
o,
lse,
tmp,
softmax_scale,
q.stride(0),
q.stride(2),
q.stride(1),
k.stride(0),
k.stride(2),
k.stride(1),
v.stride(0),
v.stride(2),
v.stride(1),
*bias_strides,
o.stride(0),
o.stride(2),
o.stride(1),
nheads,
seqlen_q,
seqlen_k,
seqlen_q_rounded,
d,
seqlen_q // 32,
seqlen_k // 32,  # key for triton cache (limit number of compilations)
# Can't use kwargs here because triton autotune expects key to be args, not kwargs
# IS_CAUSAL=causal, BLOCK_HEADDIM=d,
bias_type,
causal,
BLOCK_HEADDIM,
BLOCK_M=BLOCK,
BLOCK_N=BLOCK,
num_warps=num_warps,
num_stages=1,
)
return o, lse, softmax_scale  # softmax_scale could have been updated
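# Illustrative reference (hypothetical helper, not part of this file): a plain PyTorch
# version of what the forward kernel computes, handy for cross-checking small inputs.
# Layout matches this file, (batch, seqlen, nheads, headdim); the causal mask keeps key
# positions <= query position.
import math
import torch

def _attention_reference(q, k, v, bias=None, causal=False, softmax_scale=None):
    softmax_scale = softmax_scale or 1.0 / math.sqrt(q.shape[-1])
    scores = torch.einsum("bqhd,bkhd->bhqk", q.float(), k.float()) * softmax_scale
    if bias is not None:
        scores = scores + bias.float()  # broadcasts over (batch, nheads, seqlen_q, seqlen_k)
    if causal:
        seqlen_q, seqlen_k = scores.shape[-2:]
        mask = torch.ones(seqlen_q, seqlen_k, dtype=torch.bool, device=scores.device).triu(1)
        scores = scores.masked_fill(mask, float("-inf"))
    attn = torch.softmax(scores, dim=-1)
    return torch.einsum("bhqk,bkhd->bqhd", attn, v.float()).to(v.dtype)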
def _flash_attn_backward(
do, q, k, v, o, lse, dq, dk, dv, bias=None, causal=False, softmax_scale=None
):
# Make sure that the last dimension is contiguous
if do.stride(-1) != 1:
do = do.contiguous()
...@@ -662,53 +914,94 @@ def _flash_attn_backward(do, q, k, v, o, lse, dq, dk, dv, bias=None, causal=Fals
BLOCK_HEADDIM = max(triton.next_power_of_2(d), 16)
grid = lambda META: (triton.cdiv(seqlen_q, META["BLOCK_M"]), batch * nheads)
_bwd_preprocess_do_o_dot[grid](
o,
do,
delta,
o.stride(0),
o.stride(2),
o.stride(1),
do.stride(0),
do.stride(2),
do.stride(1),
nheads,
seqlen_q,
seqlen_q_rounded,
d,
BLOCK_M=128,
BLOCK_HEADDIM=BLOCK_HEADDIM,
) )
has_bias = bias is not None
bias_type = "none"
if has_bias:
assert bias.dtype in [q.dtype, torch.float]
assert bias.is_cuda
assert bias.dim() == 4
assert bias.stride(-1) == 1
if bias.shape[2:] == (1, seqlen_k):
bias_type = "vector"
elif bias.shape[2:] == (seqlen_q, seqlen_k):
bias_type = "matrix"
else:
raise RuntimeError(
"Last 2 dimensions of bias must be (1, seqlen_k)" " or (seqlen_q, seqlen_k)"
)
bias = bias.expand(batch, nheads, seqlen_q, seqlen_k)
bias_strides = (bias.stride(0), bias.stride(1), bias.stride(2)) if has_bias else (0, 0, 0)
# BLOCK_M = 128
# BLOCK_N = 64
# num_warps = 4
grid = lambda META: (
triton.cdiv(seqlen_k, META["BLOCK_N"]) if META["SEQUENCE_PARALLEL"] else 1,
batch * nheads,
)
_bwd_kernel[grid](
q,
k,
v,
bias,
do,
dq_accum,
dk,
dv,
lse,
delta,
softmax_scale,
q.stride(0),
q.stride(2),
q.stride(1),
k.stride(0),
k.stride(2),
k.stride(1),
v.stride(0),
v.stride(2),
v.stride(1),
*bias_strides,
do.stride(0),
do.stride(2),
do.stride(1),
dq_accum.stride(0),
dq_accum.stride(2),
dq_accum.stride(1),
dk.stride(0),
dk.stride(2),
dk.stride(1),
dv.stride(0),
dv.stride(2),
dv.stride(1),
nheads,
seqlen_q,
seqlen_k,
seqlen_q_rounded,
d,
seqlen_q // 32,
seqlen_k // 32, # key for triton cache (limit number of compilations)
# Can't use kwargs here because triton autotune expects key to be args, not kwargs # Can't use kwargs here because triton autotune expects key to be args, not kwargs
# IS_CAUSAL=causal, BLOCK_HEADDIM=d, # IS_CAUSAL=causal, BLOCK_HEADDIM=d,
bias_type, causal, BLOCK_HEADDIM, bias_type,
causal,
BLOCK_HEADDIM,
# SEQUENCE_PARALLEL=False, # SEQUENCE_PARALLEL=False,
# BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, # BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N,
# num_warps=num_warps, # num_warps=num_warps,
...@@ -718,7 +1011,6 @@ def _flash_attn_backward(do, q, k, v, o, lse, dq, dk, dv, bias=None, causal=Fals
class FlashAttnQKVPackedFunc(torch.autograd.Function):
@staticmethod
def forward(ctx, qkv, bias=None, causal=False, softmax_scale=None):
"""
...@@ -731,8 +1023,12 @@ class FlashAttnQKVPackedFunc(torch.autograd.Function):
if qkv.stride(-1) != 1:
qkv = qkv.contiguous()
o, lse, ctx.softmax_scale = _flash_attn_forward(
qkv[:, :, 0],
qkv[:, :, 1],
qkv[:, :, 2],
bias=bias,
causal=causal,
softmax_scale=softmax_scale,
)
ctx.save_for_backward(qkv, o, lse, bias)
ctx.causal = causal
...@@ -741,14 +1037,25 @@ class FlashAttnQKVPackedFunc(torch.autograd.Function):
@staticmethod
def backward(ctx, do):
qkv, o, lse, bias = ctx.saved_tensors
assert not ctx.needs_input_grad[1], "FlashAttention does not support bias gradient yet"
# Triton's autotune causes the Tensor._version to change, and so Pytorch autograd
# does a memcpy. To avoid this we run in inference_mode, which doesn't track the version.
with torch.inference_mode():
dqkv = torch.empty_like(qkv)
_flash_attn_backward(
do,
qkv[:, :, 0],
qkv[:, :, 1],
qkv[:, :, 2],
o,
lse,
dqkv[:, :, 0],
dqkv[:, :, 1],
dqkv[:, :, 2],
bias=bias,
causal=ctx.causal,
softmax_scale=ctx.softmax_scale,
)
return dqkv, None, None, None
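# Illustrative usage sketch (assumes a CUDA GPU with Triton installed): qkv is
# (batch, seqlen, 3, nheads, headdim) in fp16/bf16 with headdim <= 128, per the asserts in
# _flash_attn_forward above; flash_attn_qkvpacked_func is the alias defined just below.
import torch

if torch.cuda.is_available():
    _qkv = torch.randn(2, 1024, 3, 8, 64, device="cuda", dtype=torch.float16, requires_grad=True)
    _out = flash_attn_qkvpacked_func(_qkv, None, True)  # positional args: (qkv, bias, causal)
    _out.sum().backward()                               # gradient lands in _qkv.grad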
...@@ -756,7 +1063,6 @@ flash_attn_qkvpacked_func = FlashAttnQKVPackedFunc.apply
class FlashAttnKVPackedFunc(torch.autograd.Function):
@staticmethod
def forward(ctx, q, kv, bias=None, causal=False, softmax_scale=None):
"""
...@@ -779,15 +1085,26 @@ class FlashAttnKVPackedFunc(torch.autograd.Function):
def backward(ctx, do):
q, kv, o, lse, bias = ctx.saved_tensors
if len(ctx.needs_input_grad) >= 3:
assert not ctx.needs_input_grad[2], "FlashAttention does not support bias gradient yet"
# Triton's autotune causes the Tensor._version to change, and so Pytorch autograd
# does a memcpy. To avoid this we run in inference_mode, which doesn't track the version.
with torch.inference_mode():
dq = torch.empty_like(q)
dkv = torch.empty_like(kv)
_flash_attn_backward(
do,
q,
kv[:, :, 0],
kv[:, :, 1],
o,
lse,
dq,
dkv[:, :, 0],
dkv[:, :, 1],
bias=bias,
causal=ctx.causal,
softmax_scale=ctx.softmax_scale,
)
return dq, dkv, None, None, None return dq, dkv, None, None, None
...@@ -795,7 +1112,6 @@ flash_attn_kvpacked_func = FlashAttnKVPackedFunc.apply
class FlashAttnFunc(torch.autograd.Function):
@staticmethod
def forward(ctx, q, k, v, bias=None, causal=False, softmax_scale=None):
"""
...@@ -817,15 +1133,27 @@ class FlashAttnFunc(torch.autograd.Function):
@staticmethod
def backward(ctx, do):
q, k, v, o, lse, bias = ctx.saved_tensors
assert not ctx.needs_input_grad[3], "FlashAttention does not support bias gradient yet"
# Triton's autotune causes the Tensor._version to change, and so Pytorch autograd
# does a memcpy. To avoid this we run in inference_mode, which doesn't track the version.
with torch.inference_mode():
dq = torch.empty_like(q)
dk = torch.empty_like(k)
dv = torch.empty_like(v)
_flash_attn_backward(
do,
q,
k,
v,
o,
lse,
dq,
dk,
dv,
bias=bias,
causal=ctx.causal,
softmax_scale=ctx.softmax_scale,
)
return dq, dk, dv, None, None, None
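# Illustrative test sketch (hypothetical helper): compare the Triton path against a plain
# PyTorch reference such as the _attention_reference sketch above, for both outputs and
# gradients on small CUDA inputs. The tolerance is illustrative only.
import torch

def _check_flash_attn(flash_fn, ref_fn, atol=1e-2):
    if not torch.cuda.is_available():
        return
    q, k, v = (torch.randn(1, 128, 4, 64, device="cuda", dtype=torch.float16, requires_grad=True)
               for _ in range(3))
    out = flash_fn(q, k, v, None, True)          # bias=None, causal=True
    out_ref = ref_fn(q, k, v, causal=True)
    assert (out - out_ref).abs().max().item() <= atol
    g = torch.randn_like(out)
    grads = torch.autograd.grad(out, (q, k, v), g, retain_graph=True)
    grads_ref = torch.autograd.grad(out_ref, (q, k, v), g)
    assert all((a - b).abs().max().item() <= atol for a, b in zip(grads, grads_ref))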
......
...@@ -11,22 +11,41 @@ This is a Triton implementation of the Flash Attention algorithm
import pytest
import torch
import triton
import triton.language as tl
@triton.jit
def _fwd_kernel(
Q,
K,
V,
sm_scale,
TMP,
L,
M, # NOTE: TMP is a scratchpad buffer to workaround a compiler bug
Out,
stride_qz,
stride_qh,
stride_qm,
stride_qk,
stride_kz,
stride_kh,
stride_kn,
stride_kk,
stride_vz,
stride_vh,
stride_vk,
stride_vn,
stride_oz,
stride_oh,
stride_om,
stride_on,
Z,
H,
N_CTX,
BLOCK_M: tl.constexpr,
BLOCK_DMODEL: tl.constexpr,
BLOCK_N: tl.constexpr,
):
start_m = tl.program_id(0)
...@@ -100,9 +119,13 @@ def _fwd_kernel(
@triton.jit
def _bwd_preprocess(
Out,
DO,
L,
NewDO,
Delta,
BLOCK_M: tl.constexpr,
D_HEAD: tl.constexpr,
):
off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
off_n = tl.arange(0, D_HEAD)
...@@ -120,16 +143,36 @@ def _bwd_preprocess(
@triton.jit
def _bwd_kernel(
Q,
K,
V,
sm_scale,
Out,
DO,
DQ,
DK,
DV,
L,
M,
D,
stride_qz,
stride_qh,
stride_qm,
stride_qk,
stride_kz,
stride_kh,
stride_kn,
stride_kk,
stride_vz,
stride_vh,
stride_vk,
stride_vn,
Z,
H,
N_CTX,
num_block,
BLOCK_M: tl.constexpr,
BLOCK_DMODEL: tl.constexpr,
BLOCK_N: tl.constexpr,
):
off_hz = tl.program_id(0)
...@@ -203,7 +246,6 @@ def _bwd_kernel(
class _attention(torch.autograd.Function):
@staticmethod
def forward(ctx, q, k, v, sm_scale):
BLOCK = 128
...@@ -213,22 +255,45 @@ class _attention(torch.autograd.Function):
assert Lk in {16, 32, 64, 128}
o = torch.empty_like(q)
grid = (triton.cdiv(q.shape[2], BLOCK), q.shape[0] * q.shape[1])
tmp = torch.empty(
(q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32
)
L = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
m = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
num_warps = 4 if Lk <= 64 else 8
_fwd_kernel[grid](
q,
k,
v,
sm_scale,
tmp,
L,
m,
o,
q.stride(0),
q.stride(1),
q.stride(2),
q.stride(3),
k.stride(0),
BLOCK_M=BLOCK, BLOCK_N=BLOCK, k.stride(1),
BLOCK_DMODEL=Lk, num_warps=num_warps, k.stride(2),
k.stride(3),
v.stride(0),
v.stride(1),
v.stride(2),
v.stride(3),
o.stride(0),
o.stride(1),
o.stride(2),
o.stride(3),
q.shape[0],
q.shape[1],
q.shape[2],
BLOCK_M=BLOCK,
BLOCK_N=BLOCK,
BLOCK_DMODEL=Lk,
num_warps=num_warps,
num_stages=1,
)
ctx.save_for_backward(q, k, v, o, L, m)
...@@ -247,27 +312,51 @@ class _attention(torch.autograd.Function):
dv = torch.empty_like(v)
do_scaled = torch.empty_like(do)
delta = torch.empty_like(l)
_bwd_preprocess[(ctx.grid[0] * ctx.grid[1],)](
o,
do,
l,
do_scaled,
delta,
BLOCK_M=ctx.BLOCK,
D_HEAD=ctx.BLOCK_DMODEL,
) )
# NOTE: kernel currently buggy for other values of `num_warps`
num_warps = 8
_bwd_kernel[(ctx.grid[1],)](
q,
k,
v,
ctx.sm_scale,
o,
do_scaled,
dq,
dk,
dv,
l,
m,
delta,
q.stride(0),
q.stride(1),
q.stride(2),
q.stride(3),
k.stride(0),
k.stride(1),
k.stride(2),
k.stride(3),
v.stride(0),
v.stride(1),
v.stride(2),
v.stride(3),
q.shape[0],
q.shape[1],
q.shape[2],
ctx.grid[0],
BLOCK_M=ctx.BLOCK,
BLOCK_N=ctx.BLOCK,
BLOCK_DMODEL=ctx.BLOCK_DMODEL,
num_warps=num_warps,
num_stages=1,
)
return dq.to(q.dtype), dk, dv, None
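# Illustrative aside (plain PyTorch, not kernel code): the "delta" written by _bwd_preprocess
# above is the row-wise term D_i = sum_j dO_ij * O_ij from the FlashAttention backward
# (dS = P * (dP - D) with dP = dO @ V^T), computed once so the kernel never has to
# re-materialize the softmax. Equivalent reduction:
import torch

_o = torch.randn(2, 4, 128, 64)                    # (batch, heads, seqlen, head_dim)
_do = torch.randn_like(_o)
_delta = (_do.float() * _o.float()).sum(dim=-1)    # (batch, heads, seqlen)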
......
import math
import hydra
import torch
import torch.nn as nn
from einops import rearrange
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input
from flash_attn.flash_blocksparse_attn_interface import (
convert_blockmask,
flash_blocksparse_attn_func,
)
class FlashBlocksparseAttention(nn.Module): class FlashBlocksparseAttention(nn.Module):
...@@ -21,8 +22,16 @@ class FlashBlocksparseAttention(nn.Module): ...@@ -21,8 +22,16 @@ class FlashBlocksparseAttention(nn.Module):
attention_dropout: The dropout rate to apply to the attention attention_dropout: The dropout rate to apply to the attention
(default: 0.1) (default: 0.1)
""" """
def __init__(
self,
sparsity_config,
softmax_temp=None,
attention_dropout=0.0,
max_seq_length=2048,
device=None,
dtype=None,
):
super().__init__() super().__init__()
self.sparsity_config = hydra.utils.instantiate(sparsity_config) self.sparsity_config = hydra.utils.instantiate(sparsity_config)
self.softmax_temp = softmax_temp self.softmax_temp = softmax_temp
...@@ -36,8 +45,17 @@ class FlashBlocksparseAttention(nn.Module): ...@@ -36,8 +45,17 @@ class FlashBlocksparseAttention(nn.Module):
self.register_buffer("blockmask_converted", blockmask_converted) self.register_buffer("blockmask_converted", blockmask_converted)
# logger.info(f'Attention class {self.__class__}: saving={self.layout.float().mean()}') # logger.info(f'Attention class {self.__class__}: saving={self.layout.float().mean()}')
def forward(
self,
qkv,
attn_mask=None,
key_padding_mask=None,
causal=False,
cu_seqlens=None,
max_s=None,
need_weights=False,
convert_mask=True,
):
"""Implements the multihead softmax attention. """Implements the multihead softmax attention.
Arguments Arguments
--------- ---------
...@@ -57,47 +75,76 @@ class FlashBlocksparseAttention(nn.Module): ...@@ -57,47 +75,76 @@ class FlashBlocksparseAttention(nn.Module):
seqlen = qkv.shape[1] seqlen = qkv.shape[1]
# Convert mask to take a subset # Convert mask to take a subset
seqlen_rounded = ((seqlen + 256 - 1) // 256) * 256 seqlen_rounded = ((seqlen + 256 - 1) // 256) * 256
assert seqlen_rounded // 16 <= self.layout.shape[0], (
seqlen_rounded // 256 <= self.layout.shape[1]
)
blockmask = self.layout[: seqlen_rounded // 16, : seqlen_rounded // 256]
if key_padding_mask is None:
qkv = rearrange(qkv, "b s ... -> (b s) ...")
max_s = seqlen
cu_seqlens = torch.arange(
0, (batch_size + 1) * seqlen, step=seqlen, dtype=torch.int32, device=qkv.device
)
output = flash_blocksparse_attn_func(
qkv,
cu_seqlens,
blockmask,
self.dropout_p if self.training else 0.0,
max_s,
softmax_scale=self.softmax_temp,
causal=causal,
) )
output = rearrange(output, "(b s) ... -> b s ...", b=batch_size)
else:
key_padding_mask_bool = key_padding_mask.bool_matrix
nheads = qkv.shape[-2]
x = rearrange(qkv, "b s three h d -> b s (three h d)")
x_unpad, indices, cu_seqlens, max_s = unpad_input(x, key_padding_mask_bool)
x_unpad = rearrange(x_unpad, "nnz (three h d) -> nnz three h d", three=3, h=nheads)
output_unpad = flash_blocksparse_attn_func(
x_unpad,
cu_seqlens,
blockmask,
self.dropout_p if self.training else 0.0,
max_s,
softmax_scale=self.softmax_temp,
causal=causal,
)
output = rearrange(
pad_input(
rearrange(output_unpad, "nnz h d -> nnz (h d)"), indices, batch_size, seqlen
),
"b s (h d) -> b s h d",
h=nheads,
) )
else: else:
assert max_s is not None assert max_s is not None
seqlen = max_s seqlen = max_s
# Convert mask to take a subset # Convert mask to take a subset
seqlen_rounded = ((seqlen + 256 - 1) // 256) * 256 seqlen_rounded = ((seqlen + 256 - 1) // 256) * 256
assert seqlen_rounded // 16 <= self.layout.shape[0], (
seqlen_rounded // 256 <= self.layout.shape[1]
)
blockmask = self.layout[: seqlen_rounded // 16, : seqlen_rounded // 256]
if convert_mask:
output = flash_blocksparse_attn_func(
qkv,
cu_seqlens,
blockmask,
self.dropout_p if self.training else 0.0,
max_s,
softmax_scale=self.softmax_temp,
causal=causal,
) )
else: else:
output = flash_blocksparse_attn_func( output = flash_blocksparse_attn_func(
qkv, cu_seqlens, self.blockmask_converted, self.dropout_p if self.training else 0.0, qkv,
max_s, softmax_scale=self.softmax_temp, causal=causal, cu_seqlens,
self.blockmask_converted,
self.dropout_p if self.training else 0.0,
max_s,
softmax_scale=self.softmax_temp,
causal=causal,
convert_mask=False, convert_mask=False,
) )
...@@ -105,12 +152,22 @@ class FlashBlocksparseAttention(nn.Module):
class FlashBlocksparseMHA(nn.Module):
def __init__(
self,
embed_dim,
num_heads,
sparsity_config,
bias=True,
batch_first=True,
attention_dropout=0.0,
causal=False,
max_seq_length=2048,
device=None,
dtype=None,
**kwargs,
) -> None:
assert batch_first assert batch_first
factory_kwargs = {'device': device, 'dtype': dtype} factory_kwargs = {"device": device, "dtype": dtype}
super().__init__() super().__init__()
self.embed_dim = embed_dim self.embed_dim = embed_dim
self.causal = causal self.causal = causal
...@@ -122,15 +179,19 @@ class FlashBlocksparseMHA(nn.Module): ...@@ -122,15 +179,19 @@ class FlashBlocksparseMHA(nn.Module):
self.Wqkv = nn.Linear(embed_dim, 3 * embed_dim, bias=bias, **factory_kwargs) self.Wqkv = nn.Linear(embed_dim, 3 * embed_dim, bias=bias, **factory_kwargs)
self.inner_attn = FlashBlocksparseAttention( self.inner_attn = FlashBlocksparseAttention(
sparsity_config,
attention_dropout=attention_dropout,
max_seq_length=max_seq_length,
**factory_kwargs,
) )
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias, **factory_kwargs) self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias, **factory_kwargs)
def forward(
self, x, x_ignored_, x_ignored_1_, attn_mask=None, key_padding_mask=None, need_weights=False
):
qkv = self.Wqkv(x)
qkv = rearrange(qkv, "b s (three h d) -> b s three h d", three=3, h=self.num_heads)
context, attn_weights = self.inner_attn(
qkv, key_padding_mask=key_padding_mask, need_weights=need_weights, causal=self.causal
)
return self.out_proj(rearrange(context, "b s h d -> b s (h d)")), attn_weights
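# Illustrative sketch of the layout cropping done in FlashBlocksparseAttention.forward()
# above (assumes max_seq_length=2048; the buffer construction itself is elided from this
# diff): rows of the layout cover 16 tokens, columns cover 256 tokens, and seqlen is
# rounded up to a multiple of 256 before slicing.
import torch

_seqlen = 700
_seqlen_rounded = ((_seqlen + 256 - 1) // 256) * 256             # -> 768
_layout = torch.ones(2048 // 16, 2048 // 256, dtype=torch.bool)  # assumed full-length layout
_blockmask = _layout[: _seqlen_rounded // 16, : _seqlen_rounded // 256]
assert _blockmask.shape == (48, 3)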
# Adapted from https://github.com/mlcommons/training_results_v1.1/blob/main/NVIDIA/benchmarks/bert/implementations/pytorch/fmha.py
import flash_attn_cuda
import torch
import torch.nn as nn
def convert_blockmask(blockmask, causal):
"""Convert from the 0-1 format to the format used by the CUDA code.
...@@ -40,29 +39,51 @@ def convert_blockmask(blockmask, causal):
return nonzero_idx.T.contiguous().to(dtype=torch.int32)
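# Toy illustration only, not the exact format the CUDA kernels consume (the real conversion
# above does more work before returning int32 indices): the basic move is from a dense 0/1
# block mask to integer indices of the active blocks that a sparse kernel can iterate over.
import torch

_blockmask = torch.tensor([[1, 0, 1],
                           [0, 1, 1]])
_active = _blockmask.nonzero()   # (row_block, col_block) pairs: [[0,0],[0,2],[1,1],[1,2]]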
def _flash_blocksparse_attn_forward(
qkv, cu_seqlens, blockmask, dropout_p, max_s, softmax_scale, causal, return_softmax
):
context, softmax_lse, *rest = flash_attn_cuda.fwd_block(
qkv, cu_seqlens, blockmask, dropout_p, max_s, softmax_scale, causal, return_softmax, None
)
# if context.isnan().any() or softmax_lse.isnan().any(): # if context.isnan().any() or softmax_lse.isnan().any():
# breakpoint() # breakpoint()
S_dmask = rest[0] if return_softmax else None S_dmask = rest[0] if return_softmax else None
return context, softmax_lse, S_dmask return context, softmax_lse, S_dmask
def _flash_blocksparse_attn_backward(
dout,
qkv,
out,
S_dmask,
softmax_lse,
cu_seqlens,
blockmask,
dropout_p,
max_s,
softmax_scale,
causal,
):
dqkv, dp, softmax_d = flash_attn_cuda.bwd_block(
dout,
qkv,
out,
S_dmask,
softmax_lse,
cu_seqlens,
blockmask,
dropout_p,
softmax_scale,
max_s,
causal,
None,
)
# if dqkv.isnan().any() or softmax_d.isnan().any(): # if dqkv.isnan().any() or softmax_d.isnan().any():
# breakpoint() # breakpoint()
return dqkv return dqkv
class FlashBlocksparseAttnFun(torch.autograd.Function): class FlashBlocksparseAttnFun(torch.autograd.Function):
@staticmethod @staticmethod
def forward(ctx, qkv, cu_seqlens, blockmask, dropout_p, max_s, softmax_scale, causal): def forward(ctx, qkv, cu_seqlens, blockmask, dropout_p, max_s, softmax_scale, causal):
# Save rng_state because the backward pass will regenerate the dropout mask # Save rng_state because the backward pass will regenerate the dropout mask
...@@ -70,8 +91,14 @@ class FlashBlocksparseAttnFun(torch.autograd.Function): ...@@ -70,8 +91,14 @@ class FlashBlocksparseAttnFun(torch.autograd.Function):
if softmax_scale is None: if softmax_scale is None:
softmax_scale = qkv.shape[-1] ** (-0.5) softmax_scale = qkv.shape[-1] ** (-0.5)
context, softmax_lse, S_dmask = _flash_blocksparse_attn_forward( context, softmax_lse, S_dmask = _flash_blocksparse_attn_forward(
qkv, cu_seqlens, blockmask, dropout_p, max_s, softmax_scale, causal=causal, qkv,
return_softmax=False cu_seqlens,
blockmask,
dropout_p,
max_s,
softmax_scale,
causal=causal,
return_softmax=False,
) )
ctx.save_for_backward(qkv, context, S_dmask, softmax_lse, cu_seqlens, blockmask, rng_state) ctx.save_for_backward(qkv, context, S_dmask, softmax_lse, cu_seqlens, blockmask, rng_state)
ctx.dropout_p = dropout_p ctx.dropout_p = dropout_p
...@@ -88,8 +115,17 @@ class FlashBlocksparseAttnFun(torch.autograd.Function): ...@@ -88,8 +115,17 @@ class FlashBlocksparseAttnFun(torch.autograd.Function):
torch.cuda.set_rng_state(rng_state) torch.cuda.set_rng_state(rng_state)
# S_dmask is None, temporarily use another tensor just to get it running # S_dmask is None, temporarily use another tensor just to get it running
dqkv = _flash_blocksparse_attn_backward( dqkv = _flash_blocksparse_attn_backward(
dout, qkv, context, context, softmax_lse, cu_seqlens, blockmask, ctx.dropout_p, dout,
ctx.max_s, ctx.softmax_scale, ctx.causal qkv,
context,
context,
softmax_lse,
cu_seqlens,
blockmask,
ctx.dropout_p,
ctx.max_s,
ctx.softmax_scale,
ctx.causal,
) )
if rng_state is not None: if rng_state is not None:
torch.cuda.set_rng_state(cur_rng_state) torch.cuda.set_rng_state(cur_rng_state)
...@@ -99,7 +135,6 @@ class FlashBlocksparseAttnFun(torch.autograd.Function): ...@@ -99,7 +135,6 @@ class FlashBlocksparseAttnFun(torch.autograd.Function):
# We duplicate code to return both the output and the softmax for testing # We duplicate code to return both the output and the softmax for testing
# Returning both makes backward a bit slower, so we want to keep using the other version for speed. # Returning both makes backward a bit slower, so we want to keep using the other version for speed.
class FlashBlocksparseAttnFunWithS(torch.autograd.Function): class FlashBlocksparseAttnFunWithS(torch.autograd.Function):
@staticmethod @staticmethod
def forward(ctx, qkv, cu_seqlens, blockmask, dropout_p, max_s, softmax_scale, causal): def forward(ctx, qkv, cu_seqlens, blockmask, dropout_p, max_s, softmax_scale, causal):
# Save rng_state because the backward pass is gonna regenerate the dropout mask # Save rng_state because the backward pass is gonna regenerate the dropout mask
...@@ -107,8 +142,14 @@ class FlashBlocksparseAttnFunWithS(torch.autograd.Function): ...@@ -107,8 +142,14 @@ class FlashBlocksparseAttnFunWithS(torch.autograd.Function):
if softmax_scale is None: if softmax_scale is None:
softmax_scale = qkv.shape[-1] ** (-0.5) softmax_scale = qkv.shape[-1] ** (-0.5)
context, softmax_lse, S_dmask = _flash_blocksparse_attn_forward( context, softmax_lse, S_dmask = _flash_blocksparse_attn_forward(
qkv, cu_seqlens, blockmask, dropout_p, max_s, softmax_scale, causal=causal, qkv,
return_softmax=True cu_seqlens,
blockmask,
dropout_p,
max_s,
softmax_scale,
causal=causal,
return_softmax=True,
) )
ctx.save_for_backward(qkv, context, S_dmask, softmax_lse, cu_seqlens, blockmask, rng_state) ctx.save_for_backward(qkv, context, S_dmask, softmax_lse, cu_seqlens, blockmask, rng_state)
ctx.dropout_p = dropout_p ctx.dropout_p = dropout_p
...@@ -124,18 +165,35 @@ class FlashBlocksparseAttnFunWithS(torch.autograd.Function): ...@@ -124,18 +165,35 @@ class FlashBlocksparseAttnFunWithS(torch.autograd.Function):
cur_rng_state = torch.cuda.get_rng_state() cur_rng_state = torch.cuda.get_rng_state()
torch.cuda.set_rng_state(rng_state) torch.cuda.set_rng_state(rng_state)
dqkv = _flash_blocksparse_attn_backward( dqkv = _flash_blocksparse_attn_backward(
dout, qkv, context, S_dmask, softmax_lse, cu_seqlens, blockmask, ctx.dropout_p, dout,
ctx.max_s, ctx.softmax_scale, ctx.causal qkv,
context,
S_dmask,
softmax_lse,
cu_seqlens,
blockmask,
ctx.dropout_p,
ctx.max_s,
ctx.softmax_scale,
ctx.causal,
) )
if rng_state is not None: if rng_state is not None:
torch.cuda.set_rng_state(cur_rng_state) torch.cuda.set_rng_state(cur_rng_state)
return dqkv, None, None, None, None, None, None return dqkv, None, None, None, None, None, None
def flash_blocksparse_attn_func(
qkv,
cu_seqlens,
blockmask,
dropout_p,
max_s,
softmax_scale=None,
causal=False,
return_attn_probs=False,
convert_mask=True,
):
"""dropout_p should be set to 0.0 during evaluation"""
func = FlashBlocksparseAttnFun if not return_attn_probs else FlashBlocksparseAttnFunWithS func = FlashBlocksparseAttnFun if not return_attn_probs else FlashBlocksparseAttnFunWithS
if convert_mask: if convert_mask:
blockmask = convert_blockmask(blockmask, causal=causal) blockmask = convert_blockmask(blockmask, causal=causal)
......
...@@ -17,13 +17,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from apex._autocast_utils import _cast_if_autocast_enabled
from apex.transformer.enums import AttnMaskType
from fused_softmax_lib import (
scaled_masked_softmax_backward,
scaled_masked_softmax_forward,
scaled_masked_softmax_get_batch_per_block,
scaled_upper_triang_masked_softmax_backward,
scaled_upper_triang_masked_softmax_forward,
)
class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function): class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function):
...@@ -37,9 +39,7 @@ class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function): ...@@ -37,9 +39,7 @@ class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function):
@staticmethod @staticmethod
def forward(ctx, inputs, scale): def forward(ctx, inputs, scale):
scale_t = torch.tensor([scale]) scale_t = torch.tensor([scale])
softmax_results = scaled_upper_triang_masked_softmax_forward(inputs, scale_t[0])
ctx.save_for_backward(softmax_results, scale_t) ctx.save_for_backward(softmax_results, scale_t)
return softmax_results return softmax_results
...@@ -81,9 +81,7 @@ class ScaledMaskedSoftmax(torch.autograd.Function): ...@@ -81,9 +81,7 @@ class ScaledMaskedSoftmax(torch.autograd.Function):
@staticmethod @staticmethod
def backward(ctx, output_grads): def backward(ctx, output_grads):
softmax_results, scale_t = ctx.saved_tensors softmax_results, scale_t = ctx.saved_tensors
input_grads = scaled_masked_softmax_backward(output_grads, softmax_results, scale_t[0])
return input_grads, None, None return input_grads, None, None
...@@ -122,9 +120,7 @@ class FusedScaleMaskSoftmax(torch.nn.Module): ...@@ -122,9 +120,7 @@ class FusedScaleMaskSoftmax(torch.nn.Module):
self.input_in_fp16 = input_in_fp16 self.input_in_fp16 = input_in_fp16
self.input_in_bf16 = input_in_bf16 self.input_in_bf16 = input_in_bf16
if self.input_in_fp16 and self.input_in_bf16: if self.input_in_fp16 and self.input_in_bf16:
raise RuntimeError("both fp16 and bf16 flags cannot be active at the same time.")
self.input_in_float16 = self.input_in_fp16 or self.input_in_bf16 self.input_in_float16 = self.input_in_fp16 or self.input_in_bf16
self.attn_mask_type = attn_mask_type self.attn_mask_type = attn_mask_type
self.scaled_masked_softmax_fusion = scaled_masked_softmax_fusion self.scaled_masked_softmax_fusion = scaled_masked_softmax_fusion
......
...@@ -4,11 +4,10 @@
from functools import partial
import torch.nn as nn
from einops import rearrange
from torch import _assert
from torch.nn.modules.utils import _pair
try: try:
from flash_attn.ops.fused_dense import FusedDense from flash_attn.ops.fused_dense import FusedDense
except ImportError: except ImportError:
...@@ -16,8 +15,8 @@ except ImportError:
class PatchEmbed(nn.Module):
"""2D Image to Patch Embedding"""
def __init__( def __init__(
self, self,
img_size=224, img_size=224,
...@@ -38,7 +37,7 @@ class PatchEmbed(nn.Module): ...@@ -38,7 +37,7 @@ class PatchEmbed(nn.Module):
self.num_patches = self.grid_size[0] * self.grid_size[1] self.num_patches = self.grid_size[0] * self.grid_size[1]
self.flatten = flatten self.flatten = flatten
if fused_bias_fc and FusedDense is None: if fused_bias_fc and FusedDense is None:
raise ImportError('fused_dense is not installed') raise ImportError("fused_dense is not installed")
linear_cls = nn.Linear if not fused_bias_fc or not bias else FusedDense linear_cls = nn.Linear if not fused_bias_fc or not bias else FusedDense
self.proj = linear_cls(in_chans * patch_size[0] * patch_size[1], embed_dim, bias=bias) self.proj = linear_cls(in_chans * patch_size[0] * patch_size[1], embed_dim, bias=bias)
...@@ -46,11 +45,23 @@ class PatchEmbed(nn.Module):
def forward(self, x):
_, _, H, W = x.shape
_assert(
H == self.img_size[0],
f"Input image height ({H}) doesn't match model ({self.img_size[0]}).",
)
_assert(
W == self.img_size[1],
f"Input image width ({W}) doesn't match model ({self.img_size[1]}).",
)
x = self.proj(
rearrange(
x,
"b c (h p1) (w p2) -> b h w (c p1 p2)",
p1=self.patch_size[0],
p2=self.patch_size[1],
)
)
if self.flatten:
x = rearrange(x, "b h w c -> b (h w) c")
x = self.norm(x)
return x
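# Illustrative check (not part of the module): the rearrange + Linear above is the same
# projection as a Conv2d with kernel_size = stride = patch_size, phrased as a matmul so it
# can be swapped for FusedDense. Quick CPU verification of that equivalence:
import torch
import torch.nn as nn
from einops import rearrange

_p, _c, _d = 16, 3, 8
_x = torch.randn(2, _c, 64, 64)
_lin = nn.Linear(_c * _p * _p, _d)
_conv = nn.Conv2d(_c, _d, kernel_size=_p, stride=_p)
with torch.no_grad():
    _conv.weight.copy_(_lin.weight.view(_d, _c, _p, _p))   # same weights, conv layout
    _conv.bias.copy_(_lin.bias)
_out_lin = _lin(rearrange(_x, "b c (h p1) (w p2) -> b h w (c p1 p2)", p1=_p, p2=_p))
_out_conv = rearrange(_conv(_x), "b d h w -> b h w d")
assert torch.allclose(_out_lin, _out_conv, atol=1e-4)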
# Copyright (c) 2023, Tri Dao.
import math
from typing import Optional, Tuple
import rotary_emb
import torch
from einops import rearrange, repeat
def rotate_half(x, interleaved=False):
if not interleaved:
...@@ -16,7 +14,7 @@ def rotate_half(x, interleaved=False):
return torch.cat((-x2, x1), dim=-1)
else:
x1, x2 = x[..., ::2], x[..., 1::2]
return rearrange(torch.stack((-x2, x1), dim=-1), "... d two -> ... (d two)", two=2)
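# Illustrative numbers for rotate_half above (uses the function defined just above):
# GPT-NeoX style rotates the two halves of the feature dimension, GPT-J ("interleaved")
# style rotates even/odd pairs.
import torch

_x = torch.tensor([1.0, 2.0, 3.0, 4.0])
assert torch.equal(rotate_half(_x), torch.tensor([-3.0, -4.0, 1.0, 2.0]))
assert torch.equal(rotate_half(_x, interleaved=True), torch.tensor([-2.0, 1.0, -4.0, 3.0]))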
def apply_rotary_emb_torch(x, cos, sin, interleaved=False):
...@@ -26,14 +24,15 @@ def apply_rotary_emb_torch(x, cos, sin, interleaved=False):
"""
ro_dim = cos.shape[-1] * 2
assert ro_dim <= x.shape[-1]
cos = repeat(cos, "s d -> s 1 (2 d)")
sin = repeat(sin, "s d -> s 1 (2 d)")
return torch.cat(
[x[..., :ro_dim] * cos + rotate_half(x[..., :ro_dim], interleaved) * sin, x[..., ro_dim:]],
dim=-1,
)
class ApplyRotaryEmb(torch.autograd.Function): class ApplyRotaryEmb(torch.autograd.Function):
@staticmethod @staticmethod
def forward(ctx, x, cos, sin, interleaved=False, inplace=False): def forward(ctx, x, cos, sin, interleaved=False, inplace=False):
""" """
...@@ -57,10 +56,20 @@ class ApplyRotaryEmb(torch.autograd.Function): ...@@ -57,10 +56,20 @@ class ApplyRotaryEmb(torch.autograd.Function):
if inplace: if inplace:
o1, o2 = x1, x2 o1, o2 = x1, x2
else: else:
o1, o2 = (
out_ro.chunk(2, dim=-1)
if not interleaved
else (out_ro[..., ::2], out_ro[..., 1::2])
)
rotary_emb.apply_rotary(
x1,
x2,
rearrange(cos[:seqlen], "s d -> s 1 d"),
rearrange(sin[:seqlen], "s d -> s 1 d"),
o1,
o2,
False,
)
if not inplace and rotary_dim < headdim: if not inplace and rotary_dim < headdim:
out[..., rotary_dim:].copy_(x[..., rotary_dim:]) out[..., rotary_dim:].copy_(x[..., rotary_dim:])
ctx.save_for_backward(cos, sin) ctx.save_for_backward(cos, sin)
...@@ -76,17 +85,28 @@ class ApplyRotaryEmb(torch.autograd.Function): ...@@ -76,17 +85,28 @@ class ApplyRotaryEmb(torch.autograd.Function):
rotary_dim *= 2 rotary_dim *= 2
inplace = ctx.inplace inplace = ctx.inplace
do_ro = do[..., :rotary_dim] do_ro = do[..., :rotary_dim]
do1, do2 = (
do_ro.chunk(2, dim=-1) if not ctx.interleaved else (do_ro[..., ::2], do_ro[..., 1::2])
)
dx = torch.empty_like(do) if not inplace else do dx = torch.empty_like(do) if not inplace else do
if inplace: if inplace:
dx1, dx2 = do1, do2 dx1, dx2 = do1, do2
else: else:
dx_ro = dx[..., :rotary_dim] dx_ro = dx[..., :rotary_dim]
dx1, dx2 = (
dx_ro.chunk(2, dim=-1)
if not ctx.interleaved
else (dx_ro[..., ::2], dx_ro[..., 1::2])
)
rotary_emb.apply_rotary(
do1,
do2,
rearrange(cos[:seqlen], "s d -> s 1 d"),
rearrange(sin[:seqlen], "s d -> s 1 d"),
dx1,
dx2,
True,
)
if not inplace and rotary_dim < headdim: if not inplace and rotary_dim < headdim:
dx[..., rotary_dim:].copy_(do[..., rotary_dim:]) dx[..., rotary_dim:].copy_(do[..., rotary_dim:])
return dx, None, None, None, None return dx, None, None, None, None
...@@ -96,7 +116,6 @@ apply_rotary_emb_func = ApplyRotaryEmb.apply ...@@ -96,7 +116,6 @@ apply_rotary_emb_func = ApplyRotaryEmb.apply
class ApplyRotaryEmbQKV_(torch.autograd.Function): class ApplyRotaryEmbQKV_(torch.autograd.Function):
@staticmethod @staticmethod
def forward(ctx, qkv, cos, sin, cos_k=None, sin_k=None, interleaved=False): def forward(ctx, qkv, cos, sin, cos_k=None, sin_k=None, interleaved=False):
""" """
...@@ -119,12 +138,26 @@ class ApplyRotaryEmbQKV_(torch.autograd.Function): ...@@ -119,12 +138,26 @@ class ApplyRotaryEmbQKV_(torch.autograd.Function):
assert sin.shape == cos_k.shape == sin_k.shape == (rotary_seqlen, rotary_dim // 2) assert sin.shape == cos_k.shape == sin_k.shape == (rotary_seqlen, rotary_dim // 2)
q_ro = qkv[:, :, 0, :, :rotary_dim] q_ro = qkv[:, :, 0, :, :rotary_dim]
q1, q2 = q_ro.chunk(2, dim=-1) if not interleaved else (q_ro[..., ::2], q_ro[..., 1::2]) q1, q2 = q_ro.chunk(2, dim=-1) if not interleaved else (q_ro[..., ::2], q_ro[..., 1::2])
rotary_emb.apply_rotary(q1, q2, rearrange(cos[:seqlen], 's d -> s 1 d'), rotary_emb.apply_rotary(
rearrange(sin[:seqlen], 's d -> s 1 d'), q1, q2, False) q1,
q2,
rearrange(cos[:seqlen], "s d -> s 1 d"),
rearrange(sin[:seqlen], "s d -> s 1 d"),
q1,
q2,
False,
)
k_ro = qkv[:, :, 1, :, :rotary_dim] k_ro = qkv[:, :, 1, :, :rotary_dim]
k1, k2 = k_ro.chunk(2, dim=-1) if not interleaved else (k_ro[..., ::2], k_ro[..., 1::2]) k1, k2 = k_ro.chunk(2, dim=-1) if not interleaved else (k_ro[..., ::2], k_ro[..., 1::2])
rotary_emb.apply_rotary(k1, k2, rearrange(cos_k[:seqlen], 's d -> s 1 d'), rotary_emb.apply_rotary(
rearrange(sin_k[:seqlen], 's d -> s 1 d'), k1, k2, False) k1,
k2,
rearrange(cos_k[:seqlen], "s d -> s 1 d"),
rearrange(sin_k[:seqlen], "s d -> s 1 d"),
k1,
k2,
False,
)
ctx.save_for_backward(cos, sin, cos_k, sin_k) ctx.save_for_backward(cos, sin, cos_k, sin_k)
ctx.interleaved = interleaved ctx.interleaved = interleaved
return qkv return qkv
...@@ -136,15 +169,31 @@ class ApplyRotaryEmbQKV_(torch.autograd.Function): ...@@ -136,15 +169,31 @@ class ApplyRotaryEmbQKV_(torch.autograd.Function):
rotary_dim = cos.shape[-1] rotary_dim = cos.shape[-1]
rotary_dim *= 2 rotary_dim *= 2
dq_ro = dqkv[:, :, 0, :, :rotary_dim] dq_ro = dqkv[:, :, 0, :, :rotary_dim]
dq1, dq2 = (dq_ro.chunk(2, dim=-1) if not ctx.interleaved dq1, dq2 = (
else (dq_ro[..., ::2], dq_ro[..., 1::2])) dq_ro.chunk(2, dim=-1) if not ctx.interleaved else (dq_ro[..., ::2], dq_ro[..., 1::2])
rotary_emb.apply_rotary(dq1, dq2, rearrange(cos[:seqlen], 's d -> s 1 d'), )
rearrange(sin[:seqlen], 's d -> s 1 d'), dq1, dq2, True) rotary_emb.apply_rotary(
dq1,
dq2,
rearrange(cos[:seqlen], "s d -> s 1 d"),
rearrange(sin[:seqlen], "s d -> s 1 d"),
dq1,
dq2,
True,
)
dk_ro = dqkv[:, :, 1, :, :rotary_dim] dk_ro = dqkv[:, :, 1, :, :rotary_dim]
dk1, dk2 = (dk_ro.chunk(2, dim=-1) if not ctx.interleaved dk1, dk2 = (
else (dk_ro[..., ::2], dk_ro[..., 1::2])) dk_ro.chunk(2, dim=-1) if not ctx.interleaved else (dk_ro[..., ::2], dk_ro[..., 1::2])
rotary_emb.apply_rotary(dk1, dk2, rearrange(cos_k[:seqlen], 's d -> s 1 d'), )
rearrange(sin_k[:seqlen], 's d -> s 1 d'), dk1, dk2, True) rotary_emb.apply_rotary(
dk1,
dk2,
rearrange(cos_k[:seqlen], "s d -> s 1 d"),
rearrange(sin_k[:seqlen], "s d -> s 1 d"),
dk1,
dk2,
True,
)
return dqkv, None, None, None, None, None return dqkv, None, None, None, None, None
@@ -152,7 +201,6 @@ apply_rotary_emb_qkv_ = ApplyRotaryEmbQKV_.apply

class ApplyRotaryEmbKV_(torch.autograd.Function):
    @staticmethod
    def forward(ctx, kv, cos, sin, interleaved=False):
        """
@@ -171,9 +219,15 @@ class ApplyRotaryEmbKV_(torch.autograd.Function):
        assert seqlen <= rotary_seqlen
        k_ro = kv[:, :, 0, :, :rotary_dim]
        k1, k2 = k_ro.chunk(2, dim=-1) if not interleaved else (k_ro[..., ::2], k_ro[..., 1::2])
        rotary_emb.apply_rotary(
            k1,
            k2,
            rearrange(cos[:seqlen], "s d -> s 1 d"),
            rearrange(sin[:seqlen], "s d -> s 1 d"),
            k1,
            k2,
            False,
        )  # conj=False since this is the forward pass
        ctx.save_for_backward(cos, sin)
        ctx.interleaved = interleaved
        return kv
@@ -185,11 +239,18 @@ class ApplyRotaryEmbKV_(torch.autograd.Function):
        rotary_dim = cos.shape[-1]
        rotary_dim *= 2
        dk_ro = dkv[:, :, 0, :, :rotary_dim]
        dk1, dk2 = (
            dk_ro.chunk(2, dim=-1) if not ctx.interleaved else (dk_ro[..., ::2], dk_ro[..., 1::2])
        )
        rotary_emb.apply_rotary(
            dk1,
            dk2,
            rearrange(cos[:seqlen], "s d -> s 1 d"),
            rearrange(sin[:seqlen], "s d -> s 1 d"),
            dk1,
            dk2,
            True,
        )  # conj=True since this is the backward pass
        return dkv, None, None, None
@@ -214,8 +275,15 @@ class RotaryEmbedding(torch.nn.Module):
    Reference: https://github.com/sunyt32/torchscale/blob/main/torchscale/component/xpos_relative_position.py
    """

    def __init__(
        self,
        dim: int,
        base=10000.0,
        interleaved=False,
        scale_base=None,
        pos_idx_in_fp32=True,
        device=None,
    ):
        """
        interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead
            of 1st half and 2nd half (GPT-NeoX style).
@@ -239,8 +307,11 @@ class RotaryEmbedding(torch.nn.Module):
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.interleaved = interleaved
        self.scale_base = scale_base
        scale = (
            (torch.arange(0, dim, 2, device=device, dtype=torch.float32) + 0.4 * dim) / (1.4 * dim)
            if scale_base is not None
            else None
        )
        self.register_buffer("scale", scale, persistent=False)
        self._seq_len_cached = 0
@@ -250,17 +321,21 @@ class RotaryEmbedding(torch.nn.Module):
        self._sin_k_cached = None

    def _compute_inv_freq(self, device=None):
        return 1.0 / (
            self.base
            ** (torch.arange(0, self.dim, 2, device=device, dtype=torch.float32) / self.dim)
        )
    def _update_cos_sin_cache(self, seqlen, device=None, dtype=None):
        # Reset the tables if the sequence length has changed,
        # if we're on a new device (possibly due to tracing for instance),
        # or if we're switching from inference mode to training
        if (
            seqlen > self._seq_len_cached
            or self._cos_cached.device != device
            or self._cos_cached.dtype != dtype
            or (self.training and self._cos_cached.is_inference())
        ):
            self._seq_len_cached = seqlen
            # We want fp32 here, not self.inv_freq.dtype, since the model could be loaded in bf16
            # And the output of arange can be quite large, so bf16 would lose a lot of precision.
@@ -285,17 +360,20 @@ class RotaryEmbedding(torch.nn.Module):
                self._cos_cached = torch.cos(freqs).to(dtype)
                self._sin_cached = torch.sin(freqs).to(dtype)
            else:
                power = (
                    torch.arange(seqlen, dtype=self.scale.dtype, device=self.scale.device)
                    - seqlen // 2
                ) / self.scale_base
                scale = self.scale.to(device=power.device) ** rearrange(power, "s -> s 1")
                # We want the multiplication by scale to happen in fp32
                self._cos_cached = (torch.cos(freqs) * scale).to(dtype)
                self._sin_cached = (torch.sin(freqs) * scale).to(dtype)
                self._cos_k_cached = (torch.cos(freqs) / scale).to(dtype)
                self._sin_k_cached = (torch.sin(freqs) / scale).to(dtype)
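
# A standalone sketch (plain float32, no caching or device/dtype handling) of how the
# tables above are built: freqs is the outer product of positions with
# inv_freq = 1 / base**(2i/dim), and when scale_base is set (xPos) a per-position scale
# is multiplied into the q tables and divided out of the k tables.
import torch

dim, base, seqlen, scale_base = 8, 10000.0, 6, 512.0
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
t = torch.arange(seqlen, dtype=torch.float32)
freqs = torch.outer(t, inv_freq)                                  # (seqlen, dim // 2)
scale = (torch.arange(0, dim, 2, dtype=torch.float32) + 0.4 * dim) / (1.4 * dim)
power = (t - seqlen // 2) / scale_base
scale = scale ** power.unsqueeze(-1)                              # (seqlen, dim // 2)
cos_q, sin_q = torch.cos(freqs) * scale, torch.sin(freqs) * scale
cos_k, sin_k = torch.cos(freqs) / scale, torch.sin(freqs) / scale
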
    def forward(
        self, qkv: torch.Tensor, kv: Optional[torch.Tensor] = None, seqlen_offset: int = 0
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        qkv: (batch, seqlen, 3, nheads, headdim) if kv is none,
            else it's just q of shape (batch, seqlen, nheads, headdim)
@@ -308,29 +386,43 @@ class RotaryEmbedding(torch.nn.Module):
        if kv is None:
            if self.scale is None:
                return apply_rotary_emb_qkv_(
                    qkv,
                    self._cos_cached[seqlen_offset:],
                    self._sin_cached[seqlen_offset:],
                    None,
                    None,
                    self.interleaved,
                )
            else:
                return apply_rotary_emb_qkv_(
                    qkv,
                    self._cos_cached[seqlen_offset:],
                    self._sin_cached[seqlen_offset:],
                    self._cos_k_cached[seqlen_offset:],
                    self._sin_k_cached[seqlen_offset:],
                    self.interleaved,
                )
        else:
            q = qkv
            q = apply_rotary_emb_func(
                q,
                self._cos_cached[seqlen_offset:],
                self._sin_cached[seqlen_offset:],
                self.interleaved,
                True,
            )
            if self.scale is None:
                kv = apply_rotary_emb_kv_(
                    kv,
                    self._cos_cached[seqlen_offset:],
                    self._sin_cached[seqlen_offset:],
                    self.interleaved,
                )
            else:
                kv = apply_rotary_emb_kv_(
                    kv,
                    self._cos_k_cached[seqlen_offset:],
                    self._sin_k_cached[seqlen_offset:],
                    self.interleaved,
                )
            return q, kv
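
# Hypothetical usage of the module above (assumes the rotary CUDA extension shipped with
# flash-attn is installed, a CUDA device is available, and that this module is importable
# as flash_attn.layers.rotary; shapes follow the docstring above).
import torch
from flash_attn.layers.rotary import RotaryEmbedding

if torch.cuda.is_available():
    batch, seqlen, nheads, headdim = 2, 128, 12, 64
    rotary = RotaryEmbedding(dim=headdim, interleaved=False, device="cuda")
    qkv = torch.randn(batch, seqlen, 3, nheads, headdim, device="cuda", dtype=torch.float16)
    qkv = rotary(qkv)  # rotates q and k in place, v is left untouched
    # With a separate (batch, seqlen, 2, nheads, headdim) kv tensor, q is passed alone:
    q = torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=torch.float16)
    kv = torch.randn(batch, seqlen, 2, nheads, headdim, device="cuda", dtype=torch.float16)
    q, kv = rotary(q, kv=kv, seqlen_offset=0)
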
@@ -5,7 +5,6 @@
# The original xentropy interface is here: https://github.com/NVIDIA/apex/blob/master/apex/contrib/xentropy/softmax_xentropy.py
import torch
import torch.nn as nn
import xentropy_cuda_lib

# `all_gather_into_tensor` and `reduce_scatter_tensor` are new placeholders for
@@ -17,10 +16,16 @@ if "all_gather_into_tensor" not in dir(torch.distributed):

class SoftmaxCrossEntropyLossFn(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx,
        logits,
        labels,
        smoothing=0.0,
        ignored_index=-100,
        inplace_backward=False,
        process_group=None,
    ):
        """
        logits: (batch, vocab_size)
        labels: (batch,)
@@ -34,7 +39,7 @@ class SoftmaxCrossEntropyLossFn(torch.autograd.Function):
        if world_size == 1:
            losses, lse = xentropy_cuda_lib.forward(logits, labels, smoothing)
            losses.masked_fill_(labels == ignored_index, 0)
            labels_local = labels
        else:
            rank = torch.distributed.get_rank(process_group)
@@ -48,8 +53,9 @@ class SoftmaxCrossEntropyLossFn(torch.autograd.Function):
            # For tensor parallel cross entropy with smoothing, we want to pass in the total number
            # of classes so that smoothing can be applied correctly. If total_classes=-1, use the
            # last dimension of the input tensor.
            losses, lse_local = xentropy_cuda_lib.forward(
                logits, labels_local, smoothing, world_size * vocab_size
            )
            assert lse_local.shape == (batch,)
            assert losses.shape == (batch,)
            losses.masked_fill_(ignored_mask, 0)
@@ -61,10 +67,12 @@ class SoftmaxCrossEntropyLossFn(torch.autograd.Function):
            # For labels not in the vocab of this partition, losses contains
            # 0.1 * (lse_local - sum logit / total_classes).
            lse_allgather = torch.empty(
                world_size, batch, dtype=lse_local.dtype, device=lse_local.device
            )
            torch.distributed.all_gather_into_tensor(
                lse_allgather, lse_local.contiguous(), group=process_group
            )
            handle_losses = torch.distributed.all_reduce(
                losses, op=torch.distributed.ReduceOp.SUM, group=process_group, async_op=True
            )
@@ -74,16 +82,18 @@ class SoftmaxCrossEntropyLossFn(torch.autograd.Function):
            # If there's smoothing=0.1, the total losses are
            # 0.9 * (lse_local - predicted_logit) + 0.1 * (sum of all lse_local - sum logit / total_classes)
            # We want 0.9 * (lse - predicted_logit) + 0.1 * (lse - sum logit / total_classes).
            rank_per_sample = torch.div(labels, vocab_size, rounding_mode="floor")
            lse_local = lse_allgather[
                rank_per_sample, torch.arange(batch, device=lse_allgather.device)
            ]
            handle_losses.wait()
            if smoothing == 0.0:
                losses += lse - lse_local
            else:
                losses += (1 - smoothing) * (lse - lse_local) + smoothing * (
                    lse - lse_allgather.sum(dim=0)
                )
            losses.masked_fill_(ignored_mask, 0)

        ctx.save_for_backward(logits, lse, labels_local)
@@ -96,19 +106,24 @@ class SoftmaxCrossEntropyLossFn(torch.autograd.Function):
    def backward(ctx, grad_loss):
        logits, lse, labels = ctx.saved_tensors
        grad_loss = grad_loss.contiguous()
        grad_loss.masked_fill_(labels == ctx.ignored_index, 0)
        grad_logits = xentropy_cuda_lib.backward(
            grad_loss, logits, lse, labels, ctx.smoothing, ctx.inplace_backward, ctx.total_classes
        )
        return grad_logits, None, None, None, None, None, None
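
# A self-contained numerical sketch (pure PyTorch, single process, no label smoothing) of
# the log-sum-exp bookkeeping above: each "rank" holds a vocab shard, computes a local lse,
# the global lse is the logsumexp over ranks, and losses are corrected by (lse - lse_local).
import torch
import torch.nn.functional as F

batch, vocab, world_size = 4, 12, 3
logits = torch.randn(batch, vocab)
labels = torch.randint(0, vocab, (batch,))
shards = logits.chunk(world_size, dim=-1)                        # one vocab shard per "rank"
lse_local = torch.stack([s.logsumexp(dim=-1) for s in shards])   # (world_size, batch)
lse = lse_local.logsumexp(dim=0)                                 # == logits.logsumexp(dim=-1)
rank_per_sample = torch.div(labels, vocab // world_size, rounding_mode="floor")
lse_owner = lse_local[rank_per_sample, torch.arange(batch)]
predicted_logit = logits[torch.arange(batch), labels]
losses = (lse_owner - predicted_logit) + (lse - lse_owner)       # owner's loss + correction
assert torch.allclose(losses, F.cross_entropy(logits, labels, reduction="none"), atol=1e-5)
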
class CrossEntropyLoss(nn.Module):
    def __init__(
        self,
        ignore_index=-100,
        reduction="mean",
        label_smoothing=0.0,
        inplace_backward=False,
        process_group=None,
    ):
        super().__init__()
        if reduction not in ["mean", "none"]:
            raise NotImplementedError("Only support reduction = 'mean' or 'none'")
        self.ignore_index = ignore_index
        self.reduction = reduction
@@ -120,10 +135,14 @@ class CrossEntropyLoss(nn.Module):
        assert input.is_cuda and target.is_cuda
        # SoftmaxCrossEntropyLoss implicitly casts to float
        loss = SoftmaxCrossEntropyLossFn.apply(
            input,
            target,
            self.label_smoothing,
            self.ignore_index,
            self.inplace_backward,
            self.process_group,
        )
        if self.reduction == "mean":
            return loss.sum() / (target != self.ignore_index).sum()
        else:
            return loss
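
# Hypothetical usage of the CrossEntropyLoss defined above (requires CUDA tensors and the
# xentropy_cuda extension; with process_group=None it acts as an ordinary label-smoothed
# cross entropy averaged over non-ignored targets).
import torch

if torch.cuda.is_available():
    loss_fn = CrossEntropyLoss(label_smoothing=0.1, inplace_backward=True)
    logits = torch.randn(8, 32000, device="cuda", dtype=torch.float16, requires_grad=True)
    labels = torch.randint(0, 32000, (8,), device="cuda")
    labels[0] = -100  # ignored position, contributes nothing to the loss
    loss = loss_fn(logits, labels)
    loss.backward()
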
@@ -5,29 +5,32 @@
# Inspired by https://github.com/huggingface/transformers/blob/main/src/transformers/models/bert/modeling_bert.py

import logging
import re
from collections import OrderedDict
from collections.abc import Sequence
from functools import partial

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from transformers import BertConfig
from transformers.models.bert.modeling_bert import (
    BaseModelOutputWithPoolingAndCrossAttentions,
    BertForPreTrainingOutput,
)

from flash_attn.bert_padding import (
    index_first_axis,
    index_first_axis_residual,
    pad_input,
    unpad_input,
)
from flash_attn.modules.block import Block
from flash_attn.modules.embedding import BertEmbeddings
from flash_attn.modules.mha import MHA
from flash_attn.modules.mlp import FusedMLP, Mlp
from flash_attn.utils.pretrained import state_dict_from_pretrained
try: try:
...@@ -50,48 +53,63 @@ logger = logging.getLogger(__name__) ...@@ -50,48 +53,63 @@ logger = logging.getLogger(__name__)
def create_mixer_cls(config, cross_attn=False, return_residual=False): def create_mixer_cls(config, cross_attn=False, return_residual=False):
use_flash_attn = getattr(config, 'use_flash_attn', False) use_flash_attn = getattr(config, "use_flash_attn", False)
fused_bias_fc = getattr(config, 'fused_bias_fc', False) fused_bias_fc = getattr(config, "fused_bias_fc", False)
rotary_kwargs = {} rotary_kwargs = {}
if config.position_embedding_type == "rotary": if config.position_embedding_type == "rotary":
rotary_kwargs["rotary_emb_dim"] = getattr(config, "rotary_emb_dim", config.hidden_size) rotary_kwargs["rotary_emb_dim"] = getattr(config, "rotary_emb_dim", config.hidden_size)
rotary_kwargs["rotary_emb_base"] = getattr(config, "rotary_emb_base", 10000.0) rotary_kwargs["rotary_emb_base"] = getattr(config, "rotary_emb_base", 10000.0)
rotary_kwargs["rotary_emb_scale_base"] = getattr(config, "rotary_emb_scale_base", None) rotary_kwargs["rotary_emb_scale_base"] = getattr(config, "rotary_emb_scale_base", None)
rotary_kwargs["rotary_emb_interleaved"] = getattr(config, "rotary_emb_interleaved", False) rotary_kwargs["rotary_emb_interleaved"] = getattr(config, "rotary_emb_interleaved", False)
    mixer_cls = partial(
        MHA,
        num_heads=config.num_attention_heads,
        cross_attn=cross_attn,
        dropout=config.attention_probs_dropout_prob,
        causal=False,
        fused_bias_fc=fused_bias_fc,
        use_flash_attn=use_flash_attn,
        return_residual=return_residual,
        **rotary_kwargs,
    )
    return mixer_cls
def create_mlp_cls(config, layer_idx=None, return_residual=False): def create_mlp_cls(config, layer_idx=None, return_residual=False):
inner_dim = config.intermediate_size inner_dim = config.intermediate_size
fused_mlp = getattr(config, 'fused_mlp', False) fused_mlp = getattr(config, "fused_mlp", False)
if fused_mlp: if fused_mlp:
        assert config.hidden_act in ["gelu_new", "gelu_fast"], "fused_mlp only supports approximate gelu"
    if not fused_mlp:
        approximate = "tanh" if config.hidden_act in ["gelu_new", "gelu_fast"] else "none"
        mlp_cls = partial(
            Mlp,
            hidden_features=inner_dim,
            activation=partial(F.gelu, approximate=approximate),
            return_residual=return_residual,
        )
    else:
        if FusedMLP is None:
            raise ImportError("fused_dense is not installed")
        mlp_checkpoint_lvl = getattr(config, "mlp_checkpoint_lvl", 0)
        # mlp_checkpoint_lvl could be a list, which contains the checkpoint_lvl for each layer
        if isinstance(mlp_checkpoint_lvl, Sequence):
            assert layer_idx is not None
            mlp_checkpoint_lvl = mlp_checkpoint_lvl[layer_idx]
        mlp_cls = partial(
            FusedMLP,
            hidden_features=inner_dim,
            checkpoint_lvl=mlp_checkpoint_lvl,
            return_residual=return_residual,
        )
    return mlp_cls
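
# A small sketch of how these factory helpers are typically consumed: they return classes
# partially bound to the config, and the Block built below instantiates them with the
# hidden size. (Mlp taking in_features as its first positional argument is an assumption
# about flash_attn.modules.mlp, not something spelled out in this file.)
from transformers import BertConfig

config = BertConfig(hidden_size=768, intermediate_size=3072, hidden_act="gelu_new")
mlp_cls = create_mlp_cls(config, layer_idx=0)
mlp = mlp_cls(config.hidden_size)     # -> Mlp(768, hidden_features=3072, ...)
mixer_cls = create_mixer_cls(config)  # -> partial(MHA, num_heads=12, ...)
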
def create_block(config, layer_idx=None): def create_block(config, layer_idx=None):
last_layer_subset = getattr(config, 'last_layer_subset', False) last_layer_subset = getattr(config, "last_layer_subset", False)
cross_attn=last_layer_subset and layer_idx == config.num_hidden_layers - 1 cross_attn = last_layer_subset and layer_idx == config.num_hidden_layers - 1
# TD [2022-12-19]: For cross attention (last layer), we actually want to return the # TD [2022-12-19]: For cross attention (last layer), we actually want to return the
# residual x_kv, not residual x. But it's annoying to change the API (and it only affects # residual x_kv, not residual x. But it's annoying to change the API (and it only affects
# one layer) so we just choose not to return residual in this case. # one layer) so we just choose not to return residual in this case.
...@@ -99,11 +117,17 @@ def create_block(config, layer_idx=None): ...@@ -99,11 +117,17 @@ def create_block(config, layer_idx=None):
mixer_cls = create_mixer_cls(config, cross_attn, return_residual=return_residual) mixer_cls = create_mixer_cls(config, cross_attn, return_residual=return_residual)
mlp_cls = create_mlp_cls(config, layer_idx, return_residual=return_residual) mlp_cls = create_mlp_cls(config, layer_idx, return_residual=return_residual)
norm_cls = partial(nn.LayerNorm, eps=config.layer_norm_eps) norm_cls = partial(nn.LayerNorm, eps=config.layer_norm_eps)
    block = Block(
        config.hidden_size,
        mixer_cls,
        mlp_cls,
        norm_cls=norm_cls,
        prenorm=False,
        resid_dropout1=config.hidden_dropout_prob,
        resid_dropout2=config.hidden_dropout_prob,
        fused_dropout_add_ln=getattr(config, "fused_dropout_add_ln", False),
        return_residual=return_residual,
    )
    return block
...@@ -120,12 +144,12 @@ def _init_weights(module, initializer_range=0.02): ...@@ -120,12 +144,12 @@ def _init_weights(module, initializer_range=0.02):
class BertEncoder(nn.Module): class BertEncoder(nn.Module):
def __init__(self, config: BertConfig): def __init__(self, config: BertConfig):
super().__init__() super().__init__()
self.use_flash_attn = getattr(config, 'use_flash_attn', False) self.use_flash_attn = getattr(config, "use_flash_attn", False)
        self.layers = nn.ModuleList(
            [create_block(config, layer_idx=i) for i in range(config.num_hidden_layers)]
        )
def forward(self, hidden_states, key_padding_mask=None, subset_mask=None): def forward(self, hidden_states, key_padding_mask=None, subset_mask=None):
"""If subset_mask is not None, we only want output for the subset of the sequence. """If subset_mask is not None, we only want output for the subset of the sequence.
...@@ -133,8 +157,9 @@ class BertEncoder(nn.Module): ...@@ -133,8 +157,9 @@ class BertEncoder(nn.Module):
subset_mask: (batch, seqlen), dtype=torch.bool subset_mask: (batch, seqlen), dtype=torch.bool
""" """
if key_padding_mask is None or not self.use_flash_attn: if key_padding_mask is None or not self.use_flash_attn:
mixer_kwargs = ({'key_padding_mask': key_padding_mask} mixer_kwargs = (
if key_padding_mask is not None else None) {"key_padding_mask": key_padding_mask} if key_padding_mask is not None else None
)
for layer in self.layers: for layer in self.layers:
hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs) hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
if subset_mask is not None: if subset_mask is not None:
...@@ -144,7 +169,7 @@ class BertEncoder(nn.Module): ...@@ -144,7 +169,7 @@ class BertEncoder(nn.Module):
hidden_states, indices, cu_seqlens, max_seqlen_in_batch = unpad_input( hidden_states, indices, cu_seqlens, max_seqlen_in_batch = unpad_input(
hidden_states, key_padding_mask hidden_states, key_padding_mask
) )
mixer_kwargs = {'cu_seqlens': cu_seqlens, 'max_seqlen': max_seqlen_in_batch} mixer_kwargs = {"cu_seqlens": cu_seqlens, "max_seqlen": max_seqlen_in_batch}
if subset_mask is None: if subset_mask is None:
for layer in self.layers: for layer in self.layers:
hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs) hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
...@@ -153,33 +178,40 @@ class BertEncoder(nn.Module): ...@@ -153,33 +178,40 @@ class BertEncoder(nn.Module):
for layer in self.layers[:-1]: for layer in self.layers[:-1]:
hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs) hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
if key_padding_mask is not None: if key_padding_mask is not None:
                    subset_idx = torch.nonzero(
                        subset_mask[key_padding_mask], as_tuple=False
                    ).flatten()
                    subset_seqlens = (subset_mask & key_padding_mask).sum(dim=-1, dtype=torch.int32)
                    subset_cu_seqlens = F.pad(
                        torch.cumsum(subset_seqlens, dim=0, dtype=torch.torch.int32), (1, 0)
                    )
                else:
                    subset_idx = torch.nonzero(subset_mask, as_tuple=False).flatten()
                    subset_seqlens = subset_mask.sum(dim=-1, dtype=torch.int32)
                    subset_cu_seqlens = F.pad(
                        torch.cumsum(subset_seqlens, dim=0, dtype=torch.torch.int32), (1, 0)
                    )
                hidden_states_subset, hidden_states = index_first_axis_residual(
                    hidden_states, subset_idx
                )
                # It's ok to set max_seqlen_q to be much larger
                mixer_kwargs = {
                    "x_kv": hidden_states,
                    "cu_seqlens": subset_cu_seqlens,
                    "max_seqlen": max_seqlen_in_batch,
                    "cu_seqlens_k": cu_seqlens,
                    "max_seqlen_k": max_seqlen_in_batch,
                }
                hidden_states = self.layers[-1](hidden_states_subset, mixer_kwargs=mixer_kwargs)
        return hidden_states
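
# A round-trip sketch of the unpad/pad helpers this encoder relies on (these utilities run
# fine on CPU): tokens are packed into one (total_tokens, dim) tensor plus cu_seqlens
# offsets, which is the layout the varlen flash-attention path consumes.
import torch
from flash_attn.bert_padding import pad_input, unpad_input

batch, seqlen, dim = 2, 6, 8
hidden_states = torch.randn(batch, seqlen, dim)
key_padding_mask = torch.tensor([[1, 1, 1, 1, 0, 0], [1, 1, 1, 1, 1, 1]], dtype=torch.bool)
packed, indices, cu_seqlens, max_seqlen_in_batch = unpad_input(hidden_states, key_padding_mask)
# packed: (10, 8); cu_seqlens: tensor([0, 4, 10], dtype=torch.int32); max_seqlen_in_batch: 6
repadded = pad_input(packed, indices, batch, seqlen)
assert torch.equal(repadded[key_padding_mask], hidden_states[key_padding_mask])
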
class BertPooler(nn.Module): class BertPooler(nn.Module):
def __init__(self, config): def __init__(self, config):
super().__init__() super().__init__()
fused_bias_fc = getattr(config, 'fused_bias_fc', False) fused_bias_fc = getattr(config, "fused_bias_fc", False)
if fused_bias_fc and FusedDense is None: if fused_bias_fc and FusedDense is None:
raise ImportError('fused_dense is not installed') raise ImportError("fused_dense is not installed")
linear_cls = nn.Linear if not fused_bias_fc else FusedDense linear_cls = nn.Linear if not fused_bias_fc else FusedDense
self.dense = linear_cls(config.hidden_size, config.hidden_size) self.dense = linear_cls(config.hidden_size, config.hidden_size)
self.activation = nn.Tanh() self.activation = nn.Tanh()
...@@ -194,18 +226,17 @@ class BertPooler(nn.Module): ...@@ -194,18 +226,17 @@ class BertPooler(nn.Module):
class BertPredictionHeadTransform(nn.Module): class BertPredictionHeadTransform(nn.Module):
def __init__(self, config): def __init__(self, config):
super().__init__() super().__init__()
fused_bias_fc = getattr(config, 'fused_bias_fc', False) fused_bias_fc = getattr(config, "fused_bias_fc", False)
if fused_bias_fc and FusedDense is None: if fused_bias_fc and FusedDense is None:
raise ImportError('fused_dense is not installed') raise ImportError("fused_dense is not installed")
self.fused_dropout_add_ln = getattr(config, 'fused_dropout_add_ln', False) self.fused_dropout_add_ln = getattr(config, "fused_dropout_add_ln", False)
if self.fused_dropout_add_ln and layer_norm is None: if self.fused_dropout_add_ln and layer_norm is None:
raise ImportError('dropout_add_layer_norm is not installed') raise ImportError("dropout_add_layer_norm is not installed")
linear_cls = nn.Linear if not fused_bias_fc else FusedDense linear_cls = nn.Linear if not fused_bias_fc else FusedDense
self.dense = linear_cls(config.hidden_size, config.hidden_size) self.dense = linear_cls(config.hidden_size, config.hidden_size)
approximate = 'tanh' if config.hidden_act in ['gelu_new', 'gelu_fast'] else 'none' approximate = "tanh" if config.hidden_act in ["gelu_new", "gelu_fast"] else "none"
self.transform_act_fn = nn.GELU(approximate=approximate) self.transform_act_fn = nn.GELU(approximate=approximate)
self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
...@@ -215,18 +246,18 @@ class BertPredictionHeadTransform(nn.Module): ...@@ -215,18 +246,18 @@ class BertPredictionHeadTransform(nn.Module):
if not self.fused_dropout_add_ln: if not self.fused_dropout_add_ln:
hidden_states = self.layer_norm(hidden_states) hidden_states = self.layer_norm(hidden_states)
else: else:
hidden_states = layer_norm(hidden_states, self.layer_norm.weight, self.layer_norm.bias, hidden_states = layer_norm(
self.layer_norm.eps) hidden_states, self.layer_norm.weight, self.layer_norm.bias, self.layer_norm.eps
)
return hidden_states return hidden_states
class BertLMPredictionHead(nn.Module): class BertLMPredictionHead(nn.Module):
def __init__(self, config): def __init__(self, config):
super().__init__() super().__init__()
fused_bias_fc = getattr(config, 'fused_bias_fc', False) fused_bias_fc = getattr(config, "fused_bias_fc", False)
if fused_bias_fc and FusedDense is None: if fused_bias_fc and FusedDense is None:
raise ImportError('fused_dense is not installed') raise ImportError("fused_dense is not installed")
linear_cls = nn.Linear if not fused_bias_fc else FusedDense linear_cls = nn.Linear if not fused_bias_fc else FusedDense
self.transform = BertPredictionHeadTransform(config) self.transform = BertPredictionHeadTransform(config)
...@@ -254,9 +285,10 @@ class BertPreTrainingHeads(nn.Module): ...@@ -254,9 +285,10 @@ class BertPreTrainingHeads(nn.Module):
class BertPreTrainedModel(nn.Module): class BertPreTrainedModel(nn.Module):
""" An abstract class to handle weights initialization and """An abstract class to handle weights initialization and
a simple interface for dowloading and loading pretrained models. a simple interface for dowloading and loading pretrained models.
""" """
def __init__(self, config, *inputs, **kwargs): def __init__(self, config, *inputs, **kwargs):
super().__init__() super().__init__()
if not isinstance(config, BertConfig): if not isinstance(config, BertConfig):
...@@ -265,7 +297,8 @@ class BertPreTrainedModel(nn.Module): ...@@ -265,7 +297,8 @@ class BertPreTrainedModel(nn.Module):
"To create a model from a Google pretrained model use " "To create a model from a Google pretrained model use "
"`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format( "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
self.__class__.__name__, self.__class__.__name__ self.__class__.__name__, self.__class__.__name__
)) )
)
self.config = config self.config = config
@classmethod @classmethod
...@@ -287,28 +320,33 @@ class BertPreTrainedModel(nn.Module): ...@@ -287,28 +320,33 @@ class BertPreTrainedModel(nn.Module):
""" """
# Instantiate model. # Instantiate model.
model = cls(config, *inputs, **kwargs) model = cls(config, *inputs, **kwargs)
load_return = model.load_state_dict(remap_state_dict(state_dict_from_pretrained(model_name), load_return = model.load_state_dict(
config), strict=False) remap_state_dict(state_dict_from_pretrained(model_name), config), strict=False
)
logger.info(load_return) logger.info(load_return)
return model return model
class BertModel(BertPreTrainedModel): class BertModel(BertPreTrainedModel):
def __init__(self, config: BertConfig, add_pooling_layer=True): def __init__(self, config: BertConfig, add_pooling_layer=True):
super().__init__(config) super().__init__(config)
self.pad_vocab_size_multiple = getattr(config, 'pad_vocab_size_multiple', 1) self.pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
if config.vocab_size % self.pad_vocab_size_multiple != 0: if config.vocab_size % self.pad_vocab_size_multiple != 0:
            config.vocab_size += self.pad_vocab_size_multiple - (
                config.vocab_size % self.pad_vocab_size_multiple
            )
        self.fused_dropout_add_ln = getattr(config, "fused_dropout_add_ln", False)
        if self.fused_dropout_add_ln and layer_norm is None:
            raise ImportError("dropout_add_layer_norm is not installed")
        assert config.hidden_act in ["gelu", "gelu_new", "gelu_fast"]
        self.embeddings = BertEmbeddings(
            config.hidden_size,
            config.vocab_size,
            config.max_position_embeddings,
            config.type_vocab_size,
            padding_idx=config.pad_token_id,
        )
self.emb_drop = nn.Dropout(config.hidden_dropout_prob) self.emb_drop = nn.Dropout(config.hidden_dropout_prob)
self.emb_ln = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.emb_ln = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.encoder = BertEncoder(config) self.encoder = BertEncoder(config)
...@@ -316,36 +354,46 @@ class BertModel(BertPreTrainedModel): ...@@ -316,36 +354,46 @@ class BertModel(BertPreTrainedModel):
self.apply(partial(_init_weights, initializer_range=config.initializer_range)) self.apply(partial(_init_weights, initializer_range=config.initializer_range))
    def forward(
        self,
        input_ids,
        position_ids=None,
        token_type_ids=None,
        attention_mask=None,
        masked_tokens_mask=None,
    ):
"""If masked_tokens_mask is not None (i.e. last_layer_subset == True in BertForPreTraining), """If masked_tokens_mask is not None (i.e. last_layer_subset == True in BertForPreTraining),
we only want the output for the masked tokens. This means that we only compute the last we only want the output for the masked tokens. This means that we only compute the last
layer output for these tokens. layer output for these tokens.
masked_tokens_mask: (batch, seqlen), dtype=torch.bool masked_tokens_mask: (batch, seqlen), dtype=torch.bool
""" """
hidden_states = self.embeddings(input_ids, position_ids=position_ids, hidden_states = self.embeddings(
token_type_ids=token_type_ids) input_ids, position_ids=position_ids, token_type_ids=token_type_ids
)
# TD [2022-12:18]: Don't need to force residual in fp32 # TD [2022-12:18]: Don't need to force residual in fp32
# BERT puts embedding LayerNorm before embedding dropout. # BERT puts embedding LayerNorm before embedding dropout.
if not self.fused_dropout_add_ln: if not self.fused_dropout_add_ln:
hidden_states = self.emb_ln(hidden_states) hidden_states = self.emb_ln(hidden_states)
else: else:
hidden_states = layer_norm(hidden_states, self.emb_ln.weight, self.emb_ln.bias, hidden_states = layer_norm(
self.emb_ln.eps) hidden_states, self.emb_ln.weight, self.emb_ln.bias, self.emb_ln.eps
)
hidden_states = self.emb_drop(hidden_states) hidden_states = self.emb_drop(hidden_states)
if masked_tokens_mask is not None: if masked_tokens_mask is not None:
batch_size, seqlen = input_ids.shape[:2] batch_size, seqlen = input_ids.shape[:2]
# We also need the first column for the CLS token # We also need the first column for the CLS token
first_col_mask = torch.zeros(batch_size, seqlen, dtype=torch.bool, first_col_mask = torch.zeros(
device=input_ids.device) batch_size, seqlen, dtype=torch.bool, device=input_ids.device
)
first_col_mask[:, 0] = True first_col_mask[:, 0] = True
subset_mask = masked_tokens_mask | first_col_mask subset_mask = masked_tokens_mask | first_col_mask
else: else:
subset_mask = None subset_mask = None
sequence_output = self.encoder(hidden_states, key_padding_mask=attention_mask, sequence_output = self.encoder(
subset_mask=subset_mask) hidden_states, key_padding_mask=attention_mask, subset_mask=subset_mask
)
if masked_tokens_mask is None: if masked_tokens_mask is None:
pooled_output = self.pooler(sequence_output) if self.pooler is not None else None pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
...@@ -358,8 +406,7 @@ class BertModel(BertPreTrainedModel): ...@@ -358,8 +406,7 @@ class BertModel(BertPreTrainedModel):
else: else:
pool_input = sequence_output[first_col_mask[subset_mask]] pool_input = sequence_output[first_col_mask[subset_mask]]
sequence_output = sequence_output[masked_tokens_mask[subset_mask]] sequence_output = sequence_output[masked_tokens_mask[subset_mask]]
pooled_output = (self.pooler(pool_input, pool=False) pooled_output = self.pooler(pool_input, pool=False) if self.pooler is not None else None
if self.pooler is not None else None)
return BaseModelOutputWithPoolingAndCrossAttentions( return BaseModelOutputWithPoolingAndCrossAttentions(
last_hidden_state=sequence_output, last_hidden_state=sequence_output,
...@@ -368,22 +415,24 @@ class BertModel(BertPreTrainedModel): ...@@ -368,22 +415,24 @@ class BertModel(BertPreTrainedModel):
class BertForPreTraining(BertPreTrainedModel): class BertForPreTraining(BertPreTrainedModel):
def __init__(self, config: BertConfig): def __init__(self, config: BertConfig):
super().__init__(config) super().__init__(config)
# If dense_seq_output, we only need to pass the hidden states for the masked out tokens # If dense_seq_output, we only need to pass the hidden states for the masked out tokens
# (around 15%) to the classifier heads. # (around 15%) to the classifier heads.
self.dense_seq_output = getattr(config, 'dense_seq_output', False) self.dense_seq_output = getattr(config, "dense_seq_output", False)
# If last_layer_subset, we only need the compute the last layer for a subset of tokens # If last_layer_subset, we only need the compute the last layer for a subset of tokens
# (e.g., the tokens we need to compute the masked LM loss and the next-sentence prediction). # (e.g., the tokens we need to compute the masked LM loss and the next-sentence prediction).
self.last_layer_subset = getattr(config, 'last_layer_subset', False) self.last_layer_subset = getattr(config, "last_layer_subset", False)
if self.last_layer_subset: if self.last_layer_subset:
assert self.dense_seq_output, 'last_layer_subset requires dense_seq_output' assert self.dense_seq_output, "last_layer_subset requires dense_seq_output"
use_xentropy = getattr(config, 'use_xentropy', False) use_xentropy = getattr(config, "use_xentropy", False)
if use_xentropy and CrossEntropyLoss is None: if use_xentropy and CrossEntropyLoss is None:
raise ImportError('xentropy_cuda is not installed') raise ImportError("xentropy_cuda is not installed")
loss_cls = (nn.CrossEntropyLoss if not use_xentropy loss_cls = (
else partial(CrossEntropyLoss, inplace_backward=True)) nn.CrossEntropyLoss
if not use_xentropy
else partial(CrossEntropyLoss, inplace_backward=True)
)
self.bert = BertModel(config) self.bert = BertModel(config)
self.cls = BertPreTrainingHeads(config) self.cls = BertPreTrainingHeads(config)
...@@ -397,8 +446,15 @@ class BertForPreTraining(BertPreTrainedModel): ...@@ -397,8 +446,15 @@ class BertForPreTraining(BertPreTrainedModel):
def tie_weights(self): def tie_weights(self):
self.cls.predictions.decoder.weight = self.bert.embeddings.word_embeddings.weight self.cls.predictions.decoder.weight = self.bert.embeddings.word_embeddings.weight
def forward(self, input_ids, position_ids=None, token_type_ids=None, attention_mask=None, def forward(
labels=None, next_sentence_label=None): self,
input_ids,
position_ids=None,
token_type_ids=None,
attention_mask=None,
labels=None,
next_sentence_label=None,
):
""" """
If labels are provided, they must be 0 for masked out tokens (as specified in the attention If labels are provided, they must be 0 for masked out tokens (as specified in the attention
mask). mask).
...@@ -414,28 +470,38 @@ class BertForPreTraining(BertPreTrainedModel): ...@@ -414,28 +470,38 @@ class BertForPreTraining(BertPreTrainedModel):
""" """
masked_tokens_mask = labels > 0 if (self.last_layer_subset and labels is not None) else None masked_tokens_mask = labels > 0 if (self.last_layer_subset and labels is not None) else None
outputs = self.bert( outputs = self.bert(
input_ids, position_ids=position_ids, token_type_ids=token_type_ids, input_ids,
position_ids=position_ids,
token_type_ids=token_type_ids,
attention_mask=attention_mask.bool() if attention_mask is not None else None, attention_mask=attention_mask.bool() if attention_mask is not None else None,
masked_tokens_mask=masked_tokens_mask masked_tokens_mask=masked_tokens_mask,
) )
sequence_output, pooled_output = outputs.last_hidden_state, outputs.pooler_output sequence_output, pooled_output = outputs.last_hidden_state, outputs.pooler_output
if self.dense_seq_output and labels is not None: if self.dense_seq_output and labels is not None:
masked_token_idx = torch.nonzero(labels.flatten() > 0, as_tuple=False).flatten() masked_token_idx = torch.nonzero(labels.flatten() > 0, as_tuple=False).flatten()
if not self.last_layer_subset: if not self.last_layer_subset:
sequence_output = index_first_axis(rearrange(sequence_output, 'b s d -> (b s) d'), sequence_output = index_first_axis(
masked_token_idx) rearrange(sequence_output, "b s d -> (b s) d"), masked_token_idx
)
prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output) prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output)
total_loss = None total_loss = None
if labels is not None and next_sentence_label is not None: if labels is not None and next_sentence_label is not None:
if self.dense_seq_output and labels is not None: # prediction_scores are already flattened if (
masked_lm_loss = self.mlm_loss(prediction_scores, self.dense_seq_output and labels is not None
labels.flatten()[masked_token_idx]) ): # prediction_scores are already flattened
masked_lm_loss = self.mlm_loss(
prediction_scores, labels.flatten()[masked_token_idx]
)
else: else:
masked_lm_loss = self.mlm_loss(rearrange(prediction_scores, '... v -> (...) v'), masked_lm_loss = self.mlm_loss(
rearrange(labels, '... -> (...)')) rearrange(prediction_scores, "... v -> (...) v"),
next_sentence_loss = self.nsp_loss(rearrange(seq_relationship_score, '... t -> (...) t'), rearrange(labels, "... -> (...)"),
rearrange(next_sentence_label, '... -> (...)')) )
next_sentence_loss = self.nsp_loss(
rearrange(seq_relationship_score, "... t -> (...) t"),
rearrange(next_sentence_label, "... -> (...)"),
)
total_loss = masked_lm_loss.float() + next_sentence_loss.float() total_loss = masked_lm_loss.float() + next_sentence_loss.float()
return BertForPreTrainingOutput( return BertForPreTrainingOutput(
...@@ -448,83 +514,106 @@ class BertForPreTraining(BertPreTrainedModel): ...@@ -448,83 +514,106 @@ class BertForPreTraining(BertPreTrainedModel):
def remap_state_dict(state_dict, config): def remap_state_dict(state_dict, config):
# LayerNorm # LayerNorm
def key_mapping_ln_gamma_beta(key): def key_mapping_ln_gamma_beta(key):
key = re.sub(r'LayerNorm.gamma$', 'LayerNorm.weight', key) key = re.sub(r"LayerNorm.gamma$", "LayerNorm.weight", key)
key = re.sub(r'LayerNorm.beta$', 'LayerNorm.bias', key) key = re.sub(r"LayerNorm.beta$", "LayerNorm.bias", key)
return key return key
state_dict = OrderedDict((key_mapping_ln_gamma_beta(k), v) for k, v in state_dict.items()) state_dict = OrderedDict((key_mapping_ln_gamma_beta(k), v) for k, v in state_dict.items())
# Layers # Layers
def key_mapping_layers(key): def key_mapping_layers(key):
return re.sub(r'^bert.encoder.layer.', 'bert.encoder.layers.', key) return re.sub(r"^bert.encoder.layer.", "bert.encoder.layers.", key)
state_dict = OrderedDict((key_mapping_layers(k), v) for k, v in state_dict.items()) state_dict = OrderedDict((key_mapping_layers(k), v) for k, v in state_dict.items())
# LayerNorm # LayerNorm
def key_mapping_ln(key): def key_mapping_ln(key):
key = re.sub(r'^bert.embeddings.LayerNorm.', 'bert.emb_ln.', key) key = re.sub(r"^bert.embeddings.LayerNorm.", "bert.emb_ln.", key)
key = re.sub(r'^bert.encoder.layers.(\d+).attention.output.LayerNorm.(weight|bias)', key = re.sub(
r'bert.encoder.layers.\1.norm1.\2', key) r"^bert.encoder.layers.(\d+).attention.output.LayerNorm.(weight|bias)",
key = re.sub(r'^bert.encoder.layers.(\d+).output.LayerNorm.(weight|bias)', r"bert.encoder.layers.\1.norm1.\2",
r'bert.encoder.layers.\1.norm2.\2', key) key,
key = re.sub(r'^cls.predictions.transform.LayerNorm.(weight|bias)', )
r'cls.predictions.transform.layer_norm.\1', key) key = re.sub(
r"^bert.encoder.layers.(\d+).output.LayerNorm.(weight|bias)",
r"bert.encoder.layers.\1.norm2.\2",
key,
)
key = re.sub(
r"^cls.predictions.transform.LayerNorm.(weight|bias)",
r"cls.predictions.transform.layer_norm.\1",
key,
)
return key return key
state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items()) state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())
# MLP # MLP
def key_mapping_mlp(key): def key_mapping_mlp(key):
key = re.sub(r'^bert.encoder.layers.(\d+).intermediate.dense.(weight|bias)', key = re.sub(
r'bert.encoder.layers.\1.mlp.fc1.\2', key) r"^bert.encoder.layers.(\d+).intermediate.dense.(weight|bias)",
key = re.sub(r'^bert.encoder.layers.(\d+).output.dense.(weight|bias)', r"bert.encoder.layers.\1.mlp.fc1.\2",
r'bert.encoder.layers.\1.mlp.fc2.\2', key) key,
)
key = re.sub(
r"^bert.encoder.layers.(\d+).output.dense.(weight|bias)",
r"bert.encoder.layers.\1.mlp.fc2.\2",
key,
)
return key return key
state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items()) state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())
# Attention # Attention
last_layer_subset = getattr(config, 'last_layer_subset', False) last_layer_subset = getattr(config, "last_layer_subset", False)
for d in range(config.num_hidden_layers): for d in range(config.num_hidden_layers):
Wq = state_dict.pop(f'bert.encoder.layers.{d}.attention.self.query.weight') Wq = state_dict.pop(f"bert.encoder.layers.{d}.attention.self.query.weight")
Wk = state_dict.pop(f'bert.encoder.layers.{d}.attention.self.key.weight') Wk = state_dict.pop(f"bert.encoder.layers.{d}.attention.self.key.weight")
Wv = state_dict.pop(f'bert.encoder.layers.{d}.attention.self.value.weight') Wv = state_dict.pop(f"bert.encoder.layers.{d}.attention.self.value.weight")
bq = state_dict.pop(f'bert.encoder.layers.{d}.attention.self.query.bias') bq = state_dict.pop(f"bert.encoder.layers.{d}.attention.self.query.bias")
bk = state_dict.pop(f'bert.encoder.layers.{d}.attention.self.key.bias') bk = state_dict.pop(f"bert.encoder.layers.{d}.attention.self.key.bias")
bv = state_dict.pop(f'bert.encoder.layers.{d}.attention.self.value.bias') bv = state_dict.pop(f"bert.encoder.layers.{d}.attention.self.value.bias")
if not (last_layer_subset and d == config.num_hidden_layers - 1): if not (last_layer_subset and d == config.num_hidden_layers - 1):
            state_dict[f"bert.encoder.layers.{d}.mixer.Wqkv.weight"] = torch.cat(
                [Wq, Wk, Wv], dim=0
            )
            state_dict[f"bert.encoder.layers.{d}.mixer.Wqkv.bias"] = torch.cat([bq, bk, bv], dim=0)
        else:
            state_dict[f"bert.encoder.layers.{d}.mixer.Wq.weight"] = Wq
            state_dict[f"bert.encoder.layers.{d}.mixer.Wkv.weight"] = torch.cat([Wk, Wv], dim=0)
            state_dict[f"bert.encoder.layers.{d}.mixer.Wq.bias"] = bq
            state_dict[f"bert.encoder.layers.{d}.mixer.Wkv.bias"] = torch.cat([bk, bv], dim=0)
def key_mapping_attn(key): def key_mapping_attn(key):
return re.sub(r'^bert.encoder.layers.(\d+).attention.output.dense.(weight|bias)', return re.sub(
r'bert.encoder.layers.\1.mixer.out_proj.\2', key) r"^bert.encoder.layers.(\d+).attention.output.dense.(weight|bias)",
r"bert.encoder.layers.\1.mixer.out_proj.\2",
key,
)
state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items()) state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())
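
# A quick check of why Q/K/V can simply be concatenated along dim 0 above: a single matmul
# against cat([Wq, Wk, Wv]) produces the three projections stacked along the feature dim,
# which is the layout the fused Wqkv path expects (toy sizes, plain PyTorch).
import torch

hidden = 16
Wq, Wk, Wv = (torch.randn(hidden, hidden) for _ in range(3))
bq, bk, bv = (torch.randn(hidden) for _ in range(3))
x = torch.randn(4, hidden)
fused = x @ torch.cat([Wq, Wk, Wv], dim=0).t() + torch.cat([bq, bk, bv], dim=0)
q_out, k_out, v_out = fused.chunk(3, dim=-1)
assert torch.allclose(q_out, x @ Wq.t() + bq, atol=1e-5)
assert torch.allclose(k_out, x @ Wk.t() + bk, atol=1e-5)
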
def key_mapping_decoder_bias(key): def key_mapping_decoder_bias(key):
return re.sub(r'^cls.predictions.bias', 'cls.predictions.decoder.bias', key) return re.sub(r"^cls.predictions.bias", "cls.predictions.decoder.bias", key)
state_dict = OrderedDict((key_mapping_decoder_bias(k), v) for k, v in state_dict.items()) state_dict = OrderedDict((key_mapping_decoder_bias(k), v) for k, v in state_dict.items())
# Word embedding # Word embedding
pad_vocab_size_multiple = getattr(config, 'pad_vocab_size_multiple', 1) pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
if pad_vocab_size_multiple > 1: if pad_vocab_size_multiple > 1:
word_embeddings = state_dict['bert.embeddings.word_embeddings.weight'] word_embeddings = state_dict["bert.embeddings.word_embeddings.weight"]
state_dict['bert.embeddings.word_embeddings.weight'] = F.pad( state_dict["bert.embeddings.word_embeddings.weight"] = F.pad(
word_embeddings, (0, 0, 0, config.vocab_size - word_embeddings.shape[0]) word_embeddings, (0, 0, 0, config.vocab_size - word_embeddings.shape[0])
) )
decoder_weight = state_dict['cls.predictions.decoder.weight'] decoder_weight = state_dict["cls.predictions.decoder.weight"]
state_dict['cls.predictions.decoder.weight'] = F.pad( state_dict["cls.predictions.decoder.weight"] = F.pad(
decoder_weight, (0, 0, 0, config.vocab_size - decoder_weight.shape[0]) decoder_weight, (0, 0, 0, config.vocab_size - decoder_weight.shape[0])
) )
# If the vocab was padded, we want to set the decoder bias for those padded indices to be # If the vocab was padded, we want to set the decoder bias for those padded indices to be
# strongly negative (i.e. the decoder shouldn't predict those indices). # strongly negative (i.e. the decoder shouldn't predict those indices).
# TD [2022-05-09]: I don't think it affects the MLPerf training. # TD [2022-05-09]: I don't think it affects the MLPerf training.
decoder_bias = state_dict['cls.predictions.decoder.bias'] decoder_bias = state_dict["cls.predictions.decoder.bias"]
state_dict['cls.predictions.decoder.bias'] = F.pad( state_dict["cls.predictions.decoder.bias"] = F.pad(
decoder_bias, (0, config.vocab_size - decoder_bias.shape[0]), value=-100.0 decoder_bias, (0, config.vocab_size - decoder_bias.shape[0]), value=-100.0
) )
......
...@@ -2,93 +2,114 @@ ...@@ -2,93 +2,114 @@
import math import math
import re import re
from collections import OrderedDict from collections import OrderedDict
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from einops import rearrange from einops import rearrange
from transformers import FalconConfig, GPT2Config
from transformers import GPT2Config, FalconConfig
def remap_state_dict_hf_falcon(state_dict, config): def remap_state_dict_hf_falcon(state_dict, config):
def key_mapping_layers(key): def key_mapping_layers(key):
return re.sub(r'^transformer.h.', 'transformer.layers.', key) return re.sub(r"^transformer.h.", "transformer.layers.", key)
state_dict = OrderedDict((key_mapping_layers(k), v) for k, v in state_dict.items()) state_dict = OrderedDict((key_mapping_layers(k), v) for k, v in state_dict.items())
# Word embedding # Word embedding
def key_mapping_emb(key): def key_mapping_emb(key):
        return re.sub(
            r"^transformer.word_embeddings.", "transformer.embeddings.word_embeddings.", key
        )
word_embeddings = state_dict.pop('transformer.embeddings.word_embeddings.weight') word_embeddings = state_dict.pop("transformer.embeddings.word_embeddings.weight")
# It's possible that vocab_size is padded to be a multiple of 8, for example. # It's possible that vocab_size is padded to be a multiple of 8, for example.
pad_vocab_size_multiple = getattr(config, 'pad_vocab_size_multiple', 1) pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
vocab_size = (math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple) vocab_size = math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple
state_dict['transformer.embeddings.word_embeddings.weight'] = F.pad( state_dict["transformer.embeddings.word_embeddings.weight"] = F.pad(
word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0]) word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0])
) )
if getattr(config, 'tie_word_embeddings'): if getattr(config, "tie_word_embeddings"):
state_dict['lm_head.weight'] = state_dict['transformer.embeddings.word_embeddings.weight'] state_dict["lm_head.weight"] = state_dict["transformer.embeddings.word_embeddings.weight"]
else: else:
output_embeddings = state_dict.pop('lm_head.weight') output_embeddings = state_dict.pop("lm_head.weight")
# It's possible that vocab_size is padded to be a multiple of 8, for example. # It's possible that vocab_size is padded to be a multiple of 8, for example.
state_dict['lm_head.weight'] = F.pad( state_dict["lm_head.weight"] = F.pad(
output_embeddings, (0, 0, 0, vocab_size - output_embeddings.shape[0]) output_embeddings, (0, 0, 0, vocab_size - output_embeddings.shape[0])
) )
output_embeddings_bias = state_dict.pop('lm_head.bias') output_embeddings_bias = state_dict.pop("lm_head.bias")
state_dict['lm_head.bias'] = F.pad( state_dict["lm_head.bias"] = F.pad(
output_embeddings_bias, (0, vocab_size - output_embeddings_bias.shape[0]) output_embeddings_bias, (0, vocab_size - output_embeddings_bias.shape[0])
) )
# LayerNorm # LayerNorm
def key_mapping_ln(key): def key_mapping_ln(key):
        key = re.sub(
            r"^transformer.layers.(\d+).input_layernorm.", r"transformer.layers.\1.norm1.", key
        )
        key = re.sub(
            r"^transformer.layers.(\d+).post_attention_layernorm.",
            r"transformer.layers.\1.norm2.",
            key,
        )
        key = re.sub(r"^transformer.layers.(\d+).ln_attn.", r"transformer.layers.\1.norm1.", key)
        key = re.sub(r"^transformer.layers.(\d+).ln_mlp.", r"transformer.layers.\1.norm2.", key)
return key return key
state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items()) state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())
# MLP # MLP
def key_mapping_mlp(key): def key_mapping_mlp(key):
key = re.sub(r'^transformer.layers.(\d+).mlp.dense_h_to_4h.', key = re.sub(
r'transformer.layers.\1.mlp.fc1.', key) r"^transformer.layers.(\d+).mlp.dense_h_to_4h.", r"transformer.layers.\1.mlp.fc1.", key
key = re.sub(r'^transformer.layers.(\d+).mlp.dense_4h_to_h.', )
r'transformer.layers.\1.mlp.fc2.', key) key = re.sub(
r"^transformer.layers.(\d+).mlp.dense_4h_to_h.", r"transformer.layers.\1.mlp.fc2.", key
)
return key return key
state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items()) state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())
def key_mapping_attn(key): def key_mapping_attn(key):
key = re.sub(r'^transformer.layers.(\d+).self_attention.query_key_value.', key = re.sub(
r'transformer.layers.\1.mixer.Wqkv.', key) r"^transformer.layers.(\d+).self_attention.query_key_value.",
key = re.sub(r'^transformer.layers.(\d+).self_attention.dense.', r"transformer.layers.\1.mixer.Wqkv.",
r'transformer.layers.\1.mixer.out_proj.', key) key,
)
key = re.sub(
r"^transformer.layers.(\d+).self_attention.dense.",
r"transformer.layers.\1.mixer.out_proj.",
key,
)
return key return key
state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items()) state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())
n_head = config.n_head n_head = config.n_head
n_head_kv = getattr(config, "n_head_kv", 1) n_head_kv = getattr(config, "n_head_kv", 1)
headdim = config.hidden_size // n_head headdim = config.hidden_size // n_head
for l in range(config.n_layer): for l in range(config.n_layer):
# The weights are stored in a different layout compared to our implementation # The weights are stored in a different layout compared to our implementation
Wqkv = rearrange(state_dict.pop(f'transformer.layers.{l}.mixer.Wqkv.weight'), Wqkv = rearrange(
state_dict.pop(f"transformer.layers.{l}.mixer.Wqkv.weight"),
"(group ratio headdim) ... -> group ratio headdim ...", "(group ratio headdim) ... -> group ratio headdim ...",
ratio=n_head // n_head_kv + 2, headdim=headdim) ratio=n_head // n_head_kv + 2,
headdim=headdim,
)
Wq = rearrange(Wqkv[:, :-2], "group ratio headdim ... -> (group ratio headdim) ...") Wq = rearrange(Wqkv[:, :-2], "group ratio headdim ... -> (group ratio headdim) ...")
Wk = rearrange(Wqkv[:, [-2]], "group ratio headdim ... -> (group ratio headdim) ...") Wk = rearrange(Wqkv[:, [-2]], "group ratio headdim ... -> (group ratio headdim) ...")
Wv = rearrange(Wqkv[:, [-1]], "group ratio headdim ... -> (group ratio headdim) ...") Wv = rearrange(Wqkv[:, [-1]], "group ratio headdim ... -> (group ratio headdim) ...")
state_dict[f'transformer.layers.{l}.mixer.Wqkv.weight'] = torch.cat([Wq, Wk, Wv], dim=0) state_dict[f"transformer.layers.{l}.mixer.Wqkv.weight"] = torch.cat([Wq, Wk, Wv], dim=0)
return state_dict return state_dict
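Falcon packs Wqkv as (n_head_kv groups) x (queries-per-group + 1 key + 1 value) x headdim; the loop above splits that interleaved layout into contiguous Q, K, V blocks. A standalone sketch with tiny, hypothetical sizes showing the same rearranges:

# Sketch of the Falcon Wqkv layout conversion with tiny, hypothetical sizes:
# n_head=4 query heads, n_head_kv=2 KV groups, headdim=3, hidden=5.
import torch
from einops import rearrange

n_head, n_head_kv, headdim, hidden = 4, 2, 3, 5
ratio = n_head // n_head_kv + 2          # queries per group + one K + one V
Wqkv = torch.randn(n_head_kv * ratio * headdim, hidden)

w = rearrange(Wqkv, "(group ratio headdim) ... -> group ratio headdim ...",
              ratio=ratio, headdim=headdim)
Wq = rearrange(w[:, :-2], "group ratio headdim ... -> (group ratio headdim) ...")
Wk = rearrange(w[:, [-2]], "group ratio headdim ... -> (group ratio headdim) ...")
Wv = rearrange(w[:, [-1]], "group ratio headdim ... -> (group ratio headdim) ...")
# Q rows: n_head * headdim; K and V rows: n_head_kv * headdim each.
assert Wq.shape == (n_head * headdim, hidden)
assert Wk.shape == (n_head_kv * headdim, hidden)
assert torch.cat([Wq, Wk, Wv], dim=0).shape == Wqkv.shape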
def falcon_config_to_gpt2_config(falcon_config: FalconConfig) -> GPT2Config: def falcon_config_to_gpt2_config(falcon_config: FalconConfig) -> GPT2Config:
# The 40b config uses "n_head_kv" instead of "num_kv_heads" # The 40b config uses "n_head_kv" instead of "num_kv_heads"
n_head_kv = getattr(falcon_config, "n_head_kv", n_head_kv = getattr(
1 if getattr(falcon_config, "multi_query", False) falcon_config,
else falcon_config.n_head) "n_head_kv",
1 if getattr(falcon_config, "multi_query", False) else falcon_config.n_head,
)
# HACK: the 40b config has 2 LN per layer instead of 1, but that's not reflected in the config. # HACK: the 40b config has 2 LN per layer instead of 1, but that's not reflected in the config.
# So we have to infer it from the number of heads in the key/value block # So we have to infer it from the number of heads in the key/value block
parallel_block_tied_norm = n_head_kv == 1 parallel_block_tied_norm = n_head_kv == 1
......
...@@ -11,6 +11,8 @@ import torch ...@@ -11,6 +11,8 @@ import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from einops import rearrange from einops import rearrange
from transformers import GPT2Config
from flash_attn.models.falcon import remap_state_dict_hf_falcon from flash_attn.models.falcon import remap_state_dict_hf_falcon
from flash_attn.models.gpt_neox import remap_state_dict_hf_gpt_neox from flash_attn.models.gpt_neox import remap_state_dict_hf_gpt_neox
from flash_attn.models.gptj import remap_state_dict_hf_gptj from flash_attn.models.gptj import remap_state_dict_hf_gptj
...@@ -27,10 +29,9 @@ from flash_attn.modules.mlp import ( ...@@ -27,10 +29,9 @@ from flash_attn.modules.mlp import (
ParallelMLP, ParallelMLP,
) )
from flash_attn.ops.activations import sqrelu_fwd from flash_attn.ops.activations import sqrelu_fwd
from flash_attn.utils.distributed import all_gather_raw, sync_shared_params, get_dim_for_local_rank from flash_attn.utils.distributed import all_gather_raw, get_dim_for_local_rank, sync_shared_params
from flash_attn.utils.generation import GenerationMixin from flash_attn.utils.generation import GenerationMixin
from flash_attn.utils.pretrained import state_dict_from_pretrained from flash_attn.utils.pretrained import state_dict_from_pretrained
from transformers import GPT2Config
try: try:
from flash_attn.ops.fused_dense import ColumnParallelLinear from flash_attn.ops.fused_dense import ColumnParallelLinear
...@@ -690,7 +691,7 @@ def shard_state_dict_tp(state_dict, config, world_size, rank): ...@@ -690,7 +691,7 @@ def shard_state_dict_tp(state_dict, config, world_size, rank):
if key in state_dict: if key in state_dict:
x = state_dict[key] x = state_dict[key]
dim = x.shape[0] // world_size dim = x.shape[0] // world_size
state_dict[key] = x[rank * dim: (rank + 1) * dim] state_dict[key] = x[rank * dim : (rank + 1) * dim]
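Sharding along the first axis is a plain contiguous slice per rank; concatenating the slices in rank order recovers the full weight. A minimal sketch with hypothetical sizes:

# Sketch of first-dim sharding: each rank keeps a contiguous
# 1/world_size slice of the rows.
import torch

world_size = 4
x = torch.arange(8 * 2).reshape(8, 2)      # pretend weight with 8 output rows
dim = x.shape[0] // world_size
shards = [x[rank * dim : (rank + 1) * dim] for rank in range(world_size)]
assert torch.equal(torch.cat(shards, dim=0), x)   # shards tile the original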
def shard_last_dim(state_dict, key, multiple_of=1): def shard_last_dim(state_dict, key, multiple_of=1):
if key in state_dict: if key in state_dict:
...@@ -707,17 +708,19 @@ def shard_state_dict_tp(state_dict, config, world_size, rank): ...@@ -707,17 +708,19 @@ def shard_state_dict_tp(state_dict, config, world_size, rank):
x = state_dict[key] x = state_dict[key]
dim = x.shape[0] // world_size // 2 dim = x.shape[0] // world_size // 2
state_dict[key] = rearrange( state_dict[key] = rearrange(
rearrange(x, "(two o) ... -> two o ...", two=2)[:, rank * dim: (rank + 1) * dim], rearrange(x, "(two o) ... -> two o ...", two=2)[:, rank * dim : (rank + 1) * dim],
"two o ... -> (two o) ...", "two o ... -> (two o) ...",
) )
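The "(two o)" rearrange handles weights that stack two projections along dim 0 (e.g. a gated-MLP fc1): each rank must take its slice from both halves, not one contiguous block. A sketch with made-up sizes showing which rows a rank ends up with:

# Sketch of the "(two o) -> two o" sharding used for weights that stack two
# projections along dim 0; hypothetical sizes.
import torch
from einops import rearrange

world_size, rank = 2, 1
x = torch.arange(8).reshape(8, 1)          # rows 0-3 = half A, rows 4-7 = half B
dim = x.shape[0] // world_size // 2        # rows each rank keeps per half
shard = rearrange(
    rearrange(x, "(two o) ... -> two o ...", two=2)[:, rank * dim : (rank + 1) * dim],
    "two o ... -> (two o) ...",
)
# Rank 1 gets rows [2, 3] of half A followed by rows [6, 7] of half B.
assert shard.flatten().tolist() == [2, 3, 6, 7]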
def shard_qkv_headdim(state_dict, key): def shard_qkv_headdim(state_dict, key):
if key in state_dict: if key in state_dict:
n_head_each_rank = [ n_head_each_rank = [
get_dim_for_local_rank(n_head, world_size, local_rank) for local_rank in range(world_size) get_dim_for_local_rank(n_head, world_size, local_rank)
for local_rank in range(world_size)
] ]
n_head_kv_each_rank = [ n_head_kv_each_rank = [
get_dim_for_local_rank(n_head_kv, world_size, local_rank) for local_rank in range(world_size) get_dim_for_local_rank(n_head_kv, world_size, local_rank)
for local_rank in range(world_size)
] ]
beg_n_head = sum(n_head_each_rank[:rank]) beg_n_head = sum(n_head_each_rank[:rank])
...@@ -729,7 +732,8 @@ def shard_state_dict_tp(state_dict, config, world_size, rank): ...@@ -729,7 +732,8 @@ def shard_state_dict_tp(state_dict, config, world_size, rank):
if n_head_kv == n_head: if n_head_kv == n_head:
x = rearrange(state_dict[key], "(three d) ... -> three d ...", three=3) x = rearrange(state_dict[key], "(three d) ... -> three d ...", three=3)
state_dict[key] = rearrange( state_dict[key] = rearrange(
x[:, beg_n_head * head_dim : end_n_head * head_dim], "three d ... -> (three d) ..." x[:, beg_n_head * head_dim : end_n_head * head_dim],
"three d ... -> (three d) ...",
) )
else: else:
x = rearrange( x = rearrange(
...@@ -741,8 +745,14 @@ def shard_state_dict_tp(state_dict, config, world_size, rank): ...@@ -741,8 +745,14 @@ def shard_state_dict_tp(state_dict, config, world_size, rank):
torch.cat( torch.cat(
[ [
x[beg_n_head:end_n_head], x[beg_n_head:end_n_head],
x[n_head + beg_n_head_kv: n_head + end_n_head_kv], x[n_head + beg_n_head_kv : n_head + end_n_head_kv],
x[n_head + n_head_kv + beg_n_head_kv: n_head + n_head_kv + end_n_head_kv], x[
n_head
+ n_head_kv
+ beg_n_head_kv : n_head
+ n_head_kv
+ end_n_head_kv
],
], ],
dim=0, dim=0,
), ),
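For the grouped-KV case the packed tensor is laid out as [all q heads | all k heads | all v heads], and each rank concatenates its slice of each of the three segments. A sketch with one entry per head (the real code slices by head * head_dim), where a simple even split stands in for get_dim_for_local_rank purely as an assumption for illustration:

# Sketch of slicing a packed [q-heads | k-heads | v-heads] tensor per rank,
# with hypothetical sizes and an assumed even head split.
import torch

n_head, n_head_kv, world_size, rank = 8, 2, 2, 0
heads_per_rank = n_head // world_size              # 4 query heads per rank
kv_per_rank = n_head_kv // world_size              # 1 kv head per rank
x = torch.arange(n_head + 2 * n_head_kv)           # one entry per packed head

beg_q, end_q = rank * heads_per_rank, (rank + 1) * heads_per_rank
beg_kv, end_kv = rank * kv_per_rank, (rank + 1) * kv_per_rank
shard = torch.cat(
    [
        x[beg_q:end_q],                                                # this rank's q heads
        x[n_head + beg_kv : n_head + end_kv],                          # its k heads
        x[n_head + n_head_kv + beg_kv : n_head + n_head_kv + end_kv],  # its v heads
    ],
    dim=0,
)
assert shard.tolist() == [0, 1, 2, 3, 8, 10]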
...@@ -824,7 +834,7 @@ def combine_state_dicts_tp(state_dicts, config): ...@@ -824,7 +834,7 @@ def combine_state_dicts_tp(state_dicts, config):
torch.cat([x[:n_head_per_rank] for x in xs], dim=0), torch.cat([x[:n_head_per_rank] for x in xs], dim=0),
torch.cat( torch.cat(
[ [
x[n_head_per_rank: n_head_per_rank + n_head_kv_per_rank] x[n_head_per_rank : n_head_per_rank + n_head_kv_per_rank]
for x in xs for x in xs
], ],
dim=0, dim=0,
......
...@@ -2,80 +2,100 @@ ...@@ -2,80 +2,100 @@
import math import math
import re import re
from collections import OrderedDict from collections import OrderedDict
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from einops import rearrange from einops import rearrange
from transformers import GPT2Config, GPTNeoXConfig from transformers import GPT2Config, GPTNeoXConfig
def remap_state_dict_hf_gpt_neox(state_dict, config): def remap_state_dict_hf_gpt_neox(state_dict, config):
def key_mapping_layers(key): def key_mapping_layers(key):
return re.sub(r'^gpt_neox.', 'transformer.', key) return re.sub(r"^gpt_neox.", "transformer.", key)
state_dict = OrderedDict((key_mapping_layers(k), v) for k, v in state_dict.items()) state_dict = OrderedDict((key_mapping_layers(k), v) for k, v in state_dict.items())
# Word embedding # Word embedding
def key_mapping_emb(key): def key_mapping_emb(key):
return re.sub(r'^transformer.embed_in.', 'transformer.embeddings.word_embeddings.', key) return re.sub(r"^transformer.embed_in.", "transformer.embeddings.word_embeddings.", key)
state_dict = OrderedDict((key_mapping_emb(k), v) for k, v in state_dict.items()) state_dict = OrderedDict((key_mapping_emb(k), v) for k, v in state_dict.items())
word_embeddings = state_dict.pop('transformer.embeddings.word_embeddings.weight') word_embeddings = state_dict.pop("transformer.embeddings.word_embeddings.weight")
# It's possible that vocab_size is padded to be a multiple of 8, for example. # It's possible that vocab_size is padded to be a multiple of 8, for example.
pad_vocab_size_multiple = getattr(config, 'pad_vocab_size_multiple', 1) pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
vocab_size = (math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple) vocab_size = math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple
state_dict['transformer.embeddings.word_embeddings.weight'] = F.pad( state_dict["transformer.embeddings.word_embeddings.weight"] = F.pad(
word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0]) word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0])
) )
if getattr(config, 'tie_word_embeddings'): if getattr(config, "tie_word_embeddings"):
state_dict['lm_head.weight'] = state_dict['transformer.embeddings.word_embeddings.weight'] state_dict["lm_head.weight"] = state_dict["transformer.embeddings.word_embeddings.weight"]
else: else:
output_embeddings = state_dict.pop('embed_out.weight') output_embeddings = state_dict.pop("embed_out.weight")
# It's possible that vocab_size is padded to be a multiple of 8, for example. # It's possible that vocab_size is padded to be a multiple of 8, for example.
state_dict['lm_head.weight'] = F.pad( state_dict["lm_head.weight"] = F.pad(
output_embeddings, (0, 0, 0, vocab_size - output_embeddings.shape[0]) output_embeddings, (0, 0, 0, vocab_size - output_embeddings.shape[0])
) )
# LayerNorm # LayerNorm
def key_mapping_ln(key): def key_mapping_ln(key):
key = re.sub(r'^transformer.final_layer_norm.', r'transformer.ln_f.', key) key = re.sub(r"^transformer.final_layer_norm.", r"transformer.ln_f.", key)
key = re.sub(r'^transformer.layers.(\d+).input_layernorm.', r'transformer.layers.\1.norm1.', key) key = re.sub(
key = re.sub(r'^transformer.layers.(\d+).post_attention_layernorm.', r'transformer.layers.\1.norm2.', key) r"^transformer.layers.(\d+).input_layernorm.", r"transformer.layers.\1.norm1.", key
)
key = re.sub(
r"^transformer.layers.(\d+).post_attention_layernorm.",
r"transformer.layers.\1.norm2.",
key,
)
return key return key
state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items()) state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())
# MLP # MLP
def key_mapping_mlp(key): def key_mapping_mlp(key):
key = re.sub(r'^transformer.layers.(\d+).mlp.dense_h_to_4h.', r'transformer.layers.\1.mlp.fc1.', key) key = re.sub(
key = re.sub(r'^transformer.layers.(\d+).mlp.dense_4h_to_h.', r'transformer.layers.\1.mlp.fc2.', key) r"^transformer.layers.(\d+).mlp.dense_h_to_4h.", r"transformer.layers.\1.mlp.fc1.", key
)
key = re.sub(
r"^transformer.layers.(\d+).mlp.dense_4h_to_h.", r"transformer.layers.\1.mlp.fc2.", key
)
return key return key
state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items()) state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())
# Attention # Attention
for l in range(config.n_layer): for l in range(config.n_layer):
# We don't store these biases # We don't store these biases
state_dict.pop(f'transformer.layers.{l}.attention.bias') state_dict.pop(f"transformer.layers.{l}.attention.bias")
state_dict.pop(f'transformer.layers.{l}.attention.masked_bias') state_dict.pop(f"transformer.layers.{l}.attention.masked_bias")
# GPT-NeoX stores Wqkv as ((nheads 3 headdim), hidden_dim) # GPT-NeoX stores Wqkv as ((nheads 3 headdim), hidden_dim)
# while we store Wqkv as ((3 nheads headdim), hidden_dim) # while we store Wqkv as ((3 nheads headdim), hidden_dim)
headdim = config.hidden_size // config.num_attention_heads headdim = config.hidden_size // config.num_attention_heads
Wqkv = state_dict.pop(f'transformer.layers.{l}.attention.query_key_value.weight') Wqkv = state_dict.pop(f"transformer.layers.{l}.attention.query_key_value.weight")
state_dict[f'transformer.layers.{l}.mixer.Wqkv.weight'] = rearrange( state_dict[f"transformer.layers.{l}.mixer.Wqkv.weight"] = rearrange(
Wqkv, '(nheads three headdim) ... -> (three nheads headdim) ...', Wqkv,
three=3, headdim=headdim "(nheads three headdim) ... -> (three nheads headdim) ...",
three=3,
headdim=headdim,
) )
bqkv = state_dict.pop(f'transformer.layers.{l}.attention.query_key_value.bias') bqkv = state_dict.pop(f"transformer.layers.{l}.attention.query_key_value.bias")
state_dict[f'transformer.layers.{l}.mixer.Wqkv.bias'] = rearrange( state_dict[f"transformer.layers.{l}.mixer.Wqkv.bias"] = rearrange(
bqkv, '(nheads three headdim) -> (three nheads headdim)', bqkv, "(nheads three headdim) -> (three nheads headdim)", three=3, headdim=headdim
three=3, headdim=headdim
) )
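The comment above describes the layout change: GPT-NeoX interleaves (q, k, v) per head, while this repo wants all q heads first, then all k, then all v. A sketch with labelled rows and tiny, hypothetical sizes:

# Sketch of the GPT-NeoX -> flash-attn QKV layout change with tiny sizes.
import torch
from einops import rearrange

nheads, headdim = 2, 2
# Encode each row as head * 100 + which-of-qkv * 10 + dim, so the order is visible.
rows = torch.tensor(
    [h * 100 + w * 10 + d for h in range(nheads) for w in range(3) for d in range(headdim)]
)
out = rearrange(rows, "(nheads three headdim) -> (three nheads headdim)",
                three=3, headdim=headdim)
# All q rows (w=0) of every head now come first, then all k rows, then all v rows.
assert out.tolist() == [0, 1, 100, 101, 10, 11, 110, 111, 20, 21, 120, 121]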
def key_mapping_attn(key): def key_mapping_attn(key):
key = re.sub(r'^transformer.layers.(\d+).attention.dense.', key = re.sub(
r'transformer.layers.\1.mixer.out_proj.', key) r"^transformer.layers.(\d+).attention.dense.",
key = re.sub(r'^transformer.layers.(\d+).attention.rotary_emb.', r"transformer.layers.\1.mixer.out_proj.",
r'transformer.layers.\1.mixer.rotary_emb.', key) key,
)
key = re.sub(
r"^transformer.layers.(\d+).attention.rotary_emb.",
r"transformer.layers.\1.mixer.rotary_emb.",
key,
)
return key return key
state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items()) state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())
return state_dict return state_dict
......
...@@ -2,67 +2,78 @@ ...@@ -2,67 +2,78 @@
import math import math
import re import re
from collections import OrderedDict from collections import OrderedDict
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from transformers import GPT2Config, GPTJConfig from transformers import GPT2Config, GPTJConfig
def remap_state_dict_hf_gptj(state_dict, config): def remap_state_dict_hf_gptj(state_dict, config):
def key_mapping_layers(key): def key_mapping_layers(key):
return re.sub(r'^transformer.h.', 'transformer.layers.', key) return re.sub(r"^transformer.h.", "transformer.layers.", key)
state_dict = OrderedDict((key_mapping_layers(k), v) for k, v in state_dict.items()) state_dict = OrderedDict((key_mapping_layers(k), v) for k, v in state_dict.items())
# Word embedding # Word embedding
def key_mapping_emb(key): def key_mapping_emb(key):
return re.sub(r'^transformer.wte.', 'transformer.embeddings.word_embeddings.', key) return re.sub(r"^transformer.wte.", "transformer.embeddings.word_embeddings.", key)
state_dict = OrderedDict((key_mapping_emb(k), v) for k, v in state_dict.items()) state_dict = OrderedDict((key_mapping_emb(k), v) for k, v in state_dict.items())
word_embeddings = state_dict.pop('transformer.embeddings.word_embeddings.weight') word_embeddings = state_dict.pop("transformer.embeddings.word_embeddings.weight")
# It's possible that vocab_size is padded to be a multiple of 8, for example. # It's possible that vocab_size is padded to be a multiple of 8, for example.
pad_vocab_size_multiple = getattr(config, 'pad_vocab_size_multiple', 1) pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
vocab_size = (math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple) vocab_size = math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple
state_dict['transformer.embeddings.word_embeddings.weight'] = F.pad( state_dict["transformer.embeddings.word_embeddings.weight"] = F.pad(
word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0]) word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0])
) )
if getattr(config, 'tie_word_embeddings'): if getattr(config, "tie_word_embeddings"):
state_dict['lm_head.weight'] = state_dict['transformer.embeddings.word_embeddings.weight'] state_dict["lm_head.weight"] = state_dict["transformer.embeddings.word_embeddings.weight"]
else: else:
output_embeddings = state_dict.pop('lm_head.weight') output_embeddings = state_dict.pop("lm_head.weight")
# It's possible that vocab_size is padded to be a multiple of 8, for example. # It's possible that vocab_size is padded to be a multiple of 8, for example.
state_dict['lm_head.weight'] = F.pad( state_dict["lm_head.weight"] = F.pad(
output_embeddings, (0, 0, 0, vocab_size - output_embeddings.shape[0]) output_embeddings, (0, 0, 0, vocab_size - output_embeddings.shape[0])
) )
output_embeddings_bias = state_dict.pop('lm_head.bias') output_embeddings_bias = state_dict.pop("lm_head.bias")
state_dict['lm_head.bias'] = F.pad( state_dict["lm_head.bias"] = F.pad(
output_embeddings_bias, (0, vocab_size - output_embeddings_bias.shape[0]) output_embeddings_bias, (0, vocab_size - output_embeddings_bias.shape[0])
) )
# LayerNorm # LayerNorm
def key_mapping_ln(key): def key_mapping_ln(key):
return re.sub(r'^transformer.layers.(\d+).ln_1.', r'transformer.layers.\1.norm1.', key) return re.sub(r"^transformer.layers.(\d+).ln_1.", r"transformer.layers.\1.norm1.", key)
state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items()) state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())
# MLP # MLP
def key_mapping_mlp(key): def key_mapping_mlp(key):
key = re.sub(r'^transformer.layers.(\d+).mlp.fc_in.', r'transformer.layers.\1.mlp.fc1.', key) key = re.sub(
key = re.sub(r'^transformer.layers.(\d+).mlp.fc_out.', r'transformer.layers.\1.mlp.fc2.', key) r"^transformer.layers.(\d+).mlp.fc_in.", r"transformer.layers.\1.mlp.fc1.", key
)
key = re.sub(
r"^transformer.layers.(\d+).mlp.fc_out.", r"transformer.layers.\1.mlp.fc2.", key
)
return key return key
state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items()) state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())
# Attention # Attention
for l in range(config.n_layer): for l in range(config.n_layer):
Wq = state_dict.pop(f'transformer.layers.{l}.attn.q_proj.weight') Wq = state_dict.pop(f"transformer.layers.{l}.attn.q_proj.weight")
Wk = state_dict.pop(f'transformer.layers.{l}.attn.k_proj.weight') Wk = state_dict.pop(f"transformer.layers.{l}.attn.k_proj.weight")
Wv = state_dict.pop(f'transformer.layers.{l}.attn.v_proj.weight') Wv = state_dict.pop(f"transformer.layers.{l}.attn.v_proj.weight")
state_dict[f'transformer.layers.{l}.mixer.Wqkv.weight'] = torch.cat([Wq, Wk, Wv], dim=0) state_dict[f"transformer.layers.{l}.mixer.Wqkv.weight"] = torch.cat([Wq, Wk, Wv], dim=0)
# We don't store these biases # We don't store these biases
state_dict.pop(f'transformer.layers.{l}.attn.bias') state_dict.pop(f"transformer.layers.{l}.attn.bias")
state_dict.pop(f'transformer.layers.{l}.attn.masked_bias') state_dict.pop(f"transformer.layers.{l}.attn.masked_bias")
def key_mapping_attn(key): def key_mapping_attn(key):
return re.sub(r'^transformer.layers.(\d+).attn.out_proj.', return re.sub(
r'transformer.layers.\1.mixer.out_proj.', key) r"^transformer.layers.(\d+).attn.out_proj.",
r"transformer.layers.\1.mixer.out_proj.",
key,
)
state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items()) state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())
return state_dict return state_dict
......
...@@ -15,63 +15,81 @@ from transformers import GPT2Config, LlamaConfig ...@@ -15,63 +15,81 @@ from transformers import GPT2Config, LlamaConfig
def remap_state_dict_meta_llama(state_dict, config): def remap_state_dict_meta_llama(state_dict, config):
def key_mapping_layers(key): def key_mapping_layers(key):
return f'transformer.{key}' if not key.startswith('output.') else key return f"transformer.{key}" if not key.startswith("output.") else key
state_dict = OrderedDict((key_mapping_layers(k), v) for k, v in state_dict.items()) state_dict = OrderedDict((key_mapping_layers(k), v) for k, v in state_dict.items())
# Word embedding # Word embedding
def key_mapping_emb(key): def key_mapping_emb(key):
return re.sub(r'^transformer.tok_embeddings.', 'transformer.embeddings.word_embeddings.', key) return re.sub(
r"^transformer.tok_embeddings.", "transformer.embeddings.word_embeddings.", key
)
state_dict = OrderedDict((key_mapping_emb(k), v) for k, v in state_dict.items()) state_dict = OrderedDict((key_mapping_emb(k), v) for k, v in state_dict.items())
word_embeddings = state_dict.pop('transformer.embeddings.word_embeddings.weight') word_embeddings = state_dict.pop("transformer.embeddings.word_embeddings.weight")
# It's possible that vocab_size is padded to be a multiple of 8, for example. # It's possible that vocab_size is padded to be a multiple of 8, for example.
pad_vocab_size_multiple = getattr(config, 'pad_vocab_size_multiple', 1) pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
vocab_size = (math.ceil(word_embeddings.shape[0] / pad_vocab_size_multiple) vocab_size = (
* pad_vocab_size_multiple) math.ceil(word_embeddings.shape[0] / pad_vocab_size_multiple) * pad_vocab_size_multiple
state_dict['transformer.embeddings.word_embeddings.weight'] = F.pad( )
state_dict["transformer.embeddings.word_embeddings.weight"] = F.pad(
word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0]) word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0])
) )
if getattr(config, 'tie_word_embeddings'): if getattr(config, "tie_word_embeddings"):
state_dict['lm_head.weight'] = state_dict['transformer.embeddings.word_embeddings.weight'] state_dict["lm_head.weight"] = state_dict["transformer.embeddings.word_embeddings.weight"]
else: else:
output_embeddings = state_dict.pop('output.weight') output_embeddings = state_dict.pop("output.weight")
# Need to recompute vocab_size since LLaMa shards the word embeddings and output embeddings # Need to recompute vocab_size since LLaMa shards the word embeddings and output embeddings
# differently. # differently.
vocab_size = (math.ceil(output_embeddings.shape[0] / pad_vocab_size_multiple) vocab_size = (
* pad_vocab_size_multiple) math.ceil(output_embeddings.shape[0] / pad_vocab_size_multiple)
* pad_vocab_size_multiple
)
# It's possible that vocab_size is padded to be a multiple of 8, for example. # It's possible that vocab_size is padded to be a multiple of 8, for example.
state_dict['lm_head.weight'] = F.pad( state_dict["lm_head.weight"] = F.pad(
output_embeddings, (0, 0, 0, vocab_size - output_embeddings.shape[0]) output_embeddings, (0, 0, 0, vocab_size - output_embeddings.shape[0])
) )
# LayerNorm # LayerNorm
def key_mapping_ln(key): def key_mapping_ln(key):
key = re.sub(r'^transformer.norm.', r'transformer.ln_f.', key) key = re.sub(r"^transformer.norm.", r"transformer.ln_f.", key)
key = re.sub(r'^transformer.layers.(\d+).attention_norm.', r'transformer.layers.\1.norm1.', key) key = re.sub(
key = re.sub(r'^transformer.layers.(\d+).ffn_norm.', r'transformer.layers.\1.norm2.', key) r"^transformer.layers.(\d+).attention_norm.", r"transformer.layers.\1.norm1.", key
)
key = re.sub(r"^transformer.layers.(\d+).ffn_norm.", r"transformer.layers.\1.norm2.", key)
return key return key
state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items()) state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())
# MLP # MLP
for l in range(config.n_layer): for l in range(config.n_layer):
w1 = state_dict.pop(f'transformer.layers.{l}.feed_forward.w1.weight') w1 = state_dict.pop(f"transformer.layers.{l}.feed_forward.w1.weight")
w3 = state_dict.pop(f'transformer.layers.{l}.feed_forward.w3.weight') w3 = state_dict.pop(f"transformer.layers.{l}.feed_forward.w3.weight")
# Our ordering is different # Our ordering is different
state_dict[f'transformer.layers.{l}.mlp.fc1.weight'] = torch.cat([w3, w1], dim=0) state_dict[f"transformer.layers.{l}.mlp.fc1.weight"] = torch.cat([w3, w1], dim=0)
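The fused fc1 concatenates up_proj (w3) rows before gate_proj (w1) rows; the split convention on the activation side is taken from the GatedMlp links cited below in this file and is treated here as an assumption. A sketch of what that ordering implies for a SwiGLU forward pass:

# Sketch of the assumed fused-fc1 ordering: rows of up_proj (w3) first,
# then gate_proj (w1); the chunk convention is an assumption for illustration.
import torch
import torch.nn.functional as F

hidden, intermediate = 4, 6
w1 = torch.randn(intermediate, hidden)     # gate_proj in the HF naming
w3 = torch.randn(intermediate, hidden)     # up_proj in the HF naming
fc1_weight = torch.cat([w3, w1], dim=0)    # (2 * intermediate, hidden)

x = torch.randn(1, hidden)
y = x @ fc1_weight.t()
up, gate = y.chunk(2, dim=-1)              # first half <- w3, second half <- w1
swiglu = F.silu(gate) * up
assert torch.allclose(swiglu, F.silu(x @ w1.t()) * (x @ w3.t()), atol=1e-6)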
def key_mapping_mlp(key): def key_mapping_mlp(key):
return re.sub(r'^transformer.layers.(\d+).feed_forward.w2.', return re.sub(
r'transformer.layers.\1.mlp.fc2.', key) r"^transformer.layers.(\d+).feed_forward.w2.", r"transformer.layers.\1.mlp.fc2.", key
)
state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items()) state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())
# Attention # Attention
for l in range(config.n_layer): for l in range(config.n_layer):
Wq = state_dict.pop(f'transformer.layers.{l}.attention.wq.weight') Wq = state_dict.pop(f"transformer.layers.{l}.attention.wq.weight")
Wk = state_dict.pop(f'transformer.layers.{l}.attention.wk.weight') Wk = state_dict.pop(f"transformer.layers.{l}.attention.wk.weight")
Wv = state_dict.pop(f'transformer.layers.{l}.attention.wv.weight') Wv = state_dict.pop(f"transformer.layers.{l}.attention.wv.weight")
state_dict[f'transformer.layers.{l}.mixer.Wqkv.weight'] = torch.cat([Wq, Wk, Wv], dim=0) state_dict[f"transformer.layers.{l}.mixer.Wqkv.weight"] = torch.cat([Wq, Wk, Wv], dim=0)
# We don't store these # We don't store these
state_dict.pop(f'transformer.layers.{l}.attention.inner_attention.rope.freqs', None) state_dict.pop(f"transformer.layers.{l}.attention.inner_attention.rope.freqs", None)
def key_mapping_attn(key): def key_mapping_attn(key):
return re.sub(r'^transformer.layers.(\d+).attention.wo.', return re.sub(
r'transformer.layers.\1.mixer.out_proj.', key) r"^transformer.layers.(\d+).attention.wo.",
r"transformer.layers.\1.mixer.out_proj.",
key,
)
state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items()) state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())
state_dict.pop("transformer.rope.freqs", None) state_dict.pop("transformer.rope.freqs", None)
...@@ -82,29 +100,32 @@ def remap_state_dict_meta_llama(state_dict, config): ...@@ -82,29 +100,32 @@ def remap_state_dict_meta_llama(state_dict, config):
def remap_state_dict_hf_llama(state_dict, config): def remap_state_dict_hf_llama(state_dict, config):
# Embedding # Embedding
def key_mapping_emb(key): def key_mapping_emb(key):
return re.sub(r'^model.embed_tokens.', 'transformer.embeddings.word_embeddings.', key) return re.sub(r"^model.embed_tokens.", "transformer.embeddings.word_embeddings.", key)
state_dict = OrderedDict((key_mapping_emb(k), v) for k, v in state_dict.items()) state_dict = OrderedDict((key_mapping_emb(k), v) for k, v in state_dict.items())
word_embeddings = state_dict.pop('transformer.embeddings.word_embeddings.weight') word_embeddings = state_dict.pop("transformer.embeddings.word_embeddings.weight")
# It's possible that vocab_size is padded to be a multiple of 8, for example. # It's possible that vocab_size is padded to be a multiple of 8, for example.
pad_vocab_size_multiple = getattr(config, 'pad_vocab_size_multiple', 1) pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
vocab_size = (math.ceil(word_embeddings.shape[0] / pad_vocab_size_multiple) vocab_size = (
* pad_vocab_size_multiple) math.ceil(word_embeddings.shape[0] / pad_vocab_size_multiple) * pad_vocab_size_multiple
state_dict['transformer.embeddings.word_embeddings.weight'] = F.pad( )
state_dict["transformer.embeddings.word_embeddings.weight"] = F.pad(
word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0]) word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0])
) )
# LM head # LM head
if getattr(config, 'tie_word_embeddings'): if getattr(config, "tie_word_embeddings"):
state_dict['lm_head.weight'] = state_dict['transformer.embeddings.word_embeddings.weight'] state_dict["lm_head.weight"] = state_dict["transformer.embeddings.word_embeddings.weight"]
else: else:
output_embeddings = state_dict.pop('lm_head.weight') output_embeddings = state_dict.pop("lm_head.weight")
# Need to recompute vocab_size since LLaMa shards the word embeddings and output embeddings # Need to recompute vocab_size since LLaMa shards the word embeddings and output embeddings
# differently. # differently.
vocab_size = (math.ceil(output_embeddings.shape[0] / pad_vocab_size_multiple) vocab_size = (
* pad_vocab_size_multiple) math.ceil(output_embeddings.shape[0] / pad_vocab_size_multiple)
* pad_vocab_size_multiple
)
# It's possible that vocab_size is padded to be a multiple of 8, for example. # It's possible that vocab_size is padded to be a multiple of 8, for example.
state_dict['lm_head.weight'] = F.pad( state_dict["lm_head.weight"] = F.pad(
output_embeddings, (0, 0, 0, vocab_size - output_embeddings.shape[0]) output_embeddings, (0, 0, 0, vocab_size - output_embeddings.shape[0])
) )
...@@ -113,21 +134,22 @@ def remap_state_dict_hf_llama(state_dict, config): ...@@ -113,21 +134,22 @@ def remap_state_dict_hf_llama(state_dict, config):
# Fusing weights this way based on difference in the following: # Fusing weights this way based on difference in the following:
# https://github.com/huggingface/transformers/blob/b42010bb1d3cbf262d27e0a328661885be46dfdb/src/transformers/models/llama/modeling_llama.py#L220 # https://github.com/huggingface/transformers/blob/b42010bb1d3cbf262d27e0a328661885be46dfdb/src/transformers/models/llama/modeling_llama.py#L220
# https://github.com/Dao-AILab/flash-attention/blob/c60851a8253257eb970e06a022c82517a8033e8c/flash_attn/modules/mlp.py#L115 # https://github.com/Dao-AILab/flash-attention/blob/c60851a8253257eb970e06a022c82517a8033e8c/flash_attn/modules/mlp.py#L115
w1 = state_dict.pop(f'model.layers.{l}.mlp.gate_proj.weight') w1 = state_dict.pop(f"model.layers.{l}.mlp.gate_proj.weight")
w3 = state_dict.pop(f'model.layers.{l}.mlp.up_proj.weight') w3 = state_dict.pop(f"model.layers.{l}.mlp.up_proj.weight")
state_dict[f'transformer.layers.{l}.mlp.fc1.weight'] = torch.cat([w3, w1], dim=0) state_dict[f"transformer.layers.{l}.mlp.fc1.weight"] = torch.cat([w3, w1], dim=0)
def key_mapping_mlp(key): def key_mapping_mlp(key):
return re.sub(r'^model.layers.(\d+).mlp.down_proj.', return re.sub(r"^model.layers.(\d+).mlp.down_proj.", r"transformer.layers.\1.mlp.fc2.", key)
r'transformer.layers.\1.mlp.fc2.', key)
state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items()) state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())
# LayerNorm # LayerNorm
def key_mapping_ln(key): def key_mapping_ln(key):
key = re.sub(r'^model.norm.', r'transformer.ln_f.', key) key = re.sub(r"^model.norm.", r"transformer.ln_f.", key)
key = re.sub(r'^model.layers.(\d+).input_layernorm.', r'transformer.layers.\1.norm1.', key) key = re.sub(r"^model.layers.(\d+).input_layernorm.", r"transformer.layers.\1.norm1.", key)
key = re.sub(r'^model.layers.(\d+).post_attention_layernorm.', r'transformer.layers.\1.norm2.', key) key = re.sub(
r"^model.layers.(\d+).post_attention_layernorm.", r"transformer.layers.\1.norm2.", key
)
return key return key
state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items()) state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())
...@@ -135,42 +157,52 @@ def remap_state_dict_hf_llama(state_dict, config): ...@@ -135,42 +157,52 @@ def remap_state_dict_hf_llama(state_dict, config):
def inv_permute(w): def inv_permute(w):
# Inverse of permute implemented in: # Inverse of permute implemented in:
# https://github.com/huggingface/transformers/blob/b42010bb1d3cbf262d27e0a328661885be46dfdb/src/transformers/models/llama/convert_llama_weights_to_hf.py#L114 # https://github.com/huggingface/transformers/blob/b42010bb1d3cbf262d27e0a328661885be46dfdb/src/transformers/models/llama/convert_llama_weights_to_hf.py#L114
return w.reshape( return (
config.n_head, 2, config.n_embd // config.n_head // 2, config.n_embd w.reshape(config.n_head, 2, config.n_embd // config.n_head // 2, config.n_embd)
).transpose(1, 2).reshape(config.n_embd, config.n_embd) .transpose(1, 2)
.reshape(config.n_embd, config.n_embd)
)
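A round-trip check for inv_permute with tiny, hypothetical dims; the forward permute below is reproduced from the HF conversion script linked in the comment and should be treated as an assumption:

# Round-trip sketch: inv_permute undoes the rotary-weight permutation.
import torch

n_head, n_embd = 2, 8
head_half = n_embd // n_head // 2

def permute(w):
    # Mirrors the permutation in the linked convert_llama_weights_to_hf script.
    return w.reshape(n_head, head_half, 2, n_embd).transpose(1, 2).reshape(n_embd, n_embd)

def inv_permute(w):
    return w.reshape(n_head, 2, head_half, n_embd).transpose(1, 2).reshape(n_embd, n_embd)

w = torch.randn(n_embd, n_embd)
assert torch.equal(inv_permute(permute(w)), w)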
# Attention # Attention
for l in range(config.n_layer): for l in range(config.n_layer):
Wq = state_dict.pop(f'model.layers.{l}.self_attn.q_proj.weight') Wq = state_dict.pop(f"model.layers.{l}.self_attn.q_proj.weight")
Wk = state_dict.pop(f'model.layers.{l}.self_attn.k_proj.weight') Wk = state_dict.pop(f"model.layers.{l}.self_attn.k_proj.weight")
Wv = state_dict.pop(f'model.layers.{l}.self_attn.v_proj.weight') Wv = state_dict.pop(f"model.layers.{l}.self_attn.v_proj.weight")
state_dict[f'transformer.layers.{l}.mixer.Wqkv.weight'] = torch.cat( state_dict[f"transformer.layers.{l}.mixer.Wqkv.weight"] = torch.cat(
[inv_permute(Wq), inv_permute(Wk), Wv], dim=0 [inv_permute(Wq), inv_permute(Wk), Wv], dim=0
) )
# We don't store these # We don't store these
state_dict.pop(f'model.layers.{l}.self_attn.rotary_emb.inv_freq', None) state_dict.pop(f"model.layers.{l}.self_attn.rotary_emb.inv_freq", None)
def key_mapping_attn(key): def key_mapping_attn(key):
return re.sub(r'^model.layers.(\d+).self_attn.o_proj.', return re.sub(
r'transformer.layers.\1.mixer.out_proj.', key) r"^model.layers.(\d+).self_attn.o_proj.", r"transformer.layers.\1.mixer.out_proj.", key
)
state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items()) state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())
return state_dict return state_dict
def config_from_meta_checkpoint(checkpoint_path: Union[str, os.PathLike], model_name: str) -> LlamaConfig: def config_from_meta_checkpoint(
checkpoint_path: Union[str, os.PathLike], model_name: str
) -> LlamaConfig:
"""Load a LlamaConfig from a checkpoint path.""" """Load a LlamaConfig from a checkpoint path."""
with open(Path(checkpoint_path) / model_name / 'params.json') as f: with open(Path(checkpoint_path) / model_name / "params.json") as f:
params = json.load(f) params = json.load(f)
config = LlamaConfig(hidden_size=params['dim'], intermediate_size=None, config = LlamaConfig(
num_attention_heads=params['n_heads'], hidden_size=params["dim"],
num_hidden_layers=params['n_layers'], intermediate_size=None,
rms_norm_eps=params['norm_eps']) num_attention_heads=params["n_heads"],
num_hidden_layers=params["n_layers"],
rms_norm_eps=params["norm_eps"],
)
return config return config
def config_from_hf_checkpoint(checkpoint_path: Union[str, os.PathLike], model_name: str) -> LlamaConfig: def config_from_hf_checkpoint(
return LlamaConfig.from_pretrained(Path(checkpoint_path) / f'{model_name}-hf' / "config.json") checkpoint_path: Union[str, os.PathLike], model_name: str
) -> LlamaConfig:
return LlamaConfig.from_pretrained(Path(checkpoint_path) / f"{model_name}-hf" / "config.json")
def config_from_checkpoint( def config_from_checkpoint(
...@@ -182,10 +214,14 @@ def config_from_checkpoint( ...@@ -182,10 +214,14 @@ def config_from_checkpoint(
return config_from_hf_checkpoint(checkpoint_path, model_name) return config_from_hf_checkpoint(checkpoint_path, model_name)
def state_dicts_from_checkpoint(checkpoint_path: Union[str, os.PathLike], model_name: str) -> list[dict]: def state_dicts_from_checkpoint(
checkpoint_path: Union[str, os.PathLike], model_name: str
) -> list[dict]:
# Need to sort, otherwise we mess up the ordering and the weights are wrong # Need to sort, otherwise we mess up the ordering and the weights are wrong
return [torch.load(path, map_location='cpu') return [
for path in sorted((Path(checkpoint_path) / model_name).glob('consolidated.*.pth'))] torch.load(path, map_location="cpu")
for path in sorted((Path(checkpoint_path) / model_name).glob("consolidated.*.pth"))
]
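A tiny sketch of why the shard paths are sorted: Path.glob order is not guaranteed, while the zero-padded rank in the (hypothetical) names sorts correctly:

# Why the shard paths are sorted before loading: lexical order of zero-padded
# rank numbers matches rank order; glob order alone does not guarantee this.
names = ["consolidated.02.pth", "consolidated.00.pth", "consolidated.01.pth"]
assert sorted(names) == ["consolidated.00.pth", "consolidated.01.pth", "consolidated.02.pth"]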
def llama_config_to_gpt2_config(llama_config: LlamaConfig) -> GPT2Config: def llama_config_to_gpt2_config(llama_config: LlamaConfig) -> GPT2Config:
...@@ -196,7 +232,7 @@ def llama_config_to_gpt2_config(llama_config: LlamaConfig) -> GPT2Config: ...@@ -196,7 +232,7 @@ def llama_config_to_gpt2_config(llama_config: LlamaConfig) -> GPT2Config:
n_layer=llama_config.num_hidden_layers, n_layer=llama_config.num_hidden_layers,
n_head=llama_config.num_attention_heads, n_head=llama_config.num_attention_heads,
n_inner=llama_config.intermediate_size, n_inner=llama_config.intermediate_size,
activation_function='swiglu', # Hardcode since HF calls it 'silu' activation_function="swiglu", # Hardcode since HF calls it 'silu'
# Llama doesn't have dropout, idk if it's because they only release the inference code # Llama doesn't have dropout, idk if it's because they only release the inference code
resid_pdrop=0.0, resid_pdrop=0.0,
embd_pdrop=0.0, embd_pdrop=0.0,
......
...@@ -2,75 +2,86 @@ ...@@ -2,75 +2,86 @@
import math import math
import re import re
from collections import OrderedDict from collections import OrderedDict
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from transformers import GPT2Config, OPTConfig from transformers import GPT2Config, OPTConfig
def remap_state_dict_hf_opt(state_dict, config): def remap_state_dict_hf_opt(state_dict, config):
def key_mapping_model(key): def key_mapping_model(key):
key = re.sub(r'^model.decoder.', 'transformer.', key) key = re.sub(r"^model.decoder.", "transformer.", key)
# The OPT-350m model uses '^decoder' instead of '^model.decoder' # The OPT-350m model uses '^decoder' instead of '^model.decoder'
key = re.sub(r'^decoder.', 'transformer.', key) key = re.sub(r"^decoder.", "transformer.", key)
return key return key
state_dict = OrderedDict((key_mapping_model(k), v) for k, v in state_dict.items()) state_dict = OrderedDict((key_mapping_model(k), v) for k, v in state_dict.items())
# Word embedding and position embedding # Word embedding and position embedding
def key_mapping_emb(key): def key_mapping_emb(key):
key = re.sub(r'^transformer.embed_tokens.', 'transformer.embeddings.word_embeddings.', key) key = re.sub(r"^transformer.embed_tokens.", "transformer.embeddings.word_embeddings.", key)
# The OPT-350m model has project_in and project_out # The OPT-350m model has project_in and project_out
key = re.sub(r'^transformer.project_in.', 'transformer.embeddings.project_in.', key) key = re.sub(r"^transformer.project_in.", "transformer.embeddings.project_in.", key)
key = re.sub(r'^transformer.project_out.', 'project_out.', key) key = re.sub(r"^transformer.project_out.", "project_out.", key)
key = re.sub(r'^transformer.embed_positions.', key = re.sub(
'transformer.embeddings.position_embeddings.', key) r"^transformer.embed_positions.", "transformer.embeddings.position_embeddings.", key
)
return key return key
state_dict = OrderedDict((key_mapping_emb(k), v) for k, v in state_dict.items()) state_dict = OrderedDict((key_mapping_emb(k), v) for k, v in state_dict.items())
# OPT uses the first 2 indices of pos_emb for padding tokens # OPT uses the first 2 indices of pos_emb for padding tokens
pos_embeddings = state_dict.pop('transformer.embeddings.position_embeddings.weight') pos_embeddings = state_dict.pop("transformer.embeddings.position_embeddings.weight")
state_dict['transformer.embeddings.position_embeddings.weight'] = pos_embeddings[2:] state_dict["transformer.embeddings.position_embeddings.weight"] = pos_embeddings[2:]
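Dropping pos_embeddings[:2] assumes the HF OPT table carries a 2-row offset in front of the real positions (the padding indices mentioned above). A sketch with hypothetical sizes:

# Sketch of the OPT position-embedding offset: the (assumed) HF table has
# max_position_embeddings + 2 rows; dropping the first two leaves one row
# per actual position.
import torch

max_position_embeddings, hidden = 6, 4     # hypothetical
hf_pos_emb = torch.randn(max_position_embeddings + 2, hidden)
pos_emb = hf_pos_emb[2:]
assert pos_emb.shape[0] == max_position_embeddings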
word_embeddings = state_dict.pop('transformer.embeddings.word_embeddings.weight') word_embeddings = state_dict.pop("transformer.embeddings.word_embeddings.weight")
# It's possible that vocab_size is padded to be a multiple of 8, for example. # It's possible that vocab_size is padded to be a multiple of 8, for example.
pad_vocab_size_multiple = getattr(config, 'pad_vocab_size_multiple', 1) pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
vocab_size = (math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple) vocab_size = math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple
state_dict['transformer.embeddings.word_embeddings.weight'] = F.pad( state_dict["transformer.embeddings.word_embeddings.weight"] = F.pad(
word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0]) word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0])
) )
state_dict['lm_head.weight'] = state_dict['transformer.embeddings.word_embeddings.weight'] state_dict["lm_head.weight"] = state_dict["transformer.embeddings.word_embeddings.weight"]
# LayerNorm # LayerNorm
def key_mapping_ln(key): def key_mapping_ln(key):
key = re.sub(r'^transformer.final_layer_norm.', r'transformer.ln_f.', key) key = re.sub(r"^transformer.final_layer_norm.", r"transformer.ln_f.", key)
# The OPT-175B checkpoint calls this 'decoder.layer_norm' instead of 'decoder.final_layer_norm' # The OPT-175B checkpoint calls this 'decoder.layer_norm' instead of 'decoder.final_layer_norm'
key = re.sub(r'^transformer.layer_norm.', r'transformer.ln_f.', key) key = re.sub(r"^transformer.layer_norm.", r"transformer.ln_f.", key)
key = re.sub(r'^transformer.layers.(\d+).self_attn_layer_norm.', key = re.sub(
r'transformer.layers.\1.norm1.', key) r"^transformer.layers.(\d+).self_attn_layer_norm.", r"transformer.layers.\1.norm1.", key
key = re.sub(r'^transformer.layers.(\d+).final_layer_norm.', )
r'transformer.layers.\1.norm2.', key) key = re.sub(
r"^transformer.layers.(\d+).final_layer_norm.", r"transformer.layers.\1.norm2.", key
)
return key return key
state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items()) state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())
# MLP # MLP
def key_mapping_mlp(key): def key_mapping_mlp(key):
return re.sub(r'^transformer.layers.(\d+).fc(1|2).', return re.sub(
r'transformer.layers.\1.mlp.fc\2.', key) r"^transformer.layers.(\d+).fc(1|2).", r"transformer.layers.\1.mlp.fc\2.", key
)
state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items()) state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())
# Attention # Attention
for l in range(config.n_layer): for l in range(config.n_layer):
Wq = state_dict.pop(f'transformer.layers.{l}.self_attn.q_proj.weight') Wq = state_dict.pop(f"transformer.layers.{l}.self_attn.q_proj.weight")
Wk = state_dict.pop(f'transformer.layers.{l}.self_attn.k_proj.weight') Wk = state_dict.pop(f"transformer.layers.{l}.self_attn.k_proj.weight")
Wv = state_dict.pop(f'transformer.layers.{l}.self_attn.v_proj.weight') Wv = state_dict.pop(f"transformer.layers.{l}.self_attn.v_proj.weight")
bq = state_dict.pop(f'transformer.layers.{l}.self_attn.q_proj.bias') bq = state_dict.pop(f"transformer.layers.{l}.self_attn.q_proj.bias")
bk = state_dict.pop(f'transformer.layers.{l}.self_attn.k_proj.bias') bk = state_dict.pop(f"transformer.layers.{l}.self_attn.k_proj.bias")
bv = state_dict.pop(f'transformer.layers.{l}.self_attn.v_proj.bias') bv = state_dict.pop(f"transformer.layers.{l}.self_attn.v_proj.bias")
state_dict[f'transformer.layers.{l}.mixer.Wqkv.weight'] = torch.cat([Wq, Wk, Wv], dim=0) state_dict[f"transformer.layers.{l}.mixer.Wqkv.weight"] = torch.cat([Wq, Wk, Wv], dim=0)
state_dict[f'transformer.layers.{l}.mixer.Wqkv.bias'] = torch.cat([bq, bk, bv], dim=0) state_dict[f"transformer.layers.{l}.mixer.Wqkv.bias"] = torch.cat([bq, bk, bv], dim=0)
def key_mapping_attn(key): def key_mapping_attn(key):
return re.sub(r'^transformer.layers.(\d+).self_attn.out_proj.', return re.sub(
r'transformer.layers.\1.mixer.out_proj.', key) r"^transformer.layers.(\d+).self_attn.out_proj.",
r"transformer.layers.\1.mixer.out_proj.",
key,
)
state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items()) state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())
return state_dict return state_dict
...@@ -79,8 +90,11 @@ def remap_state_dict_hf_opt(state_dict, config): ...@@ -79,8 +90,11 @@ def remap_state_dict_hf_opt(state_dict, config):
def opt_config_to_gpt2_config(opt_config: OPTConfig) -> GPT2Config: def opt_config_to_gpt2_config(opt_config: OPTConfig) -> GPT2Config:
assert opt_config.layerdrop == 0.0 assert opt_config.layerdrop == 0.0
assert opt_config.layer_norm_elementwise_affine assert opt_config.layer_norm_elementwise_affine
word_embed_proj_dim = (None if opt_config.word_embed_proj_dim == opt_config.hidden_size word_embed_proj_dim = (
else opt_config.word_embed_proj_dim) None
if opt_config.word_embed_proj_dim == opt_config.hidden_size
else opt_config.word_embed_proj_dim
)
return GPT2Config( return GPT2Config(
vocab_size=opt_config.vocab_size, vocab_size=opt_config.vocab_size,
n_positions=opt_config.max_position_embeddings, n_positions=opt_config.max_position_embeddings,
...@@ -98,5 +112,5 @@ def opt_config_to_gpt2_config(opt_config: OPTConfig) -> GPT2Config: ...@@ -98,5 +112,5 @@ def opt_config_to_gpt2_config(opt_config: OPTConfig) -> GPT2Config:
eos_token_id=opt_config.eos_token_id, eos_token_id=opt_config.eos_token_id,
# These are new arguments not in the original GPT2Config # These are new arguments not in the original GPT2Config
prenorm=opt_config.do_layer_norm_before, prenorm=opt_config.do_layer_norm_before,
word_embed_proj_dim=word_embed_proj_dim word_embed_proj_dim=word_embed_proj_dim,
) )
...@@ -10,13 +10,14 @@ import torch ...@@ -10,13 +10,14 @@ import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from einops import rearrange from einops import rearrange
from timm.models.helpers import named_apply
from torch.nn.init import trunc_normal_
from torchvision.ops import StochasticDepth
from flash_attn.layers.patch_embed import PatchEmbed from flash_attn.layers.patch_embed import PatchEmbed
from flash_attn.modules.block import Block from flash_attn.modules.block import Block
from flash_attn.modules.mha import MHA from flash_attn.modules.mha import MHA
from flash_attn.modules.mlp import FusedMLP, Mlp from flash_attn.modules.mlp import FusedMLP, Mlp
from timm.models.helpers import named_apply
from torch.nn.init import trunc_normal_
from torchvision.ops import StochasticDepth
try: try:
from flash_attn.ops.layer_norm import dropout_add_layer_norm from flash_attn.ops.layer_norm import dropout_add_layer_norm
......
# Copyright (c) 2022, Tri Dao. # Copyright (c) 2022, Tri Dao.
from typing import Optional
from functools import partial from functools import partial
from typing import Optional
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from torch import Tensor from torch import Tensor
from torchvision.ops import StochasticDepth from torchvision.ops import StochasticDepth
from flash_attn.modules.mha import MHA from flash_attn.modules.mha import MHA
...@@ -35,11 +34,24 @@ except ImportError: ...@@ -35,11 +34,24 @@ except ImportError:
class Block(nn.Module): class Block(nn.Module):
def __init__(
def __init__(self, dim, mixer_cls=None, mlp_cls=None, norm_cls=nn.LayerNorm, self,
dropout_cls=nn.Dropout, prenorm=True, resid_dropout1=0., resid_dropout2=0., dim,
drop_path1=0., drop_path2=0., fused_dropout_add_ln=False, return_residual=False, mixer_cls=None,
residual_in_fp32=False, sequence_parallel=False, mark_shared_params=False): mlp_cls=None,
norm_cls=nn.LayerNorm,
dropout_cls=nn.Dropout,
prenorm=True,
resid_dropout1=0.0,
resid_dropout2=0.0,
drop_path1=0.0,
drop_path2=0.0,
fused_dropout_add_ln=False,
return_residual=False,
residual_in_fp32=False,
sequence_parallel=False,
mark_shared_params=False,
):
""" """
For prenorm=True, this Block has a slightly different structure compared to a regular For prenorm=True, this Block has a slightly different structure compared to a regular
prenorm Transformer block. prenorm Transformer block.
...@@ -63,26 +75,27 @@ class Block(nn.Module): ...@@ -63,26 +75,27 @@ class Block(nn.Module):
self.return_residual = return_residual self.return_residual = return_residual
self.residual_in_fp32 = residual_in_fp32 self.residual_in_fp32 = residual_in_fp32
if self.residual_in_fp32: if self.residual_in_fp32:
assert self.prenorm, 'residual_in_fp32 is only compatible with prenorm=True' assert self.prenorm, "residual_in_fp32 is only compatible with prenorm=True"
if mixer_cls is None: if mixer_cls is None:
mixer_cls = partial(MHA, num_heads=dim // 64) mixer_cls = partial(MHA, num_heads=dim // 64)
if mlp_cls is None: if mlp_cls is None:
mlp_cls = partial(Mlp, hidden_features=4 * dim) mlp_cls = partial(Mlp, hidden_features=4 * dim)
self.mixer = mixer_cls(dim) self.mixer = mixer_cls(dim)
self.dropout1 = dropout_cls(resid_dropout1) self.dropout1 = dropout_cls(resid_dropout1)
self.drop_path1 = StochasticDepth(drop_path1, mode='row') self.drop_path1 = StochasticDepth(drop_path1, mode="row")
self.norm1 = norm_cls(dim) self.norm1 = norm_cls(dim)
self.mlp = mlp_cls(dim) self.mlp = mlp_cls(dim)
if not isinstance(self.mlp, nn.Identity): if not isinstance(self.mlp, nn.Identity):
self.dropout2 = dropout_cls(resid_dropout2) self.dropout2 = dropout_cls(resid_dropout2)
self.drop_path2 = StochasticDepth(drop_path2, mode='row') self.drop_path2 = StochasticDepth(drop_path2, mode="row")
self.norm2 = norm_cls(dim) self.norm2 = norm_cls(dim)
if self.fused_dropout_add_ln: if self.fused_dropout_add_ln:
assert dropout_add_layer_norm is not None, 'dropout_layer_norm is not installed' assert dropout_add_layer_norm is not None, "dropout_layer_norm is not installed"
assert dropout_add_rms_norm is not None, 'dropout_layer_norm is not installed' assert dropout_add_rms_norm is not None, "dropout_layer_norm is not installed"
assert (isinstance(self.norm1, (nn.LayerNorm, RMSNorm)) assert isinstance(self.norm1, (nn.LayerNorm, RMSNorm)) and isinstance(
and isinstance(self.dropout1, nn.Dropout)) self.dropout1, nn.Dropout
)
# TD [2023-01-07]: TODO: During training, if sequence_parallel is False and dropout != 0.0, # TD [2023-01-07]: TODO: During training, if sequence_parallel is False and dropout != 0.0,
# then the input to each worker in the tensor parallel group will be different. # then the input to each worker in the tensor parallel group will be different.
...@@ -94,22 +107,27 @@ class Block(nn.Module): ...@@ -94,22 +107,27 @@ class Block(nn.Module):
if sequence_parallel: if sequence_parallel:
for p in self.norm1.parameters(): for p in self.norm1.parameters():
p._sequence_parallel = True p._sequence_parallel = True
if hasattr(self, 'norm2'): if hasattr(self, "norm2"):
for p in self.norm2.parameters(): for p in self.norm2.parameters():
p._sequence_parallel = True p._sequence_parallel = True
# Mark the norm parameters as "shared_params" so that we sync their values at init. # Mark the norm parameters as "shared_params" so that we sync their values at init.
if mark_shared_params: if mark_shared_params:
for p in self.norm1.parameters(): for p in self.norm1.parameters():
p._shared_params = True p._shared_params = True
if hasattr(self, 'norm2'): if hasattr(self, "norm2"):
for p in self.norm2.parameters(): for p in self.norm2.parameters():
p._shared_params = True p._shared_params = True
def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs): def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs) return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)
def forward(self, hidden_states: Tensor, residual: Optional[Tensor] = None, def forward(
mixer_subset=None, mixer_kwargs=None): self,
hidden_states: Tensor,
residual: Optional[Tensor] = None,
mixer_subset=None,
mixer_kwargs=None,
):
r"""Pass the input through the encoder layer. r"""Pass the input through the encoder layer.
Args: Args:
...@@ -119,8 +137,11 @@ class Block(nn.Module): ...@@ -119,8 +137,11 @@ class Block(nn.Module):
before applying the query projection. Useful for e.g., ViT where we only care before applying the query projection. Useful for e.g., ViT where we only care
about the CLS token in the last layer. about the CLS token in the last layer.
""" """
fused_add_norm_fn = (dropout_add_rms_norm if RMSNorm and isinstance(self.norm1, RMSNorm) fused_add_norm_fn = (
else dropout_add_layer_norm) dropout_add_rms_norm
if RMSNorm and isinstance(self.norm1, RMSNorm)
else dropout_add_layer_norm
)
if self.prenorm: if self.prenorm:
if not self.fused_dropout_add_ln: if not self.fused_dropout_add_ln:
dropped = self.drop_path1(self.dropout1(hidden_states)) dropped = self.drop_path1(self.dropout1(hidden_states))
...@@ -132,19 +153,28 @@ class Block(nn.Module): ...@@ -132,19 +153,28 @@ class Block(nn.Module):
if self.drop_path1.p == 0 or not self.training: if self.drop_path1.p == 0 or not self.training:
rowscale1 = None rowscale1 = None
else: else:
rowscale1 = self.drop_path1(torch.ones( rowscale1 = self.drop_path1(
hidden_states.shape[:-1], device=hidden_states.device, torch.ones(
dtype=hidden_states.dtype) hidden_states.shape[:-1],
device=hidden_states.device,
dtype=hidden_states.dtype,
)
) )
hidden_states, residual = fused_add_norm_fn( hidden_states, residual = fused_add_norm_fn(
hidden_states, residual, self.norm1.weight, self.norm1.bias, hidden_states,
self.dropout1.p if self.training else 0.0, self.norm1.eps, residual,
rowscale=rowscale1, prenorm=True, residual_in_fp32=self.residual_in_fp32 self.norm1.weight,
self.norm1.bias,
self.dropout1.p if self.training else 0.0,
self.norm1.eps,
rowscale=rowscale1,
prenorm=True,
residual_in_fp32=self.residual_in_fp32,
) )
if mixer_kwargs is None: if mixer_kwargs is None:
mixer_kwargs = {} mixer_kwargs = {}
if mixer_subset is not None: if mixer_subset is not None:
mixer_kwargs['mixer_subset'] = mixer_subset mixer_kwargs["mixer_subset"] = mixer_subset
hidden_states = self.mixer(hidden_states, **mixer_kwargs) hidden_states = self.mixer(hidden_states, **mixer_kwargs)
if mixer_subset is not None: if mixer_subset is not None:
residual = residual[:, mixer_subset] residual = residual[:, mixer_subset]
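Ignoring drop_path/rowscale, the fused prenorm branch above is assumed to compute the same thing as the unfused branch: add the dropped sub-layer output to the running residual, then layer-norm that sum for the next module. A minimal unfused sketch:

# Minimal unfused sketch (assumed semantics, ignoring drop_path/rowscale):
# new residual = dropout(sublayer_out) + old residual, and norm(residual)
# feeds the next module while the residual is carried forward.
import torch
import torch.nn as nn

hidden = 16
norm1 = nn.LayerNorm(hidden)
dropout1 = nn.Dropout(0.0)

sublayer_out = torch.randn(2, 3, hidden)   # (batch, seqlen, hidden)
residual = torch.randn(2, 3, hidden)

residual = dropout1(sublayer_out) + residual
hidden_states = norm1(residual.to(dtype=norm1.weight.dtype))
assert hidden_states.shape == residual.shape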
...@@ -159,14 +189,23 @@ class Block(nn.Module): ...@@ -159,14 +189,23 @@ class Block(nn.Module):
if self.drop_path2.p == 0 or not self.training: if self.drop_path2.p == 0 or not self.training:
rowscale2 = None rowscale2 = None
else: else:
rowscale2 = self.drop_path2(torch.ones( rowscale2 = self.drop_path2(
hidden_states.shape[:-1], device=hidden_states.device, torch.ones(
dtype=hidden_states.dtype) hidden_states.shape[:-1],
device=hidden_states.device,
dtype=hidden_states.dtype,
)
) )
hidden_states, residual = fused_add_norm_fn( hidden_states, residual = fused_add_norm_fn(
hidden_states, residual, self.norm2.weight, self.norm2.bias, hidden_states,
self.dropout2.p if self.training else 0.0, self.norm2.eps, residual,
rowscale=rowscale2, prenorm=True, residual_in_fp32=self.residual_in_fp32 self.norm2.weight,
self.norm2.bias,
self.dropout2.p if self.training else 0.0,
self.norm2.eps,
rowscale=rowscale2,
prenorm=True,
residual_in_fp32=self.residual_in_fp32,
) )
hidden_states = self.mlp(hidden_states) hidden_states = self.mlp(hidden_states)
return hidden_states, residual return hidden_states, residual
@@ -178,38 +217,58 @@ class Block(nn.Module):
            if self.return_residual:  # mixer out is actually a pair here
                mixer_out, hidden_states = mixer_out
            if not self.fused_dropout_add_ln:
                hidden_states = self.norm1(
                    (self.drop_path1(self.dropout1(mixer_out)) + hidden_states).to(
                        dtype=self.norm1.weight.dtype
                    )
                )
            else:
                if self.drop_path1.p == 0 or not self.training:
                    rowscale1 = None
                else:
                    rowscale1 = self.drop_path1(
                        torch.ones(
                            mixer_out.shape[:-1], device=mixer_out.device, dtype=mixer_out.dtype
                        )
                    )
                hidden_states = fused_add_norm_fn(
                    mixer_out,
                    hidden_states,
                    self.norm1.weight,
                    self.norm1.bias,
                    self.dropout1.p if self.training else 0.0,
                    self.norm1.eps,
                    rowscale=rowscale1,
                    prenorm=False,
                )
            if not isinstance(self.mlp, nn.Identity):
                mlp_out = self.mlp(hidden_states)
                if self.return_residual:  # mlp out is actually a pair here
                    mlp_out, hidden_states = mlp_out
                if not self.fused_dropout_add_ln:
                    hidden_states = self.norm2(
                        (self.drop_path2(self.dropout2(mlp_out)) + hidden_states).to(
                            dtype=self.norm2.weight.dtype
                        )
                    )
                else:
                    if self.drop_path2.p == 0 or not self.training:
                        rowscale2 = None
                    else:
                        rowscale2 = self.drop_path2(
                            torch.ones(
                                mlp_out.shape[:-1], device=mlp_out.device, dtype=mlp_out.dtype
                            )
                        )
                    hidden_states = fused_add_norm_fn(
                        mlp_out,
                        hidden_states,
                        self.norm2.weight,
                        self.norm2.bias,
                        self.dropout2.p if self.training else 0.0,
                        self.norm2.eps,
                        rowscale=rowscale2,
                        prenorm=False,
                    )
            return hidden_states
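    # Summary of the two paths above (ignoring dropout / DropPath and dtype casts):
    #   prenorm:  residual = residual + branch_out; hidden_states = norm(residual),
    #             done once around the mixer and once around the MLP, and forward
    #             returns (hidden_states, residual) so the next Block continues the stream.
    #   postnorm: hidden_states = norm(hidden_states + branch_out) after the mixer and
    #             again after the MLP, and forward returns hidden_states alone.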
@@ -219,10 +278,21 @@ class ParallelBlock(nn.Module):
    and PaLM.
    """

    def __init__(
        self,
        dim,
        mixer_cls=None,
        mlp_cls=None,
        norm_cls=nn.LayerNorm,
        dropout_cls=nn.Dropout,
        resid_dropout1=0.0,
        resid_dropout2=0.0,
        tied_norm=False,
        fused_dropout_add_ln=False,
        residual_in_fp32=False,
        sequence_parallel=False,
        mark_shared_params=False,
    ):
        """
        This Block has a slightly different structure compared to a regular
        prenorm Transformer block.
@@ -250,10 +320,15 @@ class ParallelBlock(nn.Module):
            self.norm2 = norm_cls(dim)
        if self.fused_dropout_add_ln:
            assert (
                dropout_add_layer_norm_parallel_residual is not None
            ), "dropout_layer_norm is not installed"
            assert (
                dropout_add_rms_norm_parallel_residual is not None
            ), "dropout_layer_norm is not installed"
            assert isinstance(self.norm1, (nn.LayerNorm, RMSNorm)) and isinstance(
                self.dropout1, nn.Dropout
            )
        # TD [2023-01-07]: TODO: During training, if sequence_parallel is False and dropout != 0.0,
        # then the input to each worker in the tensor parallel group will be different.
@@ -265,22 +340,27 @@ class ParallelBlock(nn.Module):
        if sequence_parallel:
            for p in self.norm1.parameters():
                p._sequence_parallel = True
            if hasattr(self, "norm2"):
                for p in self.norm2.parameters():
                    p._sequence_parallel = True
        # Mark the norm parameters as "shared_params" so that we sync their values at init.
        if mark_shared_params:
            for p in self.norm1.parameters():
                p._shared_params = True
            if hasattr(self, "norm2"):
                for p in self.norm2.parameters():
                    p._shared_params = True
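        # These flags are only markers on the parameters; the training / setup code is expected
        # to act on them. A minimal hypothetical consumer (illustrative only, not this library's
        # actual helper) could look like:
        #     for p in model.parameters():
        #         if getattr(p, "_shared_params", False):
        #             torch.distributed.broadcast(p.data, src=0, group=tp_group)  # sync at init
        #         if getattr(p, "_sequence_parallel", False):
        #             pass  # e.g. all-reduce these gradients across the tensor-parallel group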

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)

    def forward(
        self,
        hidden_states1: Tensor,
        hidden_states2: Optional[Tensor] = None,
        residual: Optional[Tensor] = None,
        mixer_kwargs=None,
    ):
        r"""Pass the input through the encoder layer.

        Args:
@@ -290,30 +370,47 @@ class ParallelBlock(nn.Module):
        """
        # TODO: Ideally we should only do the allgather / allreduce once for
        # the Linear to MLP & Attention
        fused_add_norm_fn = (
            dropout_add_rms_norm_parallel_residual
            if isinstance(self.norm1, RMSNorm)
            else dropout_add_layer_norm_parallel_residual
        )
        if not self.fused_dropout_add_ln:
            dropped1 = self.dropout1(hidden_states1)
            # For the very 1st block, we only want 1 dropout, not two different dropouts
            if hidden_states2 is not None:
                dropped2 = self.dropout2(hidden_states2)
                residual = (
                    (residual + dropped1 + dropped2)
                    if residual is not None
                    else dropped1 + dropped2
                )
            else:
                residual = (residual + dropped1) if residual is not None else dropped1
            hidden_states1 = self.norm1(residual.to(dtype=self.norm1.weight.dtype))
            hidden_states2 = (
                self.norm2(residual.to(dtype=self.norm2.weight.dtype))
                if not self.tied_norm
                else hidden_states1
            )
            if self.residual_in_fp32:
                residual = residual.to(torch.float32)
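            # Net effect of this unfused path: one shared residual stream,
            #     residual = residual + dropout1(hidden_states1) + dropout2(hidden_states2),
            # is re-normalized into two views: hidden_states1 via norm1, and hidden_states2 via
            # norm2 (or the same tensor when tied_norm). These presumably feed this block's
            # mixer and MLP further down; the rest of forward is collapsed in this diff.
            # The parallel layout matches GPT-J / GPT-NeoX / PaLM, per the class docstring.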
        else:
            weight2, bias2 = (
                (self.norm2.weight, self.norm2.bias) if not self.tied_norm else (None, None)
            )
            hidden_states1, hidden_states2, residual = fused_add_norm_fn(
                hidden_states1,
                hidden_states2,
                residual,
                self.norm1.weight,
                self.norm1.bias,
                weight2,
                bias2,
                self.dropout1.p if self.training else 0.0,
                self.norm1.eps,
                prenorm=True,
                residual_in_fp32=self.residual_in_fp32,
            )
            if self.tied_norm:
                hidden_states2 = hidden_states1
...