Commit f1a73d07 authored by Tri Dao

Run isort and black on python files

parent cbb4cf5f
@@ -2,42 +2,52 @@
import torch
import torch.nn as nn
from einops import rearrange
from torch import Tensor

from flash_attn.utils.distributed import all_reduce, reduce_scatter


class GPT2Embeddings(nn.Module):
    def __init__(
        self,
        embed_dim,
        vocab_size,
        max_position_embeddings,
        padding_idx=None,
        word_embed_proj_dim=None,
        device=None,
        dtype=None,
    ):
        """
        If max_position_embeddings <= 0, there's no position embeddings
        If word_embed_proj_dim is not None (e.g., OPT-350m), we embed to that dimension
            then project up to embed_dim
        """
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        if word_embed_proj_dim is None:
            self.word_embeddings = nn.Embedding(
                vocab_size, embed_dim, padding_idx=padding_idx, **factory_kwargs
            )
            self.project_in = None
        else:
            self.word_embeddings = nn.Embedding(
                vocab_size, word_embed_proj_dim, padding_idx=padding_idx, **factory_kwargs
            )
            self.project_in = nn.Linear(
                word_embed_proj_dim, embed_dim, bias=False, **factory_kwargs
            )
        self.max_position_embeddings = max_position_embeddings
        if self.max_position_embeddings > 0:
            self.position_embeddings = nn.Embedding(
                max_position_embeddings, embed_dim, **factory_kwargs
            )

    def forward(self, input_ids, position_ids=None):
        """
        input_ids: (batch, seqlen)
        position_ids: (batch, seqlen)
        """
        batch_size, seqlen = input_ids.shape
        embeddings = self.word_embeddings(input_ids)
@@ -52,31 +62,39 @@ class GPT2Embeddings(nn.Module):


class BertEmbeddings(nn.Module):
    def __init__(
        self,
        embed_dim,
        vocab_size,
        max_position_embeddings,
        type_vocab_size,
        padding_idx=None,
        device=None,
        dtype=None,
    ):
        """
        If max_position_embeddings <= 0, there's no position embeddings
        If type_vocab_size <= 0, there's no token type embeddings
        """
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.word_embeddings = nn.Embedding(
            vocab_size, embed_dim, padding_idx=padding_idx, **factory_kwargs
        )
        self.max_position_embeddings = max_position_embeddings
        self.type_vocab_size = type_vocab_size
        if self.max_position_embeddings > 0:
            self.position_embeddings = nn.Embedding(
                max_position_embeddings, embed_dim, **factory_kwargs
            )
        if self.type_vocab_size > 0:
            self.token_type_embeddings = nn.Embedding(type_vocab_size, embed_dim, **factory_kwargs)

    def forward(self, input_ids, position_ids=None, token_type_ids=None):
        """
        input_ids: (batch, seqlen)
        position_ids: (batch, seqlen)
        token_type_ids: (batch, seqlen)
        """
        batch_size, seqlen = input_ids.shape
        embeddings = self.word_embeddings(input_ids)
@@ -94,16 +112,17 @@ class BertEmbeddings(nn.Module):


class VocabParallelEmbedding(nn.Embedding):
    def __init__(self, num_embeddings, *args, process_group=None, padding_idx=None, **kwargs):
        self.process_group = process_group
        if process_group is not None:
            world_size = torch.distributed.get_world_size(process_group)
            if num_embeddings % world_size != 0:
                raise ValueError(
                    f"num_embeddings ({num_embeddings}) must be divisible by "
                    f"world_size ({world_size})"
                )
            if world_size > 1 and padding_idx is not None:
                raise RuntimeError("ParallelEmbedding does not support padding_idx")
        else:
            world_size = 1
        super().__init__(num_embeddings // world_size, *args, padding_idx=padding_idx, **kwargs)
@@ -125,33 +144,45 @@ class VocabParallelEmbedding(nn.Embedding):


class ColumnParallelEmbedding(nn.Embedding):
    def __init__(self, num_embeddings, embedding_dim, *args, process_group=None, **kwargs):
        self.process_group = process_group
        if process_group is not None:
            world_size = torch.distributed.get_world_size(process_group)
            if embedding_dim % world_size != 0:
                raise ValueError(
                    f"embedding_dim ({embedding_dim}) must be divisible by "
                    f"world_size ({world_size})"
                )
        else:
            world_size = 1
        super().__init__(num_embeddings, embedding_dim // world_size, *args, **kwargs)


class ParallelGPT2Embeddings(nn.Module):
    def __init__(
        self,
        embed_dim,
        vocab_size,
        max_position_embeddings,
        process_group,
        padding_idx=None,
        sequence_parallel=True,
        device=None,
        dtype=None,
    ):
        """
        If max_position_embeddings <= 0, there's no position embeddings
        """
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.process_group = process_group
        self.sequence_parallel = sequence_parallel
        self.word_embeddings = VocabParallelEmbedding(
            vocab_size,
            embed_dim,
            padding_idx=padding_idx,
            process_group=process_group,
            **factory_kwargs,
        )
        self.max_position_embeddings = max_position_embeddings
        if self.max_position_embeddings > 0:
@@ -161,8 +192,8 @@ class ParallelGPT2Embeddings(nn.Module):

    def forward(self, input_ids, position_ids=None, combine_batch_seqlen_dim=False):
        """
        input_ids: (batch, seqlen)
        position_ids: (batch, seqlen)
        """
        batch_size, seqlen = input_ids.shape
        world_size = torch.distributed.get_world_size(self.process_group)
@@ -176,8 +207,10 @@ class ParallelGPT2Embeddings(nn.Module):
        else:
            partition_dim = self.position_embeddings.embedding_dim
            rank = torch.distributed.get_rank(self.process_group)
            embeddings[
                ..., rank * partition_dim : (rank + 1) * partition_dim
            ] += position_embeddings
        if combine_batch_seqlen_dim:
            embeddings = rearrange(embeddings, "b s d -> (b s) d")
        reduce_fn = reduce_scatter if self.sequence_parallel else all_reduce
        return embeddings if world_size <= 1 else reduce_fn(embeddings, self.process_group)
@@ -732,8 +732,12 @@ class ParallelMHA(nn.Module):
            self.num_heads % self.num_heads_kv == 0
        ), "num_heads must be divisible by num_heads_kv"
        self.num_heads_per_rank = get_dim_for_local_rank(
            self.num_heads, self.world_size, self.local_rank
        )
        self.num_heads_kv_per_rank = get_dim_for_local_rank(
            self.num_heads, self.world_size, self.local_rank
        )
        self.head_dim = self.embed_dim // num_heads
        qkv_dim = self.head_dim * (self.num_heads + 2 * self.num_heads_kv)
...
@@ -17,10 +17,19 @@ except ImportError:


class Mlp(nn.Module):
    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        activation=F.gelu,
        bias1=True,
        bias2=True,
        return_residual=False,
        device=None,
        dtype=None,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        out_features = out_features if out_features is not None else in_features
        hidden_features = hidden_features if hidden_features is not None else in_features * 4
@@ -37,21 +46,42 @@ class Mlp(nn.Module):


class ParallelMLP(nn.Module):
    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        activation=F.gelu,
        process_group: ProcessGroup = None,
        sequence_parallel=True,
        bias1=True,
        bias2=True,
        device=None,
        dtype=None,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        assert ColumnParallelLinear is not None, "Need to install fused_dense"
        assert RowParallelLinear is not None, "Need to install fused_dense"
        out_features = out_features if out_features is not None else in_features
        hidden_features = hidden_features if hidden_features is not None else in_features * 4
        self.fc1 = ColumnParallelLinear(
            in_features,
            hidden_features,
            process_group,
            bias=bias1,
            sequence_parallel=sequence_parallel,
            **factory_kwargs,
        )
        self.activation = activation
        self.fc2 = RowParallelLinear(
            hidden_features,
            out_features,
            process_group,
            bias=bias2,
            sequence_parallel=sequence_parallel,
            **factory_kwargs,
        )

    def forward(self, x):
        y = self.fc1(x)
@@ -61,15 +91,25 @@ class ParallelMLP(nn.Module):


class GatedMlp(nn.Module):
    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        activation=F.sigmoid,
        bias1=True,
        bias2=True,
        multiple_of=256,
        return_residual=False,
        device=None,
        dtype=None,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        out_features = out_features if out_features is not None else in_features
        hidden_features = (
            hidden_features if hidden_features is not None else int(8 * in_features / 3)
        )
        hidden_features = (hidden_features + multiple_of - 1) // multiple_of * multiple_of
        self.return_residual = return_residual
        self.fc1 = nn.Linear(in_features, 2 * hidden_features, bias=bias1, **factory_kwargs)
@@ -88,24 +128,48 @@ class GatedMlp(nn.Module):


class ParallelGatedMlp(nn.Module):
    """Parallel GatedMlp"""

    def __init__(
        self,
        in_features,
        process_group,
        hidden_features=None,
        out_features=None,
        activation=F.sigmoid,
        bias1=True,
        bias2=True,
        multiple_of=256,
        sequence_parallel=True,
        device=None,
        dtype=None,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        out_features = out_features if out_features is not None else in_features
        hidden_features = (
            hidden_features if hidden_features is not None else int(8 * in_features / 3)
        )
        hidden_features = (hidden_features + multiple_of - 1) // multiple_of * multiple_of
        if ColumnParallelLinear is None or RowParallelLinear is None:
            raise ImportError("fused_dense is not installed")
        self.fc1 = ColumnParallelLinear(
            in_features,
            2 * hidden_features,
            process_group,
            bias=bias1,
            sequence_parallel=sequence_parallel,
            **factory_kwargs,
        )
        self.activation = activation
        self.fc2 = RowParallelLinear(
            hidden_features,
            out_features,
            process_group,
            bias=bias2,
            sequence_parallel=sequence_parallel,
            **factory_kwargs,
        )

    def forward(self, x):
        y = self.fc1(x)
...
@@ -5,7 +5,6 @@ import torch
import torch.nn as nn
import torch.nn.functional as F

# 1/sqrt(2*pi)-> 0.3989423
# 1/sqrt(2) -> 0.70710678
# sqrt(2/pi) -> 0.79788456
@@ -18,17 +17,19 @@ def bias_gelu(y, bias):
    x = bias + y
    return (x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))).to(dtype=y.dtype)


# gradient of tanh approximation of gelu
# gradient of actual gelu is:
# 0.5 * (1. + torch.erf(x * 0.70710678)) + 0.3989423 * x * torch.exp(-0.5 * x * x)
@torch.jit.script
def bias_gelu_back(g, y, bias):
    """Assume that y has shape (B, D) and bias has shape (D)"""
    x = bias + y
    tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
    # sqrt(2/pi) * 3 * 0.044715 -> 0.1070322243
    ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (
        1 + tanh_out
    )
    grad_y = ff * g
    return grad_y.to(dtype=y.dtype), grad_y.sum(dim=(0), dtype=bias.dtype)
@@ -56,6 +57,7 @@ bias_gelu_impl = GeLUFunction.apply
def gelu_fwd(x):
    return (x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))).to(dtype=x.dtype)


# gradient of tanh approximation of gelu
# gradient of actual gelu is:
# 0.5 * (1. + torch.erf(x * 0.70710678)) + 0.3989423 * x * torch.exp(-0.5 * x * x)
@@ -63,7 +65,9 @@ def gelu_fwd(x):
def gelu_bwd(g, x):
    tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
    # sqrt(2/pi) * 3 * 0.044715 -> 0.1070322243
    ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (
        1 + tanh_out
    )
    return (ff * g).to(dtype=x.dtype)
@@ -76,10 +80,11 @@ class FastGeLUFunction(torch.autograd.Function):

    @staticmethod
    def backward(ctx, grad_output):
        (input,) = ctx.saved_tensors
        tmp = gelu_bwd(grad_output, input)
        return tmp


fast_gelu_impl = FastGeLUFunction.apply
...
@@ -10,6 +10,10 @@ import fused_dense_lib as fused_dense_cuda
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.cuda.amp import custom_bwd, custom_fwd
from torch.distributed import ProcessGroup

from flash_attn.ops.activations import gelu_bwd, relu_bwd, sqrelu_bwd, sqrelu_fwd
from flash_attn.utils.distributed import (
    all_gather_raw,
@@ -18,9 +22,6 @@ from flash_attn.utils.distributed import (
    reduce_scatter,
    reduce_scatter_raw,
)


class FusedDenseFunc(torch.autograd.Function):
...
# Copyright (c) 2022, Tri Dao.
# Adapted from https://github.com/NVIDIA/apex/blob/master/apex/contrib/layer_norm/layer_norm.py

import dropout_layer_norm
import torch
from torch.nn import init


def maybe_align(x, alignment_in_bytes=16):
    """Assume that x already has last dim divisible by alignment_in_bytes"""
    # TD [2023-07-04] I'm not 100% sure that clone will align the memory
    # https://discuss.pytorch.org/t/how-to-ensure-that-tensor-data-ptr-is-aligned-to-16-bytes/183440
    return x if x.data_ptr() % alignment_in_bytes == 0 else x.clone()
def _dropout_add_layer_norm_forward(
    x0,
    residual,
    gamma,
    beta,
    rowscale,
    colscale,
    dropout_p,
    epsilon,
    residual_in_fp32=False,
    is_rms_norm=False,
):
    """Assume that arguments are contiguous and aligned to 16 bytes"""
    hidden_size = gamma.numel()
    x0mat = x0.view((-1, hidden_size))
    residualmat = residual.view((-1, hidden_size)) if residual is not None else None
    rowscale = rowscale.view(-1) if rowscale is not None else None
    zmat, xmat, dmask, mu, rsigma = dropout_layer_norm.dropout_add_ln_fwd(
        x0mat,
        residualmat,
        gamma,
        beta,
        rowscale,
        colscale,
        None,
        None,
        dropout_p,
        epsilon,
        1.0,
        0,
        None,
        residual_in_fp32,
        is_rms_norm,
    )
    # dmask is None if dropout_p == 0.0
    # xmat is None if dropout_p == 0.0 and residual is None and residual_dtype != input_dtype
    return zmat, xmat if xmat is not None else x0mat, dmask, mu, rsigma
def _dropout_add_layer_norm_backward(
    dz,
    dx,
    x,
    x0,
    dmask,
    mu,
    rsigma,
    gamma,
    rowscale,
    colscale,
    dropout_p,
    has_residual,
    is_rms_norm=False,
):
    """Assume that arguments are contiguous and aligned to 16 bytes
    dx == None means that it was a post-norm architecture
    (x = drop(x0) + residual was not returned in the fwd).
    x0 must not be None if we have colscale.
@@ -46,10 +79,25 @@
    x0mat = x0.view((-1, hidden_size)) if x0 is not None else None
    rowscale = rowscale.view(-1) if rowscale is not None else None
    if colscale is not None:
        assert x0 is not None, "x0 is required to compute the gradient of colscale"
    dx0mat, dresidualmat, dgamma, dbeta, _, _, *rest = dropout_layer_norm.dropout_add_ln_bwd(
        dzmat,
        dxmat,
        xmat,
        x0mat,
        dmask,
        mu,
        rsigma,
        gamma,
        rowscale,
        colscale,
        None,
        None,
        dropout_p,
        1.0,
        0,
        has_residual,
        is_rms_norm,
    )
    # dresidualmat is None if not has_residual
    if colscale is None:
def _dropout_add_layer_norm_subset_forward(
    x0,
    residual,
    gamma,
    beta,
    colscale,
    x0_subset,
    out_subset,
    dropout_p,
    epsilon,
    rowscale_const,
    out_numrows,
    residual_in_fp32=False,
    is_rms_norm=False,
):
    """Assume that arguments are contiguous and aligned to 16 bytes"""
    hidden_size = gamma.numel()
    x0mat = x0.view((-1, hidden_size))
    residualmat = residual.view((-1, hidden_size)) if residual is not None else None
    x0_subset = x0_subset.view(-1) if x0_subset is not None else None
    out_subset = out_subset.view(-1) if out_subset is not None else None
    zmat, xmat, dmask, mu, rsigma = dropout_layer_norm.dropout_add_ln_fwd(
        x0mat,
        residualmat,
        gamma,
        beta,
        None,
        colscale,
        x0_subset,
        out_subset,
        dropout_p,
        epsilon,
        rowscale_const,
        out_numrows,
        None,
        residual_in_fp32,
        is_rms_norm,
    )
    # dmask is None if dropout_p == 0.0
    # xmat is None if dropout_p == 0.0 and residual is None and residual_dtype != input_dtype
    return zmat, xmat if xmat is not None else x0mat, dmask, mu, rsigma
def _dropout_add_layer_norm_subset_backward(
    dz,
    dx,
    x,
    x0,
    dmask,
    mu,
    rsigma,
    gamma,
    colscale,
    x0_subset,
    out_subset,
    dropout_p,
    rowscale_const,
    x0_numrows,
    has_residual,
    is_rms_norm=False,
):
    """Assume that arguments are contiguous and aligned to 16 bytes
    dx == None means that it was a post-norm architecture
    (x = drop(x0) + residual was not returned in the fwd).
    x0 must not be None if we have colscale.
@@ -94,10 +181,25 @@
    x0_subset = x0_subset.view(-1) if x0_subset is not None else None
    out_subset = out_subset.view(-1) if out_subset is not None else None
    if colscale is not None:
        assert x0 is not None, "x0 is required to compute the gradient of colscale"
    dx0mat, dresidualmat, dgamma, dbeta, _, _, *rest = dropout_layer_norm.dropout_add_ln_bwd(
        dzmat,
        dxmat,
        xmat,
        x0mat,
        dmask,
        mu,
        rsigma,
        gamma,
        None,
        colscale,
        x0_subset,
        out_subset,
        dropout_p,
        rowscale_const,
        x0_numrows,
        has_residual,
        is_rms_norm,
    )
    # dresidualmat is None if not has_residual
    if colscale is None:
@@ -108,18 +210,44 @@


def _dropout_add_layer_norm_parallel_residual_forward(
    x0,
    x1,
    residual,
    gamma0,
    beta0,
    gamma1,
    beta1,
    dropout_p,
    epsilon,
    residual_in_fp32=False,
    is_rms_norm=False,
):
    """Assume that arguments are contiguous and aligned to 16 bytes"""
    hidden_size = gamma0.numel()
    x0mat = x0.view((-1, hidden_size))
    x1mat = x1.view((-1, hidden_size)) if x1 is not None else None
    residualmat = residual.view((-1, hidden_size)) if residual is not None else None
    (
        z0mat,
        z1mat,
        xmat,
        dmask0,
        dmask1,
        mu,
        rsigma,
    ) = dropout_layer_norm.dropout_add_ln_parallel_residual_fwd(
        x0mat,
        x1mat,
        residualmat,
        gamma0,
        beta0,
        gamma1,
        beta1,
        dropout_p,
        epsilon,
        None,
        residual_in_fp32,
        is_rms_norm,
    )
    # dmask0 and dmask1 are None if dropout_p == 0.0
    # xmat is None if dropout_p == 0.0 and residual is None and residual_dtype != input_dtype
@@ -127,10 +255,22 @@ def _dropout_add_layer_norm_parallel_residual_forward(


def _dropout_add_layer_norm_parallel_residual_backward(
    dz0,
    dz1,
    dx,
    x,
    dmask0,
    dmask1,
    mu,
    rsigma,
    gamma0,
    gamma1,
    dropout_p,
    has_x1,
    has_residual,
    is_rms_norm=False,
):
    """Assume that arguments are contiguous and aligned to 16 bytes
    dx == None means that it was a post-norm architecture
    (x = drop(x0) + residual was not returned in the fwd).
    """
@@ -139,9 +279,30 @@ def _dropout_add_layer_norm_parallel_residual_backward(
    dz0mat = dz0.view(xmat.shape)
    dz1mat = dz1.view(xmat.shape) if dz1 is not None else None
    dxmat = dx.view(xmat.shape) if dx is not None else None
    (
        dx0mat,
        dx1mat,
        dresidualmat,
        dgamma0,
        dbeta0,
        dgamma1,
        dbeta1,
        *rest,
    ) = dropout_layer_norm.dropout_add_ln_parallel_residual_bwd(
        dz0mat,
        dz1mat,
        dxmat,
        xmat,
        dmask0,
        dmask1,
        mu,
        rsigma,
        gamma0,
        gamma1,
        dropout_p,
        has_x1,
        has_residual,
        is_rms_norm,
    )
    # dresidualmat is None if not has_residual
    return dx0mat, dx1mat, dresidualmat, dgamma0, dbeta0, dgamma1, dbeta1
@@ -149,8 +310,21 @@ def _dropout_add_layer_norm_parallel_residual_backward(


class DropoutAddLayerNormFn(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx,
        x0,
        residual,
        gamma,
        beta,
        rowscale,
        colscale,
        dropout_p,
        epsilon,
        residual_in_fp32=False,
        prenorm=False,
        is_rms_norm=False,
        return_dmask=False,
    ):
        x0 = maybe_align(x0.contiguous(), 16)
        residual = maybe_align(residual.contiguous(), 16) if residual is not None else None
        gamma = maybe_align(gamma.contiguous(), 16)
@@ -158,26 +332,43 @@ class DropoutAddLayerNormFn(torch.autograd.Function):
        rowscale = maybe_align(rowscale.contiguous(), 16) if rowscale is not None else None
        colscale = maybe_align(colscale.contiguous(), 16) if colscale is not None else None
        zmat, xmat, dmask, mu, rsigma = _dropout_add_layer_norm_forward(
            x0,
            residual,
            gamma,
            beta,
            rowscale,
            colscale,
            dropout_p,
            epsilon,
            residual_in_fp32,
            is_rms_norm,
        )
        # Only need to save x0 if we need to compute gradient wrt colscale
        x0_saved = x0 if colscale is not None else None
        ctx.save_for_backward(
            xmat.view(x0.shape), x0_saved, dmask, gamma, mu, rsigma, rowscale, colscale
        )
        ctx.prenorm = prenorm
        ctx.dropout_p = dropout_p
        ctx.has_residual = residual is not None
        ctx.is_rms_norm = is_rms_norm
        ctx.has_beta = beta is not None
        if not return_dmask:
            return (
                zmat.view(x0.shape) if not prenorm else (zmat.view(x0.shape), xmat.view(x0.shape))
            )
        else:
            dmask = (
                dmask.view(x0.shape)
                if dropout_p > 0.0
                else torch.ones(x0.shape, dtype=torch.uint8, device=x0.device)
            )
            ctx.mark_non_differentiable(dmask)
            return (
                (zmat.view(x0.shape), dmask)
                if not prenorm
                else (zmat.view(x0.shape), xmat.view(x0.shape), dmask)
            )

    @staticmethod
    def backward(ctx, dz, *args):
@@ -189,35 +380,85 @@ class DropoutAddLayerNormFn(torch.autograd.Function):
        dropout_p = ctx.dropout_p
        has_residual = ctx.has_residual
        dx0mat, dresidualmat, dgamma, dbeta, *rest = _dropout_add_layer_norm_backward(
            dz,
            dx,
            x,
            x0,
            dmask,
            mu,
            rsigma,
            gamma,
            rowscale,
            colscale,
            dropout_p,
            has_residual,
            ctx.is_rms_norm,
        )
        dx0 = dx0mat.view(x.shape)
        dresidual = dresidualmat.view(x.shape) if dresidualmat is not None else None
        dcolscale = rest[0] if colscale is not None else None
        return (
            dx0,
            dresidual,
            dgamma,
            dbeta if ctx.has_beta else None,
            None,
            dcolscale,
            None,
            None,
            None,
            None,
            None,
            None,
        )


class DropoutAddLayerNormSubsetFn(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx,
        x0,
        residual,
        gamma,
        beta,
        colscale,
        x0_subset,
        out_subset,
        dropout_p,
        epsilon,
        rowscale_const,
        out_numrows,
        residual_in_fp32=False,
        prenorm=False,
        is_rms_norm=False,
        return_dmask=False,
    ):
        x0 = maybe_align(x0.contiguous(), 16)
        residual = maybe_align(residual.contiguous(), 16) if residual is not None else None
        gamma = maybe_align(gamma.contiguous(), 16)
        beta = maybe_align(beta.contiguous(), 16) if beta is not None else None
        colscale = maybe_align(colscale.contiguous(), 16) if colscale is not None else None
        zmat, xmat, dmask, mu, rsigma = _dropout_add_layer_norm_subset_forward(
            x0,
            residual,
            gamma,
            beta,
            colscale,
            x0_subset,
            out_subset,
            dropout_p,
            epsilon,
            rowscale_const,
            out_numrows,
            residual_in_fp32,
            is_rms_norm,
        )
        # Only need to save x0 if we need to compute gradient wrt colscale
        x0_saved = x0 if colscale is not None else None
        x_shape = (-1, *x0.shape[1:])
        ctx.save_for_backward(
            xmat.view(x_shape), x0_saved, dmask, gamma, mu, rsigma, colscale, x0_subset, out_subset
        )
        ctx.prenorm = prenorm
        ctx.dropout_p = dropout_p
        ctx.rowscale_const = rowscale_const
@@ -227,14 +468,16 @@ class DropoutAddLayerNormSubsetFn(torch.autograd.Function):
        ctx.has_beta = beta is not None
        z_shape = (-1, *x0.shape[1:])
        if not return_dmask:
            return zmat.view(z_shape) if not prenorm else (zmat.view(z_shape), xmat.view(x0.shape))
        else:
            z = zmat.view(z_shape)
            dmask = (
                dmask.view(x0.shape)
                if dropout_p > 0.0
                else torch.ones(x0.shape, dtype=torch.uint8, device=x0.device)
            )
            ctx.mark_non_differentiable(dmask)
            return (z, dmask) if not prenorm else (z, xmat.view(x_shape), dmask)

    @staticmethod
    def backward(ctx, dz, *args):
@@ -246,20 +489,63 @@ class DropoutAddLayerNormSubsetFn(torch.autograd.Function):
        dropout_p = ctx.dropout_p
        has_residual = ctx.has_residual
        dx0mat, dresidualmat, dgamma, dbeta, *rest = _dropout_add_layer_norm_subset_backward(
            dz,
            dx,
            x,
            x0,
            dmask,
            mu,
            rsigma,
            gamma,
            colscale,
            x0_subset,
            out_subset,
            dropout_p,
            ctx.rowscale_const,
            ctx.x0_numrows,
            has_residual,
            ctx.is_rms_norm,
        )
        dx0 = dx0mat.view(-1, *x.shape[1:])
        dresidual = dresidualmat.view(x.shape) if dresidualmat is not None else None
        dcolscale = rest[0] if colscale is not None else None
        return (
            dx0,
            dresidual,
            dgamma,
            dbeta if ctx.has_beta else None,
            dcolscale,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
        )


class DropoutAddLayerNormParallelResidualFn(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx,
        x0,
        x1,
        residual,
        gamma0,
        beta0,
        gamma1,
        beta1,
        dropout_p,
        epsilon,
        residual_in_fp32=False,
        prenorm=False,
        is_rms_norm=False,
        return_dmask=False,
    ):
        x0 = maybe_align(x0.contiguous(), 16)
        x1 = maybe_align(x1.contiguous(), 16) if x1 is not None else None
        residual = maybe_align(residual.contiguous(), 16) if residual is not None else None
beta0 = maybe_align(beta0.contiguous(), 16) if beta0 is not None else None beta0 = maybe_align(beta0.contiguous(), 16) if beta0 is not None else None
gamma1 = maybe_align(gamma1.contiguous(), 16) if gamma1 is not None else None gamma1 = maybe_align(gamma1.contiguous(), 16) if gamma1 is not None else None
beta1 = maybe_align(beta1.contiguous(), 16) if beta1 is not None else None beta1 = maybe_align(beta1.contiguous(), 16) if beta1 is not None else None
z0mat, z1mat, xmat, dmask0, dmask1, mu, rsigma = _dropout_add_layer_norm_parallel_residual_forward( (
x0, x1, residual, gamma0, beta0, gamma1, beta1, dropout_p, epsilon, z0mat,
residual_in_fp32, is_rms_norm z1mat,
xmat,
dmask0,
dmask1,
mu,
rsigma,
) = _dropout_add_layer_norm_parallel_residual_forward(
x0,
x1,
residual,
gamma0,
beta0,
gamma1,
beta1,
dropout_p,
epsilon,
residual_in_fp32,
is_rms_norm,
) )
ctx.save_for_backward(xmat.view(x0.shape), dmask0, dmask1, gamma0, gamma1, mu, rsigma) ctx.save_for_backward(xmat.view(x0.shape), dmask0, dmask1, gamma0, gamma1, mu, rsigma)
ctx.prenorm = prenorm ctx.prenorm = prenorm
@@ -282,13 +585,21 @@ class DropoutAddLayerNormParallelResidualFn(torch.autograd.Function):
        if not return_dmask:
            return z if not prenorm else (*z, xmat.view(x0.shape))
        else:
            dmask0 = (
                dmask0.view(x0.shape)
                if dropout_p > 0.0
                else torch.ones(x0.shape, dtype=torch.uint8, device=x0.device)
            )
            dmask1 = (
                dmask1.view(x0.shape)
                if dropout_p > 0.0 and x1 is not None
                else torch.ones(x0.shape, dtype=torch.uint8, device=x0.device)
            )
            ctx.mark_non_differentiable(dmask0)
            ctx.mark_non_differentiable(dmask1)
            return (
                (*z, dmask0, dmask1) if not prenorm else (*z, xmat.view(x0.shape), dmask0, dmask1)
            )

    @staticmethod
    def backward(ctx, dz0, dz1, *args):
@@ -299,63 +610,170 @@ class DropoutAddLayerNormParallelResidualFn(torch.autograd.Function):
        dropout_p = ctx.dropout_p
        has_x1 = ctx.has_x1
        has_residual = ctx.has_residual
        (
            dx0mat,
            dx1mat,
            dresidualmat,
            dgamma0,
            dbeta0,
            dgamma1,
            dbeta1,
        ) = _dropout_add_layer_norm_parallel_residual_backward(
            dz0,
            dz1,
            dx,
            x,
            dmask0,
            dmask1,
            mu,
            rsigma,
            gamma0,
            gamma1,
            dropout_p,
            has_x1,
            has_residual,
            ctx.is_rms_norm,
        )
        dx0 = dx0mat.view(x.shape)
        dx1 = dx1mat.view(x.shape) if dx1mat is not None else None
        dresidual = dresidualmat.view(x.shape) if dresidualmat is not None else None
        return (
            dx0,
            dx1,
            dresidual,
            dgamma0,
            dbeta0 if ctx.has_beta else None,
            dgamma1,
            dbeta1 if ctx.has_beta else None,
            None,
            None,
            None,
            None,
            None,
            None,
        )


def layer_norm(x, weight, bias, epsilon):
    return DropoutAddLayerNormFn.apply(x, None, weight, bias, None, None, 0.0, epsilon, False)


def dropout_add_layer_norm(
    x0,
    residual,
    weight,
    bias,
    dropout_p,
    epsilon,
    rowscale=None,
    layerscale=None,
    prenorm=False,
    residual_in_fp32=False,
    return_dropout_mask=False,
):
    """residual_in_fp32 only has an effect if residual is None.
    Otherwise residual dtype is residual.dtype.
    """
    return DropoutAddLayerNormFn.apply(
        x0,
        residual,
        weight,
        bias,
        rowscale,
        layerscale,
        dropout_p,
        epsilon,
        residual_in_fp32,
        prenorm,
        False,
        return_dropout_mask,
    )
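For reference, a minimal usage sketch of dropout_add_layer_norm in the pre-norm configuration (not part of the commit; it assumes the dropout_layer_norm CUDA extension is installed and uses illustrative shapes):

# Hypothetical pre-norm usage: returns the normalized output and the updated residual stream.
import torch
from flash_attn.ops.layer_norm import dropout_add_layer_norm

hidden = 1024
x0 = torch.randn(8, 512, hidden, device="cuda", dtype=torch.float16)   # block output
residual = torch.randn_like(x0)                                        # running residual
weight = torch.ones(hidden, device="cuda", dtype=torch.float16)
bias = torch.zeros(hidden, device="cuda", dtype=torch.float16)

out, new_residual = dropout_add_layer_norm(
    x0, residual, weight, bias, dropout_p=0.1, epsilon=1e-5, prenorm=True
)
# out is LayerNorm(dropout(x0) + residual); new_residual is dropout(x0) + residual.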
def dropout_add_layer_norm_subset(
    x0,
    residual,
    weight,
    bias,
    dropout_p,
    epsilon,
    layerscale=None,
    x0_subset=None,
    out_subset=None,
    rowscale_const=1.0,
    out_numrows=0,
    prenorm=False,
    residual_in_fp32=False,
    return_dropout_mask=False,
):
    """residual_in_fp32 only has an effect if residual is None.
    Otherwise residual dtype is residual.dtype.
    """
    return DropoutAddLayerNormSubsetFn.apply(
        x0,
        residual,
        weight,
        bias,
        layerscale,
        x0_subset,
        out_subset,
        dropout_p,
        epsilon,
        rowscale_const,
        out_numrows,
        residual_in_fp32,
        prenorm,
        False,
        return_dropout_mask,
    )


def dropout_add_layer_norm_parallel_residual(
    x0,
    x1,
    residual,
    weight0,
    bias0,
    weight1,
    bias1,
    dropout_p,
    epsilon,
    prenorm=False,
    residual_in_fp32=False,
    return_dropout_mask=False,
):
    """residual_in_fp32 only has an effect if residual is None.
    Otherwise residual dtype is residual.dtype.
    """
    return DropoutAddLayerNormParallelResidualFn.apply(
        x0,
        x1,
        residual,
        weight0,
        bias0,
        weight1,
        bias1,
        dropout_p,
        epsilon,
        residual_in_fp32,
        prenorm,
        False,
        return_dropout_mask,
    )


class DropoutAddLayerNorm(torch.nn.Module):
    def __init__(
        self,
        hidden_size,
        prenorm=False,
        p=0.0,
        eps=1e-5,
        residual_in_fp32=False,
        device=None,
        dtype=None,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.prenorm = prenorm
        self.p = p
@@ -370,6 +788,13 @@ class DropoutAddLayerNorm(torch.nn.Module):
        init.zeros_(self.bias)

    def forward(self, x0, residual=None):
        return dropout_add_layer_norm(
            x0,
            residual,
            self.weight,
            self.bias,
            self.p if self.training else 0.0,
            self.eps,
            prenorm=self.prenorm,
            residual_in_fp32=self.residual_in_fp32,
        )
@@ -4,60 +4,130 @@
import torch
from torch.nn import init

from flash_attn.ops.layer_norm import (
    DropoutAddLayerNormFn,
    DropoutAddLayerNormParallelResidualFn,
    DropoutAddLayerNormSubsetFn,
)


def rms_norm(x, weight, epsilon):
    return DropoutAddLayerNormFn.apply(
        x, None, weight, None, None, None, 0.0, epsilon, False, False, True
    )


def dropout_add_rms_norm(
    x0,
    residual,
    weight,
    bias,
    dropout_p,
    epsilon,
    rowscale=None,
    layerscale=None,
    prenorm=False,
    residual_in_fp32=False,
    return_dropout_mask=False,
):
    """residual_in_fp32 only has an effect if residual is None.
    Otherwise residual dtype is residual.dtype.
    """
    return DropoutAddLayerNormFn.apply(
        x0,
        residual,
        weight,
        bias,
        rowscale,
        layerscale,
        dropout_p,
        epsilon,
        residual_in_fp32,
        prenorm,
        True,
        return_dropout_mask,
    )


def dropout_add_rms_norm_subset(
    x0,
    residual,
    weight,
    bias,
    dropout_p,
    epsilon,
    layerscale=None,
    x0_subset=None,
    out_subset=None,
    rowscale_const=1.0,
    out_numrows=0,
    prenorm=False,
    residual_in_fp32=False,
    return_dropout_mask=False,
):
    """residual_in_fp32 only has an effect if residual is None.
    Otherwise residual dtype is residual.dtype.
    """
    return DropoutAddLayerNormSubsetFn.apply(
        x0,
        residual,
        weight,
        bias,
        layerscale,
        x0_subset,
        out_subset,
        dropout_p,
        epsilon,
        rowscale_const,
        out_numrows,
        residual_in_fp32,
        prenorm,
        True,
        return_dropout_mask,
    )


def dropout_add_rms_norm_parallel_residual(
    x0,
    x1,
    residual,
    weight0,
    bias0,
    weight1,
    bias1,
    dropout_p,
    epsilon,
    prenorm=False,
    residual_in_fp32=False,
    return_dropout_mask=False,
):
    """residual_in_fp32 only has an effect if residual is None.
    Otherwise residual dtype is residual.dtype.
    """
    return DropoutAddLayerNormParallelResidualFn.apply(
        x0,
        x1,
        residual,
        weight0,
        bias0,
        weight1,
        bias1,
        dropout_p,
        epsilon,
        residual_in_fp32,
        prenorm,
        True,
        return_dropout_mask,
    )


class RMSNorm(torch.nn.Module):
    def __init__(self, hidden_size, eps=1e-5, device=None, dtype=None):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.eps = eps
        self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
        self.register_parameter("bias", None)
        self.reset_parameters()

    def reset_parameters(self):
...@@ -68,22 +138,37 @@ class RMSNorm(torch.nn.Module): ...@@ -68,22 +138,37 @@ class RMSNorm(torch.nn.Module):
class DropoutAddRMSNorm(torch.nn.Module): class DropoutAddRMSNorm(torch.nn.Module):
def __init__(self, hidden_size, prenorm=False, p=0.0, eps=1e-5, residual_in_fp32=False, def __init__(
device=None, dtype=None): self,
factory_kwargs = {'device': device, 'dtype': dtype} hidden_size,
prenorm=False,
p=0.0,
eps=1e-5,
residual_in_fp32=False,
device=None,
dtype=None,
):
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__() super().__init__()
self.prenorm = prenorm self.prenorm = prenorm
self.p = p self.p = p
self.eps = eps self.eps = eps
self.residual_in_fp32 = residual_in_fp32 self.residual_in_fp32 = residual_in_fp32
self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs)) self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
self.register_parameter('bias', None) self.register_parameter("bias", None)
self.reset_parameters() self.reset_parameters()
def reset_parameters(self): def reset_parameters(self):
init.ones_(self.weight) init.ones_(self.weight)
def forward(self, x0, residual=None): def forward(self, x0, residual=None):
return dropout_add_rms_norm(x0, residual, self.weight, None, return dropout_add_rms_norm(
self.p if self.training else 0.0, self.eps, x0,
prenorm=self.prenorm, residual_in_fp32=self.residual_in_fp32) residual,
self.weight,
None,
self.p if self.training else 0.0,
self.eps,
prenorm=self.prenorm,
residual_in_fp32=self.residual_in_fp32,
)
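# Sketch of the DropoutAddRMSNorm wrapper above used in a pre-norm residual block
# (same CUDA / fused-extension assumptions as the functional form; the import path
# and the tuple return with prenorm=True are assumptions):
import torch
from flash_attn.ops.rms_norm import DropoutAddRMSNorm  # assumed module path

norm = DropoutAddRMSNorm(
    1024, prenorm=True, p=0.1, residual_in_fp32=True, device="cuda", dtype=torch.float16
)
x0 = torch.randn(4, 256, 1024, device="cuda", dtype=torch.float16)
hidden, residual = norm(x0)                 # first call: residual starts from x0
hidden, residual = norm(hidden, residual)   # later calls thread the residual through
norm.eval()                                 # in eval mode the dropout probability becomes 0.0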
...@@ -11,7 +11,6 @@ from typing import Optional ...@@ -11,7 +11,6 @@ from typing import Optional
import triton import triton
import triton.language as tl import triton.language as tl
_sqrt2pi = math.sqrt(2.0 / math.pi) _sqrt2pi = math.sqrt(2.0 / math.pi)
_sqrt1_2 = math.sqrt(1.0 / 2) _sqrt1_2 = math.sqrt(1.0 / 2)
_gaussian_pdf_normalization = 1.0 / math.sqrt(2 * math.pi) _gaussian_pdf_normalization = 1.0 / math.sqrt(2 * math.pi)
...@@ -142,6 +141,7 @@ def gelu_grad(x): ...@@ -142,6 +141,7 @@ def gelu_grad(x):
pdf = tl.exp(-0.5 * x * x) * _gaussian_pdf_normalization pdf = tl.exp(-0.5 * x * x) * _gaussian_pdf_normalization
return cdf + x * pdf return cdf + x * pdf
@triton.jit @triton.jit
def gelu_approx(x): def gelu_approx(x):
""" """
...@@ -157,6 +157,6 @@ def gelu_approx_grad(x): ...@@ -157,6 +157,6 @@ def gelu_approx_grad(x):
# CREDITS: Fast implementation proposed in # CREDITS: Fast implementation proposed in
# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/fused_bias_gelu.py#L30 # https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/fused_bias_gelu.py#L30
tanh_out = tanh(0.79788456 * x * (1 + 0.044715 * x * x)) tanh_out = tanh(0.79788456 * x * (1 + 0.044715 * x * x))
return 0.5 * x * ( return 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (
(1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x) 1 + tanh_out
) + 0.5 * (1 + tanh_out) )
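# Reference sketch (plain PyTorch, not the Triton kernel): the constants above come
# from the tanh GELU approximation
#   gelu_approx(x) = 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x**3)))
# with 0.79788456 ~= sqrt(2/pi) and 0.1070322243 ~= 3 * 0.044715 * sqrt(2/pi).
# Cross-check of the same formula against torch's built-in tanh approximation
# (assumes PyTorch >= 1.12):
import math
import torch
import torch.nn.functional as F

def gelu_tanh_reference(x: torch.Tensor) -> torch.Tensor:
    return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * x**3)))

_x = torch.randn(1024, dtype=torch.float64)
assert torch.allclose(gelu_tanh_reference(_x), F.gelu(_x, approximate="tanh"), atol=1e-6)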
...@@ -9,8 +9,14 @@ from torch.autograd.function import FunctionCtx ...@@ -9,8 +9,14 @@ from torch.autograd.function import FunctionCtx
from torch.cuda.amp import custom_fwd from torch.cuda.amp import custom_fwd
from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time
from flash_attn.ops.triton.k_activations import gelu, gelu_grad, gelu_approx, gelu_approx_grad, squared_relu, squared_relu_grad from flash_attn.ops.triton.k_activations import (
gelu,
gelu_approx,
gelu_approx_grad,
gelu_grad,
squared_relu,
squared_relu_grad,
)
# CREDITS: Initially inspired by the Triton tutorial on matrix multiplications # CREDITS: Initially inspired by the Triton tutorial on matrix multiplications
...@@ -28,7 +34,12 @@ def get_configs_io_bound(): ...@@ -28,7 +34,12 @@ def get_configs_io_bound():
num_warps = 2 if block_n <= 64 else 4 num_warps = 2 if block_n <= 64 else 4
configs.append( configs.append(
triton.Config( triton.Config(
{"BLOCK_M": block_m, "BLOCK_N": block_n, "BLOCK_K": block_k, "SPLIT_K": 1}, {
"BLOCK_M": block_m,
"BLOCK_N": block_n,
"BLOCK_K": block_k,
"SPLIT_K": 1,
},
num_stages=num_stages, num_stages=num_stages,
num_warps=num_warps, num_warps=num_warps,
) )
...@@ -43,29 +54,75 @@ def get_configs_io_bound(): ...@@ -43,29 +54,75 @@ def get_configs_io_bound():
@triton.autotune( @triton.autotune(
configs=[ configs=[
triton.Config({"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=3, num_warps=8), triton.Config(
triton.Config({"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=3, num_warps=8), {"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=3, num_warps=8
triton.Config({"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4), ),
triton.Config({"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4), triton.Config(
triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4), {"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=3, num_warps=8
triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4), ),
triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4), triton.Config(
triton.Config({"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4), {"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4
triton.Config({"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=5, num_warps=2), ),
triton.Config(
{"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4
),
triton.Config(
{"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4
),
triton.Config(
{"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4
),
triton.Config(
{"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4
),
triton.Config(
{"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4
),
triton.Config(
{"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=5, num_warps=2
),
# good for int8 # good for int8
triton.Config({"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=3, num_warps=8), triton.Config(
triton.Config({"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=3, num_warps=8), {"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 128, "SPLIT_K": 1},
triton.Config({"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=4, num_warps=4), num_stages=3,
triton.Config({"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=4, num_warps=4), num_warps=8,
triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=4, num_warps=4), ),
triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4), triton.Config(
triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4), {"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 128, "SPLIT_K": 1},
triton.Config({"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4), num_stages=3,
triton.Config({"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=5, num_warps=2), num_warps=8,
),
triton.Config(
{"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=4, num_warps=4
),
triton.Config(
{"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=4, num_warps=4
),
triton.Config(
{"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 128, "SPLIT_K": 1},
num_stages=4,
num_warps=4,
),
triton.Config(
{"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4
),
triton.Config(
{"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4
),
triton.Config(
{"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4
),
triton.Config(
{"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=5, num_warps=2
),
] ]
+ get_configs_io_bound(), + get_configs_io_bound(),
key=["CACHE_KEY_M", "CACHE_KEY_N", "CACHE_KEY_K"], key=["CACHE_KEY_M", "CACHE_KEY_N", "CACHE_KEY_K"],
prune_configs_by={"early_config_prune": early_config_prune, "perf_model": estimate_matmul_time, "top_k": 10}, prune_configs_by={
"early_config_prune": early_config_prune,
"perf_model": estimate_matmul_time,
"top_k": 10,
},
) )
@triton.heuristics( @triton.heuristics(
{ {
...@@ -204,7 +261,7 @@ def triton_linear_act( ...@@ -204,7 +261,7 @@ def triton_linear_act(
x: torch.Tensor, x: torch.Tensor,
weight: torch.Tensor, weight: torch.Tensor,
bias: Optional[torch.Tensor] = None, bias: Optional[torch.Tensor] = None,
activation: str = 'id', activation: str = "id",
save_act_input: bool = False, save_act_input: bool = False,
) -> torch.Tensor: ) -> torch.Tensor:
""" """
...@@ -221,7 +278,7 @@ def triton_linear_act( ...@@ -221,7 +278,7 @@ def triton_linear_act(
# dtype = torch.get_autocast_gpu_dtype() # dtype = torch.get_autocast_gpu_dtype()
# x, weight, bias = [a.to(dtype=dtype) for a in [x, weight, bias]] # x, weight, bias = [a.to(dtype=dtype) for a in [x, weight, bias]]
assert activation in ['id', 'gelu', 'gelu_approx', 'squared_relu'] assert activation in ["id", "gelu", "gelu_approx", "squared_relu"]
batch_shape, n = x.shape[:-1], x.shape[-1] batch_shape, n = x.shape[:-1], x.shape[-1]
batch_dim = batch_shape.numel() batch_dim = batch_shape.numel()
...@@ -233,12 +290,20 @@ def triton_linear_act( ...@@ -233,12 +290,20 @@ def triton_linear_act(
weight = weight.contiguous() weight = weight.contiguous()
bias = bias.contiguous() if bias is not None else None bias = bias.contiguous() if bias is not None else None
assert x.dtype == weight.dtype, f"Input and weight must have the same dtype, got {x.dtype} and {weight.dtype}" assert (
x.dtype == weight.dtype
), f"Input and weight must have the same dtype, got {x.dtype} and {weight.dtype}"
if bias is not None: if bias is not None:
assert x.dtype == bias.dtype, f"Input and bias must have the same dtype, got {x.dtype} and {bias.dtype}" assert (
assert x_reshaped.shape[1] == weight.shape[1], f"Incompatible dimensions: {x_reshaped.shape} - {weight.shape}" x.dtype == bias.dtype
), f"Input and bias must have the same dtype, got {x.dtype} and {bias.dtype}"
assert (
x_reshaped.shape[1] == weight.shape[1]
), f"Incompatible dimensions: {x_reshaped.shape} - {weight.shape}"
assert bias is None or bias.shape[0] == weight.shape[0], "Incompatible dimensions between weight and bias" assert (
bias is None or bias.shape[0] == weight.shape[0]
), "Incompatible dimensions between weight and bias"
M, K = x_reshaped.shape M, K = x_reshaped.shape
N, K = weight.shape N, K = weight.shape
...@@ -278,35 +343,83 @@ def triton_linear_act( ...@@ -278,35 +343,83 @@ def triton_linear_act(
if not save_act_input: if not save_act_input:
return output.reshape(*batch_shape, output.shape[-1]) return output.reshape(*batch_shape, output.shape[-1])
else: else:
return (output.reshape(*batch_shape, output.shape[-1]), return (
act_input.reshape(*batch_shape, act_input.shape[-1])) output.reshape(*batch_shape, output.shape[-1]),
act_input.reshape(*batch_shape, act_input.shape[-1]),
)
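# A minimal usage sketch for triton_linear_act (requires a CUDA GPU with Triton;
# shapes follow nn.Linear conventions, i.e. weight is (out_features, in_features)):
import torch
from flash_attn.ops.triton.linear import triton_linear_act

x = torch.randn(4, 512, 1024, device="cuda", dtype=torch.float16)
weight = torch.randn(4096, 1024, device="cuda", dtype=torch.float16)
bias = torch.randn(4096, device="cuda", dtype=torch.float16)
out = triton_linear_act(x, weight, bias, activation="gelu_approx")   # shape (4, 512, 4096)
out, act_input = triton_linear_act(
    x, weight, bias, activation="squared_relu", save_act_input=True
)  # act_input holds the pre-activation values, reusable in the backward pass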
@triton.autotune( @triton.autotune(
configs=[ configs=[
triton.Config({"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=3, num_warps=8), triton.Config(
triton.Config({"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=3, num_warps=8), {"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=3, num_warps=8
triton.Config({"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4), ),
triton.Config({"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4), triton.Config(
triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4), {"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=3, num_warps=8
triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4), ),
triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4), triton.Config(
triton.Config({"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4), {"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4
triton.Config({"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=5, num_warps=2), ),
triton.Config(
{"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4
),
triton.Config(
{"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4
),
triton.Config(
{"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4
),
triton.Config(
{"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4
),
triton.Config(
{"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4
),
triton.Config(
{"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=5, num_warps=2
),
# good for int8 # good for int8
triton.Config({"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=3, num_warps=8), triton.Config(
triton.Config({"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=3, num_warps=8), {"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 128, "SPLIT_K": 1},
triton.Config({"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=4, num_warps=4), num_stages=3,
triton.Config({"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=4, num_warps=4), num_warps=8,
triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=4, num_warps=4), ),
triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4), triton.Config(
triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4), {"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 128, "SPLIT_K": 1},
triton.Config({"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4), num_stages=3,
triton.Config({"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=5, num_warps=2), num_warps=8,
),
triton.Config(
{"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=4, num_warps=4
),
triton.Config(
{"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=4, num_warps=4
),
triton.Config(
{"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 128, "SPLIT_K": 1},
num_stages=4,
num_warps=4,
),
triton.Config(
{"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4
),
triton.Config(
{"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4
),
triton.Config(
{"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4
),
triton.Config(
{"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=5, num_warps=2
),
] ]
+ get_configs_io_bound(), + get_configs_io_bound(),
key=["CACHE_KEY_M", "CACHE_KEY_N", "CACHE_KEY_K"], key=["CACHE_KEY_M", "CACHE_KEY_N", "CACHE_KEY_K"],
prune_configs_by={"early_config_prune": early_config_prune, "perf_model": estimate_matmul_time, "top_k": 10}, prune_configs_by={
"early_config_prune": early_config_prune,
"perf_model": estimate_matmul_time,
"top_k": 10,
},
) )
@triton.heuristics( @triton.heuristics(
{ {
...@@ -395,7 +508,7 @@ def kernel_bwd( ...@@ -395,7 +508,7 @@ def kernel_bwd(
B += BLOCK_K * stride_bk B += BLOCK_K * stride_bk
# optional: fused activation (while the data is in shared memory) # optional: fused activation (while the data is in shared memory)
if ACTIVATION != 'id': if ACTIVATION != "id":
act_in_ptrs = ACT_INPUT + ram[:, None] * stride_cm + rbn[None, :] act_in_ptrs = ACT_INPUT + ram[:, None] * stride_cm + rbn[None, :]
act_input = tl.load(act_in_ptrs).to(acc.dtype) act_input = tl.load(act_in_ptrs).to(acc.dtype)
if ACTIVATION == "gelu": if ACTIVATION == "gelu":
...@@ -418,7 +531,7 @@ def kernel_bwd( ...@@ -418,7 +531,7 @@ def kernel_bwd(
def triton_dgrad_act( def triton_dgrad_act(
grad_output: torch.Tensor, grad_output: torch.Tensor,
weight: torch.Tensor, weight: torch.Tensor,
activation: str = 'id', activation: str = "id",
act_input: Optional[torch.Tensor] = None, act_input: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> torch.Tensor:
""" """
...@@ -430,7 +543,7 @@ def triton_dgrad_act( ...@@ -430,7 +543,7 @@ def triton_dgrad_act(
:param act_input: an optional tensor to save the activation inputs (for backward) :param act_input: an optional tensor to save the activation inputs (for backward)
:return: result tensor :return: result tensor
""" """
assert activation in ['id', 'gelu', 'gelu_approx', 'squared_relu'] assert activation in ["id", "gelu", "gelu_approx", "squared_relu"]
batch_shape, n = grad_output.shape[:-1], grad_output.shape[-1] batch_shape, n = grad_output.shape[:-1], grad_output.shape[-1]
batch_dim = batch_shape.numel() batch_dim = batch_shape.numel()
...@@ -441,10 +554,14 @@ def triton_dgrad_act( ...@@ -441,10 +554,14 @@ def triton_dgrad_act(
if weight.stride(0) > 1 and weight.stride(1) > 1: if weight.stride(0) > 1 and weight.stride(1) > 1:
weight = weight.contiguous() weight = weight.contiguous()
assert grad_output.dtype == weight.dtype, f"grad_output and weight must have the same dtype, got {grad_output.dtype} and {weight.dtype}" assert (
assert grad_output_reshaped.shape[1] == weight.shape[0], f"Incompatible dimensions: {grad_output_reshaped.shape} - {weight.shape}" grad_output.dtype == weight.dtype
if activation != 'id': ), f"grad_output and weight must have the same dtype, got {grad_output.dtype} and {weight.dtype}"
assert act_input is not None, f'act_input is required for activation {activation}' assert (
grad_output_reshaped.shape[1] == weight.shape[0]
), f"Incompatible dimensions: {grad_output_reshaped.shape} - {weight.shape}"
if activation != "id":
assert act_input is not None, f"act_input is required for activation {activation}"
# M, N, K in bwd are different from M, N, K in fwd # M, N, K in bwd are different from M, N, K in fwd
M, K = grad_output_reshaped.shape M, K = grad_output_reshaped.shape
......
# The triton fused matmul + sqrelu is faster for fp16 but slower for bf16, compared # The triton fused matmul + sqrelu is faster for fp16 but slower for bf16, compared
# to naive implementation. # to naive implementation.
import fused_dense_lib as fused_dense_cuda
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from torch.cuda.amp import custom_bwd, custom_fwd from torch.cuda.amp import custom_bwd, custom_fwd
import fused_dense_lib as fused_dense_cuda from flash_attn.ops.activations import sqrelu_bwd, sqrelu_fwd
from flash_attn.ops.triton.linear import triton_dgrad_act, triton_linear_act
from flash_attn.ops.triton.linear import triton_linear_act, triton_dgrad_act
from flash_attn.ops.activations import sqrelu_fwd, sqrelu_bwd
class FusedDenseSqreluDenseFunc(torch.autograd.Function): class FusedDenseSqreluDenseFunc(torch.autograd.Function):
@staticmethod @staticmethod
@custom_fwd @custom_fwd
def forward(ctx, x, weight1, bias1, weight2, bias2, checkpoint_lvl=0): def forward(ctx, x, weight1, bias1, weight2, bias2, checkpoint_lvl=0):
...@@ -23,8 +21,9 @@ class FusedDenseSqreluDenseFunc(torch.autograd.Function): ...@@ -23,8 +21,9 @@ class FusedDenseSqreluDenseFunc(torch.autograd.Function):
""" """
if torch.is_autocast_enabled(): if torch.is_autocast_enabled():
dtype = torch.get_autocast_gpu_dtype() dtype = torch.get_autocast_gpu_dtype()
x, weight1, bias1, weight2, bias2 = [a.to(dtype=dtype) x, weight1, bias1, weight2, bias2 = [
for a in [x, weight1, bias1, weight2, bias2]] a.to(dtype=dtype) for a in [x, weight1, bias1, weight2, bias2]
]
is_bf16 = x.dtype == torch.bfloat16 is_bf16 = x.dtype == torch.bfloat16
assert checkpoint_lvl in [0, 1, 2] assert checkpoint_lvl in [0, 1, 2]
x = x.contiguous() x = x.contiguous()
...@@ -35,13 +34,18 @@ class FusedDenseSqreluDenseFunc(torch.autograd.Function): ...@@ -35,13 +34,18 @@ class FusedDenseSqreluDenseFunc(torch.autograd.Function):
batch_shape, n = x.shape[:-1], x.shape[-1] batch_shape, n = x.shape[:-1], x.shape[-1]
batch_dim = batch_shape.numel() batch_dim = batch_shape.numel()
if is_bf16: if is_bf16:
act_input = fused_dense_cuda.linear_bias_forward(x.reshape(batch_dim, n), weight1, bias1) act_input = fused_dense_cuda.linear_bias_forward(
x.reshape(batch_dim, n), weight1, bias1
)
output1 = sqrelu_fwd(act_input) output1 = sqrelu_fwd(act_input)
else: else:
save_act_input = checkpoint_lvl != 2 save_act_input = checkpoint_lvl != 2
result = triton_linear_act( result = triton_linear_act(
x.reshape(batch_dim, n), weight1, bias1, activation='squared_relu', x.reshape(batch_dim, n),
save_act_input=save_act_input weight1,
bias1,
activation="squared_relu",
save_act_input=save_act_input,
) )
if save_act_input: if save_act_input:
output1, act_input = result output1, act_input = result
...@@ -69,16 +73,21 @@ class FusedDenseSqreluDenseFunc(torch.autograd.Function): ...@@ -69,16 +73,21 @@ class FusedDenseSqreluDenseFunc(torch.autograd.Function):
if checkpoint_lvl == 0: if checkpoint_lvl == 0:
act_input, output1 = rest act_input, output1 = rest
elif checkpoint_lvl == 1: elif checkpoint_lvl == 1:
act_input, = rest (act_input,) = rest
output1 = sqrelu_fwd(act_input) output1 = sqrelu_fwd(act_input)
elif checkpoint_lvl == 2: elif checkpoint_lvl == 2:
if is_bf16: if is_bf16:
act_input = fused_dense_cuda.linear_bias_forward(x.reshape(batch_dim, n), weight1, bias1) act_input = fused_dense_cuda.linear_bias_forward(
x.reshape(batch_dim, n), weight1, bias1
)
output1 = sqrelu_fwd(act_input) output1 = sqrelu_fwd(act_input)
else: else:
output1, act_input = triton_linear_act( output1, act_input = triton_linear_act(
x.reshape(batch_dim, n), weight1, bias1, activation='squared_relu', x.reshape(batch_dim, n),
save_act_input=True weight1,
bias1,
activation="squared_relu",
save_act_input=True,
) )
if is_bf16: if is_bf16:
...@@ -92,8 +101,9 @@ class FusedDenseSqreluDenseFunc(torch.autograd.Function): ...@@ -92,8 +101,9 @@ class FusedDenseSqreluDenseFunc(torch.autograd.Function):
else: else:
grad_output = grad_output.reshape(batch_dim, grad_output.shape[-1]) grad_output = grad_output.reshape(batch_dim, grad_output.shape[-1])
grad_weight2, grad_bias2 = fused_dense_cuda.linear_bias_wgrad(output1, grad_output) grad_weight2, grad_bias2 = fused_dense_cuda.linear_bias_wgrad(output1, grad_output)
grad_act_input = triton_dgrad_act(grad_output, weight2, activation='squared_relu', grad_act_input = triton_dgrad_act(
act_input=act_input) grad_output, weight2, activation="squared_relu", act_input=act_input
)
grad_input, grad_weight1, grad_bias1 = fused_dense_cuda.linear_bias_backward( grad_input, grad_weight1, grad_bias1 = fused_dense_cuda.linear_bias_backward(
x.reshape(batch_dim, n), weight1, grad_act_input x.reshape(batch_dim, n), weight1, grad_act_input
) )
...@@ -104,9 +114,17 @@ fused_dense_sqrelu_dense_function = FusedDenseSqreluDenseFunc.apply ...@@ -104,9 +114,17 @@ fused_dense_sqrelu_dense_function = FusedDenseSqreluDenseFunc.apply
class FusedDenseSqreluDense(nn.Module): class FusedDenseSqreluDense(nn.Module):
def __init__(
def __init__(self, in_features, hidden_features=None, out_features=None, bias1=True, bias2=True, self,
checkpoint_lvl=0, device=None, dtype=None): in_features,
hidden_features=None,
out_features=None,
bias1=True,
bias2=True,
checkpoint_lvl=0,
device=None,
dtype=None,
):
""" """
checkpoint_lvl (increasing lvl means slower but more memory saving): checkpoint_lvl (increasing lvl means slower but more memory saving):
0: no recomputation in the bwd 0: no recomputation in the bwd
...@@ -114,7 +132,7 @@ class FusedDenseSqreluDense(nn.Module): ...@@ -114,7 +132,7 @@ class FusedDenseSqreluDense(nn.Module):
2: recompute act_input and output1 (the squared-ReLU activation) in the bwd 2: recompute act_input and output1 (the squared-ReLU activation) in the bwd
""" """
assert checkpoint_lvl in [0, 1, 2] assert checkpoint_lvl in [0, 1, 2]
factory_kwargs = {'device': device, 'dtype': dtype} factory_kwargs = {"device": device, "dtype": dtype}
super().__init__() super().__init__()
out_features = out_features or in_features out_features = out_features or in_features
hidden_features = hidden_features or in_features * 4 hidden_features = hidden_features or in_features * 4
...@@ -126,6 +144,6 @@ class FusedDenseSqreluDense(nn.Module): ...@@ -126,6 +144,6 @@ class FusedDenseSqreluDense(nn.Module):
def forward(self, x): def forward(self, x):
assert x.is_cuda assert x.is_cuda
return fused_dense_sqrelu_dense_function(x, self.fc1.weight, self.fc1.bias, return fused_dense_sqrelu_dense_function(
self.fc2.weight, self.fc2.bias, x, self.fc1.weight, self.fc1.bias, self.fc2.weight, self.fc2.bias, self.checkpoint_lvl
self.checkpoint_lvl) )
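# A minimal usage sketch for the fused dense -> squared-ReLU -> dense block above
# (requires a CUDA GPU, Triton, and the fused_dense_lib CUDA extension; the import
# location is an assumption):
import torch
from flash_attn.ops.triton.mlp import FusedDenseSqreluDense  # assumed module path

mlp = FusedDenseSqreluDense(
    1024, hidden_features=4096, checkpoint_lvl=1, device="cuda", dtype=torch.float16
)  # checkpoint_lvl=1 recomputes the activation in backward to save memory
x = torch.randn(2, 512, 1024, device="cuda", dtype=torch.float16, requires_grad=True)
y = mlp(x)
y.sum().backward()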
...@@ -5,31 +5,43 @@ import torch ...@@ -5,31 +5,43 @@ import torch
import torch.utils.benchmark as benchmark import torch.utils.benchmark as benchmark
def benchmark_forward(fn, *inputs, repeats=10, desc='', verbose=True, amp=False, def benchmark_forward(
amp_dtype=torch.float16, **kwinputs): fn, *inputs, repeats=10, desc="", verbose=True, amp=False, amp_dtype=torch.float16, **kwinputs
""" Use Pytorch Benchmark on the forward pass of an arbitrary function. """ ):
"""Use Pytorch Benchmark on the forward pass of an arbitrary function."""
if verbose: if verbose:
print(desc, '- Forward pass') print(desc, "- Forward pass")
def amp_wrapper(*inputs, **kwinputs): def amp_wrapper(*inputs, **kwinputs):
with torch.autocast(device_type='cuda', dtype=amp_dtype, enabled=amp): with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp):
fn(*inputs, **kwinputs) fn(*inputs, **kwinputs)
t = benchmark.Timer( t = benchmark.Timer(
stmt='fn_amp(*inputs, **kwinputs)', stmt="fn_amp(*inputs, **kwinputs)",
globals={'fn_amp': amp_wrapper, 'inputs': inputs, 'kwinputs': kwinputs}, globals={"fn_amp": amp_wrapper, "inputs": inputs, "kwinputs": kwinputs},
num_threads=torch.get_num_threads(), num_threads=torch.get_num_threads(),
) )
m = t.timeit(repeats) m = t.timeit(repeats)
if verbose: if verbose:
print(m) print(m)
return t, m return t, m
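# Usage sketch for benchmark_forward (any callable works; amp=True additionally
# wraps the call in CUDA autocast as above; import path assumed):
import torch
from flash_attn.utils.benchmark import benchmark_forward

lin = torch.nn.Linear(1024, 1024)
x = torch.randn(64, 1024)
t, m = benchmark_forward(lin, x, repeats=30, desc="nn.Linear")
print(m.mean)   # torch.utils.benchmark Measurement: mean time per run, in seconds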
def benchmark_backward(fn, *inputs, grad=None, repeats=10, desc='', verbose=True, amp=False, def benchmark_backward(
amp_dtype=torch.float16, **kwinputs): fn,
""" Use Pytorch Benchmark on the backward pass of an arbitrary function. """ *inputs,
grad=None,
repeats=10,
desc="",
verbose=True,
amp=False,
amp_dtype=torch.float16,
**kwinputs,
):
"""Use Pytorch Benchmark on the backward pass of an arbitrary function."""
if verbose: if verbose:
print(desc, '- Backward pass') print(desc, "- Backward pass")
with torch.autocast(device_type='cuda', dtype=amp_dtype, enabled=amp): with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp):
y = fn(*inputs, **kwinputs) y = fn(*inputs, **kwinputs)
if type(y) is tuple: if type(y) is tuple:
y = y[0] y = y[0]
...@@ -37,7 +49,8 @@ def benchmark_backward(fn, *inputs, grad=None, repeats=10, desc='', verbose=True ...@@ -37,7 +49,8 @@ def benchmark_backward(fn, *inputs, grad=None, repeats=10, desc='', verbose=True
grad = torch.randn_like(y) grad = torch.randn_like(y)
else: else:
if grad.shape != y.shape: if grad.shape != y.shape:
raise RuntimeError('Grad shape does not match output shape') raise RuntimeError("Grad shape does not match output shape")
def f(*inputs, y, grad): def f(*inputs, y, grad):
# Set .grad to None to avoid extra operation of gradient accumulation # Set .grad to None to avoid extra operation of gradient accumulation
for x in inputs: for x in inputs:
...@@ -46,22 +59,31 @@ def benchmark_backward(fn, *inputs, grad=None, repeats=10, desc='', verbose=True ...@@ -46,22 +59,31 @@ def benchmark_backward(fn, *inputs, grad=None, repeats=10, desc='', verbose=True
y.backward(grad, retain_graph=True) y.backward(grad, retain_graph=True)
t = benchmark.Timer( t = benchmark.Timer(
stmt='f(*inputs, y=y, grad=grad)', stmt="f(*inputs, y=y, grad=grad)",
globals={'f': f, 'inputs': inputs, 'y': y, 'grad': grad}, globals={"f": f, "inputs": inputs, "y": y, "grad": grad},
num_threads=torch.get_num_threads(), num_threads=torch.get_num_threads(),
) )
m = t.timeit(repeats) m = t.timeit(repeats)
if verbose: if verbose:
print(m) print(m)
return t, m return t, m
def benchmark_combined(fn, *inputs, grad=None, repeats=10, desc='', verbose=True, amp=False, def benchmark_combined(
amp_dtype=torch.float16, **kwinputs): fn,
""" Use Pytorch Benchmark on the forward+backward pass of an arbitrary function. """ *inputs,
grad=None,
repeats=10,
desc="",
verbose=True,
amp=False,
amp_dtype=torch.float16,
**kwinputs,
):
"""Use Pytorch Benchmark on the forward+backward pass of an arbitrary function."""
if verbose: if verbose:
print(desc, '- Forward + Backward pass') print(desc, "- Forward + Backward pass")
with torch.autocast(device_type='cuda', dtype=amp_dtype, enabled=amp): with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp):
y = fn(*inputs, **kwinputs) y = fn(*inputs, **kwinputs)
if type(y) is tuple: if type(y) is tuple:
y = y[0] y = y[0]
...@@ -69,68 +91,142 @@ def benchmark_combined(fn, *inputs, grad=None, repeats=10, desc='', verbose=True ...@@ -69,68 +91,142 @@ def benchmark_combined(fn, *inputs, grad=None, repeats=10, desc='', verbose=True
grad = torch.randn_like(y) grad = torch.randn_like(y)
else: else:
if grad.shape != y.shape: if grad.shape != y.shape:
raise RuntimeError('Grad shape does not match output shape') raise RuntimeError("Grad shape does not match output shape")
def f(grad, *inputs, **kwinputs): def f(grad, *inputs, **kwinputs):
for x in inputs: for x in inputs:
if isinstance(x, torch.Tensor): if isinstance(x, torch.Tensor):
x.grad = None x.grad = None
with torch.autocast(device_type='cuda', dtype=amp_dtype, enabled=amp): with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp):
y = fn(*inputs, **kwinputs) y = fn(*inputs, **kwinputs)
if type(y) is tuple: if type(y) is tuple:
y = y[0] y = y[0]
y.backward(grad, retain_graph=True) y.backward(grad, retain_graph=True)
t = benchmark.Timer( t = benchmark.Timer(
stmt='f(grad, *inputs, **kwinputs)', stmt="f(grad, *inputs, **kwinputs)",
globals={'f': f, 'fn': fn, 'inputs': inputs, 'grad': grad, 'kwinputs': kwinputs}, globals={"f": f, "fn": fn, "inputs": inputs, "grad": grad, "kwinputs": kwinputs},
num_threads=torch.get_num_threads(), num_threads=torch.get_num_threads(),
) )
m = t.timeit(repeats) m = t.timeit(repeats)
if verbose: if verbose:
print(m) print(m)
return t, m return t, m
def benchmark_fwd_bwd(fn, *inputs, grad=None, repeats=10, desc='', verbose=True, amp=False, def benchmark_fwd_bwd(
amp_dtype=torch.float16, **kwinputs): fn,
""" Use Pytorch Benchmark on the forward+backward pass of an arbitrary function. """ *inputs,
grad=None,
repeats=10,
desc="",
verbose=True,
amp=False,
amp_dtype=torch.float16,
**kwinputs,
):
"""Use Pytorch Benchmark on the forward+backward pass of an arbitrary function."""
return ( return (
benchmark_forward(fn, *inputs, repeats=repeats, desc=desc, verbose=verbose, benchmark_forward(
amp=amp, amp_dtype=amp_dtype, **kwinputs), fn,
benchmark_backward(fn, *inputs, grad=grad, repeats=repeats, desc=desc, verbose=verbose, *inputs,
amp=amp, amp_dtype=amp_dtype, **kwinputs), repeats=repeats,
desc=desc,
verbose=verbose,
amp=amp,
amp_dtype=amp_dtype,
**kwinputs,
),
benchmark_backward(
fn,
*inputs,
grad=grad,
repeats=repeats,
desc=desc,
verbose=verbose,
amp=amp,
amp_dtype=amp_dtype,
**kwinputs,
),
) )
def benchmark_all(fn, *inputs, grad=None, repeats=10, desc='', verbose=True, amp=False, def benchmark_all(
amp_dtype=torch.float16, **kwinputs): fn,
""" Use Pytorch Benchmark on the forward+backward pass of an arbitrary function. """ *inputs,
grad=None,
repeats=10,
desc="",
verbose=True,
amp=False,
amp_dtype=torch.float16,
**kwinputs,
):
"""Use Pytorch Benchmark on the forward+backward pass of an arbitrary function."""
return ( return (
benchmark_forward(fn, *inputs, repeats=repeats, desc=desc, verbose=verbose, benchmark_forward(
amp=amp, amp_dtype=amp_dtype, **kwinputs), fn,
benchmark_backward(fn, *inputs, grad=grad, repeats=repeats, desc=desc, verbose=verbose, *inputs,
amp=amp, amp_dtype=amp_dtype, **kwinputs), repeats=repeats,
benchmark_combined(fn, *inputs, grad=grad, repeats=repeats, desc=desc, verbose=verbose, desc=desc,
amp=amp, amp_dtype=amp_dtype, **kwinputs), verbose=verbose,
amp=amp,
amp_dtype=amp_dtype,
**kwinputs,
),
benchmark_backward(
fn,
*inputs,
grad=grad,
repeats=repeats,
desc=desc,
verbose=verbose,
amp=amp,
amp_dtype=amp_dtype,
**kwinputs,
),
benchmark_combined(
fn,
*inputs,
grad=grad,
repeats=repeats,
desc=desc,
verbose=verbose,
amp=amp,
amp_dtype=amp_dtype,
**kwinputs,
),
) )
def pytorch_profiler(fn, *inputs, trace_filename=None, backward=False, amp=False, def pytorch_profiler(
amp_dtype=torch.float16, cpu=False, verbose=True, **kwinputs): fn,
""" Wrap benchmark functions in Pytorch profiler to see CUDA information. """ *inputs,
trace_filename=None,
backward=False,
amp=False,
amp_dtype=torch.float16,
cpu=False,
verbose=True,
**kwinputs,
):
"""Wrap benchmark functions in Pytorch profiler to see CUDA information."""
if backward: if backward:
with torch.autocast(device_type='cuda', dtype=amp_dtype, enabled=amp): with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp):
g = torch.randn_like(fn(*inputs, **kwinputs)) g = torch.randn_like(fn(*inputs, **kwinputs))
for _ in range(30): # Warm up for _ in range(30): # Warm up
if backward: if backward:
for x in inputs: for x in inputs:
if isinstance(x, torch.Tensor): if isinstance(x, torch.Tensor):
x.grad = None x.grad = None
with torch.autocast(device_type='cuda', dtype=amp_dtype, enabled=amp): with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp):
out = fn(*inputs, **kwinputs) out = fn(*inputs, **kwinputs)
# Backward should be done outside autocast # Backward should be done outside autocast
if backward: if backward:
out.backward(g, retain_graph=True) out.backward(g, retain_graph=True)
activities = ([torch.profiler.ProfilerActivity.CPU] if cpu else []) + [torch.profiler.ProfilerActivity.CUDA] activities = ([torch.profiler.ProfilerActivity.CPU] if cpu else []) + [
torch.profiler.ProfilerActivity.CUDA
]
with torch.profiler.profile( with torch.profiler.profile(
activities=activities, activities=activities,
record_shapes=True, record_shapes=True,
...@@ -141,9 +237,10 @@ def pytorch_profiler(fn, *inputs, trace_filename=None, backward=False, amp=False ...@@ -141,9 +237,10 @@ def pytorch_profiler(fn, *inputs, trace_filename=None, backward=False, amp=False
for x in inputs: for x in inputs:
if isinstance(x, torch.Tensor): if isinstance(x, torch.Tensor):
x.grad = None x.grad = None
with torch.autocast(device_type='cuda', dtype=amp_dtype, enabled=amp): with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp):
out = fn(*inputs, **kwinputs) out = fn(*inputs, **kwinputs)
if backward: out.backward(g, retain_graph=True) if backward:
out.backward(g, retain_graph=True)
if verbose: if verbose:
# print(prof.key_averages().table(sort_by="self_cuda_time_total", row_limit=50)) # print(prof.key_averages().table(sort_by="self_cuda_time_total", row_limit=50))
print(prof.key_averages().table(row_limit=50)) print(prof.key_averages().table(row_limit=50))
...@@ -151,14 +248,14 @@ def pytorch_profiler(fn, *inputs, trace_filename=None, backward=False, amp=False ...@@ -151,14 +248,14 @@ def pytorch_profiler(fn, *inputs, trace_filename=None, backward=False, amp=False
prof.export_chrome_trace(trace_filename) prof.export_chrome_trace(trace_filename)
def benchmark_memory(fn, *inputs, desc='', verbose=True, **kwinputs): def benchmark_memory(fn, *inputs, desc="", verbose=True, **kwinputs):
torch.cuda.empty_cache() torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats() torch.cuda.reset_peak_memory_stats()
torch.cuda.synchronize() torch.cuda.synchronize()
fn(*inputs, **kwinputs) fn(*inputs, **kwinputs)
torch.cuda.synchronize() torch.cuda.synchronize()
mem = torch.cuda.max_memory_allocated() / ((2 ** 20) * 1000) mem = torch.cuda.max_memory_allocated() / ((2**20) * 1000)
if verbose: if verbose:
print(f'{desc} max memory: {mem}GB') print(f"{desc} max memory: {mem}GB")
torch.cuda.empty_cache() torch.cuda.empty_cache()
return mem return mem
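# Combined sketch for the remaining helpers (requires a CUDA GPU, since
# benchmark_memory and pytorch_profiler query CUDA; import path assumed):
import torch
from flash_attn.utils.benchmark import benchmark_all, benchmark_memory, pytorch_profiler

lin = torch.nn.Linear(2048, 2048, device="cuda", dtype=torch.float16)
x = torch.randn(256, 2048, device="cuda", dtype=torch.float16, requires_grad=True)
benchmark_all(lin, x, repeats=20, desc="fp16 linear")   # forward, backward, and combined timings
mem_gb = benchmark_memory(lin, x, desc="fp16 linear")   # peak CUDA memory, as reported above
pytorch_profiler(lin, x, backward=True)                 # prints a kernel-level profile table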
...@@ -17,10 +17,12 @@ if "reduce_scatter_tensor" not in dir(torch.distributed): ...@@ -17,10 +17,12 @@ if "reduce_scatter_tensor" not in dir(torch.distributed):
# Raw operation, does not support autograd, but does support async # Raw operation, does not support autograd, but does support async
def all_gather_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool = False): def all_gather_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool = False):
world_size = torch.distributed.get_world_size(process_group) world_size = torch.distributed.get_world_size(process_group)
output = torch.empty(world_size * input_.shape[0], *input_.shape[1:], output = torch.empty(
dtype=input_.dtype, device=input_.device) world_size * input_.shape[0], *input_.shape[1:], dtype=input_.dtype, device=input_.device
handle = torch.distributed.all_gather_into_tensor(output, input_.contiguous(), )
group=process_group, async_op=async_op) handle = torch.distributed.all_gather_into_tensor(
output, input_.contiguous(), group=process_group, async_op=async_op
)
return output, handle return output, handle
...@@ -28,11 +30,12 @@ def all_gather_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool = ...@@ -28,11 +30,12 @@ def all_gather_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool =
def reduce_scatter_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool = False): def reduce_scatter_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool = False):
world_size = torch.distributed.get_world_size(process_group) world_size = torch.distributed.get_world_size(process_group)
assert input_.shape[0] % world_size == 0 assert input_.shape[0] % world_size == 0
output = torch.empty(input_.shape[0] // world_size, *input_.shape[1:], output = torch.empty(
dtype=input_.dtype, device=input_.device) input_.shape[0] // world_size, *input_.shape[1:], dtype=input_.dtype, device=input_.device
handle = torch.distributed.reduce_scatter_tensor(output, input_.contiguous(), )
group=process_group, handle = torch.distributed.reduce_scatter_tensor(
async_op=async_op) output, input_.contiguous(), group=process_group, async_op=async_op
)
return output, handle return output, handle
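# Shape-contract sketch for the raw collectives above (assumes the process group is
# already initialized, e.g. launched with torchrun + NCCL, one CUDA device per rank):
import torch
import torch.distributed as dist
from flash_attn.utils.distributed import all_gather_raw, reduce_scatter_raw

def roundtrip(x: torch.Tensor, group=None):
    world_size = dist.get_world_size(group)
    gathered, handle = all_gather_raw(x, group, async_op=True)
    handle.wait()                                     # async_op=True returns a work handle
    assert gathered.shape[0] == world_size * x.shape[0]
    reduced, _ = reduce_scatter_raw(gathered, group)  # first dim must be divisible by world_size
    assert reduced.shape == x.shape                   # reduce_scatter sums across ranks by default
    return reduced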
...@@ -102,8 +105,9 @@ all_reduce = AllReduceFunc.apply ...@@ -102,8 +105,9 @@ all_reduce = AllReduceFunc.apply
def sync_shared_params(model: torch.nn.Module, process_group: ProcessGroup): def sync_shared_params(model: torch.nn.Module, process_group: ProcessGroup):
# We want to iterate over parameters with _shared_params=True in the same order, # We want to iterate over parameters with _shared_params=True in the same order,
# as different ranks might have different number of parameters (e.g., only rank 0 has bias). # as different ranks might have different number of parameters (e.g., only rank 0 has bias).
pamams_shared = {name: p for name, p in model.named_parameters() pamams_shared = {
if getattr(p, '_shared_params', False)} name: p for name, p in model.named_parameters() if getattr(p, "_shared_params", False)
}
for _, p in sorted(pamams_shared.items()): for _, p in sorted(pamams_shared.items()):
with torch.no_grad(): with torch.no_grad():
# Broadcast needs src to be global rank, not group rank # Broadcast needs src to be global rank, not group rank
...@@ -116,8 +120,9 @@ def sync_shared_params(model: torch.nn.Module, process_group: ProcessGroup): ...@@ -116,8 +120,9 @@ def sync_shared_params(model: torch.nn.Module, process_group: ProcessGroup):
def allreduce_sequence_parallel_grad(model: torch.nn.Module, process_group: ProcessGroup): def allreduce_sequence_parallel_grad(model: torch.nn.Module, process_group: ProcessGroup):
# We want to iterate over parameters with _sequence_parallel=True in the same order, # We want to iterate over parameters with _sequence_parallel=True in the same order,
# as different ranks might have different number of parameters (e.g., only rank 0 has bias). # as different ranks might have different number of parameters (e.g., only rank 0 has bias).
params_seqparallel = {name: p for name, p in model.named_parameters() params_seqparallel = {
if getattr(p, '_sequence_parallel', False)} name: p for name, p in model.named_parameters() if getattr(p, "_sequence_parallel", False)
}
grads = [p.grad for _, p in sorted(params_seqparallel.items())] grads = [p.grad for _, p in sorted(params_seqparallel.items())]
if grads: if grads:
with torch.no_grad(): with torch.no_grad():
......
# Copyright (c) 2023, Tri Dao. # Copyright (c) 2023, Tri Dao.
# Adapted from https://github.com/NVIDIA/Megatron-LM/blob/0bb597b42c53355a567aba2a1357cc34b9d99ddd/megatron/text_generation/forward_step.py#L31 # Adapted from https://github.com/NVIDIA/Megatron-LM/blob/0bb597b42c53355a567aba2a1357cc34b9d99ddd/megatron/text_generation/forward_step.py#L31
from typing import Optional, Union, Sequence, Callable
import gc import gc
import time import time
from dataclasses import dataclass, field
from collections import namedtuple from collections import namedtuple
from dataclasses import dataclass, field
from typing import Callable, Optional, Sequence, Union
import torch import torch
from torch import Tensor
from torch.profiler import profile, record_function, ProfilerActivity
from einops import rearrange from einops import rearrange
from torch import Tensor
from torch.profiler import ProfilerActivity, profile, record_function
from transformers.generation import GreedySearchDecoderOnlyOutput, SampleDecoderOnlyOutput from transformers.generation import GreedySearchDecoderOnlyOutput, SampleDecoderOnlyOutput
...@@ -20,6 +17,7 @@ from transformers.generation import GreedySearchDecoderOnlyOutput, SampleDecoder ...@@ -20,6 +17,7 @@ from transformers.generation import GreedySearchDecoderOnlyOutput, SampleDecoder
class InferenceParams: class InferenceParams:
"""Inference parameters that are passed to the main model in order """Inference parameters that are passed to the main model in order
to efficiently calculate and store the context during inference.""" to efficiently calculate and store the context during inference."""
max_sequence_len: int max_sequence_len: int
max_batch_size: int max_batch_size: int
sequence_len_offset: int = 0 sequence_len_offset: int = 0
...@@ -38,11 +36,13 @@ def modify_logits_for_top_p_filtering(logits, top_p): ...@@ -38,11 +36,13 @@ def modify_logits_for_top_p_filtering(logits, top_p):
# First sort and calculate cumulative sum of probabilities. # First sort and calculate cumulative sum of probabilities.
sorted_logits, sorted_indices = torch.sort(logits, descending=False) sorted_logits, sorted_indices = torch.sort(logits, descending=False)
cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1) cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
# Remove tokens with cumulative top_p above the threshold (tokens with 0 are kept) # Remove tokens with cumulative top_p above the threshold (tokens with 0 are kept)
sorted_indices_to_remove = cumulative_probs <= (1 - top_p) sorted_indices_to_remove = cumulative_probs <= (1 - top_p)
# scatter sorted tensors to original indexing # scatter sorted tensors to original indexing
indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove) indices_to_remove = sorted_indices_to_remove.scatter(
logits.masked_fill_(indices_to_remove, float('-inf')) 1, sorted_indices, sorted_indices_to_remove
)
logits.masked_fill_(indices_to_remove, float("-inf"))
def sample(logits, top_k=1, top_p=0.0, temperature=1.0): def sample(logits, top_k=1, top_p=0.0, temperature=1.0):
...@@ -54,7 +54,7 @@ def sample(logits, top_k=1, top_p=0.0, temperature=1.0): ...@@ -54,7 +54,7 @@ def sample(logits, top_k=1, top_p=0.0, temperature=1.0):
return logits.argmax(dim=-1) return logits.argmax(dim=-1)
else: else:
if top_p > 0.0: if top_p > 0.0:
assert top_p <= 1.0, 'top-p should be in (0, 1].' assert top_p <= 1.0, "top-p should be in (0, 1]."
if top_k > 0: if top_k > 0:
top_k = min(top_k, logits.size(-1)) # Safety check top_k = min(top_k, logits.size(-1)) # Safety check
logits_top, indices = torch.topk(logits, top_k, dim=-1) logits_top, indices = torch.topk(logits, top_k, dim=-1)
...@@ -62,17 +62,31 @@ def sample(logits, top_k=1, top_p=0.0, temperature=1.0): ...@@ -62,17 +62,31 @@ def sample(logits, top_k=1, top_p=0.0, temperature=1.0):
modify_logits_for_top_p_filtering(logits_top, top_p) modify_logits_for_top_p_filtering(logits_top, top_p)
return indices[ return indices[
torch.arange(indices.shape[0], device=indices.device), torch.arange(indices.shape[0], device=indices.device),
torch.multinomial(torch.softmax(logits_top, dim=-1), num_samples=1).squeeze(dim=-1) torch.multinomial(torch.softmax(logits_top, dim=-1), num_samples=1).squeeze(dim=-1),
] ]
else: else:
logits_top = logits / temperature logits_top = logits / temperature
modify_logits_for_top_p_filtering(logits_top, top_p) modify_logits_for_top_p_filtering(logits_top, top_p)
return torch.multinomial(torch.softmax(logits_top, dim=-1), num_samples=1).squeeze(dim=-1) return torch.multinomial(torch.softmax(logits_top, dim=-1), num_samples=1).squeeze(
dim=-1
)
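# Sampling sketch (CPU works; logits has shape (batch_size, vocab_size); the
# import path is an assumption):
import torch
from flash_attn.utils.generation import sample

logits = torch.randn(2, 32000)
greedy = sample(logits)                                        # top_k=1 -> plain argmax
nucleus = sample(logits, top_k=0, top_p=0.9, temperature=0.8)  # pure top-p sampling
mixed = sample(logits, top_k=50, top_p=0.95)                   # top-k applied first, then top-p
assert greedy.shape == nucleus.shape == mixed.shape == (2,)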
def decode(input_ids, model, max_length, top_k=1, top_p=0.0, temperature=1.0, def decode(
eos_token_id=None, teacher_outputs=None, vocab_size=None, tensor_parallel=1, input_ids,
fused_ft_kernel=False, cg=False, timing=False): model,
max_length,
top_k=1,
top_p=0.0,
temperature=1.0,
eos_token_id=None,
teacher_outputs=None,
vocab_size=None,
tensor_parallel=1,
fused_ft_kernel=False,
cg=False,
timing=False,
):
"""Decoding, either greedy or with top-k or top-p sampling. """Decoding, either greedy or with top-k or top-p sampling.
If top-k = 0, don't limit the number of candidates (pure sampling). If top-k = 0, don't limit the number of candidates (pure sampling).
Top-k and top-p can be used together. If top_k > 0 and top_p > 0, then top-k is applied first, Top-k and top-p can be used together. If top_k > 0 and top_p > 0, then top-k is applied first,
...@@ -92,19 +106,24 @@ def decode(input_ids, model, max_length, top_k=1, top_p=0.0, temperature=1.0, ...@@ -92,19 +106,24 @@ def decode(input_ids, model, max_length, top_k=1, top_p=0.0, temperature=1.0,
teacher_output_len = teacher_outputs.shape[1] if teacher_outputs is not None else 0 teacher_output_len = teacher_outputs.shape[1] if teacher_outputs is not None else 0
if cg: if cg:
assert fused_ft_kernel assert fused_ft_kernel
if not hasattr(model, '_decoding_cache'): if not hasattr(model, "_decoding_cache"):
model._decoding_cache = None model._decoding_cache = None
model._decoding_cache = update_graph_cache( model._decoding_cache = update_graph_cache(
model, model._decoding_cache, batch_size, seqlen_og, max_length, model,
tensor_parallel=tensor_parallel model._decoding_cache,
batch_size,
seqlen_og,
max_length,
tensor_parallel=tensor_parallel,
) )
inference_params = model._decoding_cache.inference_params inference_params = model._decoding_cache.inference_params
inference_params.max_sequence_len = max_length inference_params.max_sequence_len = max_length
inference_params.max_batch_size = batch_size inference_params.max_batch_size = batch_size
inference_params.sequence_len_offset = 0 inference_params.sequence_len_offset = 0
else: else:
inference_params = InferenceParams(max_sequence_len=max_length, max_batch_size=batch_size, inference_params = InferenceParams(
fused_ft_kernel=fused_ft_kernel) max_sequence_len=max_length, max_batch_size=batch_size, fused_ft_kernel=fused_ft_kernel
)
scores = [] scores = []
with torch.inference_mode(): with torch.inference_mode():
if timing: if timing:
...@@ -123,18 +142,32 @@ def decode(input_ids, model, max_length, top_k=1, top_p=0.0, temperature=1.0, ...@@ -123,18 +142,32 @@ def decode(input_ids, model, max_length, top_k=1, top_p=0.0, temperature=1.0,
sequences = [next_token] sequences = [next_token]
inference_params.sequence_len_offset = seqlen_og inference_params.sequence_len_offset = seqlen_og
while True: while True:
position_ids = torch.full((batch_size, 1), inference_params.sequence_len_offset, position_ids = torch.full(
dtype=torch.long, device=input_ids.device) (batch_size, 1),
inference_params.sequence_len_offset,
dtype=torch.long,
device=input_ids.device,
)
if not cg: if not cg:
logits = model(rearrange(next_token, 'b -> b 1'), position_ids=position_ids, logits = model(
inference_params=inference_params, last_token_only=True).logits rearrange(next_token, "b -> b 1"),
position_ids=position_ids,
inference_params=inference_params,
last_token_only=True,
).logits
else: else:
logits = model._decoding_cache.run(rearrange(next_token, 'b -> b 1'), position_ids, logits = model._decoding_cache.run(
inference_params.sequence_len_offset) rearrange(next_token, "b -> b 1"),
position_ids,
inference_params.sequence_len_offset,
)
if vocab_size is not None: if vocab_size is not None:
logits = logits[..., :vocab_size] logits = logits[..., :vocab_size]
scores.append(logits if not cg else logits.clone()) scores.append(logits if not cg else logits.clone())
if teacher_outputs is None or teacher_output_len <= inference_params.sequence_len_offset + 1: if (
teacher_outputs is None
or teacher_output_len <= inference_params.sequence_len_offset + 1
):
next_token = sample(logits, top_k=top_k, temperature=temperature) next_token = sample(logits, top_k=top_k, temperature=temperature)
else: else:
next_token = teacher_outputs[:, inference_params.sequence_len_offset + 1] next_token = teacher_outputs[:, inference_params.sequence_len_offset + 1]
...@@ -148,30 +181,45 @@ def decode(input_ids, model, max_length, top_k=1, top_p=0.0, temperature=1.0, ...@@ -148,30 +181,45 @@ def decode(input_ids, model, max_length, top_k=1, top_p=0.0, temperature=1.0,
if tensor_parallel > 1: if tensor_parallel > 1:
torch.distributed.barrier() torch.distributed.barrier()
torch.cuda.synchronize() torch.cuda.synchronize()
print(f'Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms') print(f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms")
output_cls = GreedySearchDecoderOnlyOutput if top_k == 1 else SampleDecoderOnlyOutput output_cls = GreedySearchDecoderOnlyOutput if top_k == 1 else SampleDecoderOnlyOutput
return output_cls( return output_cls(
sequences=torch.cat([input_ids, torch.stack(sequences, dim=1)], dim=1), sequences=torch.cat([input_ids, torch.stack(sequences, dim=1)], dim=1), scores=tuple(scores)
scores=tuple(scores)
) )
class GenerationMixin: class GenerationMixin:
def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs): def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
raise NotImplementedError raise NotImplementedError
def generate(self, input_ids, max_length, top_k=1, top_p=0.0, temperature=1.0, def generate(
return_dict_in_generate=False, output_scores=False, **kwargs): self,
output = decode(input_ids, self, max_length, top_k=top_k, top_p=top_p, input_ids,
temperature=temperature, **kwargs) max_length,
top_k=1,
top_p=0.0,
temperature=1.0,
return_dict_in_generate=False,
output_scores=False,
**kwargs,
):
output = decode(
input_ids, self, max_length, top_k=top_k, top_p=top_p, temperature=temperature, **kwargs
)
if not output_scores: if not output_scores:
output.scores = None output.scores = None
return output if return_dict_in_generate else output.sequences return output if return_dict_in_generate else output.sequences
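# Generation sketch (assumes a flash_attn GPT-style model class that mixes in
# GenerationMixin, e.g. GPTLMHeadModel loaded on CUDA in fp16, and an HF-style
# tokenizer; both names are illustrative, not prescriptive):
import torch

def generate_text(model, tokenizer, prompt: str):
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to("cuda")
    out = model.generate(
        input_ids,
        max_length=input_ids.shape[1] + 64,   # total length, prompt included
        top_k=50,
        top_p=0.95,
        temperature=0.8,
        return_dict_in_generate=True,
        output_scores=True,
        fused_ft_kernel=False,                # extra kwargs are forwarded to decode()
    )
    return tokenizer.decode(out.sequences[0]), out.scores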
def allocate_inference_cache(max_batch_size, max_seqlen, nheads, headdim, layers: Union[int, Sequence], def allocate_inference_cache(
device, dtype=torch.float16): max_batch_size,
max_seqlen,
nheads,
headdim,
layers: Union[int, Sequence],
device,
dtype=torch.float16,
):
assert dtype in [torch.float16, torch.bfloat16, torch.float32] assert dtype in [torch.float16, torch.bfloat16, torch.float32]
packsize = 4 if dtype == torch.float32 else 8 packsize = 4 if dtype == torch.float32 else 8
assert headdim % packsize == 0 assert headdim % packsize == 0
...@@ -179,9 +227,13 @@ def allocate_inference_cache(max_batch_size, max_seqlen, nheads, headdim, layers ...@@ -179,9 +227,13 @@ def allocate_inference_cache(max_batch_size, max_seqlen, nheads, headdim, layers
v_cache_shape = (max_batch_size, nheads, max_seqlen, headdim) v_cache_shape = (max_batch_size, nheads, max_seqlen, headdim)
if isinstance(layers, int): if isinstance(layers, int):
layers = range(layers) layers = range(layers)
return {i: (torch.empty(k_cache_shape, device=device, dtype=dtype), return {
torch.empty(v_cache_shape, device=device, dtype=dtype)) i: (
for i in layers} torch.empty(k_cache_shape, device=device, dtype=dtype),
torch.empty(v_cache_shape, device=device, dtype=dtype),
)
for i in layers
}
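# Sketch of the per-layer K/V cache allocation above (requires CUDA; headdim must be
# a multiple of the packsize: 8 for fp16/bf16, 4 for fp32; import path assumed):
import torch
from flash_attn.utils.generation import allocate_inference_cache

kv_cache = allocate_inference_cache(
    max_batch_size=4, max_seqlen=2048, nheads=16, headdim=64,
    layers=24, device="cuda", dtype=torch.float16,
)
assert set(kv_cache.keys()) == set(range(24))
k0, v0 = kv_cache[0]
assert v0.shape == (4, 16, 2048, 64)   # (batch, nheads, seqlen, headdim)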
def seqlen_to_seqlen_type(seqlen: int) -> int: def seqlen_to_seqlen_type(seqlen: int) -> int:
...@@ -211,49 +263,70 @@ class DecodingCGCache: ...@@ -211,49 +263,70 @@ class DecodingCGCache:
@torch.inference_mode()
def update_graph_cache(
    model, cache, batch_size, seqlen_og, max_seqlen, tensor_parallel=1, dtype=None, n_warmups=2
):
    if cache is None:
        cache = DecodingCGCache()
    param_example = next(iter(model.parameters()))
    device = param_example.device
    if dtype is None:
        dtype = param_example.dtype
    if (
        (device, dtype) != (cache.device, cache.dtype)
        or batch_size > cache.max_batch_size
        or max_seqlen > cache.max_seqlen
    ):  # Invalidate the cache
        cache.callables = {}
        cache.mempool = None
        cache.inference_params = None
        gc.collect()
        cache.device, cache.dtype = device, dtype
        cache.max_batch_size, cache.max_seqlen = batch_size, max_seqlen
        if hasattr(model, "allocate_inference_cache"):
            inf_cache = model.allocate_inference_cache(batch_size, max_seqlen, dtype)
        else:
            headdim = getattr(
                model.config,
                "head_dim",
                model.config.hidden_size // model.config.num_attention_heads,
            )
            inf_cache = allocate_inference_cache(
                batch_size,
                max_seqlen,
                model.config.num_attention_heads // tensor_parallel,
                headdim,
                model.config.num_hidden_layers,
                device,
                dtype,
            )
        lengths_per_sample = torch.full((batch_size,), seqlen_og, dtype=torch.int32, device=device)
        cache.inference_params = InferenceParams(
            max_sequence_len=max_seqlen,
            max_batch_size=batch_size,
            sequence_len_offset=seqlen_og,
            key_value_memory_dict=inf_cache,
            fused_ft_kernel=True,
            lengths_per_sample=lengths_per_sample,
        )
        cache.mempool = torch.cuda.graphs.graph_pool_handle()
    for s_type in range(seqlen_to_seqlen_type(seqlen_og), seqlen_to_seqlen_type(max_seqlen) + 1):
        if (batch_size, s_type) not in cache.callables:
            max_seqlen_ = min(max(seqlen_og, seqlen_type_to_max_seqlen(s_type)), max_seqlen)
            cache.callables[batch_size, s_type] = capture_graph(
                model,
                cache.inference_params,
                batch_size,
                max_seqlen_,
                mempool=cache.mempool,
                n_warmups=n_warmups,
            )

    def dispatch(input_ids, position_ids, seqlen):
        batch_size = input_ids.shape[0]
        return cache.callables[batch_size, seqlen_to_seqlen_type(seqlen)](
            input_ids, position_ids, seqlen
        )

    cache.run = dispatch
    cache.inference_params.sequence_len_offset = 0  # Reset so it's not confusing
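
# A hedged sketch of how the graph cache is typically consulted during decoding. The wrapper
# below and the assumption that `cache.run` returns last-token logits are illustrative, not
# part of this diff:
def decode_step_with_cuda_graph(model, cg_cache, input_ids, position_ids, seqlen_og, max_seqlen):
    # Build or reuse captured graphs for this batch size and sequence-length bucket range.
    cg_cache = update_graph_cache(model, cg_cache, input_ids.shape[0], seqlen_og, max_seqlen)
    # Each incremental step dispatches to the graph captured for the current length bucket.
    logits = cg_cache.run(input_ids, position_ids, seqlen_og)
    return logits, cg_cache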
...@@ -275,8 +348,12 @@ def capture_graph(model, inference_params, batch_size, max_seqlen, mempool=None,
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        for _ in range(n_warmups):
            logits = model(
                input_ids,
                position_ids=position_ids,
                inference_params=inference_params,
                last_token_only=True,
            ).logits
        s.synchronize()
        # This might be needed for correctness if we run with NCCL_GRAPH_MIXING_SUPPORT=0,
        # which requires that graph launch and non-captured launch to not overlap (I think,
...@@ -288,8 +365,12 @@ def capture_graph(model, inference_params, batch_size, max_seqlen, mempool=None,
    # To allow capture, automatically sets a side stream as the current stream in the context
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph, pool=mempool):
        logits = model(
            input_ids,
            position_ids=position_ids,
            inference_params=inference_params,
            last_token_only=True,
        ).logits

    def run(new_input_ids, new_position_ids, seqlen):
        inference_params.lengths_per_sample[:] = seqlen
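
# capture_graph follows the standard PyTorch CUDA-graph capture/replay recipe. A generic,
# self-contained sketch of that pattern (a plain Linear layer stands in for the model; this is
# not flash-attn code, and `torch` is assumed to be imported in this module):
static_x = torch.randn(4, 128, device="cuda")
linear = torch.nn.Linear(128, 128).cuda()

side_stream = torch.cuda.Stream()
side_stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(side_stream):
    for _ in range(2):  # warm-up iterations so allocator/kernel state settles before capture
        static_y = linear(static_x)
torch.cuda.current_stream().wait_stream(side_stream)

cuda_graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(cuda_graph):  # records the kernels launched by this forward pass
    static_y = linear(static_x)

def replay(new_x):
    static_x.copy_(new_x)  # write new inputs into the captured input buffer in place
    cuda_graph.replay()  # re-launch the recorded kernels; outputs land in static_y
    return static_y.clone()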
...
...@@ -3,13 +3,18 @@ from functools import partial
import torch
from safetensors.torch import load_file as safe_load_file
from transformers.utils import (
    SAFE_WEIGHTS_INDEX_NAME,
    SAFE_WEIGHTS_NAME,
    WEIGHTS_INDEX_NAME,
    WEIGHTS_NAME,
)
from transformers.utils.hub import cached_file, get_checkpoint_shard_files


def state_dict_from_pretrained(model_name, device=None, dtype=None):
    # If not fp32, then we don't want to load directly to the GPU
    mapped_device = "cpu" if dtype not in [torch.float32, None] else device
    is_sharded = False
    load_safe = False
    resolved_archive_file = None
...@@ -20,19 +25,23 @@ def state_dict_from_pretrained(model_name, device=None, dtype=None):
    safe_weights_index_path = os.path.join(model_name, SAFE_WEIGHTS_INDEX_NAME)
    if os.path.isfile(weights_path):
        resolved_archive_file = cached_file(
            model_name, WEIGHTS_NAME, _raise_exceptions_for_missing_entries=False
        )
    elif os.path.isfile(weights_index_path):
        resolved_archive_file = cached_file(
            model_name, WEIGHTS_INDEX_NAME, _raise_exceptions_for_missing_entries=False
        )
        is_sharded = True
    elif os.path.isfile(safe_weights_path):
        resolved_archive_file = cached_file(
            model_name, SAFE_WEIGHTS_NAME, _raise_exceptions_for_missing_entries=False
        )
        load_safe = True
    elif os.path.isfile(safe_weights_index_path):
        resolved_archive_file = cached_file(
            model_name, SAFE_WEIGHTS_INDEX_NAME, _raise_exceptions_for_missing_entries=False
        )
        is_sharded = True
        load_safe = True
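
# A hedged sketch of how a resolved, non-sharded checkpoint file could then be turned into a
# state dict (illustrative helper; the real continuation of the function above is elided by the
# diff, and the sharded branches would go through get_checkpoint_shard_files):
def _load_resolved_file(resolved_archive_file, load_safe, mapped_device, dtype=None):
    if load_safe:
        state_dict = safe_load_file(resolved_archive_file)  # safetensors loads onto CPU by default
    else:
        state_dict = torch.load(resolved_archive_file, map_location=mapped_device)
    if dtype is not None:
        # Cast only floating-point tensors; integer buffers keep their dtype.
        state_dict = {k: v.to(dtype) if v.is_floating_point() else v for k, v in state_dict.items()}
    return state_dict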
...