Commit f1a73d07 authored by Tri Dao

Run isort and black on python files

parent cbb4cf5f
@@ -2,42 +2,52 @@
 import torch
 import torch.nn as nn
-from torch import Tensor
 from einops import rearrange
+from torch import Tensor
-from flash_attn.utils.distributed import reduce_scatter, all_reduce
+from flash_attn.utils.distributed import all_reduce, reduce_scatter
 class GPT2Embeddings(nn.Module):
-    def __init__(self, embed_dim, vocab_size, max_position_embeddings, padding_idx=None,
-                 word_embed_proj_dim=None, device=None, dtype=None):
+    def __init__(
+        self,
+        embed_dim,
+        vocab_size,
+        max_position_embeddings,
+        padding_idx=None,
+        word_embed_proj_dim=None,
+        device=None,
+        dtype=None,
+    ):
         """
-            If max_position_embeddings <= 0, there's no position embeddings
-            If word_embe_proj_dim is not None (e.g., OPT-350m), we embed to that dimension
-                the project up to embed_dim
+        If max_position_embeddings <= 0, there's no position embeddings
+        If word_embe_proj_dim is not None (e.g., OPT-350m), we embed to that dimension
+            the project up to embed_dim
         """
-        factory_kwargs = {'device': device, 'dtype': dtype}
+        factory_kwargs = {"device": device, "dtype": dtype}
         super().__init__()
         if word_embed_proj_dim is None:
-            self.word_embeddings = nn.Embedding(vocab_size, embed_dim, padding_idx=padding_idx,
-                                                **factory_kwargs)
+            self.word_embeddings = nn.Embedding(
+                vocab_size, embed_dim, padding_idx=padding_idx, **factory_kwargs
+            )
             self.project_in = None
         else:
-            self.word_embeddings = nn.Embedding(vocab_size, word_embed_proj_dim,
-                                                padding_idx=padding_idx, **factory_kwargs)
-            self.project_in = nn.Linear(word_embed_proj_dim, embed_dim, bias=False,
-                                        **factory_kwargs)
+            self.word_embeddings = nn.Embedding(
+                vocab_size, word_embed_proj_dim, padding_idx=padding_idx, **factory_kwargs
+            )
+            self.project_in = nn.Linear(
+                word_embed_proj_dim, embed_dim, bias=False, **factory_kwargs
+            )
         self.max_position_embeddings = max_position_embeddings
         if self.max_position_embeddings > 0:
-            self.position_embeddings = nn.Embedding(max_position_embeddings, embed_dim,
-                                                    **factory_kwargs)
+            self.position_embeddings = nn.Embedding(
+                max_position_embeddings, embed_dim, **factory_kwargs
+            )
     def forward(self, input_ids, position_ids=None):
         """
-            input_ids: (batch, seqlen)
-            position_ids: (batch, seqlen)
+        input_ids: (batch, seqlen)
+        position_ids: (batch, seqlen)
         """
         batch_size, seqlen = input_ids.shape
         embeddings = self.word_embeddings(input_ids)
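Note (not part of this commit): a minimal single-process usage sketch of the GPT2Embeddings module reformatted above, assuming flash_attn is installed and the class is importable from flash_attn.modules.embedding; the sizes are made up.

import torch

from flash_attn.modules.embedding import GPT2Embeddings

emb = GPT2Embeddings(embed_dim=64, vocab_size=100, max_position_embeddings=32)
input_ids = torch.randint(0, 100, (2, 16))  # (batch, seqlen)
out = emb(input_ids)  # position_ids is optional, per the docstring above
print(out.shape)  # expected: torch.Size([2, 16, 64])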
@@ -52,31 +62,39 @@ class GPT2Embeddings(nn.Module):
 class BertEmbeddings(nn.Module):
-    def __init__(self, embed_dim, vocab_size, max_position_embeddings, type_vocab_size,
-                 padding_idx=None, device=None, dtype=None):
+    def __init__(
+        self,
+        embed_dim,
+        vocab_size,
+        max_position_embeddings,
+        type_vocab_size,
+        padding_idx=None,
+        device=None,
+        dtype=None,
+    ):
         """
-            If max_position_embeddings <= 0, there's no position embeddings
-            If type_vocab_size <= 0, there's no token type embeddings
+        If max_position_embeddings <= 0, there's no position embeddings
+        If type_vocab_size <= 0, there's no token type embeddings
         """
-        factory_kwargs = {'device': device, 'dtype': dtype}
+        factory_kwargs = {"device": device, "dtype": dtype}
         super().__init__()
-        self.word_embeddings = nn.Embedding(vocab_size, embed_dim, padding_idx=padding_idx,
-                                            **factory_kwargs)
+        self.word_embeddings = nn.Embedding(
+            vocab_size, embed_dim, padding_idx=padding_idx, **factory_kwargs
+        )
         self.max_position_embeddings = max_position_embeddings
         self.type_vocab_size = type_vocab_size
         if self.max_position_embeddings > 0:
-            self.position_embeddings = nn.Embedding(max_position_embeddings, embed_dim,
-                                                    **factory_kwargs)
+            self.position_embeddings = nn.Embedding(
+                max_position_embeddings, embed_dim, **factory_kwargs
+            )
         if self.type_vocab_size > 0:
-            self.token_type_embeddings = nn.Embedding(type_vocab_size, embed_dim,
-                                                      **factory_kwargs)
+            self.token_type_embeddings = nn.Embedding(type_vocab_size, embed_dim, **factory_kwargs)
     def forward(self, input_ids, position_ids=None, token_type_ids=None):
         """
-            input_ids: (batch, seqlen)
-            position_ids: (batch, seqlen)
-            token_type_ids: (batch, seqlen)
+        input_ids: (batch, seqlen)
+        position_ids: (batch, seqlen)
+        token_type_ids: (batch, seqlen)
         """
         batch_size, seqlen = input_ids.shape
         embeddings = self.word_embeddings(input_ids)
@@ -94,16 +112,17 @@ class BertEmbeddings(nn.Module):
 class VocabParallelEmbedding(nn.Embedding):
     def __init__(self, num_embeddings, *args, process_group=None, padding_idx=None, **kwargs):
         self.process_group = process_group
         if process_group is not None:
             world_size = torch.distributed.get_world_size(process_group)
             if num_embeddings % world_size != 0:
-                raise ValueError(f'num_embeddings ({num_embeddings}) must be divisible by '
-                                 f'world_size ({world_size})')
+                raise ValueError(
+                    f"num_embeddings ({num_embeddings}) must be divisible by "
+                    f"world_size ({world_size})"
+                )
             if world_size > 1 and padding_idx is not None:
-                raise RuntimeError('ParallelEmbedding does not support padding_idx')
+                raise RuntimeError("ParallelEmbedding does not support padding_idx")
         else:
             world_size = 1
         super().__init__(num_embeddings // world_size, *args, padding_idx=padding_idx, **kwargs)
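Note (not part of this commit): the divisibility rule enforced above, restated standalone. Each rank of VocabParallelEmbedding owns num_embeddings // world_size contiguous rows, so the vocabulary must split evenly across ranks; the padding step below is one common workaround, not something the class itself does.

vocab_size, world_size = 50257, 8  # GPT-2's vocab does not divide evenly across 8 ranks
if vocab_size % world_size != 0:
    print(f"VocabParallelEmbedding would raise: {vocab_size} % {world_size} = {vocab_size % world_size}")
padded_vocab = ((vocab_size + world_size - 1) // world_size) * world_size
print(padded_vocab, padded_vocab // world_size)  # 50264 rows total, 6283 rows per rank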
@@ -125,33 +144,45 @@ class VocabParallelEmbedding(nn.Embedding):
 class ColumnParallelEmbedding(nn.Embedding):
     def __init__(self, num_embeddings, embedding_dim, *args, process_group=None, **kwargs):
         self.process_group = process_group
         if process_group is not None:
             world_size = torch.distributed.get_world_size(process_group)
             if embedding_dim % world_size != 0:
-                raise ValueError(f'embedding_dim ({embedding_dim}) must be divisible by '
-                                 f'world_size ({world_size})')
+                raise ValueError(
+                    f"embedding_dim ({embedding_dim}) must be divisible by "
+                    f"world_size ({world_size})"
+                )
         else:
             world_size = 1
         super().__init__(num_embeddings, embedding_dim // world_size, *args, **kwargs)
 class ParallelGPT2Embeddings(nn.Module):
-    def __init__(self, embed_dim, vocab_size, max_position_embeddings, process_group,
-                 padding_idx=None, sequence_parallel=True, device=None, dtype=None):
+    def __init__(
+        self,
+        embed_dim,
+        vocab_size,
+        max_position_embeddings,
+        process_group,
+        padding_idx=None,
+        sequence_parallel=True,
+        device=None,
+        dtype=None,
+    ):
         """
-            If max_position_embeddings <= 0, there's no position embeddings
+        If max_position_embeddings <= 0, there's no position embeddings
         """
-        factory_kwargs = {'device': device, 'dtype': dtype}
+        factory_kwargs = {"device": device, "dtype": dtype}
         super().__init__()
         self.process_group = process_group
         self.sequence_parallel = sequence_parallel
         self.word_embeddings = VocabParallelEmbedding(
-            vocab_size, embed_dim, padding_idx=padding_idx, process_group=process_group,
-            **factory_kwargs
+            vocab_size,
+            embed_dim,
+            padding_idx=padding_idx,
+            process_group=process_group,
+            **factory_kwargs,
         )
         self.max_position_embeddings = max_position_embeddings
         if self.max_position_embeddings > 0:
@@ -161,8 +192,8 @@ class ParallelGPT2Embeddings(nn.Module):
     def forward(self, input_ids, position_ids=None, combine_batch_seqlen_dim=False):
         """
-            input_ids: (batch, seqlen)
-            position_ids: (batch, seqlen)
+        input_ids: (batch, seqlen)
+        position_ids: (batch, seqlen)
         """
         batch_size, seqlen = input_ids.shape
         world_size = torch.distributed.get_world_size(self.process_group)
@@ -176,8 +207,10 @@ class ParallelGPT2Embeddings(nn.Module):
             else:
                 partition_dim = self.position_embeddings.embedding_dim
                 rank = torch.distributed.get_rank(self.process_group)
-                embeddings[..., rank * partition_dim:(rank + 1) * partition_dim] += position_embeddings
+                embeddings[
+                    ..., rank * partition_dim : (rank + 1) * partition_dim
+                ] += position_embeddings
         if combine_batch_seqlen_dim:
-            embeddings = rearrange(embeddings, 'b s d -> (b s) d')
+            embeddings = rearrange(embeddings, "b s d -> (b s) d")
         reduce_fn = reduce_scatter if self.sequence_parallel else all_reduce
         return embeddings if world_size <= 1 else reduce_fn(embeddings, self.process_group)
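Note (not part of this commit): a single-process sketch of the column-slice update in the hunk above. Each rank adds its embed_dim/world_size-wide shard of the position embeddings into its own column slice; the loop below emulates every rank locally, whereas the real module relies on the later all_reduce/reduce_scatter to combine the partial results.

import torch

batch, seqlen, embed_dim, world_size = 2, 4, 8, 2
partition_dim = embed_dim // world_size
embeddings = torch.zeros(batch, seqlen, embed_dim)
for rank in range(world_size):  # emulate each rank's local update
    position_embeddings = torch.full((batch, seqlen, partition_dim), float(rank + 1))
    embeddings[..., rank * partition_dim : (rank + 1) * partition_dim] += position_embeddings
print(embeddings[0, 0])  # tensor([1., 1., 1., 1., 2., 2., 2., 2.])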
@@ -732,8 +732,12 @@ class ParallelMHA(nn.Module):
             self.num_heads % self.num_heads_kv == 0
         ), "num_heads must be divisible by num_heads_kv"
-        self.num_heads_per_rank = get_dim_for_local_rank(self.num_heads, self.world_size, self.local_rank)
-        self.num_heads_kv_per_rank = get_dim_for_local_rank(self.num_heads, self.world_size, self.local_rank)
+        self.num_heads_per_rank = get_dim_for_local_rank(
+            self.num_heads, self.world_size, self.local_rank
+        )
+        self.num_heads_kv_per_rank = get_dim_for_local_rank(
+            self.num_heads, self.world_size, self.local_rank
+        )
         self.head_dim = self.embed_dim // num_heads
         qkv_dim = self.head_dim * (self.num_heads + 2 * self.num_heads_kv)
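Note (not part of this commit): in the evenly divisible case, the per-rank head counts reformatted above reduce to a plain integer split; get_dim_for_local_rank also covers uneven splits, which this sketch ignores. The numbers are made up.

num_heads, num_heads_kv, world_size, head_dim = 32, 8, 4, 64
assert num_heads % num_heads_kv == 0, "num_heads must be divisible by num_heads_kv"
num_heads_per_rank = num_heads // world_size        # 8 query heads on each rank
num_heads_kv_per_rank = num_heads_kv // world_size  # 2 shared KV heads on each rank (GQA)
qkv_dim = head_dim * (num_heads + 2 * num_heads_kv)  # global projection width, as in the hunk
print(num_heads_per_rank, num_heads_kv_per_rank, qkv_dim)  # 8 2 3072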
@@ -5,7 +5,6 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 # 1/sqrt(2*pi)-> 0.3989423
 # 1/sqrt(2) -> 0.70710678
 # sqrt(2/pi) -> 0.79788456
@@ -18,17 +17,19 @@ def bias_gelu(y, bias):
     x = bias + y
     return (x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))).to(dtype=y.dtype)
 # gradient of tanh approximation of gelu
 # gradient of actual gelu is:
 # 0.5 * (1. + torch.erf(x * 0.70710678)) + 0.3989423 * x * torch.exp(-0.5 * x * x)
 @torch.jit.script
 def bias_gelu_back(g, y, bias):
-    """Assume that y has shape (B, D) and bias has shape (D)
-    """
+    """Assume that y has shape (B, D) and bias has shape (D)"""
     x = bias + y
     tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
     # sqrt(2/pi) * 3 * 0.044715 -> 0.1070322243
-    ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out)
+    ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (
+        1 + tanh_out
+    )
     grad_y = ff * g
     return grad_y.to(dtype=y.dtype), grad_y.sum(dim=(0), dtype=bias.dtype)
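Note (not part of this commit): a quick numerical check that the tanh approximation used by bias_gelu above matches PyTorch's built-in approximate GELU (requires torch >= 1.12 for approximate="tanh"); the formula is copied inline so the check runs without the flash_attn extension.

import torch
import torch.nn.functional as F

y = torch.randn(4, 16, dtype=torch.float64)
bias = torch.randn(16, dtype=torch.float64)
x = bias + y
ref = F.gelu(x, approximate="tanh")
approx = x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))
print(torch.allclose(ref, approx, atol=1e-6))  # True (constants are truncated to float32 precision)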
@@ -56,6 +57,7 @@ bias_gelu_impl = GeLUFunction.apply
 def gelu_fwd(x):
     return (x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))).to(dtype=x.dtype)
 # gradient of tanh approximation of gelu
 # gradient of actual gelu is:
 # 0.5 * (1. + torch.erf(x * 0.70710678)) + 0.3989423 * x * torch.exp(-0.5 * x * x)
@@ -63,7 +65,9 @@ def gelu_fwd(x):
 def gelu_bwd(g, x):
     tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
     # sqrt(2/pi) * 3 * 0.044715 -> 0.1070322243
-    ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out)
+    ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (
+        1 + tanh_out
+    )
     return (ff * g).to(dtype=x.dtype)
@@ -76,10 +80,11 @@ class FastGeLUFunction(torch.autograd.Function):
     @staticmethod
     def backward(ctx, grad_output):
-        input, = ctx.saved_tensors
+        (input,) = ctx.saved_tensors
         tmp = gelu_bwd(grad_output, input)
         return tmp
 fast_gelu_impl = FastGeLUFunction.apply
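Note (not part of this commit): the closed-form backward in gelu_bwd above can be checked against autograd through the tanh-approximate GELU; again the formula is inlined so no CUDA build is required.

import torch
import torch.nn.functional as F

x = torch.randn(8, dtype=torch.float64, requires_grad=True)
g = torch.randn(8, dtype=torch.float64)
F.gelu(x, approximate="tanh").backward(g)

with torch.no_grad():
    tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
    ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (
        1 + tanh_out
    )
    print(torch.allclose(x.grad, ff * g, atol=1e-6))  # True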
@@ -10,6 +10,10 @@ import fused_dense_lib as fused_dense_cuda
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from torch import Tensor
+from torch.cuda.amp import custom_bwd, custom_fwd
+from torch.distributed import ProcessGroup
 from flash_attn.ops.activations import gelu_bwd, relu_bwd, sqrelu_bwd, sqrelu_fwd
 from flash_attn.utils.distributed import (
     all_gather_raw,
@@ -18,9 +22,6 @@ from flash_attn.utils.distributed import (
     reduce_scatter,
     reduce_scatter_raw,
 )
-from torch import Tensor
-from torch.cuda.amp import custom_bwd, custom_fwd
-from torch.distributed import ProcessGroup
 class FusedDenseFunc(torch.autograd.Function):
@@ -11,7 +11,6 @@ from typing import Optional
 import triton
 import triton.language as tl
 _sqrt2pi = math.sqrt(2.0 / math.pi)
 _sqrt1_2 = math.sqrt(1.0 / 2)
 _gaussian_pdf_normalization = 1.0 / math.sqrt(2 * math.pi)
@@ -142,6 +141,7 @@ def gelu_grad(x):
     pdf = tl.exp(-0.5 * x * x) * _gaussian_pdf_normalization
     return cdf + x * pdf
 @triton.jit
 def gelu_approx(x):
     """
@@ -157,6 +157,6 @@ def gelu_approx_grad(x):
     # CREDITS: Fast implementation proposed in
     # https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/fused_bias_gelu.py#L30
     tanh_out = tanh(0.79788456 * x * (1 + 0.044715 * x * x))
-    return 0.5 * x * (
-        (1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)
-    ) + 0.5 * (1 + tanh_out)
+    return 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (
+        1 + tanh_out
+    )