Commit f1a73d07 authored by Tri Dao
Run isort and black on python files

parent cbb4cf5f
@@ -732,8 +732,12 @@ class ParallelMHA(nn.Module):
             self.num_heads % self.num_heads_kv == 0
         ), "num_heads must be divisible by num_heads_kv"
-        self.num_heads_per_rank = get_dim_for_local_rank(self.num_heads, self.world_size, self.local_rank)
-        self.num_heads_kv_per_rank = get_dim_for_local_rank(self.num_heads, self.world_size, self.local_rank)
+        self.num_heads_per_rank = get_dim_for_local_rank(
+            self.num_heads, self.world_size, self.local_rank
+        )
+        self.num_heads_kv_per_rank = get_dim_for_local_rank(
+            self.num_heads, self.world_size, self.local_rank
+        )
         self.head_dim = self.embed_dim // num_heads
         qkv_dim = self.head_dim * (self.num_heads + 2 * self.num_heads_kv)
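For context on the hunk above: get_dim_for_local_rank splits a global dimension (here the number of attention heads) across tensor-parallel ranks. Below is a minimal sketch of what such a helper typically computes, assuming an even split with the first ranks absorbing any remainder; it is an illustrative stand-in, not the flash-attn source.

# Hypothetical sketch (not the flash-attn implementation): evenly partition
# `dim` units, e.g. attention heads, across `world_size` tensor-parallel ranks,
# with the first `dim % world_size` ranks holding one extra unit.
def get_dim_for_local_rank_sketch(dim: int, world_size: int, local_rank: int) -> int:
    base, remainder = divmod(dim, world_size)
    return base + (1 if local_rank < remainder else 0)

# Example: 12 heads over 8 ranks -> ranks 0-3 hold 2 heads, ranks 4-7 hold 1.
assert [get_dim_for_local_rank_sketch(12, 8, r) for r in range(8)] == [2, 2, 2, 2, 1, 1, 1, 1]

Summed over all ranks, the per-rank values recover the global dimension, so no head is dropped or duplicated.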
@@ -5,7 +5,6 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 # 1/sqrt(2*pi)-> 0.3989423
 # 1/sqrt(2) -> 0.70710678
 # sqrt(2/pi) -> 0.79788456
@@ -18,17 +17,19 @@ def bias_gelu(y, bias):
     x = bias + y
     return (x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))).to(dtype=y.dtype)
 # gradient of tanh approximation of gelu
 # gradient of actual gelu is:
 # 0.5 * (1. + torch.erf(x * 0.70710678)) + 0.3989423 * x * torch.exp(-0.5 * x * x)
 @torch.jit.script
 def bias_gelu_back(g, y, bias):
-    """Assume that y has shape (B, D) and bias has shape (D)
-    """
+    """Assume that y has shape (B, D) and bias has shape (D)"""
     x = bias + y
     tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
     # sqrt(2/pi) * 3 * 0.044715 -> 0.1070322243
-    ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out)
+    ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (
+        1 + tanh_out
+    )
     grad_y = ff * g
     return grad_y.to(dtype=y.dtype), grad_y.sum(dim=(0), dtype=bias.dtype)
@@ -56,6 +57,7 @@ bias_gelu_impl = GeLUFunction.apply
 def gelu_fwd(x):
     return (x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))).to(dtype=x.dtype)
 # gradient of tanh approximation of gelu
 # gradient of actual gelu is:
 # 0.5 * (1. + torch.erf(x * 0.70710678)) + 0.3989423 * x * torch.exp(-0.5 * x * x)
@@ -63,7 +65,9 @@ def gelu_fwd(x):
 def gelu_bwd(g, x):
     tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
     # sqrt(2/pi) * 3 * 0.044715 -> 0.1070322243
-    ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out)
+    ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (
+        1 + tanh_out
+    )
     return (ff * g).to(dtype=x.dtype)
@@ -76,10 +80,11 @@ class FastGeLUFunction(torch.autograd.Function):
     @staticmethod
     def backward(ctx, grad_output):
-        input, = ctx.saved_tensors
+        (input,) = ctx.saved_tensors
         tmp = gelu_bwd(grad_output, input)
         return tmp
 fast_gelu_impl = FastGeLUFunction.apply
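The hunks above only reflow the GELU code, but they are a convenient place to sanity-check the math: the tanh-approximation forward should match PyTorch's F.gelu(x, approximate="tanh"), and the hand-derived backward should agree with autograd through the forward. A standalone sketch of that check (not part of the commit; helper names are local to the sketch):

import torch
import torch.nn.functional as F

def gelu_fwd(x):
    # Tanh approximation of GELU, same formula as in the diff above.
    return x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))

def gelu_bwd(g, x):
    # Hand-derived gradient of the tanh approximation, same formula as above.
    tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
    ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (
        1 + tanh_out
    )
    return ff * g

x = torch.randn(4, 8, dtype=torch.float64, requires_grad=True)
g = torch.randn_like(x)

# Forward matches PyTorch's tanh-approximate GELU (up to the truncated constants).
torch.testing.assert_close(gelu_fwd(x), F.gelu(x, approximate="tanh"))

# Hand-written backward matches autograd through the forward.
(auto_grad,) = torch.autograd.grad(gelu_fwd(x), x, grad_outputs=g)
torch.testing.assert_close(gelu_bwd(g, x), auto_grad)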