Commit f1a73d07 authored by Tri Dao

Run isort and black on python files

parent cbb4cf5f
......@@ -2,42 +2,52 @@
import torch
import torch.nn as nn
from torch import Tensor
from einops import rearrange
from torch import Tensor
from flash_attn.utils.distributed import reduce_scatter, all_reduce
from flash_attn.utils.distributed import all_reduce, reduce_scatter
class GPT2Embeddings(nn.Module):
def __init__(self, embed_dim, vocab_size, max_position_embeddings, padding_idx=None,
word_embed_proj_dim=None, device=None, dtype=None):
def __init__(
self,
embed_dim,
vocab_size,
max_position_embeddings,
padding_idx=None,
word_embed_proj_dim=None,
device=None,
dtype=None,
):
"""
If max_position_embeddings <= 0, there's no position embeddings
If word_embed_proj_dim is not None (e.g., OPT-350m), we embed to that dimension
then project up to embed_dim
If max_position_embeddings <= 0, there's no position embeddings
If word_embed_proj_dim is not None (e.g., OPT-350m), we embed to that dimension
then project up to embed_dim
"""
factory_kwargs = {'device': device, 'dtype': dtype}
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
if word_embed_proj_dim is None:
self.word_embeddings = nn.Embedding(vocab_size, embed_dim, padding_idx=padding_idx,
**factory_kwargs)
self.word_embeddings = nn.Embedding(
vocab_size, embed_dim, padding_idx=padding_idx, **factory_kwargs
)
self.project_in = None
else:
self.word_embeddings = nn.Embedding(vocab_size, word_embed_proj_dim,
padding_idx=padding_idx, **factory_kwargs)
self.project_in = nn.Linear(word_embed_proj_dim, embed_dim, bias=False,
**factory_kwargs)
self.word_embeddings = nn.Embedding(
vocab_size, word_embed_proj_dim, padding_idx=padding_idx, **factory_kwargs
)
self.project_in = nn.Linear(
word_embed_proj_dim, embed_dim, bias=False, **factory_kwargs
)
self.max_position_embeddings = max_position_embeddings
if self.max_position_embeddings > 0:
self.position_embeddings = nn.Embedding(max_position_embeddings, embed_dim,
**factory_kwargs)
self.position_embeddings = nn.Embedding(
max_position_embeddings, embed_dim, **factory_kwargs
)
def forward(self, input_ids, position_ids=None):
"""
input_ids: (batch, seqlen)
position_ids: (batch, seqlen)
input_ids: (batch, seqlen)
position_ids: (batch, seqlen)
"""
batch_size, seqlen = input_ids.shape
embeddings = self.word_embeddings(input_ids)
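A minimal usage sketch of GPT2Embeddings on CPU (sizes are illustrative; the default position_ids handling lives in the part of forward() elided from this hunk):
import torch

emb = GPT2Embeddings(embed_dim=768, vocab_size=50257, max_position_embeddings=1024)
input_ids = torch.randint(0, 50257, (2, 16))  # (batch, seqlen)
hidden = emb(input_ids)  # position_ids are assumed to default to 0..seqlen-1 in the elided code
print(hidden.shape)  # torch.Size([2, 16, 768])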
......@@ -52,31 +62,39 @@ class GPT2Embeddings(nn.Module):
class BertEmbeddings(nn.Module):
def __init__(self, embed_dim, vocab_size, max_position_embeddings, type_vocab_size,
padding_idx=None, device=None, dtype=None):
def __init__(
self,
embed_dim,
vocab_size,
max_position_embeddings,
type_vocab_size,
padding_idx=None,
device=None,
dtype=None,
):
"""
If max_position_embeddings <= 0, there's no position embeddings
If type_vocab_size <= 0, there's no token type embeddings
If max_position_embeddings <= 0, there's no position embeddings
If type_vocab_size <= 0, there's no token type embeddings
"""
factory_kwargs = {'device': device, 'dtype': dtype}
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
self.word_embeddings = nn.Embedding(vocab_size, embed_dim, padding_idx=padding_idx,
**factory_kwargs)
self.word_embeddings = nn.Embedding(
vocab_size, embed_dim, padding_idx=padding_idx, **factory_kwargs
)
self.max_position_embeddings = max_position_embeddings
self.type_vocab_size = type_vocab_size
if self.max_position_embeddings > 0:
self.position_embeddings = nn.Embedding(max_position_embeddings, embed_dim,
**factory_kwargs)
self.position_embeddings = nn.Embedding(
max_position_embeddings, embed_dim, **factory_kwargs
)
if self.type_vocab_size > 0:
self.token_type_embeddings = nn.Embedding(type_vocab_size, embed_dim,
**factory_kwargs)
self.token_type_embeddings = nn.Embedding(type_vocab_size, embed_dim, **factory_kwargs)
def forward(self, input_ids, position_ids=None, token_type_ids=None):
"""
input_ids: (batch, seqlen)
position_ids: (batch, seqlen)
token_type_ids: (batch, seqlen)
input_ids: (batch, seqlen)
position_ids: (batch, seqlen)
token_type_ids: (batch, seqlen)
"""
batch_size, seqlen = input_ids.shape
embeddings = self.word_embeddings(input_ids)
......@@ -94,16 +112,17 @@ class BertEmbeddings(nn.Module):
class VocabParallelEmbedding(nn.Embedding):
def __init__(self, num_embeddings, *args, process_group=None, padding_idx=None, **kwargs):
self.process_group = process_group
if process_group is not None:
world_size = torch.distributed.get_world_size(process_group)
if num_embeddings % world_size != 0:
raise ValueError(f'num_embeddings ({num_embeddings}) must be divisible by '
f'world_size ({world_size})')
raise ValueError(
f"num_embeddings ({num_embeddings}) must be divisible by "
f"world_size ({world_size})"
)
if world_size > 1 and padding_idx is not None:
raise RuntimeError('ParallelEmbedding does not support padding_idx')
raise RuntimeError("ParallelEmbedding does not support padding_idx")
else:
world_size = 1
super().__init__(num_embeddings // world_size, *args, padding_idx=padding_idx, **kwargs)
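A worked check of the divisibility requirement above (hypothetical sizes; an actual parallel run needs an initialized process group):
vocab_size, world_size = 50304, 8  # e.g. GPT-2's 50257-token vocab padded up to a multiple of 8
assert vocab_size % world_size == 0
rows_per_rank = vocab_size // world_size  # 6288 embedding rows stored on each rank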
......@@ -125,33 +144,45 @@ class VocabParallelEmbedding(nn.Embedding):
class ColumnParallelEmbedding(nn.Embedding):
def __init__(self, num_embeddings, embedding_dim, *args, process_group=None, **kwargs):
self.process_group = process_group
if process_group is not None:
world_size = torch.distributed.get_world_size(process_group)
if embedding_dim % world_size != 0:
raise ValueError(f'embedding_dim ({embedding_dim}) must be divisible by '
f'world_size ({world_size})')
raise ValueError(
f"embedding_dim ({embedding_dim}) must be divisible by "
f"world_size ({world_size})"
)
else:
world_size = 1
super().__init__(num_embeddings, embedding_dim // world_size, *args, **kwargs)
class ParallelGPT2Embeddings(nn.Module):
def __init__(self, embed_dim, vocab_size, max_position_embeddings, process_group,
padding_idx=None, sequence_parallel=True, device=None, dtype=None):
def __init__(
self,
embed_dim,
vocab_size,
max_position_embeddings,
process_group,
padding_idx=None,
sequence_parallel=True,
device=None,
dtype=None,
):
"""
If max_position_embeddings <= 0, there's no position embeddings
If max_position_embeddings <= 0, there's no position embeddings
"""
factory_kwargs = {'device': device, 'dtype': dtype}
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
self.process_group = process_group
self.sequence_parallel = sequence_parallel
self.word_embeddings = VocabParallelEmbedding(
vocab_size, embed_dim, padding_idx=padding_idx, process_group=process_group,
**factory_kwargs
vocab_size,
embed_dim,
padding_idx=padding_idx,
process_group=process_group,
**factory_kwargs,
)
self.max_position_embeddings = max_position_embeddings
if self.max_position_embeddings > 0:
......@@ -161,8 +192,8 @@ class ParallelGPT2Embeddings(nn.Module):
def forward(self, input_ids, position_ids=None, combine_batch_seqlen_dim=False):
"""
input_ids: (batch, seqlen)
position_ids: (batch, seqlen)
input_ids: (batch, seqlen)
position_ids: (batch, seqlen)
"""
batch_size, seqlen = input_ids.shape
world_size = torch.distributed.get_world_size(self.process_group)
......@@ -176,8 +207,10 @@ class ParallelGPT2Embeddings(nn.Module):
else:
partition_dim = self.position_embeddings.embedding_dim
rank = torch.distributed.get_rank(self.process_group)
embeddings[..., rank * partition_dim:(rank + 1) * partition_dim] += position_embeddings
embeddings[
..., rank * partition_dim : (rank + 1) * partition_dim
] += position_embeddings
if combine_batch_seqlen_dim:
embeddings = rearrange(embeddings, 'b s d -> (b s) d')
embeddings = rearrange(embeddings, "b s d -> (b s) d")
reduce_fn = reduce_scatter if self.sequence_parallel else all_reduce
return embeddings if world_size <= 1 else reduce_fn(embeddings, self.process_group)
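Worked slice arithmetic for the column-parallel position embeddings above (hypothetical sizes): each rank adds only its shard of columns, so after the final reduce every column has received the position embedding exactly once.
embed_dim, world_size, rank = 1024, 4, 2
partition_dim = embed_dim // world_size  # 256 position-embedding columns held per rank
col_slice = slice(rank * partition_dim, (rank + 1) * partition_dim)  # columns 512:768 on rank 2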
......@@ -732,8 +732,12 @@ class ParallelMHA(nn.Module):
self.num_heads % self.num_heads_kv == 0
), "num_heads must be divisible by num_heads_kv"
self.num_heads_per_rank = get_dim_for_local_rank(self.num_heads, self.world_size, self.local_rank)
self.num_heads_kv_per_rank = get_dim_for_local_rank(self.num_heads, self.world_size, self.local_rank)
self.num_heads_per_rank = get_dim_for_local_rank(
self.num_heads, self.world_size, self.local_rank
)
self.num_heads_kv_per_rank = get_dim_for_local_rank(
self.num_heads_kv, self.world_size, self.local_rank
)
self.head_dim = self.embed_dim // num_heads
qkv_dim = self.head_dim * (self.num_heads + 2 * self.num_heads_kv)
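A worked example of the packed QKV width under grouped-query attention (hypothetical sizes, not taken from this diff):
embed_dim, num_heads, num_heads_kv = 4096, 32, 8
head_dim = embed_dim // num_heads  # 128
qkv_dim = head_dim * (num_heads + 2 * num_heads_kv)  # 128 * (32 + 16) = 6144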
......
......@@ -17,10 +17,19 @@ except ImportError:
class Mlp(nn.Module):
def __init__(self, in_features, hidden_features=None, out_features=None, activation=F.gelu,
bias1=True, bias2=True, return_residual=False, device=None, dtype=None):
factory_kwargs = {'device': device, 'dtype': dtype}
def __init__(
self,
in_features,
hidden_features=None,
out_features=None,
activation=F.gelu,
bias1=True,
bias2=True,
return_residual=False,
device=None,
dtype=None,
):
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
out_features = out_features if out_features is not None else in_features
hidden_features = hidden_features if hidden_features is not None else in_features * 4
......@@ -37,21 +46,42 @@ class Mlp(nn.Module):
class ParallelMLP(nn.Module):
def __init__(self, in_features, hidden_features=None, out_features=None, activation=F.gelu,
process_group: ProcessGroup = None, sequence_parallel=True,
bias1=True, bias2=True, device=None, dtype=None):
factory_kwargs = {'device': device, 'dtype': dtype}
def __init__(
self,
in_features,
hidden_features=None,
out_features=None,
activation=F.gelu,
process_group: ProcessGroup = None,
sequence_parallel=True,
bias1=True,
bias2=True,
device=None,
dtype=None,
):
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
assert ColumnParallelLinear is not None, "Need to install fused_dense"
assert RowParallelLinear is not None, "Need to install fused_dense"
out_features = out_features if out_features is not None else in_features
hidden_features = hidden_features if hidden_features is not None else in_features * 4
self.fc1 = ColumnParallelLinear(in_features, hidden_features, process_group, bias=bias1,
sequence_parallel=sequence_parallel, **factory_kwargs)
self.fc1 = ColumnParallelLinear(
in_features,
hidden_features,
process_group,
bias=bias1,
sequence_parallel=sequence_parallel,
**factory_kwargs,
)
self.activation = activation
self.fc2 = RowParallelLinear(hidden_features, out_features, process_group, bias=bias2,
sequence_parallel=sequence_parallel, **factory_kwargs)
self.fc2 = RowParallelLinear(
hidden_features,
out_features,
process_group,
bias=bias2,
sequence_parallel=sequence_parallel,
**factory_kwargs,
)
def forward(self, x):
y = self.fc1(x)
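The column/row split means fc1 produces hidden_features // world_size features on each rank and fc2 reduces the partial outputs across ranks; a worked shape sketch with hypothetical sizes:
in_features, hidden_features, world_size = 4096, 16384, 8
fc1_weight_per_rank = (hidden_features // world_size, in_features)  # (2048, 4096), column-parallel shard
fc2_weight_per_rank = (in_features, hidden_features // world_size)  # (4096, 2048), row-parallel shard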
......@@ -61,15 +91,25 @@ class ParallelMLP(nn.Module):
class GatedMlp(nn.Module):
def __init__(self, in_features, hidden_features=None, out_features=None, activation=F.sigmoid,
bias1=True, bias2=True, multiple_of=256, return_residual=False,
device=None, dtype=None):
factory_kwargs = {'device': device, 'dtype': dtype}
def __init__(
self,
in_features,
hidden_features=None,
out_features=None,
activation=F.sigmoid,
bias1=True,
bias2=True,
multiple_of=256,
return_residual=False,
device=None,
dtype=None,
):
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
out_features = out_features if out_features is not None else in_features
hidden_features = (hidden_features if hidden_features is not None
else int(8 * in_features / 3))
hidden_features = (
hidden_features if hidden_features is not None else int(8 * in_features / 3)
)
hidden_features = (hidden_features + multiple_of - 1) // multiple_of * multiple_of
self.return_residual = return_residual
self.fc1 = nn.Linear(in_features, 2 * hidden_features, bias=bias1, **factory_kwargs)
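A worked example of the 8/3 sizing and multiple_of rounding above (a hypothetical 4096-wide model):
in_features, multiple_of = 4096, 256
hidden_features = int(8 * in_features / 3)  # 10922
hidden_features = (hidden_features + multiple_of - 1) // multiple_of * multiple_of  # 11008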
......@@ -88,24 +128,48 @@ class GatedMlp(nn.Module):
class ParallelGatedMlp(nn.Module):
""" Parallel GatedMlp """
def __init__(self, in_features, process_group, hidden_features=None, out_features=None,
activation=F.sigmoid, bias1=True, bias2=True, multiple_of=256,
sequence_parallel=True, device=None, dtype=None):
factory_kwargs = {'device': device, 'dtype': dtype}
"""Parallel GatedMlp"""
def __init__(
self,
in_features,
process_group,
hidden_features=None,
out_features=None,
activation=F.sigmoid,
bias1=True,
bias2=True,
multiple_of=256,
sequence_parallel=True,
device=None,
dtype=None,
):
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
out_features = out_features if out_features is not None else in_features
hidden_features = (hidden_features if hidden_features is not None
else int(8 * in_features / 3))
hidden_features = (
hidden_features if hidden_features is not None else int(8 * in_features / 3)
)
hidden_features = (hidden_features + multiple_of - 1) // multiple_of * multiple_of
if ColumnParallelLinear is None or RowParallelLinear is None:
raise ImportError('fused_dense is not installed')
self.fc1 = ColumnParallelLinear(in_features, 2 * hidden_features, process_group, bias=bias1,
sequence_parallel=sequence_parallel, **factory_kwargs)
raise ImportError("fused_dense is not installed")
self.fc1 = ColumnParallelLinear(
in_features,
2 * hidden_features,
process_group,
bias=bias1,
sequence_parallel=sequence_parallel,
**factory_kwargs,
)
self.activation = activation
self.fc2 = RowParallelLinear(hidden_features, out_features, process_group, bias=bias2,
sequence_parallel=sequence_parallel, **factory_kwargs)
self.fc2 = RowParallelLinear(
hidden_features,
out_features,
process_group,
bias=bias2,
sequence_parallel=sequence_parallel,
**factory_kwargs,
)
def forward(self, x):
y = self.fc1(x)
......
......@@ -5,7 +5,6 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
# 1/sqrt(2*pi)-> 0.3989423
# 1/sqrt(2) -> 0.70710678
# sqrt(2/pi) -> 0.79788456
......@@ -18,17 +17,19 @@ def bias_gelu(y, bias):
x = bias + y
return (x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))).to(dtype=y.dtype)
# gradient of tanh approximation of gelu
# gradient of actual gelu is:
# 0.5 * (1. + torch.erf(x * 0.70710678)) + 0.3989423 * x * torch.exp(-0.5 * x * x)
@torch.jit.script
def bias_gelu_back(g, y, bias):
"""Assume that y has shape (B, D) and bias has shape (D)
"""
"""Assume that y has shape (B, D) and bias has shape (D)"""
x = bias + y
tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
# sqrt(2/pi) * 3 * 0.044715 -> 0.1070322243
ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out)
ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (
1 + tanh_out
)
grad_y = ff * g
return grad_y.to(dtype=y.dtype), grad_y.sum(dim=(0), dtype=bias.dtype)
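A quick numerical sanity check of the tanh approximation used above against the exact erf-based GELU (an illustrative sketch; the tolerance is only indicative):
import torch

x = torch.linspace(-4, 4, steps=9, dtype=torch.float64)
tanh_gelu = x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))
exact_gelu = 0.5 * x * (1.0 + torch.erf(x * 0.70710678))
assert torch.allclose(tanh_gelu, exact_gelu, atol=1e-3)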
......@@ -56,6 +57,7 @@ bias_gelu_impl = GeLUFunction.apply
def gelu_fwd(x):
return (x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))).to(dtype=x.dtype)
# gradient of tanh approximation of gelu
# gradient of actual gelu is:
# 0.5 * (1. + torch.erf(x * 0.70710678)) + 0.3989423 * x * torch.exp(-0.5 * x * x)
......@@ -63,7 +65,9 @@ def gelu_fwd(x):
def gelu_bwd(g, x):
tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
# sqrt(2/pi) * 3 * 0.044715 -> 0.1070322243
ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out)
ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (
1 + tanh_out
)
return (ff * g).to(dtype=x.dtype)
......@@ -76,10 +80,11 @@ class FastGeLUFunction(torch.autograd.Function):
@staticmethod
def backward(ctx, grad_output):
input, = ctx.saved_tensors
(input,) = ctx.saved_tensors
tmp = gelu_bwd(grad_output, input)
return tmp
fast_gelu_impl = FastGeLUFunction.apply
......
......@@ -10,6 +10,10 @@ import fused_dense_lib as fused_dense_cuda
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.cuda.amp import custom_bwd, custom_fwd
from torch.distributed import ProcessGroup
from flash_attn.ops.activations import gelu_bwd, relu_bwd, sqrelu_bwd, sqrelu_fwd
from flash_attn.utils.distributed import (
all_gather_raw,
......@@ -18,9 +22,6 @@ from flash_attn.utils.distributed import (
reduce_scatter,
reduce_scatter_raw,
)
from torch import Tensor
from torch.cuda.amp import custom_bwd, custom_fwd
from torch.distributed import ProcessGroup
class FusedDenseFunc(torch.autograd.Function):
......
# Copyright (c) 2022, Tri Dao.
# Adapted from https://github.com/NVIDIA/apex/blob/master/apex/contrib/layer_norm/layer_norm.py
import dropout_layer_norm
import torch
from torch.nn import init
import dropout_layer_norm
def maybe_align(x, alignment_in_bytes=16):
"""Assume that x already has last dim divisible by alignment_in_bytes
"""
"""Assume that x already has last dim divisible by alignment_in_bytes"""
# TD [2023-07-04] I'm not 100% sure that clone will align the memory
# https://discuss.pytorch.org/t/how-to-ensure-that-tensor-data-ptr-is-aligned-to-16-bytes/183440
return x if x.data_ptr() % alignment_in_bytes == 0 else x.clone()
def _dropout_add_layer_norm_forward(x0, residual, gamma, beta, rowscale, colscale, dropout_p,
epsilon, residual_in_fp32=False, is_rms_norm=False):
""" Assume that arguments are contiguous and aligned to 16 bytes
"""
def _dropout_add_layer_norm_forward(
x0,
residual,
gamma,
beta,
rowscale,
colscale,
dropout_p,
epsilon,
residual_in_fp32=False,
is_rms_norm=False,
):
"""Assume that arguments are contiguous and aligned to 16 bytes"""
hidden_size = gamma.numel()
x0mat = x0.view((-1, hidden_size))
residualmat = residual.view((-1, hidden_size)) if residual is not None else None
rowscale = rowscale.view(-1) if rowscale is not None else None
zmat, xmat, dmask, mu, rsigma = dropout_layer_norm.dropout_add_ln_fwd(
x0mat, residualmat, gamma, beta, rowscale, colscale, None, None, dropout_p, epsilon,
1.0, 0, None, residual_in_fp32, is_rms_norm
x0mat,
residualmat,
gamma,
beta,
rowscale,
colscale,
None,
None,
dropout_p,
epsilon,
1.0,
0,
None,
residual_in_fp32,
is_rms_norm,
)
# dmask is None if dropout_p == 0.0
# xmat is None if dropout_p == 0.0 and residual is None and residual_dtype != input_dtype
return zmat, xmat if xmat is not None else x0mat, dmask, mu, rsigma
def _dropout_add_layer_norm_backward(dz, dx, x, x0, dmask, mu, rsigma, gamma, rowscale, colscale,
dropout_p, has_residual, is_rms_norm=False):
""" Assume that arguments are contiguous and aligned to 16 bytes
def _dropout_add_layer_norm_backward(
dz,
dx,
x,
x0,
dmask,
mu,
rsigma,
gamma,
rowscale,
colscale,
dropout_p,
has_residual,
is_rms_norm=False,
):
"""Assume that arguments are contiguous and aligned to 16 bytes
dx == None means that it was a post-norm architecture
(x = drop(x0) + residual was not returned in the fwd).
x0 must not be None if we have colscale.
......@@ -46,10 +79,25 @@ def _dropout_add_layer_norm_backward(dz, dx, x, x0, dmask, mu, rsigma, gamma, ro
x0mat = x0.view((-1, hidden_size)) if x0 is not None else None
rowscale = rowscale.view(-1) if rowscale is not None else None
if colscale is not None:
assert x0 is not None, 'x0 is required to compute the gradient of colscale'
assert x0 is not None, "x0 is required to compute the gradient of colscale"
dx0mat, dresidualmat, dgamma, dbeta, _, _, *rest = dropout_layer_norm.dropout_add_ln_bwd(
dzmat, dxmat, xmat, x0mat, dmask, mu, rsigma, gamma, rowscale, colscale, None, None,
dropout_p, 1.0, 0, has_residual, is_rms_norm
dzmat,
dxmat,
xmat,
x0mat,
dmask,
mu,
rsigma,
gamma,
rowscale,
colscale,
None,
None,
dropout_p,
1.0,
0,
has_residual,
is_rms_norm,
)
# dresidualmat is None if not has_residual
if colscale is None:
......@@ -59,29 +107,68 @@ def _dropout_add_layer_norm_backward(dz, dx, x, x0, dmask, mu, rsigma, gamma, ro
return dx0mat, dresidualmat, dgamma, dbeta, dcolscale
def _dropout_add_layer_norm_subset_forward(x0, residual, gamma, beta, colscale, x0_subset,
out_subset, dropout_p, epsilon, rowscale_const,
out_numrows, residual_in_fp32=False, is_rms_norm=False):
""" Assume that arguments are contiguous and aligned to 16 bytes
"""
def _dropout_add_layer_norm_subset_forward(
x0,
residual,
gamma,
beta,
colscale,
x0_subset,
out_subset,
dropout_p,
epsilon,
rowscale_const,
out_numrows,
residual_in_fp32=False,
is_rms_norm=False,
):
"""Assume that arguments are contiguous and aligned to 16 bytes"""
hidden_size = gamma.numel()
x0mat = x0.view((-1, hidden_size))
residualmat = residual.view((-1, hidden_size)) if residual is not None else None
x0_subset = x0_subset.view(-1) if x0_subset is not None else None
out_subset = out_subset.view(-1) if out_subset is not None else None
zmat, xmat, dmask, mu, rsigma = dropout_layer_norm.dropout_add_ln_fwd(
x0mat, residualmat, gamma, beta, None, colscale, x0_subset, out_subset, dropout_p, epsilon,
rowscale_const, out_numrows, None, residual_in_fp32, is_rms_norm
x0mat,
residualmat,
gamma,
beta,
None,
colscale,
x0_subset,
out_subset,
dropout_p,
epsilon,
rowscale_const,
out_numrows,
None,
residual_in_fp32,
is_rms_norm,
)
# dmask is None if dropout_p == 0.0
# xmat is None if dropout_p == 0.0 and residual is None and residual_dtype != input_dtype
return zmat, xmat if xmat is not None else x0mat, dmask, mu, rsigma
def _dropout_add_layer_norm_subset_backward(dz, dx, x, x0, dmask, mu, rsigma, gamma, colscale,
x0_subset, out_subset, dropout_p, rowscale_const,
x0_numrows, has_residual, is_rms_norm=False):
""" Assume that arguments are contiguous and aligned to 16 bytes
def _dropout_add_layer_norm_subset_backward(
dz,
dx,
x,
x0,
dmask,
mu,
rsigma,
gamma,
colscale,
x0_subset,
out_subset,
dropout_p,
rowscale_const,
x0_numrows,
has_residual,
is_rms_norm=False,
):
"""Assume that arguments are contiguous and aligned to 16 bytes
dx == None means that it was a post-norm architecture
(x = drop(x0) + residual was not returned in the fwd).
x0 must not be None if we have colscale.
......@@ -94,10 +181,25 @@ def _dropout_add_layer_norm_subset_backward(dz, dx, x, x0, dmask, mu, rsigma, ga
x0_subset = x0_subset.view(-1) if x0_subset is not None else None
out_subset = out_subset.view(-1) if out_subset is not None else None
if colscale is not None:
assert x0 is not None, 'x0 is required to compute the gradient of colscale'
assert x0 is not None, "x0 is required to compute the gradient of colscale"
dx0mat, dresidualmat, dgamma, dbeta, _, _, *rest = dropout_layer_norm.dropout_add_ln_bwd(
dzmat, dxmat, xmat, x0mat, dmask, mu, rsigma, gamma, None, colscale, x0_subset, out_subset,
dropout_p, rowscale_const, x0_numrows, has_residual, is_rms_norm
dzmat,
dxmat,
xmat,
x0mat,
dmask,
mu,
rsigma,
gamma,
None,
colscale,
x0_subset,
out_subset,
dropout_p,
rowscale_const,
x0_numrows,
has_residual,
is_rms_norm,
)
# dresidualmat is None if not has_residual
if colscale is None:
......@@ -108,18 +210,44 @@ def _dropout_add_layer_norm_subset_backward(dz, dx, x, x0, dmask, mu, rsigma, ga
def _dropout_add_layer_norm_parallel_residual_forward(
x0, x1, residual, gamma0, beta0, gamma1, beta1, dropout_p,
epsilon, residual_in_fp32=False, is_rms_norm=False
x0,
x1,
residual,
gamma0,
beta0,
gamma1,
beta1,
dropout_p,
epsilon,
residual_in_fp32=False,
is_rms_norm=False,
):
""" Assume that arguments are contiguous and aligned to 16 bytes
"""
"""Assume that arguments are contiguous and aligned to 16 bytes"""
hidden_size = gamma0.numel()
x0mat = x0.view((-1, hidden_size))
x1mat = x1.view((-1, hidden_size)) if x1 is not None else None
residualmat = residual.view((-1, hidden_size)) if residual is not None else None
z0mat, z1mat, xmat, dmask0, dmask1, mu, rsigma = dropout_layer_norm.dropout_add_ln_parallel_residual_fwd(
x0mat, x1mat, residualmat, gamma0, beta0, gamma1, beta1, dropout_p, epsilon,
None, residual_in_fp32, is_rms_norm
(
z0mat,
z1mat,
xmat,
dmask0,
dmask1,
mu,
rsigma,
) = dropout_layer_norm.dropout_add_ln_parallel_residual_fwd(
x0mat,
x1mat,
residualmat,
gamma0,
beta0,
gamma1,
beta1,
dropout_p,
epsilon,
None,
residual_in_fp32,
is_rms_norm,
)
# dmask0 and dmask1 are None if dropout_p == 0.0
# xmat is None if dropout_p == 0.0 and residual is None and residual_dtype != input_dtype
......@@ -127,10 +255,22 @@ def _dropout_add_layer_norm_parallel_residual_forward(
def _dropout_add_layer_norm_parallel_residual_backward(
dz0, dz1, dx, x, dmask0, dmask1, mu, rsigma, gamma0, gamma1,
dropout_p, has_x1, has_residual, is_rms_norm=False
dz0,
dz1,
dx,
x,
dmask0,
dmask1,
mu,
rsigma,
gamma0,
gamma1,
dropout_p,
has_x1,
has_residual,
is_rms_norm=False,
):
""" Assume that arguments are contiguous and aligned to 16 bytes
"""Assume that arguments are contiguous and aligned to 16 bytes
dx == None means that it was a post-norm architecture
(x = drop(x0) + residual was not returned in the fwd).
"""
......@@ -139,9 +279,30 @@ def _dropout_add_layer_norm_parallel_residual_backward(
dz0mat = dz0.view(xmat.shape)
dz1mat = dz1.view(xmat.shape) if dz1 is not None else None
dxmat = dx.view(xmat.shape) if dx is not None else None
dx0mat, dx1mat, dresidualmat, dgamma0, dbeta0, dgamma1, dbeta1, *rest = dropout_layer_norm.dropout_add_ln_parallel_residual_bwd(
dz0mat, dz1mat, dxmat, xmat, dmask0, dmask1, mu, rsigma, gamma0, gamma1,
dropout_p, has_x1, has_residual, is_rms_norm
(
dx0mat,
dx1mat,
dresidualmat,
dgamma0,
dbeta0,
dgamma1,
dbeta1,
*rest,
) = dropout_layer_norm.dropout_add_ln_parallel_residual_bwd(
dz0mat,
dz1mat,
dxmat,
xmat,
dmask0,
dmask1,
mu,
rsigma,
gamma0,
gamma1,
dropout_p,
has_x1,
has_residual,
is_rms_norm,
)
# dresidualmat is None if not has_residual
return dx0mat, dx1mat, dresidualmat, dgamma0, dbeta0, dgamma1, dbeta1
......@@ -149,8 +310,21 @@ def _dropout_add_layer_norm_parallel_residual_backward(
class DropoutAddLayerNormFn(torch.autograd.Function):
@staticmethod
def forward(ctx, x0, residual, gamma, beta, rowscale, colscale, dropout_p, epsilon,
residual_in_fp32=False, prenorm=False, is_rms_norm=False, return_dmask=False):
def forward(
ctx,
x0,
residual,
gamma,
beta,
rowscale,
colscale,
dropout_p,
epsilon,
residual_in_fp32=False,
prenorm=False,
is_rms_norm=False,
return_dmask=False,
):
x0 = maybe_align(x0.contiguous(), 16)
residual = maybe_align(residual.contiguous(), 16) if residual is not None else None
gamma = maybe_align(gamma.contiguous(), 16)
......@@ -158,26 +332,43 @@ class DropoutAddLayerNormFn(torch.autograd.Function):
rowscale = maybe_align(rowscale.contiguous(), 16) if rowscale is not None else None
colscale = maybe_align(colscale.contiguous(), 16) if colscale is not None else None
zmat, xmat, dmask, mu, rsigma = _dropout_add_layer_norm_forward(
x0, residual, gamma, beta, rowscale, colscale, dropout_p, epsilon,
residual_in_fp32, is_rms_norm
x0,
residual,
gamma,
beta,
rowscale,
colscale,
dropout_p,
epsilon,
residual_in_fp32,
is_rms_norm,
)
# Only need to save x0 if we need to compute gradient wrt colscale
x0_saved = x0 if colscale is not None else None
ctx.save_for_backward(xmat.view(x0.shape), x0_saved, dmask, gamma, mu, rsigma, rowscale, colscale)
ctx.save_for_backward(
xmat.view(x0.shape), x0_saved, dmask, gamma, mu, rsigma, rowscale, colscale
)
ctx.prenorm = prenorm
ctx.dropout_p = dropout_p
ctx.has_residual = residual is not None
ctx.is_rms_norm = is_rms_norm
ctx.has_beta = beta is not None
if not return_dmask:
return (zmat.view(x0.shape) if not prenorm
else (zmat.view(x0.shape), xmat.view(x0.shape)))
return (
zmat.view(x0.shape) if not prenorm else (zmat.view(x0.shape), xmat.view(x0.shape))
)
else:
dmask = (dmask.view(x0.shape) if dropout_p > 0.
else torch.ones(x0.shape, dtype=torch.uint8, device=x0.device))
dmask = (
dmask.view(x0.shape)
if dropout_p > 0.0
else torch.ones(x0.shape, dtype=torch.uint8, device=x0.device)
)
ctx.mark_non_differentiable(dmask)
return ((zmat.view(x0.shape), dmask) if not prenorm
else (zmat.view(x0.shape), xmat.view(x0.shape), dmask))
return (
(zmat.view(x0.shape), dmask)
if not prenorm
else (zmat.view(x0.shape), xmat.view(x0.shape), dmask)
)
@staticmethod
def backward(ctx, dz, *args):
......@@ -189,35 +380,85 @@ class DropoutAddLayerNormFn(torch.autograd.Function):
dropout_p = ctx.dropout_p
has_residual = ctx.has_residual
dx0mat, dresidualmat, dgamma, dbeta, *rest = _dropout_add_layer_norm_backward(
dz, dx, x, x0, dmask, mu, rsigma, gamma, rowscale, colscale, dropout_p, has_residual,
ctx.is_rms_norm
dz,
dx,
x,
x0,
dmask,
mu,
rsigma,
gamma,
rowscale,
colscale,
dropout_p,
has_residual,
ctx.is_rms_norm,
)
dx0 = dx0mat.view(x.shape)
dresidual = dresidualmat.view(x.shape) if dresidualmat is not None else None
dcolscale = rest[0] if colscale is not None else None
return (dx0, dresidual, dgamma, dbeta if ctx.has_beta else None, None, dcolscale, None,
None, None, None, None, None)
return (
dx0,
dresidual,
dgamma,
dbeta if ctx.has_beta else None,
None,
dcolscale,
None,
None,
None,
None,
None,
None,
)
class DropoutAddLayerNormSubsetFn(torch.autograd.Function):
@staticmethod
def forward(ctx, x0, residual, gamma, beta, colscale, x0_subset, out_subset, dropout_p, epsilon,
rowscale_const, out_numrows, residual_in_fp32=False,
prenorm=False, is_rms_norm=False, return_dmask=False):
def forward(
ctx,
x0,
residual,
gamma,
beta,
colscale,
x0_subset,
out_subset,
dropout_p,
epsilon,
rowscale_const,
out_numrows,
residual_in_fp32=False,
prenorm=False,
is_rms_norm=False,
return_dmask=False,
):
x0 = maybe_align(x0.contiguous(), 16)
residual = maybe_align(residual.contiguous(), 16) if residual is not None else None
gamma = maybe_align(gamma.contiguous(), 16)
beta = maybe_align(beta.contiguous(), 16) if beta is not None else None
colscale = maybe_align(colscale.contiguous(), 16) if colscale is not None else None
zmat, xmat, dmask, mu, rsigma = _dropout_add_layer_norm_subset_forward(
x0, residual, gamma, beta, colscale, x0_subset, out_subset, dropout_p, epsilon,
rowscale_const, out_numrows, residual_in_fp32, is_rms_norm
x0,
residual,
gamma,
beta,
colscale,
x0_subset,
out_subset,
dropout_p,
epsilon,
rowscale_const,
out_numrows,
residual_in_fp32,
is_rms_norm,
)
# Only need to save x0 if we need to compute gradient wrt colscale
x0_saved = x0 if colscale is not None else None
x_shape = (-1, *x0.shape[1:])
ctx.save_for_backward(xmat.view(x_shape), x0_saved, dmask, gamma, mu, rsigma, colscale,
x0_subset, out_subset)
ctx.save_for_backward(
xmat.view(x_shape), x0_saved, dmask, gamma, mu, rsigma, colscale, x0_subset, out_subset
)
ctx.prenorm = prenorm
ctx.dropout_p = dropout_p
ctx.rowscale_const = rowscale_const
......@@ -227,14 +468,16 @@ class DropoutAddLayerNormSubsetFn(torch.autograd.Function):
ctx.has_beta = beta is not None
z_shape = (-1, *x0.shape[1:])
if not return_dmask:
return (zmat.view(z_shape) if not prenorm
else (zmat.view(z_shape), xmat.view(x0.shape)))
return zmat.view(z_shape) if not prenorm else (zmat.view(z_shape), xmat.view(x0.shape))
else:
z = zmat.view(z_shape)
dmask = (dmask.view(x0.shape) if dropout_p > 0.
else torch.ones(x0.shape, dtype=torch.uint8, device=x0.device))
dmask = (
dmask.view(x0.shape)
if dropout_p > 0.0
else torch.ones(x0.shape, dtype=torch.uint8, device=x0.device)
)
ctx.mark_non_differentiable(dmask)
return ((z, dmask) if not prenorm else (z, xmat.view(x_shape), dmask))
return (z, dmask) if not prenorm else (z, xmat.view(x_shape), dmask)
@staticmethod
def backward(ctx, dz, *args):
......@@ -246,20 +489,63 @@ class DropoutAddLayerNormSubsetFn(torch.autograd.Function):
dropout_p = ctx.dropout_p
has_residual = ctx.has_residual
dx0mat, dresidualmat, dgamma, dbeta, *rest = _dropout_add_layer_norm_subset_backward(
dz, dx, x, x0, dmask, mu, rsigma, gamma, colscale, x0_subset, out_subset, dropout_p,
ctx.rowscale_const, ctx.x0_numrows, has_residual, ctx.is_rms_norm
dz,
dx,
x,
x0,
dmask,
mu,
rsigma,
gamma,
colscale,
x0_subset,
out_subset,
dropout_p,
ctx.rowscale_const,
ctx.x0_numrows,
has_residual,
ctx.is_rms_norm,
)
dx0 = dx0mat.view(-1, *x.shape[1:])
dresidual = dresidualmat.view(x.shape) if dresidualmat is not None else None
dcolscale = rest[0] if colscale is not None else None
return (dx0, dresidual, dgamma, dbeta if ctx.has_beta else None, dcolscale, None, None,
None, None, None, None, None, None, None, None)
return (
dx0,
dresidual,
dgamma,
dbeta if ctx.has_beta else None,
dcolscale,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
)
class DropoutAddLayerNormParallelResidualFn(torch.autograd.Function):
@staticmethod
def forward(ctx, x0, x1, residual, gamma0, beta0, gamma1, beta1, dropout_p, epsilon,
residual_in_fp32=False, prenorm=False, is_rms_norm=False, return_dmask=False):
def forward(
ctx,
x0,
x1,
residual,
gamma0,
beta0,
gamma1,
beta1,
dropout_p,
epsilon,
residual_in_fp32=False,
prenorm=False,
is_rms_norm=False,
return_dmask=False,
):
x0 = maybe_align(x0.contiguous(), 16)
x1 = maybe_align(x1.contiguous(), 16) if x1 is not None else None
residual = maybe_align(residual.contiguous(), 16) if residual is not None else None
......@@ -267,9 +553,26 @@ class DropoutAddLayerNormParallelResidualFn(torch.autograd.Function):
beta0 = maybe_align(beta0.contiguous(), 16) if beta0 is not None else None
gamma1 = maybe_align(gamma1.contiguous(), 16) if gamma1 is not None else None
beta1 = maybe_align(beta1.contiguous(), 16) if beta1 is not None else None
z0mat, z1mat, xmat, dmask0, dmask1, mu, rsigma = _dropout_add_layer_norm_parallel_residual_forward(
x0, x1, residual, gamma0, beta0, gamma1, beta1, dropout_p, epsilon,
residual_in_fp32, is_rms_norm
(
z0mat,
z1mat,
xmat,
dmask0,
dmask1,
mu,
rsigma,
) = _dropout_add_layer_norm_parallel_residual_forward(
x0,
x1,
residual,
gamma0,
beta0,
gamma1,
beta1,
dropout_p,
epsilon,
residual_in_fp32,
is_rms_norm,
)
ctx.save_for_backward(xmat.view(x0.shape), dmask0, dmask1, gamma0, gamma1, mu, rsigma)
ctx.prenorm = prenorm
......@@ -282,13 +585,21 @@ class DropoutAddLayerNormParallelResidualFn(torch.autograd.Function):
if not return_dmask:
return z if not prenorm else (*z, xmat.view(x0.shape))
else:
dmask0 = (dmask0.view(x0.shape) if dropout_p > 0.
else torch.ones(x0.shape, dtype=torch.uint8, device=x0.device))
dmask1 = (dmask1.view(x0.shape) if dropout_p > 0. and x1 is not None
else torch.ones(x0.shape, dtype=torch.uint8, device=x0.device))
dmask0 = (
dmask0.view(x0.shape)
if dropout_p > 0.0
else torch.ones(x0.shape, dtype=torch.uint8, device=x0.device)
)
dmask1 = (
dmask1.view(x0.shape)
if dropout_p > 0.0 and x1 is not None
else torch.ones(x0.shape, dtype=torch.uint8, device=x0.device)
)
ctx.mark_non_differentiable(dmask0)
ctx.mark_non_differentiable(dmask1)
return (*z, dmask0, dmask1) if not prenorm else (*z, xmat.view(x0.shape), dmask0, dmask1)
return (
(*z, dmask0, dmask1) if not prenorm else (*z, xmat.view(x0.shape), dmask0, dmask1)
)
@staticmethod
def backward(ctx, dz0, dz1, *args):
......@@ -299,63 +610,170 @@ class DropoutAddLayerNormParallelResidualFn(torch.autograd.Function):
dropout_p = ctx.dropout_p
has_x1 = ctx.has_x1
has_residual = ctx.has_residual
dx0mat, dx1mat, dresidualmat, dgamma0, dbeta0, dgamma1, dbeta1 = _dropout_add_layer_norm_parallel_residual_backward(
dz0, dz1, dx, x, dmask0, dmask1, mu, rsigma, gamma0, gamma1, dropout_p, has_x1,
has_residual, ctx.is_rms_norm
(
dx0mat,
dx1mat,
dresidualmat,
dgamma0,
dbeta0,
dgamma1,
dbeta1,
) = _dropout_add_layer_norm_parallel_residual_backward(
dz0,
dz1,
dx,
x,
dmask0,
dmask1,
mu,
rsigma,
gamma0,
gamma1,
dropout_p,
has_x1,
has_residual,
ctx.is_rms_norm,
)
dx0 = dx0mat.view(x.shape)
dx1 = dx1mat.view(x.shape) if dx1mat is not None else None
dresidual = dresidualmat.view(x.shape) if dresidualmat is not None else None
return (dx0, dx1, dresidual, dgamma0, dbeta0 if ctx.has_beta else None, dgamma1,
dbeta1 if ctx.has_beta else None, None, None, None, None, None, None)
return (
dx0,
dx1,
dresidual,
dgamma0,
dbeta0 if ctx.has_beta else None,
dgamma1,
dbeta1 if ctx.has_beta else None,
None,
None,
None,
None,
None,
None,
)
def layer_norm(x, weight, bias, epsilon):
return DropoutAddLayerNormFn.apply(x, None, weight, bias, None, None, 0.0, epsilon, False)
def dropout_add_layer_norm(x0, residual, weight, bias, dropout_p, epsilon, rowscale=None,
layerscale=None, prenorm=False, residual_in_fp32=False,
return_dropout_mask=False):
def dropout_add_layer_norm(
x0,
residual,
weight,
bias,
dropout_p,
epsilon,
rowscale=None,
layerscale=None,
prenorm=False,
residual_in_fp32=False,
return_dropout_mask=False,
):
"""residual_in_fp32 only has an effect if residual is None.
Otherwise residual dtype is residual.dtype.
"""
return DropoutAddLayerNormFn.apply(
x0, residual, weight, bias, rowscale, layerscale, dropout_p, epsilon, residual_in_fp32, prenorm,
False, return_dropout_mask
x0,
residual,
weight,
bias,
rowscale,
layerscale,
dropout_p,
epsilon,
residual_in_fp32,
prenorm,
False,
return_dropout_mask,
)
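A minimal usage sketch of the fused op in prenorm mode (assumes the dropout_layer_norm CUDA extension is built and a GPU is available; shapes and dtypes are illustrative):
import torch

hidden = 1024
x0 = torch.randn(2, 16, hidden, device="cuda", dtype=torch.float16)
residual = torch.randn_like(x0)
weight = torch.ones(hidden, device="cuda", dtype=torch.float16)
bias = torch.zeros(hidden, device="cuda", dtype=torch.float16)
# out = LayerNorm(dropout(x0) + residual); prenorm=True also returns the updated residual stream
out, new_residual = dropout_add_layer_norm(
    x0, residual, weight, bias, dropout_p=0.1, epsilon=1e-5, prenorm=True
)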
def dropout_add_layer_norm_subset(x0, residual, weight, bias, dropout_p, epsilon, layerscale=None,
x0_subset=None, out_subset=None, rowscale_const=1.0,
out_numrows=0, prenorm=False, residual_in_fp32=False,
return_dropout_mask=False):
def dropout_add_layer_norm_subset(
x0,
residual,
weight,
bias,
dropout_p,
epsilon,
layerscale=None,
x0_subset=None,
out_subset=None,
rowscale_const=1.0,
out_numrows=0,
prenorm=False,
residual_in_fp32=False,
return_dropout_mask=False,
):
"""residual_in_fp32 only has an effect if residual is None.
Otherwise residual dtype is residual.dtype.
"""
return DropoutAddLayerNormSubsetFn.apply(
x0, residual, weight, bias, layerscale, x0_subset, out_subset, dropout_p, epsilon,
rowscale_const, out_numrows, residual_in_fp32, prenorm, False, return_dropout_mask
x0,
residual,
weight,
bias,
layerscale,
x0_subset,
out_subset,
dropout_p,
epsilon,
rowscale_const,
out_numrows,
residual_in_fp32,
prenorm,
False,
return_dropout_mask,
)
def dropout_add_layer_norm_parallel_residual(
x0, x1, residual, weight0, bias0, weight1, bias1, dropout_p, epsilon, prenorm=False,
residual_in_fp32=False, return_dropout_mask=False
x0,
x1,
residual,
weight0,
bias0,
weight1,
bias1,
dropout_p,
epsilon,
prenorm=False,
residual_in_fp32=False,
return_dropout_mask=False,
):
"""residual_in_fp32 only has an effect if residual is None.
Otherwise residual dtype is residual.dtype.
"""
return DropoutAddLayerNormParallelResidualFn.apply(
x0, x1, residual, weight0, bias0, weight1, bias1, dropout_p, epsilon, residual_in_fp32, prenorm,
False, return_dropout_mask
x0,
x1,
residual,
weight0,
bias0,
weight1,
bias1,
dropout_p,
epsilon,
residual_in_fp32,
prenorm,
False,
return_dropout_mask,
)
class DropoutAddLayerNorm(torch.nn.Module):
def __init__(self, hidden_size, prenorm=False, p=0.0, eps=1e-5, residual_in_fp32=False,
device=None, dtype=None):
factory_kwargs = {'device': device, 'dtype': dtype}
def __init__(
self,
hidden_size,
prenorm=False,
p=0.0,
eps=1e-5,
residual_in_fp32=False,
device=None,
dtype=None,
):
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
self.prenorm = prenorm
self.p = p
......@@ -370,6 +788,13 @@ class DropoutAddLayerNorm(torch.nn.Module):
init.zeros_(self.bias)
def forward(self, x0, residual=None):
return dropout_add_layer_norm(x0, residual, self.weight, self.bias,
self.p if self.training else 0.0, self.eps,
prenorm=self.prenorm, residual_in_fp32=self.residual_in_fp32)
return dropout_add_layer_norm(
x0,
residual,
self.weight,
self.bias,
self.p if self.training else 0.0,
self.eps,
prenorm=self.prenorm,
residual_in_fp32=self.residual_in_fp32,
)
......@@ -4,60 +4,130 @@
import torch
from torch.nn import init
from flash_attn.ops.layer_norm import DropoutAddLayerNormFn, DropoutAddLayerNormSubsetFn
from flash_attn.ops.layer_norm import DropoutAddLayerNormParallelResidualFn
from flash_attn.ops.layer_norm import (
DropoutAddLayerNormFn,
DropoutAddLayerNormParallelResidualFn,
DropoutAddLayerNormSubsetFn,
)
def rms_norm(x, weight, epsilon):
return DropoutAddLayerNormFn.apply(x, None, weight, None, None, None, 0.0, epsilon, False,
False, True)
return DropoutAddLayerNormFn.apply(
x, None, weight, None, None, None, 0.0, epsilon, False, False, True
)
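For reference, a plain-PyTorch formulation that the fused rms_norm above is expected to match (an illustrative sketch, not part of this diff):
import torch

def rms_norm_ref(x, weight, epsilon):
    # Normalize by the root-mean-square of the last dimension, then scale by weight
    rstd = torch.rsqrt(x.float().pow(2).mean(dim=-1, keepdim=True) + epsilon)
    return (x.float() * rstd).to(x.dtype) * weight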
def dropout_add_rms_norm(x0, residual, weight, bias, dropout_p, epsilon, rowscale=None,
layerscale=None, prenorm=False, residual_in_fp32=False,
return_dropout_mask=False):
def dropout_add_rms_norm(
x0,
residual,
weight,
bias,
dropout_p,
epsilon,
rowscale=None,
layerscale=None,
prenorm=False,
residual_in_fp32=False,
return_dropout_mask=False,
):
"""residual_in_fp32 only has an effect if residual is None.
Otherwise residual dtype is residual.dtype.
"""
return DropoutAddLayerNormFn.apply(
x0, residual, weight, bias, rowscale, layerscale, dropout_p, epsilon, residual_in_fp32, prenorm,
True, return_dropout_mask
x0,
residual,
weight,
bias,
rowscale,
layerscale,
dropout_p,
epsilon,
residual_in_fp32,
prenorm,
True,
return_dropout_mask,
)
def dropout_add_rms_norm_subset(x0, residual, weight, bias, dropout_p, epsilon, layerscale=None,
x0_subset=None, out_subset=None, rowscale_const=1.0,
out_numrows=0, prenorm=False, residual_in_fp32=False,
return_dropout_mask=False):
def dropout_add_rms_norm_subset(
x0,
residual,
weight,
bias,
dropout_p,
epsilon,
layerscale=None,
x0_subset=None,
out_subset=None,
rowscale_const=1.0,
out_numrows=0,
prenorm=False,
residual_in_fp32=False,
return_dropout_mask=False,
):
"""residual_in_fp32 only has an effect if residual is None.
Otherwise residual dtype is residual.dtype.
"""
return DropoutAddLayerNormSubsetFn.apply(
x0, residual, weight, bias, layerscale, x0_subset, out_subset, dropout_p, epsilon,
rowscale_const, out_numrows, residual_in_fp32, prenorm, True, return_dropout_mask
x0,
residual,
weight,
bias,
layerscale,
x0_subset,
out_subset,
dropout_p,
epsilon,
rowscale_const,
out_numrows,
residual_in_fp32,
prenorm,
True,
return_dropout_mask,
)
def dropout_add_rms_norm_parallel_residual(
x0, x1, residual, weight0, bias0, weight1, bias1,
dropout_p, epsilon, prenorm=False, residual_in_fp32=False, return_dropout_mask=False
x0,
x1,
residual,
weight0,
bias0,
weight1,
bias1,
dropout_p,
epsilon,
prenorm=False,
residual_in_fp32=False,
return_dropout_mask=False,
):
"""residual_in_fp32 only has an effect if residual is None.
Otherwise residual dtype is residual.dtype.
"""
return DropoutAddLayerNormParallelResidualFn.apply(
x0, x1, residual, weight0, bias0, weight1, bias1, dropout_p, epsilon, residual_in_fp32, prenorm,
True, return_dropout_mask
x0,
x1,
residual,
weight0,
bias0,
weight1,
bias1,
dropout_p,
epsilon,
residual_in_fp32,
prenorm,
True,
return_dropout_mask,
)
class RMSNorm(torch.nn.Module):
def __init__(self, hidden_size, eps=1e-5, device=None, dtype=None):
factory_kwargs = {'device': device, 'dtype': dtype}
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
self.eps = eps
self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
self.register_parameter('bias', None)
self.register_parameter("bias", None)
self.reset_parameters()
def reset_parameters(self):
......@@ -68,22 +138,37 @@ class RMSNorm(torch.nn.Module):
class DropoutAddRMSNorm(torch.nn.Module):
def __init__(self, hidden_size, prenorm=False, p=0.0, eps=1e-5, residual_in_fp32=False,
device=None, dtype=None):
factory_kwargs = {'device': device, 'dtype': dtype}
def __init__(
self,
hidden_size,
prenorm=False,
p=0.0,
eps=1e-5,
residual_in_fp32=False,
device=None,
dtype=None,
):
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
self.prenorm = prenorm
self.p = p
self.eps = eps
self.residual_in_fp32 = residual_in_fp32
self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
self.register_parameter('bias', None)
self.register_parameter("bias", None)
self.reset_parameters()
def reset_parameters(self):
init.ones_(self.weight)
def forward(self, x0, residual=None):
return dropout_add_rms_norm(x0, residual, self.weight, None,
self.p if self.training else 0.0, self.eps,
prenorm=self.prenorm, residual_in_fp32=self.residual_in_fp32)
return dropout_add_rms_norm(
x0,
residual,
self.weight,
None,
self.p if self.training else 0.0,
self.eps,
prenorm=self.prenorm,
residual_in_fp32=self.residual_in_fp32,
)
......@@ -11,7 +11,6 @@ from typing import Optional
import triton
import triton.language as tl
_sqrt2pi = math.sqrt(2.0 / math.pi)
_sqrt1_2 = math.sqrt(1.0 / 2)
_gaussian_pdf_normalization = 1.0 / math.sqrt(2 * math.pi)
......@@ -142,6 +141,7 @@ def gelu_grad(x):
pdf = tl.exp(-0.5 * x * x) * _gaussian_pdf_normalization
return cdf + x * pdf
@triton.jit
def gelu_approx(x):
"""
......@@ -157,6 +157,6 @@ def gelu_approx_grad(x):
# CREDITS: Fast implementation proposed in
# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/fused_bias_gelu.py#L30
tanh_out = tanh(0.79788456 * x * (1 + 0.044715 * x * x))
return 0.5 * x * (
(1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)
) + 0.5 * (1 + tanh_out)
return 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (
1 + tanh_out
)
......@@ -9,8 +9,14 @@ from torch.autograd.function import FunctionCtx
from torch.cuda.amp import custom_fwd
from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time
from flash_attn.ops.triton.k_activations import gelu, gelu_grad, gelu_approx, gelu_approx_grad, squared_relu, squared_relu_grad
from flash_attn.ops.triton.k_activations import (
gelu,
gelu_approx,
gelu_approx_grad,
gelu_grad,
squared_relu,
squared_relu_grad,
)
# CREDITS: Initially inspired by the Triton tutorial on matrix multiplications
......@@ -28,7 +34,12 @@ def get_configs_io_bound():
num_warps = 2 if block_n <= 64 else 4
configs.append(
triton.Config(
{"BLOCK_M": block_m, "BLOCK_N": block_n, "BLOCK_K": block_k, "SPLIT_K": 1},
{
"BLOCK_M": block_m,
"BLOCK_N": block_n,
"BLOCK_K": block_k,
"SPLIT_K": 1,
},
num_stages=num_stages,
num_warps=num_warps,
)
......@@ -43,29 +54,75 @@ def get_configs_io_bound():
@triton.autotune(
configs=[
triton.Config({"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=3, num_warps=8),
triton.Config({"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=3, num_warps=8),
triton.Config({"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
triton.Config({"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
triton.Config({"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
triton.Config({"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=5, num_warps=2),
triton.Config(
{"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=3, num_warps=8
),
triton.Config(
{"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=3, num_warps=8
),
triton.Config(
{"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4
),
triton.Config(
{"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4
),
triton.Config(
{"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4
),
triton.Config(
{"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4
),
triton.Config(
{"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4
),
triton.Config(
{"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4
),
triton.Config(
{"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=5, num_warps=2
),
# good for int8
triton.Config({"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=3, num_warps=8),
triton.Config({"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=3, num_warps=8),
triton.Config({"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=4, num_warps=4),
triton.Config({"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=4, num_warps=4),
triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=4, num_warps=4),
triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4),
triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4),
triton.Config({"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4),
triton.Config({"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=5, num_warps=2),
triton.Config(
{"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 128, "SPLIT_K": 1},
num_stages=3,
num_warps=8,
),
triton.Config(
{"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 128, "SPLIT_K": 1},
num_stages=3,
num_warps=8,
),
triton.Config(
{"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=4, num_warps=4
),
triton.Config(
{"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=4, num_warps=4
),
triton.Config(
{"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 128, "SPLIT_K": 1},
num_stages=4,
num_warps=4,
),
triton.Config(
{"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4
),
triton.Config(
{"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4
),
triton.Config(
{"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4
),
triton.Config(
{"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=5, num_warps=2
),
]
+ get_configs_io_bound(),
key=["CACHE_KEY_M", "CACHE_KEY_N", "CACHE_KEY_K"],
prune_configs_by={"early_config_prune": early_config_prune, "perf_model": estimate_matmul_time, "top_k": 10},
prune_configs_by={
"early_config_prune": early_config_prune,
"perf_model": estimate_matmul_time,
"top_k": 10,
},
)
@triton.heuristics(
{
......@@ -204,7 +261,7 @@ def triton_linear_act(
x: torch.Tensor,
weight: torch.Tensor,
bias: Optional[torch.Tensor] = None,
activation: str = 'id',
activation: str = "id",
save_act_input: bool = False,
) -> torch.Tensor:
"""
......@@ -221,7 +278,7 @@ def triton_linear_act(
# dtype = torch.get_autocast_gpu_dtype()
# x, weight, bias = [a.to(dtype=dtype) for a in [x, weight, bias]]
assert activation in ['id', 'gelu', 'gelu_approx', 'squared_relu']
assert activation in ["id", "gelu", "gelu_approx", "squared_relu"]
batch_shape, n = x.shape[:-1], x.shape[-1]
batch_dim = batch_shape.numel()
......@@ -233,12 +290,20 @@ def triton_linear_act(
weight = weight.contiguous()
bias = bias.contiguous() if bias is not None else None
assert x.dtype == weight.dtype, f"Input and weight must have the same dtype, got {x.dtype} and {weight.dtype}"
assert (
x.dtype == weight.dtype
), f"Input and weight must have the same dtype, got {x.dtype} and {weight.dtype}"
if bias is not None:
assert x.dtype == bias.dtype, f"Input and bias must have the same dtype, got {x.dtype} and {bias.dtype}"
assert x_reshaped.shape[1] == weight.shape[1], f"Incompatible dimensions: {x_reshaped.shape} - {weight.shape}"
assert (
x.dtype == bias.dtype
), f"Input and bias must have the same dtype, got {x.dtype} and {bias.dtype}"
assert (
x_reshaped.shape[1] == weight.shape[1]
), f"Incompatible dimensions: {x_reshaped.shape} - {weight.shape}"
assert bias is None or bias.shape[0] == weight.shape[0], "Incompatible dimensions in between weight and bias"
assert (
bias is None or bias.shape[0] == weight.shape[0]
), "Incompatible dimensions in between weight and bias"
M, K = x_reshaped.shape
N, K = weight.shape
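A minimal call sketch for triton_linear_act (assumes Triton and a CUDA GPU; shapes are illustrative):
import torch

x = torch.randn(4, 128, 512, device="cuda", dtype=torch.float16)
w = torch.randn(1024, 512, device="cuda", dtype=torch.float16)  # weight is (out_features, in_features)
b = torch.randn(1024, device="cuda", dtype=torch.float16)
out = triton_linear_act(x, w, b, activation="gelu_approx")  # -> (4, 128, 1024)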
......@@ -278,35 +343,83 @@ def triton_linear_act(
if not save_act_input:
return output.reshape(*batch_shape, output.shape[-1])
else:
return (output.reshape(*batch_shape, output.shape[-1]),
act_input.reshape(*batch_shape, act_input.shape[-1]))
return (
output.reshape(*batch_shape, output.shape[-1]),
act_input.reshape(*batch_shape, act_input.shape[-1]),
)
@triton.autotune(
configs=[
triton.Config({"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=3, num_warps=8),
triton.Config({"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=3, num_warps=8),
triton.Config({"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
triton.Config({"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
triton.Config({"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
triton.Config({"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=5, num_warps=2),
triton.Config(
{"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=3, num_warps=8
),
triton.Config(
{"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=3, num_warps=8
),
triton.Config(
{"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4
),
triton.Config(
{"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4
),
triton.Config(
{"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4
),
triton.Config(
{"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4
),
triton.Config(
{"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4
),
triton.Config(
{"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4
),
triton.Config(
{"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=5, num_warps=2
),
# good for int8
triton.Config({"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=3, num_warps=8),
triton.Config({"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=3, num_warps=8),
triton.Config({"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=4, num_warps=4),
triton.Config({"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=4, num_warps=4),
triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=4, num_warps=4),
triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4),
triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4),
triton.Config({"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4),
triton.Config({"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=5, num_warps=2),
triton.Config(
{"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 128, "SPLIT_K": 1},
num_stages=3,
num_warps=8,
),
triton.Config(
{"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 128, "SPLIT_K": 1},
num_stages=3,
num_warps=8,
),
triton.Config(
{"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=4, num_warps=4
),
triton.Config(
{"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=4, num_warps=4
),
triton.Config(
{"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 128, "SPLIT_K": 1},
num_stages=4,
num_warps=4,
),
triton.Config(
{"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4
),
triton.Config(
{"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4
),
triton.Config(
{"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4
),
triton.Config(
{"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=5, num_warps=2
),
]
+ get_configs_io_bound(),
key=["CACHE_KEY_M", "CACHE_KEY_N", "CACHE_KEY_K"],
prune_configs_by={"early_config_prune": early_config_prune, "perf_model": estimate_matmul_time, "top_k": 10},
prune_configs_by={
"early_config_prune": early_config_prune,
"perf_model": estimate_matmul_time,
"top_k": 10,
},
)
@triton.heuristics(
{
......@@ -395,7 +508,7 @@ def kernel_bwd(
B += BLOCK_K * stride_bk
# optional: fused activation (while the data is in shared memory)
if ACTIVATION != 'id':
if ACTIVATION != "id":
act_in_ptrs = ACT_INPUT + ram[:, None] * stride_cm + rbn[None, :]
act_input = tl.load(act_in_ptrs).to(acc.dtype)
if ACTIVATION == "gelu":
......@@ -418,7 +531,7 @@ def kernel_bwd(
def triton_dgrad_act(
grad_output: torch.Tensor,
weight: torch.Tensor,
activation: str = 'id',
activation: str = "id",
act_input: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""
......@@ -430,7 +543,7 @@ def triton_dgrad_act(
:param act_input: an optional tensor to save the activation inputs (for backward)
:return: result tensor
"""
assert activation in ['id', 'gelu', 'gelu_approx', 'squared_relu']
assert activation in ["id", "gelu", "gelu_approx", "squared_relu"]
batch_shape, n = grad_output.shape[:-1], grad_output.shape[-1]
batch_dim = batch_shape.numel()
......@@ -441,10 +554,14 @@ def triton_dgrad_act(
if weight.stride(0) > 1 and weight.stride(1) > 1:
weight = weight.contiguous()
assert grad_output.dtype == weight.dtype, f"grad_output and weight must have the same dtype, got {grad_output.dtype} and {weight.dtype}"
assert grad_output_reshaped.shape[1] == weight.shape[0], f"Incompatible dimensions: {grad_output_reshaped.shape} - {weight.shape}"
if activation != 'id':
assert act_input is not None, f'act_input is required for activation {activation}'
assert (
grad_output.dtype == weight.dtype
), f"grad_output and weight must have the same dtype, got {grad_output.dtype} and {weight.dtype}"
assert (
grad_output_reshaped.shape[1] == weight.shape[0]
), f"Incompatible dimensions: {grad_output_reshaped.shape} - {weight.shape}"
if activation != "id":
assert act_input is not None, f"act_input is required for activation {activation}"
# M, N, K in bwd are different from M, N, K in fwd
M, K = grad_output_reshaped.shape
......
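# A minimal PyTorch reference for what triton_dgrad_act is expected to compute
# with activation="squared_relu" (a sketch for illustration only; the function
# name below is an assumption, not part of the library): the matmul
# grad_output @ weight fused with the activation backward evaluated at act_input.
import torch
import torch.nn.functional as F


def dgrad_squared_relu_reference(grad_output, weight, act_input):
    # d/dx relu(x)^2 = 2 * relu(x), zero for x < 0
    return (grad_output @ weight) * (2.0 * F.relu(act_input))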
# The triton fused matmul + sqrelu is faster for fp16 but slower for bf16, compared
# to naive implementation.
import fused_dense_lib as fused_dense_cuda
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.cuda.amp import custom_bwd, custom_fwd
import fused_dense_lib as fused_dense_cuda
from flash_attn.ops.triton.linear import triton_linear_act, triton_dgrad_act
from flash_attn.ops.activations import sqrelu_fwd, sqrelu_bwd
from flash_attn.ops.activations import sqrelu_bwd, sqrelu_fwd
from flash_attn.ops.triton.linear import triton_dgrad_act, triton_linear_act
class FusedDenseSqreluDenseFunc(torch.autograd.Function):
@staticmethod
@custom_fwd
def forward(ctx, x, weight1, bias1, weight2, bias2, checkpoint_lvl=0):
......@@ -23,8 +21,9 @@ class FusedDenseSqreluDenseFunc(torch.autograd.Function):
"""
if torch.is_autocast_enabled():
dtype = torch.get_autocast_gpu_dtype()
x, weight1, bias1, weight2, bias2 = [a.to(dtype=dtype)
for a in [x, weight1, bias1, weight2, bias2]]
x, weight1, bias1, weight2, bias2 = [
a.to(dtype=dtype) for a in [x, weight1, bias1, weight2, bias2]
]
is_bf16 = x.dtype == torch.bfloat16
assert checkpoint_lvl in [0, 1, 2]
x = x.contiguous()
......@@ -35,13 +34,18 @@ class FusedDenseSqreluDenseFunc(torch.autograd.Function):
batch_shape, n = x.shape[:-1], x.shape[-1]
batch_dim = batch_shape.numel()
if is_bf16:
act_input = fused_dense_cuda.linear_bias_forward(x.reshape(batch_dim, n), weight1, bias1)
act_input = fused_dense_cuda.linear_bias_forward(
x.reshape(batch_dim, n), weight1, bias1
)
output1 = sqrelu_fwd(act_input)
else:
save_act_input = checkpoint_lvl != 2
result = triton_linear_act(
x.reshape(batch_dim, n), weight1, bias1, activation='squared_relu',
save_act_input=save_act_input
x.reshape(batch_dim, n),
weight1,
bias1,
activation="squared_relu",
save_act_input=save_act_input,
)
if save_act_input:
output1, act_input = result
......@@ -69,16 +73,21 @@ class FusedDenseSqreluDenseFunc(torch.autograd.Function):
if checkpoint_lvl == 0:
act_input, output1 = rest
elif checkpoint_lvl == 1:
act_input, = rest
(act_input,) = rest
output1 = sqrelu_fwd(act_input)
elif checkpoint_lvl == 2:
if is_bf16:
act_input = fused_dense_cuda.linear_bias_forward(x.reshape(batch_dim, n), weight1, bias1)
act_input = fused_dense_cuda.linear_bias_forward(
x.reshape(batch_dim, n), weight1, bias1
)
output1 = sqrelu_fwd(act_input)
else:
output1, act_input = triton_linear_act(
x.reshape(batch_dim, n), weight1, bias1, activation='squared_relu',
save_act_input=True
x.reshape(batch_dim, n),
weight1,
bias1,
activation="squared_relu",
save_act_input=True,
)
if is_bf16:
......@@ -92,8 +101,9 @@ class FusedDenseSqreluDenseFunc(torch.autograd.Function):
else:
grad_output = grad_output.reshape(batch_dim, grad_output.shape[-1])
grad_weight2, grad_bias2 = fused_dense_cuda.linear_bias_wgrad(output1, grad_output)
grad_act_input = triton_dgrad_act(grad_output, weight2, activation='squared_relu',
act_input=act_input)
grad_act_input = triton_dgrad_act(
grad_output, weight2, activation="squared_relu", act_input=act_input
)
grad_input, grad_weight1, grad_bias1 = fused_dense_cuda.linear_bias_backward(
x.reshape(batch_dim, n), weight1, grad_act_input
)
......@@ -104,9 +114,17 @@ fused_dense_sqrelu_dense_function = FusedDenseSqreluDenseFunc.apply
class FusedDenseSqreluDense(nn.Module):
def __init__(self, in_features, hidden_features=None, out_features=None, bias1=True, bias2=True,
checkpoint_lvl=0, device=None, dtype=None):
def __init__(
self,
in_features,
hidden_features=None,
out_features=None,
bias1=True,
bias2=True,
checkpoint_lvl=0,
device=None,
dtype=None,
):
"""
checkpoint_lvl (a higher level is slower in the bwd but saves more memory):
0: no recomputation in the bwd
......@@ -114,7 +132,7 @@ class FusedDenseSqreluDense(nn.Module):
2: recompute gelu_in and gelu_out in the bwd
"""
assert checkpoint_lvl in [0, 1, 2]
factory_kwargs = {'device': device, 'dtype': dtype}
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features * 4
......@@ -126,6 +144,6 @@ class FusedDenseSqreluDense(nn.Module):
def forward(self, x):
assert x.is_cuda
return fused_dense_sqrelu_dense_function(x, self.fc1.weight, self.fc1.bias,
self.fc2.weight, self.fc2.bias,
self.checkpoint_lvl)
return fused_dense_sqrelu_dense_function(
x, self.fc1.weight, self.fc1.bias, self.fc2.weight, self.fc2.bias, self.checkpoint_lvl
)
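# A minimal usage sketch for FusedDenseSqreluDense, assuming a CUDA device and
# that the fused_dense_lib / Triton kernels above are built; the sizes are
# arbitrary and only for illustration.
if torch.cuda.is_available():
    mlp = FusedDenseSqreluDense(
        1024, hidden_features=4096, checkpoint_lvl=1, device="cuda", dtype=torch.float16
    )
    x = torch.randn(8, 512, 1024, device="cuda", dtype=torch.float16, requires_grad=True)
    y = mlp(x)  # fc1 -> squared ReLU -> fc2, fused
    y.sum().backward()  # with checkpoint_lvl=1 the squared-ReLU output is recomputed here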
......@@ -5,31 +5,43 @@ import torch
import torch.utils.benchmark as benchmark
def benchmark_forward(fn, *inputs, repeats=10, desc='', verbose=True, amp=False,
amp_dtype=torch.float16, **kwinputs):
""" Use Pytorch Benchmark on the forward pass of an arbitrary function. """
def benchmark_forward(
fn, *inputs, repeats=10, desc="", verbose=True, amp=False, amp_dtype=torch.float16, **kwinputs
):
"""Use Pytorch Benchmark on the forward pass of an arbitrary function."""
if verbose:
print(desc, '- Forward pass')
print(desc, "- Forward pass")
def amp_wrapper(*inputs, **kwinputs):
with torch.autocast(device_type='cuda', dtype=amp_dtype, enabled=amp):
with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp):
fn(*inputs, **kwinputs)
t = benchmark.Timer(
stmt='fn_amp(*inputs, **kwinputs)',
globals={'fn_amp': amp_wrapper, 'inputs': inputs, 'kwinputs': kwinputs},
num_threads=torch.get_num_threads(),
)
stmt="fn_amp(*inputs, **kwinputs)",
globals={"fn_amp": amp_wrapper, "inputs": inputs, "kwinputs": kwinputs},
num_threads=torch.get_num_threads(),
)
m = t.timeit(repeats)
if verbose:
print(m)
return t, m
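# A small usage sketch for benchmark_forward (illustrative inputs; a GPU is
# assumed since the amp wrapper targets CUDA): it returns the Timer and its Measurement.
if torch.cuda.is_available():
    a = torch.randn(1024, 1024, device="cuda")
    timer, measurement = benchmark_forward(torch.matmul, a, a, repeats=10, desc="matmul")
    print(measurement.mean)  # mean seconds per call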
def benchmark_backward(fn, *inputs, grad=None, repeats=10, desc='', verbose=True, amp=False,
amp_dtype=torch.float16, **kwinputs):
""" Use Pytorch Benchmark on the backward pass of an arbitrary function. """
def benchmark_backward(
fn,
*inputs,
grad=None,
repeats=10,
desc="",
verbose=True,
amp=False,
amp_dtype=torch.float16,
**kwinputs,
):
"""Use Pytorch Benchmark on the backward pass of an arbitrary function."""
if verbose:
print(desc, '- Backward pass')
with torch.autocast(device_type='cuda', dtype=amp_dtype, enabled=amp):
print(desc, "- Backward pass")
with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp):
y = fn(*inputs, **kwinputs)
if type(y) is tuple:
y = y[0]
......@@ -37,7 +49,8 @@ def benchmark_backward(fn, *inputs, grad=None, repeats=10, desc='', verbose=True
grad = torch.randn_like(y)
else:
if grad.shape != y.shape:
raise RuntimeError('Grad shape does not match output shape')
raise RuntimeError("Grad shape does not match output shape")
def f(*inputs, y, grad):
# Set .grad to None to avoid extra operation of gradient accumulation
for x in inputs:
......@@ -46,22 +59,31 @@ def benchmark_backward(fn, *inputs, grad=None, repeats=10, desc='', verbose=True
y.backward(grad, retain_graph=True)
t = benchmark.Timer(
stmt='f(*inputs, y=y, grad=grad)',
globals={'f': f, 'inputs': inputs, 'y': y, 'grad': grad},
num_threads=torch.get_num_threads(),
)
stmt="f(*inputs, y=y, grad=grad)",
globals={"f": f, "inputs": inputs, "y": y, "grad": grad},
num_threads=torch.get_num_threads(),
)
m = t.timeit(repeats)
if verbose:
print(m)
return t, m
def benchmark_combined(fn, *inputs, grad=None, repeats=10, desc='', verbose=True, amp=False,
amp_dtype=torch.float16, **kwinputs):
""" Use Pytorch Benchmark on the forward+backward pass of an arbitrary function. """
def benchmark_combined(
fn,
*inputs,
grad=None,
repeats=10,
desc="",
verbose=True,
amp=False,
amp_dtype=torch.float16,
**kwinputs,
):
"""Use Pytorch Benchmark on the forward+backward pass of an arbitrary function."""
if verbose:
print(desc, '- Forward + Backward pass')
with torch.autocast(device_type='cuda', dtype=amp_dtype, enabled=amp):
print(desc, "- Forward + Backward pass")
with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp):
y = fn(*inputs, **kwinputs)
if type(y) is tuple:
y = y[0]
......@@ -69,68 +91,142 @@ def benchmark_combined(fn, *inputs, grad=None, repeats=10, desc='', verbose=True
grad = torch.randn_like(y)
else:
if grad.shape != y.shape:
raise RuntimeError('Grad shape does not match output shape')
raise RuntimeError("Grad shape does not match output shape")
def f(grad, *inputs, **kwinputs):
for x in inputs:
if isinstance(x, torch.Tensor):
x.grad = None
with torch.autocast(device_type='cuda', dtype=amp_dtype, enabled=amp):
with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp):
y = fn(*inputs, **kwinputs)
if type(y) is tuple:
y = y[0]
y.backward(grad, retain_graph=True)
t = benchmark.Timer(
stmt='f(grad, *inputs, **kwinputs)',
globals={'f': f, 'fn': fn, 'inputs': inputs, 'grad': grad, 'kwinputs': kwinputs},
num_threads=torch.get_num_threads(),
)
stmt="f(grad, *inputs, **kwinputs)",
globals={"f": f, "fn": fn, "inputs": inputs, "grad": grad, "kwinputs": kwinputs},
num_threads=torch.get_num_threads(),
)
m = t.timeit(repeats)
if verbose:
print(m)
return t, m
def benchmark_fwd_bwd(fn, *inputs, grad=None, repeats=10, desc='', verbose=True, amp=False,
amp_dtype=torch.float16, **kwinputs):
""" Use Pytorch Benchmark on the forward+backward pass of an arbitrary function. """
def benchmark_fwd_bwd(
fn,
*inputs,
grad=None,
repeats=10,
desc="",
verbose=True,
amp=False,
amp_dtype=torch.float16,
**kwinputs,
):
"""Use Pytorch Benchmark on the forward+backward pass of an arbitrary function."""
return (
benchmark_forward(fn, *inputs, repeats=repeats, desc=desc, verbose=verbose,
amp=amp, amp_dtype=amp_dtype, **kwinputs),
benchmark_backward(fn, *inputs, grad=grad, repeats=repeats, desc=desc, verbose=verbose,
amp=amp, amp_dtype=amp_dtype, **kwinputs),
benchmark_forward(
fn,
*inputs,
repeats=repeats,
desc=desc,
verbose=verbose,
amp=amp,
amp_dtype=amp_dtype,
**kwinputs,
),
benchmark_backward(
fn,
*inputs,
grad=grad,
repeats=repeats,
desc=desc,
verbose=verbose,
amp=amp,
amp_dtype=amp_dtype,
**kwinputs,
),
)
def benchmark_all(fn, *inputs, grad=None, repeats=10, desc='', verbose=True, amp=False,
amp_dtype=torch.float16, **kwinputs):
""" Use Pytorch Benchmark on the forward+backward pass of an arbitrary function. """
def benchmark_all(
fn,
*inputs,
grad=None,
repeats=10,
desc="",
verbose=True,
amp=False,
amp_dtype=torch.float16,
**kwinputs,
):
"""Use Pytorch Benchmark on the forward+backward pass of an arbitrary function."""
return (
benchmark_forward(fn, *inputs, repeats=repeats, desc=desc, verbose=verbose,
amp=amp, amp_dtype=amp_dtype, **kwinputs),
benchmark_backward(fn, *inputs, grad=grad, repeats=repeats, desc=desc, verbose=verbose,
amp=amp, amp_dtype=amp_dtype, **kwinputs),
benchmark_combined(fn, *inputs, grad=grad, repeats=repeats, desc=desc, verbose=verbose,
amp=amp, amp_dtype=amp_dtype, **kwinputs),
benchmark_forward(
fn,
*inputs,
repeats=repeats,
desc=desc,
verbose=verbose,
amp=amp,
amp_dtype=amp_dtype,
**kwinputs,
),
benchmark_backward(
fn,
*inputs,
grad=grad,
repeats=repeats,
desc=desc,
verbose=verbose,
amp=amp,
amp_dtype=amp_dtype,
**kwinputs,
),
benchmark_combined(
fn,
*inputs,
grad=grad,
repeats=repeats,
desc=desc,
verbose=verbose,
amp=amp,
amp_dtype=amp_dtype,
**kwinputs,
),
)
def pytorch_profiler(fn, *inputs, trace_filename=None, backward=False, amp=False,
amp_dtype=torch.float16, cpu=False, verbose=True, **kwinputs):
""" Wrap benchmark functions in Pytorch profiler to see CUDA information. """
def pytorch_profiler(
fn,
*inputs,
trace_filename=None,
backward=False,
amp=False,
amp_dtype=torch.float16,
cpu=False,
verbose=True,
**kwinputs,
):
"""Wrap benchmark functions in Pytorch profiler to see CUDA information."""
if backward:
with torch.autocast(device_type='cuda', dtype=amp_dtype, enabled=amp):
with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp):
g = torch.randn_like(fn(*inputs, **kwinputs))
for _ in range(30): # Warm up
for _ in range(30): # Warm up
if backward:
for x in inputs:
if isinstance(x, torch.Tensor):
x.grad = None
with torch.autocast(device_type='cuda', dtype=amp_dtype, enabled=amp):
with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp):
out = fn(*inputs, **kwinputs)
# Backward should be done outside autocast
if backward:
out.backward(g, retain_graph=True)
activities = ([torch.profiler.ProfilerActivity.CPU] if cpu else []) + [torch.profiler.ProfilerActivity.CUDA]
activities = ([torch.profiler.ProfilerActivity.CPU] if cpu else []) + [
torch.profiler.ProfilerActivity.CUDA
]
with torch.profiler.profile(
activities=activities,
record_shapes=True,
......@@ -141,9 +237,10 @@ def pytorch_profiler(fn, *inputs, trace_filename=None, backward=False, amp=False
for x in inputs:
if isinstance(x, torch.Tensor):
x.grad = None
with torch.autocast(device_type='cuda', dtype=amp_dtype, enabled=amp):
with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp):
out = fn(*inputs, **kwinputs)
if backward: out.backward(g, retain_graph=True)
if backward:
out.backward(g, retain_graph=True)
if verbose:
# print(prof.key_averages().table(sort_by="self_cuda_time_total", row_limit=50))
print(prof.key_averages().table(row_limit=50))
......@@ -151,14 +248,14 @@ def pytorch_profiler(fn, *inputs, trace_filename=None, backward=False, amp=False
prof.export_chrome_trace(trace_filename)
def benchmark_memory(fn, *inputs, desc='', verbose=True, **kwinputs):
def benchmark_memory(fn, *inputs, desc="", verbose=True, **kwinputs):
torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()
torch.cuda.synchronize()
fn(*inputs, **kwinputs)
torch.cuda.synchronize()
mem = torch.cuda.max_memory_allocated() / ((2 ** 20) * 1000)
mem = torch.cuda.max_memory_allocated() / ((2**20) * 1000)
if verbose:
print(f'{desc} max memory: {mem}GB')
print(f"{desc} max memory: {mem}GB")
torch.cuda.empty_cache()
return mem
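# A usage sketch for benchmark_memory (illustrative inputs): it reports the peak
# CUDA memory allocated while running fn once, in the ~GB unit computed above.
if torch.cuda.is_available():
    x = torch.randn(4096, 4096, device="cuda")
    peak_gb = benchmark_memory(torch.matmul, x, x, desc="matmul")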
......@@ -17,10 +17,12 @@ if "reduce_scatter_tensor" not in dir(torch.distributed):
# Raw operation, does not support autograd, but does support async
def all_gather_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool = False):
world_size = torch.distributed.get_world_size(process_group)
output = torch.empty(world_size * input_.shape[0], *input_.shape[1:],
dtype=input_.dtype, device=input_.device)
handle = torch.distributed.all_gather_into_tensor(output, input_.contiguous(),
group=process_group, async_op=async_op)
output = torch.empty(
world_size * input_.shape[0], *input_.shape[1:], dtype=input_.dtype, device=input_.device
)
handle = torch.distributed.all_gather_into_tensor(
output, input_.contiguous(), group=process_group, async_op=async_op
)
return output, handle
......@@ -28,11 +30,12 @@ def all_gather_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool =
def reduce_scatter_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool = False):
world_size = torch.distributed.get_world_size(process_group)
assert input_.shape[0] % world_size == 0
output = torch.empty(input_.shape[0] // world_size, *input_.shape[1:],
dtype=input_.dtype, device=input_.device)
handle = torch.distributed.reduce_scatter_tensor(output, input_.contiguous(),
group=process_group,
async_op=async_op)
output = torch.empty(
input_.shape[0] // world_size, *input_.shape[1:], dtype=input_.dtype, device=input_.device
)
handle = torch.distributed.reduce_scatter_tensor(
output, input_.contiguous(), group=process_group, async_op=async_op
)
return output, handle
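# A single-process reference for the shape/reduction contract of the two raw
# collectives above (a sketch that uses no torch.distributed calls; the function
# names are illustrative): all_gather concatenates the per-rank chunks along
# dim 0, while reduce_scatter sums across ranks and hands each rank one dim-0 chunk.
import torch


def all_gather_reference(inputs_per_rank):
    # inputs_per_rank: one (chunk, ...) tensor per rank; every rank gets the concatenation
    return torch.cat(inputs_per_rank, dim=0)


def reduce_scatter_reference(inputs_per_rank):
    # inputs_per_rank: one (world_size * chunk, ...) tensor per rank
    world_size = len(inputs_per_rank)
    summed = torch.stack(inputs_per_rank, dim=0).sum(dim=0)
    return list(summed.chunk(world_size, dim=0))  # chunk i is what rank i receives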
......@@ -102,8 +105,9 @@ all_reduce = AllReduceFunc.apply
def sync_shared_params(model: torch.nn.Module, process_group: ProcessGroup):
# We want to iterate over parameters with _shared_params=True in the same order,
# as different ranks might have different numbers of parameters (e.g., only rank 0 has the bias).
pamams_shared = {name: p for name, p in model.named_parameters()
if getattr(p, '_shared_params', False)}
pamams_shared = {
name: p for name, p in model.named_parameters() if getattr(p, "_shared_params", False)
}
for _, p in sorted(pamams_shared.items()):
with torch.no_grad():
# Broadcast needs src to be global rank, not group rank
......@@ -116,8 +120,9 @@ def sync_shared_params(model: torch.nn.Module, process_group: ProcessGroup):
def allreduce_sequence_parallel_grad(model: torch.nn.Module, process_group: ProcessGroup):
# We want to iterate over parameters with _sequence_parallel=True in the same order,
# as different ranks might have different numbers of parameters (e.g., only rank 0 has the bias).
params_seqparallel = {name: p for name, p in model.named_parameters()
if getattr(p, '_sequence_parallel', False)}
params_seqparallel = {
name: p for name, p in model.named_parameters() if getattr(p, "_sequence_parallel", False)
}
grads = [p.grad for _, p in sorted(params_seqparallel.items())]
if grads:
with torch.no_grad():
......
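# A sketch of how a parameter becomes visible to sync_shared_params above: the
# model code tags it with the _shared_params attribute (the Linear below is
# purely illustrative).
shared_model = torch.nn.Linear(16, 16)
shared_model.weight._shared_params = True  # picked up by the getattr(p, "_shared_params", False) filter
# sync_shared_params(shared_model, process_group) would then broadcast it across the group.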
# Copyright (c) 2023, Tri Dao.
# Adapted from https://github.com/NVIDIA/Megatron-LM/blob/0bb597b42c53355a567aba2a1357cc34b9d99ddd/megatron/text_generation/forward_step.py#L31
from typing import Optional, Union, Sequence, Callable
import gc
import time
from dataclasses import dataclass, field
from collections import namedtuple
from dataclasses import dataclass, field
from typing import Callable, Optional, Sequence, Union
import torch
from torch import Tensor
from torch.profiler import profile, record_function, ProfilerActivity
from einops import rearrange
from torch import Tensor
from torch.profiler import ProfilerActivity, profile, record_function
from transformers.generation import GreedySearchDecoderOnlyOutput, SampleDecoderOnlyOutput
......@@ -20,6 +17,7 @@ from transformers.generation import GreedySearchDecoderOnlyOutput, SampleDecoder
class InferenceParams:
"""Inference parameters that are passed to the main model in order
to efficiently calculate and store the context during inference."""
max_sequence_len: int
max_batch_size: int
sequence_len_offset: int = 0
......@@ -38,11 +36,13 @@ def modify_logits_for_top_p_filtering(logits, top_p):
# First sort and calculate cumulative sum of probabilities.
sorted_logits, sorted_indices = torch.sort(logits, descending=False)
cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
# Remove tokens with cumulative top_p above the threshold (tokens with 0 are kept)
# Remove tokens with cumulative top_p above the threshold (tokens with 0 are kept)
sorted_indices_to_remove = cumulative_probs <= (1 - top_p)
# scatter sorted tensors to original indexing
indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
logits.masked_fill_(indices_to_remove, float('-inf'))
indices_to_remove = sorted_indices_to_remove.scatter(
1, sorted_indices, sorted_indices_to_remove
)
logits.masked_fill_(indices_to_remove, float("-inf"))  # in-place so the caller's logits are updated
def sample(logits, top_k=1, top_p=0.0, temperature=1.0):
......@@ -54,7 +54,7 @@ def sample(logits, top_k=1, top_p=0.0, temperature=1.0):
return logits.argmax(dim=-1)
else:
if top_p > 0.0:
assert top_p <= 1.0, 'top-p should be in (0, 1].'
assert top_p <= 1.0, "top-p should be in (0, 1]."
if top_k > 0:
top_k = min(top_k, logits.size(-1)) # Safety check
logits_top, indices = torch.topk(logits, top_k, dim=-1)
......@@ -62,17 +62,31 @@ def sample(logits, top_k=1, top_p=0.0, temperature=1.0):
modify_logits_for_top_p_filtering(logits_top, top_p)
return indices[
torch.arange(indices.shape[0], device=indices.device),
torch.multinomial(torch.softmax(logits_top, dim=-1), num_samples=1).squeeze(dim=-1)
torch.multinomial(torch.softmax(logits_top, dim=-1), num_samples=1).squeeze(dim=-1),
]
else:
logits_top = logits / temperature
modify_logits_for_top_p_filtering(logits_top, top_p)
return torch.multinomial(torch.softmax(logits_top, dim=-1), num_samples=1).squeeze(dim=-1)
return torch.multinomial(torch.softmax(logits_top, dim=-1), num_samples=1).squeeze(
dim=-1
)
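# A usage sketch for sample() above (illustrative values): top_k=1 is greedy
# argmax, otherwise top-k and/or top-p filtering is applied before multinomial
# sampling.
logits = torch.randn(4, 32000)  # (batch, vocab_size)
greedy_tokens = sample(logits, top_k=1)  # shape (4,), argmax over the vocab
sampled_tokens = sample(logits, top_k=50, top_p=0.9, temperature=0.8)  # shape (4,)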
def decode(input_ids, model, max_length, top_k=1, top_p=0.0, temperature=1.0,
eos_token_id=None, teacher_outputs=None, vocab_size=None, tensor_parallel=1,
fused_ft_kernel=False, cg=False, timing=False):
def decode(
input_ids,
model,
max_length,
top_k=1,
top_p=0.0,
temperature=1.0,
eos_token_id=None,
teacher_outputs=None,
vocab_size=None,
tensor_parallel=1,
fused_ft_kernel=False,
cg=False,
timing=False,
):
"""Decoding, either greedy or with top-k or top-p sampling.
If top-k = 0, don't limit the number of candidates (pure sampling).
Top-k and top-p can be used together. If top_k > 0 and top_p > 0, then top-k is applied first,
......@@ -92,19 +106,24 @@ def decode(input_ids, model, max_length, top_k=1, top_p=0.0, temperature=1.0,
teacher_output_len = teacher_outputs.shape[1] if teacher_outputs is not None else 0
if cg:
assert fused_ft_kernel
if not hasattr(model, '_decoding_cache'):
if not hasattr(model, "_decoding_cache"):
model._decoding_cache = None
model._decoding_cache = update_graph_cache(
model, model._decoding_cache, batch_size, seqlen_og, max_length,
tensor_parallel=tensor_parallel
model,
model._decoding_cache,
batch_size,
seqlen_og,
max_length,
tensor_parallel=tensor_parallel,
)
inference_params = model._decoding_cache.inference_params
inference_params.max_sequence_len = max_length
inference_params.max_batch_size = batch_size
inference_params.sequence_len_offset = 0
else:
inference_params = InferenceParams(max_sequence_len=max_length, max_batch_size=batch_size,
fused_ft_kernel=fused_ft_kernel)
inference_params = InferenceParams(
max_sequence_len=max_length, max_batch_size=batch_size, fused_ft_kernel=fused_ft_kernel
)
scores = []
with torch.inference_mode():
if timing:
......@@ -123,18 +142,32 @@ def decode(input_ids, model, max_length, top_k=1, top_p=0.0, temperature=1.0,
sequences = [next_token]
inference_params.sequence_len_offset = seqlen_og
while True:
position_ids = torch.full((batch_size, 1), inference_params.sequence_len_offset,
dtype=torch.long, device=input_ids.device)
position_ids = torch.full(
(batch_size, 1),
inference_params.sequence_len_offset,
dtype=torch.long,
device=input_ids.device,
)
if not cg:
logits = model(rearrange(next_token, 'b -> b 1'), position_ids=position_ids,
inference_params=inference_params, last_token_only=True).logits
logits = model(
rearrange(next_token, "b -> b 1"),
position_ids=position_ids,
inference_params=inference_params,
last_token_only=True,
).logits
else:
logits = model._decoding_cache.run(rearrange(next_token, 'b -> b 1'), position_ids,
inference_params.sequence_len_offset)
logits = model._decoding_cache.run(
rearrange(next_token, "b -> b 1"),
position_ids,
inference_params.sequence_len_offset,
)
if vocab_size is not None:
logits = logits[..., :vocab_size]
scores.append(logits if not cg else logits.clone())
if teacher_outputs is None or teacher_output_len <= inference_params.sequence_len_offset + 1:
if (
teacher_outputs is None
or teacher_output_len <= inference_params.sequence_len_offset + 1
):
next_token = sample(logits, top_k=top_k, temperature=temperature)
else:
next_token = teacher_outputs[:, inference_params.sequence_len_offset + 1]
......@@ -148,30 +181,45 @@ def decode(input_ids, model, max_length, top_k=1, top_p=0.0, temperature=1.0,
if tensor_parallel > 1:
torch.distributed.barrier()
torch.cuda.synchronize()
print(f'Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms')
print(f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms")
output_cls = GreedySearchDecoderOnlyOutput if top_k == 1 else SampleDecoderOnlyOutput
return output_cls(
sequences=torch.cat([input_ids, torch.stack(sequences, dim=1)], dim=1),
scores=tuple(scores)
sequences=torch.cat([input_ids, torch.stack(sequences, dim=1)], dim=1), scores=tuple(scores)
)
class GenerationMixin:
def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
raise NotImplementedError
def generate(self, input_ids, max_length, top_k=1, top_p=0.0, temperature=1.0,
return_dict_in_generate=False, output_scores=False, **kwargs):
output = decode(input_ids, self, max_length, top_k=top_k, top_p=top_p,
temperature=temperature, **kwargs)
def generate(
self,
input_ids,
max_length,
top_k=1,
top_p=0.0,
temperature=1.0,
return_dict_in_generate=False,
output_scores=False,
**kwargs,
):
output = decode(
input_ids, self, max_length, top_k=top_k, top_p=top_p, temperature=temperature, **kwargs
)
if not output_scores:
output.scores = None
return output if return_dict_in_generate else output.sequences
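# A usage sketch for GenerationMixin.generate (commented out because `model` and
# `tokenizer` are assumed to exist: any flash-attn GPT-style model mixing in
# GenerationMixin, plus its tokenizer):
# input_ids = tokenizer("Hello", return_tensors="pt").input_ids.cuda()
# out = model.generate(
#     input_ids, max_length=64, top_k=40, top_p=0.95, temperature=0.8,
#     return_dict_in_generate=True, output_scores=True,
# )
# out.sequences contains the prompt followed by the generated tokens;
# out.scores is a tuple with the logits of each generated step.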
def allocate_inference_cache(max_batch_size, max_seqlen, nheads, headdim, layers: Union[int, Sequence],
device, dtype=torch.float16):
def allocate_inference_cache(
max_batch_size,
max_seqlen,
nheads,
headdim,
layers: Union[int, Sequence],
device,
dtype=torch.float16,
):
assert dtype in [torch.float16, torch.bfloat16, torch.float32]
packsize = 4 if dtype == torch.float32 else 8
assert headdim % packsize == 0
......@@ -179,9 +227,13 @@ def allocate_inference_cache(max_batch_size, max_seqlen, nheads, headdim, layers
v_cache_shape = (max_batch_size, nheads, max_seqlen, headdim)
if isinstance(layers, int):
layers = range(layers)
return {i: (torch.empty(k_cache_shape, device=device, dtype=dtype),
torch.empty(v_cache_shape, device=device, dtype=dtype))
for i in layers}
return {
i: (
torch.empty(k_cache_shape, device=device, dtype=dtype),
torch.empty(v_cache_shape, device=device, dtype=dtype),
)
for i in layers
}
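# A usage sketch for allocate_inference_cache (illustrative sizes): it returns a
# dict mapping layer index -> (k_cache, v_cache) buffers pre-allocated on the device.
if torch.cuda.is_available():
    cache = allocate_inference_cache(
        2, 256, 16, 64, layers=24, device="cuda", dtype=torch.float16
    )
    k_cache, v_cache = cache[0]
    print(v_cache.shape)  # torch.Size([2, 16, 256, 64]) per v_cache_shape above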
def seqlen_to_seqlen_type(seqlen: int) -> int:
......@@ -211,49 +263,70 @@ class DecodingCGCache:
@torch.inference_mode()
def update_graph_cache(model, cache, batch_size, seqlen_og, max_seqlen, tensor_parallel=1,
dtype=None, n_warmups=2):
def update_graph_cache(
model, cache, batch_size, seqlen_og, max_seqlen, tensor_parallel=1, dtype=None, n_warmups=2
):
if cache is None:
cache = DecodingCGCache()
param_example = next(iter(model.parameters()))
device = param_example.device
if dtype is None:
dtype = param_example.dtype
if ((device, dtype) != (cache.device, cache.dtype) or batch_size > cache.max_batch_size
or max_seqlen > cache.max_seqlen): # Invalidate the cache
if (
(device, dtype) != (cache.device, cache.dtype)
or batch_size > cache.max_batch_size
or max_seqlen > cache.max_seqlen
): # Invalidate the cache
cache.callables = {}
cache.mempool = None
cache.inference_params = None
gc.collect()
cache.device, cache.dtype = device, dtype
cache.max_batch_size, cache.max_seqlen = batch_size, max_seqlen
if hasattr(model, 'allocate_inference_cache'):
if hasattr(model, "allocate_inference_cache"):
inf_cache = model.allocate_inference_cache(batch_size, max_seqlen, dtype)
else:
headdim = getattr(model.config, 'head_dim',
model.config.hidden_size // model.config.num_attention_heads)
headdim = getattr(
model.config,
"head_dim",
model.config.hidden_size // model.config.num_attention_heads,
)
inf_cache = allocate_inference_cache(
batch_size, max_seqlen, model.config.num_attention_heads // tensor_parallel, headdim,
model.config.num_hidden_layers, device, dtype
batch_size,
max_seqlen,
model.config.num_attention_heads // tensor_parallel,
headdim,
model.config.num_hidden_layers,
device,
dtype,
)
lengths_per_sample = torch.full((batch_size,), seqlen_og, dtype=torch.int32, device=device)
cache.inference_params = InferenceParams(
max_sequence_len=max_seqlen, max_batch_size=batch_size,
sequence_len_offset=seqlen_og, key_value_memory_dict=inf_cache, fused_ft_kernel=True,
lengths_per_sample=lengths_per_sample
max_sequence_len=max_seqlen,
max_batch_size=batch_size,
sequence_len_offset=seqlen_og,
key_value_memory_dict=inf_cache,
fused_ft_kernel=True,
lengths_per_sample=lengths_per_sample,
)
cache.mempool = torch.cuda.graphs.graph_pool_handle()
for s_type in range(seqlen_to_seqlen_type(seqlen_og), seqlen_to_seqlen_type(max_seqlen) + 1):
if (batch_size, s_type) not in cache.callables:
max_seqlen_ = min(max(seqlen_og, seqlen_type_to_max_seqlen(s_type)), max_seqlen)
cache.callables[batch_size, s_type] = capture_graph(
model, cache.inference_params, batch_size, max_seqlen_, mempool=cache.mempool,
n_warmups=n_warmups
model,
cache.inference_params,
batch_size,
max_seqlen_,
mempool=cache.mempool,
n_warmups=n_warmups,
)
def dispatch(input_ids, position_ids, seqlen):
batch_size = input_ids.shape[0]
return cache.callables[batch_size, seqlen_to_seqlen_type(seqlen)](input_ids, position_ids, seqlen)
return cache.callables[batch_size, seqlen_to_seqlen_type(seqlen)](
input_ids, position_ids, seqlen
)
cache.run = dispatch
cache.inference_params.sequence_len_offset = 0 # Reset so it's not confusing
......@@ -275,8 +348,12 @@ def capture_graph(model, inference_params, batch_size, max_seqlen, mempool=None,
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
for _ in range(n_warmups):
logits = model(input_ids, position_ids=position_ids, inference_params=inference_params,
last_token_only=True).logits
logits = model(
input_ids,
position_ids=position_ids,
inference_params=inference_params,
last_token_only=True,
).logits
s.synchronize()
# This might be needed for correctness if we run with NCCL_GRAPH_MIXING_SUPPORT=0,
# which requires that graph launch and non-captured launch to not overlap (I think,
......@@ -288,8 +365,12 @@ def capture_graph(model, inference_params, batch_size, max_seqlen, mempool=None,
# To allow capture, automatically sets a side stream as the current stream in the context
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph, pool=mempool):
logits = model(input_ids, position_ids=position_ids, inference_params=inference_params,
last_token_only=True).logits
logits = model(
input_ids,
position_ids=position_ids,
inference_params=inference_params,
last_token_only=True,
).logits
def run(new_input_ids, new_position_ids, seqlen):
inference_params.lengths_per_sample[:] = seqlen
......
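# A minimal, generic sketch of the capture-and-replay pattern used above: warm up
# on a side stream, capture into a CUDA graph, then replay with static buffers.
# The tiny matmul stands in for the model call and is purely illustrative.
if torch.cuda.is_available():
    static_x = torch.randn(8, 8, device="cuda")
    w = torch.randn(8, 8, device="cuda")

    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        for _ in range(2):  # warmup launches happen outside the graph
            static_y = static_x @ w
    torch.cuda.current_stream().wait_stream(s)

    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):  # records the kernels launched in this block
        static_y = static_x @ w

    static_x.copy_(torch.randn(8, 8, device="cuda"))
    graph.replay()  # reruns the recorded kernels; static_y now holds the new result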
......@@ -3,13 +3,18 @@ from functools import partial
import torch
from safetensors.torch import load_file as safe_load_file
from transformers.utils import WEIGHTS_NAME, WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME, SAFE_WEIGHTS_INDEX_NAME
from transformers.utils import (
SAFE_WEIGHTS_INDEX_NAME,
SAFE_WEIGHTS_NAME,
WEIGHTS_INDEX_NAME,
WEIGHTS_NAME,
)
from transformers.utils.hub import cached_file, get_checkpoint_shard_files
def state_dict_from_pretrained(model_name, device=None, dtype=None):
# If not fp32, then we don't want to load directly to the GPU
mapped_device = 'cpu' if dtype not in [torch.float32, None] else device
mapped_device = "cpu" if dtype not in [torch.float32, None] else device
is_sharded = False
load_safe = False
resolved_archive_file = None
......@@ -20,19 +25,23 @@ def state_dict_from_pretrained(model_name, device=None, dtype=None):
safe_weights_index_path = os.path.join(model_name, SAFE_WEIGHTS_INDEX_NAME)
if os.path.isfile(weights_path):
resolved_archive_file = cached_file(model_name, WEIGHTS_NAME,
_raise_exceptions_for_missing_entries=False)
resolved_archive_file = cached_file(
model_name, WEIGHTS_NAME, _raise_exceptions_for_missing_entries=False
)
elif os.path.isfile(weights_index_path):
resolved_archive_file = cached_file(model_name, WEIGHTS_INDEX_NAME,
_raise_exceptions_for_missing_entries=False)
resolved_archive_file = cached_file(
model_name, WEIGHTS_INDEX_NAME, _raise_exceptions_for_missing_entries=False
)
is_sharded = True
elif os.path.isfile(safe_weights_path):
resolved_archive_file = cached_file(model_name, SAFE_WEIGHTS_NAME,
_raise_exceptions_for_missing_entries=False)
resolved_archive_file = cached_file(
model_name, SAFE_WEIGHTS_NAME, _raise_exceptions_for_missing_entries=False
)
load_safe = True
elif os.path.isfile(safe_weights_index_path):
resolved_archive_file = cached_file(model_name, SAFE_WEIGHTS_INDEX_NAME,
_raise_exceptions_for_missing_entries=False)
resolved_archive_file = cached_file(
model_name, SAFE_WEIGHTS_INDEX_NAME, _raise_exceptions_for_missing_entries=False
)
is_sharded = True
load_safe = True
......
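# A usage sketch for state_dict_from_pretrained (commented out because it needs
# network / Hugging Face Hub access; the model name is illustrative): it resolves
# .bin, sharded, or safetensors checkpoints and returns the state dict, loading
# to CPU first whenever a non-fp32 dtype is requested.
# state_dict = state_dict_from_pretrained("gpt2", device="cuda", dtype=torch.float16)
# print(len(state_dict), next(iter(state_dict)))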