Commit f1a73d07 authored by Tri Dao

Run isort and black on python files

parent cbb4cf5f
@@ -2,42 +2,52 @@
import torch
import torch.nn as nn
from einops import rearrange
from torch import Tensor

from flash_attn.utils.distributed import all_reduce, reduce_scatter


class GPT2Embeddings(nn.Module):
    def __init__(
        self,
        embed_dim,
        vocab_size,
        max_position_embeddings,
        padding_idx=None,
        word_embed_proj_dim=None,
        device=None,
        dtype=None,
    ):
        """
        If max_position_embeddings <= 0, there's no position embeddings
        If word_embed_proj_dim is not None (e.g., OPT-350m), we embed to that dimension
        then project up to embed_dim
        """
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        if word_embed_proj_dim is None:
            self.word_embeddings = nn.Embedding(
                vocab_size, embed_dim, padding_idx=padding_idx, **factory_kwargs
            )
            self.project_in = None
        else:
            self.word_embeddings = nn.Embedding(
                vocab_size, word_embed_proj_dim, padding_idx=padding_idx, **factory_kwargs
            )
            self.project_in = nn.Linear(
                word_embed_proj_dim, embed_dim, bias=False, **factory_kwargs
            )
        self.max_position_embeddings = max_position_embeddings
        if self.max_position_embeddings > 0:
            self.position_embeddings = nn.Embedding(
                max_position_embeddings, embed_dim, **factory_kwargs
            )

    def forward(self, input_ids, position_ids=None):
        """
        input_ids: (batch, seqlen)
        position_ids: (batch, seqlen)
        """
        batch_size, seqlen = input_ids.shape
        embeddings = self.word_embeddings(input_ids)
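
Usage sketch (not part of this commit) for the GPT2Embeddings module above; the import path and the OPT-350m-style sizes are assumptions for illustration:

import torch
from flash_attn.modules.embedding import GPT2Embeddings  # path assumed

# GPT-2 style: word embeddings plus learned position embeddings, both of width embed_dim.
emb = GPT2Embeddings(embed_dim=768, vocab_size=50257, max_position_embeddings=1024)
input_ids = torch.randint(0, 50257, (2, 16))
out = emb(input_ids)  # (2, 16, 768)

# OPT-350m style: embed at word_embed_proj_dim, then project_in maps up to embed_dim.
emb_opt = GPT2Embeddings(
    embed_dim=1024, vocab_size=50272, max_position_embeddings=2048, word_embed_proj_dim=512
)
out_opt = emb_opt(input_ids)  # (2, 16, 1024)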
@@ -52,31 +62,39 @@ class GPT2Embeddings(nn.Module):


class BertEmbeddings(nn.Module):
    def __init__(
        self,
        embed_dim,
        vocab_size,
        max_position_embeddings,
        type_vocab_size,
        padding_idx=None,
        device=None,
        dtype=None,
    ):
        """
        If max_position_embeddings <= 0, there's no position embeddings
        If type_vocab_size <= 0, there's no token type embeddings
        """
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.word_embeddings = nn.Embedding(
            vocab_size, embed_dim, padding_idx=padding_idx, **factory_kwargs
        )
        self.max_position_embeddings = max_position_embeddings
        self.type_vocab_size = type_vocab_size
        if self.max_position_embeddings > 0:
            self.position_embeddings = nn.Embedding(
                max_position_embeddings, embed_dim, **factory_kwargs
            )
        if self.type_vocab_size > 0:
            self.token_type_embeddings = nn.Embedding(type_vocab_size, embed_dim, **factory_kwargs)

    def forward(self, input_ids, position_ids=None, token_type_ids=None):
        """
        input_ids: (batch, seqlen)
        position_ids: (batch, seqlen)
        token_type_ids: (batch, seqlen)
        """
        batch_size, seqlen = input_ids.shape
        embeddings = self.word_embeddings(input_ids)
@@ -94,16 +112,17 @@ class BertEmbeddings(nn.Module):


class VocabParallelEmbedding(nn.Embedding):
    def __init__(self, num_embeddings, *args, process_group=None, padding_idx=None, **kwargs):
        self.process_group = process_group
        if process_group is not None:
            world_size = torch.distributed.get_world_size(process_group)
            if num_embeddings % world_size != 0:
                raise ValueError(
                    f"num_embeddings ({num_embeddings}) must be divisible by "
                    f"world_size ({world_size})"
                )
            if world_size > 1 and padding_idx is not None:
                raise RuntimeError("ParallelEmbedding does not support padding_idx")
        else:
            world_size = 1
        super().__init__(num_embeddings // world_size, *args, padding_idx=padding_idx, **kwargs)
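
VocabParallelEmbedding shards the vocabulary across ranks (num_embeddings // world_size rows per rank). Its forward sits in the collapsed part of this diff; the sketch below shows the usual Megatron-style lookup such a layer performs, written with plain torch.distributed calls rather than the module's own (unshown) code:

import torch
import torch.distributed as dist
import torch.nn.functional as F

def vocab_parallel_lookup(weight, input_ids, rank, world_size, process_group=None):
    # weight: this rank's (num_embeddings // world_size, embed_dim) shard of the embedding table.
    vocab_per_rank = weight.shape[0]
    start = rank * vocab_per_rank
    mask = (input_ids < start) | (input_ids >= start + vocab_per_rank)
    local_ids = (input_ids - start).masked_fill(mask, 0)  # shift into local index range
    out = F.embedding(local_ids, weight)
    out = out.masked_fill(mask.unsqueeze(-1), 0.0)  # rows owned by other ranks contribute zero
    if world_size > 1:
        dist.all_reduce(out, group=process_group)  # sum the partial lookups across ranks
    return out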
@@ -125,33 +144,45 @@ class VocabParallelEmbedding(nn.Embedding):


class ColumnParallelEmbedding(nn.Embedding):
    def __init__(self, num_embeddings, embedding_dim, *args, process_group=None, **kwargs):
        self.process_group = process_group
        if process_group is not None:
            world_size = torch.distributed.get_world_size(process_group)
            if embedding_dim % world_size != 0:
                raise ValueError(
                    f"embedding_dim ({embedding_dim}) must be divisible by "
                    f"world_size ({world_size})"
                )
        else:
            world_size = 1
        super().__init__(num_embeddings, embedding_dim // world_size, *args, **kwargs)


class ParallelGPT2Embeddings(nn.Module):
    def __init__(
        self,
        embed_dim,
        vocab_size,
        max_position_embeddings,
        process_group,
        padding_idx=None,
        sequence_parallel=True,
        device=None,
        dtype=None,
    ):
        """
        If max_position_embeddings <= 0, there's no position embeddings
        """
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.process_group = process_group
        self.sequence_parallel = sequence_parallel
        self.word_embeddings = VocabParallelEmbedding(
            vocab_size,
            embed_dim,
            padding_idx=padding_idx,
            process_group=process_group,
            **factory_kwargs,
        )
        self.max_position_embeddings = max_position_embeddings
        if self.max_position_embeddings > 0:
@@ -161,8 +192,8 @@ class ParallelGPT2Embeddings(nn.Module):

    def forward(self, input_ids, position_ids=None, combine_batch_seqlen_dim=False):
        """
        input_ids: (batch, seqlen)
        position_ids: (batch, seqlen)
        """
        batch_size, seqlen = input_ids.shape
        world_size = torch.distributed.get_world_size(self.process_group)
@@ -176,8 +207,10 @@ class ParallelGPT2Embeddings(nn.Module):
            else:
                partition_dim = self.position_embeddings.embedding_dim
                rank = torch.distributed.get_rank(self.process_group)
                embeddings[
                    ..., rank * partition_dim : (rank + 1) * partition_dim
                ] += position_embeddings
        if combine_batch_seqlen_dim:
            embeddings = rearrange(embeddings, "b s d -> (b s) d")
        reduce_fn = reduce_scatter if self.sequence_parallel else all_reduce
        return embeddings if world_size <= 1 else reduce_fn(embeddings, self.process_group)
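
The last two lines above pick between reduce_scatter (sequence parallel: each rank keeps the summed embeddings for only its slice of the tokens) and all_reduce (every rank keeps the full sum). A sketch of the two modes with raw torch.distributed collectives, assuming the flash_attn wrappers behave like these over the leading token dimension:

import torch
import torch.distributed as dist

def finalize_embeddings(embeddings, process_group, sequence_parallel):
    # embeddings: (b * s, d) after combine_batch_seqlen_dim, holding this rank's partial sums.
    world_size = dist.get_world_size(process_group)
    if world_size <= 1:
        return embeddings
    if sequence_parallel:
        out = torch.empty(
            embeddings.shape[0] // world_size,
            *embeddings.shape[1:],
            dtype=embeddings.dtype,
            device=embeddings.device,
        )
        # Sum over ranks, then keep only this rank's contiguous chunk of tokens.
        dist.reduce_scatter_tensor(out, embeddings.contiguous(), group=process_group)
        return out
    # Tensor parallel only: every rank keeps the full summed embeddings.
    dist.all_reduce(embeddings, group=process_group)
    return embeddings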
@@ -732,8 +732,12 @@ class ParallelMHA(nn.Module):
            self.num_heads % self.num_heads_kv == 0
        ), "num_heads must be divisible by num_heads_kv"
        self.num_heads_per_rank = get_dim_for_local_rank(
            self.num_heads, self.world_size, self.local_rank
        )
        self.num_heads_kv_per_rank = get_dim_for_local_rank(
            self.num_heads_kv, self.world_size, self.local_rank
        )
        self.head_dim = self.embed_dim // num_heads
        qkv_dim = self.head_dim * (self.num_heads + 2 * self.num_heads_kv)
...
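
A small worked example of the per-rank head counts used above for grouped-query attention. local_heads_even is a hypothetical stand-in for get_dim_for_local_rank and only covers the evenly divisible case; the sizes are illustrative, not from the commit:

def local_heads_even(num_heads: int, world_size: int, local_rank: int) -> int:
    # Hypothetical even-split helper (assumption); the real helper may handle uneven splits.
    assert num_heads % world_size == 0
    return num_heads // world_size

num_heads, num_heads_kv, world_size, embed_dim = 32, 8, 8, 2048  # example sizes
heads_per_rank = local_heads_even(num_heads, world_size, local_rank=0)        # 4 query heads per rank
heads_kv_per_rank = local_heads_even(num_heads_kv, world_size, local_rank=0)  # 1 KV head per rank
head_dim = embed_dim // num_heads                    # 64
qkv_dim = head_dim * (num_heads + 2 * num_heads_kv)  # 64 * (32 + 16) = 3072, before sharding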
@@ -5,7 +5,6 @@ import torch
import torch.nn as nn
import torch.nn.functional as F

# 1/sqrt(2*pi)-> 0.3989423
# 1/sqrt(2) -> 0.70710678
# sqrt(2/pi) -> 0.79788456
@@ -18,17 +17,19 @@ def bias_gelu(y, bias):
    x = bias + y
    return (x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))).to(dtype=y.dtype)


# gradient of tanh approximation of gelu
# gradient of actual gelu is:
# 0.5 * (1. + torch.erf(x * 0.70710678)) + 0.3989423 * x * torch.exp(-0.5 * x * x)
@torch.jit.script
def bias_gelu_back(g, y, bias):
    """Assume that y has shape (B, D) and bias has shape (D)"""
    x = bias + y
    tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
    # sqrt(2/pi) * 3 * 0.044715 -> 0.1070322243
    ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (
        1 + tanh_out
    )
    grad_y = ff * g
    return grad_y.to(dtype=y.dtype), grad_y.sum(dim=(0), dtype=bias.dtype)
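
The ff expression is the derivative of the tanh GELU approximation; a short derivation of where 0.79788456 and 0.1070322243 come from, in the same comment style as the file:

# With u(x) = sqrt(2/pi) * (x + 0.044715 * x**3), the approximation is g(x) = 0.5 * x * (1 + tanh(u(x))).
# Differentiating:
#   g'(x) = 0.5 * (1 + tanh(u)) + 0.5 * x * (1 - tanh(u)**2) * sqrt(2/pi) * (1 + 3 * 0.044715 * x**2)
# and since sqrt(2/pi) ~= 0.79788456 and 3 * 0.044715 * sqrt(2/pi) ~= 0.1070322243, this matches
#   ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out)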
@@ -56,6 +57,7 @@ bias_gelu_impl = GeLUFunction.apply
def gelu_fwd(x):
    return (x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))).to(dtype=x.dtype)


# gradient of tanh approximation of gelu
# gradient of actual gelu is:
# 0.5 * (1. + torch.erf(x * 0.70710678)) + 0.3989423 * x * torch.exp(-0.5 * x * x)
@@ -63,7 +65,9 @@ def gelu_fwd(x):
def gelu_bwd(g, x):
    tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
    # sqrt(2/pi) * 3 * 0.044715 -> 0.1070322243
    ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (
        1 + tanh_out
    )
    return (ff * g).to(dtype=x.dtype)
@@ -76,10 +80,11 @@ class FastGeLUFunction(torch.autograd.Function):
    @staticmethod
    def backward(ctx, grad_output):
        (input,) = ctx.saved_tensors
        tmp = gelu_bwd(grad_output, input)
        return tmp


fast_gelu_impl = FastGeLUFunction.apply
...
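
A quick numerical check (a test sketch, not part of this commit) of gelu_fwd / gelu_bwd against PyTorch's tanh-approximate GELU; the module path flash_attn.ops.activations is taken from the fused_dense imports later in this diff:

import torch
import torch.nn.functional as F
from flash_attn.ops.activations import gelu_bwd, gelu_fwd

x = torch.randn(4, 8, requires_grad=True)
g = torch.randn_like(x)

# Forward: the hard-coded constants should track F.gelu(..., approximate="tanh") closely.
torch.testing.assert_close(gelu_fwd(x), F.gelu(x, approximate="tanh"), atol=1e-4, rtol=1e-4)

# Backward: the hand-written gradient should match autograd through the reference implementation.
(ref_grad,) = torch.autograd.grad(F.gelu(x, approximate="tanh"), x, grad_outputs=g)
torch.testing.assert_close(gelu_bwd(g, x.detach()), ref_grad, atol=1e-4, rtol=1e-4)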
@@ -10,6 +10,10 @@ import fused_dense_lib as fused_dense_cuda
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.cuda.amp import custom_bwd, custom_fwd
from torch.distributed import ProcessGroup

from flash_attn.ops.activations import gelu_bwd, relu_bwd, sqrelu_bwd, sqrelu_fwd
from flash_attn.utils.distributed import (
    all_gather_raw,
@@ -18,9 +22,6 @@ from flash_attn.utils.distributed import (
    reduce_scatter,
    reduce_scatter_raw,
)


class FusedDenseFunc(torch.autograd.Function):
...
@@ -11,7 +11,6 @@ from typing import Optional
import triton
import triton.language as tl

_sqrt2pi = math.sqrt(2.0 / math.pi)
_sqrt1_2 = math.sqrt(1.0 / 2)
_gaussian_pdf_normalization = 1.0 / math.sqrt(2 * math.pi)
@@ -142,6 +141,7 @@ def gelu_grad(x):
    pdf = tl.exp(-0.5 * x * x) * _gaussian_pdf_normalization
    return cdf + x * pdf


@triton.jit
def gelu_approx(x):
    """
@@ -157,6 +157,6 @@ def gelu_approx_grad(x):
    # CREDITS: Fast implementation proposed in
    # https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/fused_bias_gelu.py#L30
    tanh_out = tanh(0.79788456 * x * (1 + 0.044715 * x * x))
    return 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (
        1 + tanh_out
    )