Unverified commit 9946165e authored by OlivierDehaene, committed by GitHub

chore: add pre-commit (#1569)

parent 142cdabe
@@ -40,5 +40,3 @@ __forceinline__ __device__ void dequant_6bit_16
 #endif
 #endif
@@ -251,9 +251,9 @@ class LlamaMLP(nn.Module):
             if "gelu" not in act
             else lambda x: torch.nn.functional.gelu(
                 x,
-                approximate="tanh"
-                if act in ["gelu_fast", "gelu_pytorch_tanh"]
-                else "none",
+                approximate=(
+                    "tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none"
+                ),
             )
         )
         # Fuse gate and up proj
@@ -255,9 +255,9 @@ class MistralMLP(nn.Module):
             if "gelu" not in act
             else lambda x: torch.nn.functional.gelu(
                 x,
-                approximate="tanh"
-                if act in ["gelu_fast", "gelu_pytorch_tanh"]
-                else "none",
+                approximate=(
+                    "tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none"
+                ),
             )
         )
         # Fuse gate and up proj
@@ -344,9 +344,9 @@ class BlockSparseMoE(nn.Module):
         if "gelu" in act:
             self.act = lambda x: torch.nn.functional.gelu(
                 x,
-                approximate="tanh"
-                if act in ["gelu_fast", "gelu_pytorch_tanh"]
-                else "none",
+                approximate=(
+                    "tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none"
+                ),
             )
         elif "silu" in act:
             self.act = torch.nn.functional.silu
@@ -600,9 +600,9 @@ class DenseMoE(nn.Module):
         if "gelu" in act:
             self.act = lambda x: torch.nn.functional.gelu(
                 x,
-                approximate="tanh"
-                if act in ["gelu_fast", "gelu_pytorch_tanh"]
-                else "none",
+                approximate=(
+                    "tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none"
+                ),
             )
         elif "silu" in act:
             self.act = torch.nn.functional.silu
@@ -187,9 +187,9 @@ class FlashMLP(nn.Module):
             if "gelu" not in act
             else lambda x: torch.nn.functional.gelu(
                 x,
-                approximate="tanh"
-                if act in ["gelu_fast", "gelu_pytorch_tanh"]
-                else "none",
+                approximate=(
+                    "tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none"
+                ),
             )
         )
@@ -225,9 +225,9 @@ class PhiMLP(nn.Module):
             if "gelu" not in act
             else lambda x: torch.nn.functional.gelu(
                 x,
-                approximate="tanh"
-                if act in ["gelu_fast", "gelu_pytorch_tanh"]
-                else "none",
+                approximate=(
+                    "tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none"
+                ),
            )
         )
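
All of the Python hunks above apply the same mechanical change: the formatter run by the new pre-commit hooks wraps the conditional passed to approximate= in parentheses, with no change in behavior. As a minimal sketch of what that argument does, assuming only the public torch.nn.functional.gelu API (the helper name make_act is illustrative and not from this repository):

    import torch

    def make_act(act: str):
        # Mirrors the lambda in the hunks above: activation names
        # "gelu_fast" and "gelu_pytorch_tanh" select the tanh approximation,
        # while any other gelu name falls back to the exact formulation.
        return lambda x: torch.nn.functional.gelu(
            x,
            approximate=(
                "tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none"
            ),
        )

    x = torch.randn(4)
    print(make_act("gelu_pytorch_tanh")(x))  # tanh-approximated gelu
    print(make_act("gelu")(x))               # exact gelu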