Unverified Commit 9946165e authored by OlivierDehaene's avatar OlivierDehaene Committed by GitHub
Browse files

chore: add pre-commit (#1569)

parent 142cdabe
......@@ -40,5 +40,3 @@ __forceinline__ __device__ void dequant_6bit_16
#endif
#endif
......@@ -251,9 +251,9 @@ class LlamaMLP(nn.Module):
if "gelu" not in act
else lambda x: torch.nn.functional.gelu(
x,
approximate="tanh"
if act in ["gelu_fast", "gelu_pytorch_tanh"]
else "none",
approximate=(
"tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none"
),
)
)
# Fuse gate and up proj
......
......@@ -255,9 +255,9 @@ class MistralMLP(nn.Module):
if "gelu" not in act
else lambda x: torch.nn.functional.gelu(
x,
approximate="tanh"
if act in ["gelu_fast", "gelu_pytorch_tanh"]
else "none",
approximate=(
"tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none"
),
)
)
# Fuse gate and up proj
......
......@@ -344,9 +344,9 @@ class BlockSparseMoE(nn.Module):
if "gelu" in act:
self.act = lambda x: torch.nn.functional.gelu(
x,
approximate="tanh"
if act in ["gelu_fast", "gelu_pytorch_tanh"]
else "none",
approximate=(
"tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none"
),
)
elif "silu" in act:
self.act = torch.nn.functional.silu
......@@ -600,9 +600,9 @@ class DenseMoE(nn.Module):
if "gelu" in act:
self.act = lambda x: torch.nn.functional.gelu(
x,
approximate="tanh"
if act in ["gelu_fast", "gelu_pytorch_tanh"]
else "none",
approximate=(
"tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none"
),
)
elif "silu" in act:
self.act = torch.nn.functional.silu
......
......@@ -187,9 +187,9 @@ class FlashMLP(nn.Module):
if "gelu" not in act
else lambda x: torch.nn.functional.gelu(
x,
approximate="tanh"
if act in ["gelu_fast", "gelu_pytorch_tanh"]
else "none",
approximate=(
"tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none"
),
)
)
......
......@@ -225,9 +225,9 @@ class PhiMLP(nn.Module):
if "gelu" not in act
else lambda x: torch.nn.functional.gelu(
x,
approximate="tanh"
if act in ["gelu_fast", "gelu_pytorch_tanh"]
else "none",
approximate=(
"tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none"
),
)
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment