Commit abbc1311 authored by Tri Dao

[LayerNorm] Switch from CUDA to Triton implementation

parent f5b308e2
@@ -14,3 +14,7 @@ This extension has only been tested on A100s.
 ```sh
 cd csrc/layer_norm && pip install .
 ```
+As of 2024-01-05, this extension is no longer used in the FlashAttention repo.
+We've instead switched to a Triton-based
+[implementation](https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/ops/triton/layer_norm.py).
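For reference, here is a minimal sketch of calling the Triton kernel that replaces this extension, mirroring the `layer_norm_fn` usage introduced in the diffs below. The shapes, dtype, and tolerance check are illustrative assumptions, and a CUDA device with Triton installed is required:

```python
import torch
from flash_attn.ops.triton.layer_norm import layer_norm_fn

# Plain LayerNorm through the fused Triton kernel (no dropout, no residual),
# which is how bert.py calls it after this commit.
x = torch.randn(2, 512, 768, device="cuda", dtype=torch.float16)
ln = torch.nn.LayerNorm(768).to(device="cuda", dtype=torch.float16)

out_triton = layer_norm_fn(x, ln.weight, ln.bias, eps=ln.eps)
out_ref = ln(x)  # eager PyTorch reference
print((out_triton - out_ref).abs().max())  # expected to be small (fp16 rounding)
```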
@@ -40,9 +40,10 @@ except ImportError:
     FusedDense = None
 try:
-    from flash_attn.ops.layer_norm import dropout_add_layer_norm, layer_norm
+    from flash_attn.ops.triton.layer_norm import layer_norm_fn
 except ImportError:
-    dropout_add_layer_norm, layer_norm = None, None
+    layer_norm_fn = None
 try:
     from flash_attn.losses.cross_entropy import CrossEntropyLoss
@@ -237,8 +238,8 @@ class BertPredictionHeadTransform(nn.Module):
         if fused_bias_fc and FusedDense is None:
             raise ImportError("fused_dense is not installed")
         self.fused_dropout_add_ln = getattr(config, "fused_dropout_add_ln", False)
-        if self.fused_dropout_add_ln and layer_norm is None:
-            raise ImportError("dropout_add_layer_norm is not installed")
+        if self.fused_dropout_add_ln and layer_norm_fn is None:
+            raise ImportError("Triton is not installed")
         linear_cls = nn.Linear if not fused_bias_fc else FusedDense
         self.dense = linear_cls(config.hidden_size, config.hidden_size)
         approximate = (
@@ -255,8 +256,8 @@ class BertPredictionHeadTransform(nn.Module):
         if not self.fused_dropout_add_ln:
             hidden_states = self.layer_norm(hidden_states)
         else:
-            hidden_states = layer_norm(
-                hidden_states, self.layer_norm.weight, self.layer_norm.bias, self.layer_norm.eps
+            hidden_states = layer_norm_fn(
+                hidden_states, self.layer_norm.weight, self.layer_norm.bias, eps=self.layer_norm.eps
             )
         return hidden_states
@@ -345,8 +346,8 @@ class BertModel(BertPreTrainedModel):
                 config.vocab_size % self.pad_vocab_size_multiple
             )
         self.fused_dropout_add_ln = getattr(config, "fused_dropout_add_ln", False)
-        if self.fused_dropout_add_ln and layer_norm is None:
-            raise ImportError("dropout_add_layer_norm is not installed")
+        if self.fused_dropout_add_ln and layer_norm_fn is None:
+            raise ImportError("Triton is not installed")
         assert config.hidden_act in ["gelu", "gelu_new", "gelu_fast", "gelu_pytorch_tanh"]
         self.embeddings = BertEmbeddings(
@@ -384,8 +385,8 @@ class BertModel(BertPreTrainedModel):
         if not self.fused_dropout_add_ln:
             hidden_states = self.emb_ln(hidden_states)
         else:
-            hidden_states = layer_norm(
-                hidden_states, self.emb_ln.weight, self.emb_ln.bias, self.emb_ln.eps
+            hidden_states = layer_norm_fn(
+                hidden_states, self.emb_ln.weight, self.emb_ln.bias, eps=self.emb_ln.eps
             )
         hidden_states = self.emb_drop(hidden_states)
...
-# Copyright (c) 2023, Tri Dao.
+# Copyright (c) 2024, Tri Dao.
 import logging
 import math
@@ -47,29 +47,14 @@ except ImportError:
     ColumnParallelLinear = None
 try:
-    from flash_attn.ops.layer_norm import dropout_add_layer_norm
-except ImportError:
-    dropout_add_layer_norm = None
-try:
-    from flash_attn.ops.layer_norm import dropout_add_layer_norm_parallel_residual
-except ImportError:
-    dropout_add_layer_norm_parallel_residual = None
-try:
-    from flash_attn.ops.rms_norm import RMSNorm, dropout_add_rms_norm
-except ImportError:
-    RMSNorm, dropout_add_rms_norm = None, None
-try:
-    from flash_attn.ops.rms_norm import dropout_add_rms_norm_parallel_residual
+    from flash_attn.ops.triton.mlp import FusedDenseSqreluDense
 except ImportError:
-    dropout_add_rms_norm_parallel_residual = None
+    FusedDenseSqreluDense = None
 try:
-    from flash_attn.ops.triton.mlp import FusedDenseSqreluDense
+    from flash_attn.ops.triton.layer_norm import layer_norm_fn, RMSNorm
 except ImportError:
-    FusedDenseSqreluDense = None
+    layer_norm_fn, RMSNorm = None, None
 logger = logging.getLogger(__name__)
@@ -481,13 +466,15 @@ class GPTModel(GPTPreTrainedModel):
                 for i in range(config.num_hidden_layers)
             ]
         )
+        rotary_emb_fraction = getattr(config, "rotary_emb_fraction", 0.0)
+        if rotary_emb_fraction > 0.0:  # Tie all the RotaryEmbedding modules to share the same cos/sin cache
+            for layer in self.layers[1:]:
+                layer.mixer.rotary_emb = self.layers[0].mixer.rotary_emb
         self.fused_dropout_add_ln = getattr(config, "fused_dropout_add_ln", False)
         if self.fused_dropout_add_ln:
-            if (not self.parallel_block and dropout_add_layer_norm is None) or (
-                self.parallel_block and dropout_add_layer_norm_parallel_residual is None
-            ):
-                raise ImportError("dropout_layer_norm is not installed")
+            if layer_norm_fn is None:
+                raise ImportError("Triton is not installed")
         if self.prenorm:
             self.drop_f = nn.Dropout(config.resid_pdrop)
             norm_cls = nn.LayerNorm if not use_rms_norm else RMSNorm
@@ -571,41 +558,17 @@ class GPTModel(GPTPreTrainedModel):
                 hidden_states = self.ln_f(residual.to(dtype=self.ln_f.weight.dtype))
             else:
                 # Set prenorm=False here since we don't need the residual
-                if not self.parallel_block:
-                    fused_add_norm_fn = (
-                        dropout_add_rms_norm
-                        if isinstance(self.ln_f, RMSNorm)
-                        else dropout_add_layer_norm
-                    )
-                    hidden_states = fused_add_norm_fn(
-                        hidden_states,
-                        residual,
-                        self.ln_f.weight,
-                        self.ln_f.bias,
-                        self.drop_f.p if self.training else 0.0,
-                        self.ln_f.eps,
-                        prenorm=False,
-                        residual_in_fp32=self.residual_in_fp32,
-                    )
-                else:
-                    fused_add_norm_fn = (
-                        dropout_add_rms_norm_parallel_residual
-                        if isinstance(self.ln_f, RMSNorm)
-                        else dropout_add_layer_norm_parallel_residual
-                    )
-                    hidden_states, _ = fused_add_norm_fn(
-                        hidden_states,
-                        hidden_states2,
-                        residual,
-                        self.ln_f.weight,
-                        self.ln_f.bias,
-                        None,
-                        None,
-                        self.drop_f.p if self.training else 0.0,
-                        self.ln_f.eps,
-                        prenorm=False,
-                        residual_in_fp32=self.residual_in_fp32,
-                    )
+                hidden_states = layer_norm_fn(
+                    hidden_states,
+                    self.ln_f.weight,
+                    self.ln_f.bias,
+                    residual=residual,
+                    x1=None if not self.parallel_block else hidden_states2,
+                    eps=self.ln_f.eps,
+                    dropout_p=self.drop_f.p if self.training else 0.0,
+                    prenorm=False,
+                    is_rms_norm=isinstance(self.ln_f, RMSNorm)
+                )
         return hidden_states
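The single `layer_norm_fn` call above folds the four old CUDA entry points (`dropout_add_layer_norm`, `dropout_add_rms_norm`, and their parallel-residual variants) into one Triton kernel. A hedged sketch of the non-parallel case, using only keyword arguments that appear in this diff; the shapes and dropout probability are made up for illustration:

```python
import torch
from flash_attn.ops.triton.layer_norm import layer_norm_fn

hidden_states = torch.randn(4, 1024, 2048, device="cuda", dtype=torch.float16)
residual = torch.randn_like(hidden_states)  # running residual stream
ln_f = torch.nn.LayerNorm(2048).to(device="cuda", dtype=torch.float16)

# Fused: dropout(hidden_states) + residual, then LayerNorm of the sum.
out = layer_norm_fn(
    hidden_states,
    ln_f.weight,
    ln_f.bias,
    residual=residual,
    eps=ln_f.eps,
    dropout_p=0.1,       # 0.0 at eval time, as in the code above
    prenorm=False,       # final norm: no need to return the updated residual
    is_rms_norm=False,   # True switches the same kernel to RMSNorm statistics
)
```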
...
@@ -20,9 +20,9 @@ from flash_attn.modules.mha import MHA
 from flash_attn.modules.mlp import FusedMLP, Mlp
 try:
-    from flash_attn.ops.layer_norm import dropout_add_layer_norm
+    from flash_attn.ops.triton.layer_norm import layer_norm_fn
 except ImportError:
-    dropout_add_layer_norm = None
+    layer_norm_fn = None
 def create_mixer_cls(
@@ -229,8 +229,8 @@ class VisionTransformer(nn.Module):
         self.norm = norm_layer(embed_dim)
         self.fused_dropout_add_ln = fused_dropout_add_ln
-        if self.fused_dropout_add_ln and dropout_add_layer_norm is None:
-            raise ImportError("dropout_add_layer_norm is not installed")
+        if self.fused_dropout_add_ln and layer_norm_fn is None:
+            raise ImportError("Triton is not installed")
         # Classifier Head
         self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
@@ -302,16 +302,15 @@ class VisionTransformer(nn.Module):
                 )
             )
             # Set prenorm=False here since we don't need to the residual
-            hidden_states = dropout_add_layer_norm(
+            hidden_states = layer_norm_fn(
                 hidden_states,
-                residual,
                 self.norm.weight,
                 self.norm.bias,
-                self.dropout.p if self.training else 0.0,
-                self.norm.eps,
+                residual=residual,
+                eps=self.norm.eps,
+                dropout_p=self.dropout.p if self.training else 0.0,
                 rowscale=rowscale,
                 prenorm=False,
-                residual_in_fp32=True,
             )
         return hidden_states
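`rowscale` is the one argument here that the BERT and GPT hunks do not use. Judging by how the ViT code builds it from DropPath, it supplies a per-row scale for the non-residual input, so stochastic depth can be folded into the fused norm. A sketch under that assumption; the keep probability and shapes are illustrative:

```python
import torch
from flash_attn.ops.triton.layer_norm import layer_norm_fn

x = torch.randn(8, 197, 384, device="cuda", dtype=torch.float16)
residual = torch.randn_like(x)
norm = torch.nn.LayerNorm(384).to(device="cuda", dtype=torch.float16)

# Per-row keep/scale factors with shape x.shape[:-1], as DropPath would produce.
keep_prob = 0.9
rowscale = (torch.rand(x.shape[:-1], device=x.device) < keep_prob).to(x.dtype) / keep_prob

out = layer_norm_fn(
    x,
    norm.weight,
    norm.bias,
    residual=residual,
    eps=norm.eps,
    rowscale=rowscale,  # scales each row of x before the residual add
    prenorm=False,
)
```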
...
-# Copyright (c) 2022, Tri Dao.
+# Copyright (c) 2024, Tri Dao.
 from functools import partial
 from typing import Optional
@@ -13,24 +13,9 @@ from flash_attn.modules.mha import MHA
 from flash_attn.modules.mlp import Mlp
 try:
-    from flash_attn.ops.layer_norm import dropout_add_layer_norm
+    from flash_attn.ops.triton.layer_norm import layer_norm_fn, RMSNorm
 except ImportError:
-    dropout_add_layer_norm = None
-try:
-    from flash_attn.ops.layer_norm import dropout_add_layer_norm_parallel_residual
-except ImportError:
-    dropout_add_layer_norm_parallel_residual = None
-try:
-    from flash_attn.ops.rms_norm import RMSNorm, dropout_add_rms_norm
-except ImportError:
-    RMSNorm, dropout_add_rms_norm = None, None
-try:
-    from flash_attn.ops.rms_norm import dropout_add_rms_norm_parallel_residual
-except ImportError:
-    dropout_add_rms_norm_parallel_residual = None
+    layer_norm_fn, RMSNorm = None, None
 class Block(nn.Module):
@@ -91,8 +76,7 @@ class Block(nn.Module):
             self.norm2 = norm_cls(dim)
         if self.fused_dropout_add_ln:
-            assert dropout_add_layer_norm is not None, "dropout_layer_norm is not installed"
-            assert dropout_add_rms_norm is not None, "dropout_layer_norm is not installed"
+            assert layer_norm_fn is not None, "Triton is not installed"
             assert isinstance(self.norm1, (nn.LayerNorm, RMSNorm)) and isinstance(
                 self.dropout1, nn.Dropout
             )
@@ -137,11 +121,6 @@ class Block(nn.Module):
             before applying the query projection. Useful for e.g., ViT where we only care
             about the CLS token in the last layer.
         """
-        fused_add_norm_fn = (
-            dropout_add_rms_norm
-            if RMSNorm and isinstance(self.norm1, RMSNorm)
-            else dropout_add_layer_norm
-        )
         if self.prenorm:
             if not self.fused_dropout_add_ln:
                 dropped = self.drop_path1(self.dropout1(hidden_states))
@@ -160,16 +139,17 @@ class Block(nn.Module):
                         dtype=hidden_states.dtype,
                     )
                 )
-                hidden_states, residual = fused_add_norm_fn(
+                hidden_states, residual = layer_norm_fn(
                     hidden_states,
-                    residual,
                     self.norm1.weight,
                     self.norm1.bias,
-                    self.dropout1.p if self.training else 0.0,
-                    self.norm1.eps,
+                    residual=residual,
+                    eps=self.norm1.eps,
+                    dropout_p=self.dropout1.p if self.training else 0.0,
                     rowscale=rowscale1,
                     prenorm=True,
                     residual_in_fp32=self.residual_in_fp32,
+                    is_rms_norm=isinstance(self.norm1, RMSNorm)
                 )
             if mixer_kwargs is None:
                 mixer_kwargs = {}
@@ -196,16 +176,17 @@ class Block(nn.Module):
                         dtype=hidden_states.dtype,
                     )
                 )
-                hidden_states, residual = fused_add_norm_fn(
+                hidden_states, residual = layer_norm_fn(
                     hidden_states,
-                    residual,
                     self.norm2.weight,
                     self.norm2.bias,
-                    self.dropout2.p if self.training else 0.0,
-                    self.norm2.eps,
+                    residual=residual,
+                    eps=self.norm2.eps,
+                    dropout_p=self.dropout2.p if self.training else 0.0,
                     rowscale=rowscale2,
                     prenorm=True,
                     residual_in_fp32=self.residual_in_fp32,
+                    is_rms_norm=isinstance(self.norm2, RMSNorm)
                 )
             hidden_states = self.mlp(hidden_states)
             return hidden_states, residual
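In the pre-norm path the call is made with `prenorm=True`, so it returns both the normalized output and the updated residual stream that the next sub-block consumes. A minimal sketch of that contract, assuming the signature used throughout this diff:

```python
import torch
from flash_attn.ops.triton.layer_norm import layer_norm_fn

x = torch.randn(2, 128, 1024, device="cuda", dtype=torch.float16)
norm1 = torch.nn.LayerNorm(1024).to(device="cuda", dtype=torch.float16)

# First block: no residual yet, so pass residual=None and let the kernel create it.
hidden_states, residual = layer_norm_fn(
    x,
    norm1.weight,
    norm1.bias,
    residual=None,           # later blocks pass the running residual here
    eps=norm1.eps,
    dropout_p=0.1,
    prenorm=True,            # also return dropout(x) + residual for the next layer
    residual_in_fp32=True,   # keep the residual stream in fp32
)
print(hidden_states.dtype, residual.dtype)  # e.g. torch.float16, torch.float32
```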
@@ -231,15 +212,16 @@ class Block(nn.Module):
                         mixer_out.shape[:-1], device=mixer_out.device, dtype=mixer_out.dtype
                     )
                 )
-                hidden_states = fused_add_norm_fn(
+                hidden_states = layer_norm_fn(
                     mixer_out,
-                    hidden_states,
                     self.norm1.weight,
                     self.norm1.bias,
-                    self.dropout1.p if self.training else 0.0,
-                    self.norm1.eps,
+                    residual=hidden_states,
+                    eps=self.norm1.eps,
+                    dropout_p=self.dropout1.p if self.training else 0.0,
                     rowscale=rowscale1,
                     prenorm=False,
+                    is_rms_norm=isinstance(self.norm1, RMSNorm)
                 )
             if not isinstance(self.mlp, nn.Identity):
                 mlp_out = self.mlp(hidden_states)
@@ -260,15 +242,16 @@ class Block(nn.Module):
                         mlp_out.shape[:-1], device=mlp_out.device, dtype=mlp_out.dtype
                     )
                 )
-                hidden_states = fused_add_norm_fn(
+                hidden_states = layer_norm_fn(
                     mlp_out,
-                    hidden_states,
                     self.norm2.weight,
                     self.norm2.bias,
-                    self.dropout2.p if self.training else 0.0,
-                    self.norm2.eps,
+                    residual=hidden_states,
+                    eps=self.norm2.eps,
+                    dropout_p=self.dropout2.p if self.training else 0.0,
                     rowscale=rowscale2,
                     prenorm=False,
+                    is_rms_norm=isinstance(self.norm2, RMSNorm)
                 )
             return hidden_states
@@ -320,12 +303,7 @@ class ParallelBlock(nn.Module):
             self.norm2 = norm_cls(dim)
         if self.fused_dropout_add_ln:
-            assert (
-                dropout_add_layer_norm_parallel_residual is not None
-            ), "dropout_layer_norm is not installed"
-            assert (
-                dropout_add_rms_norm_parallel_residual is not None
-            ), "dropout_layer_norm is not installed"
+            assert layer_norm_fn is not None, "Triton is not installed"
             assert isinstance(self.norm1, (nn.LayerNorm, RMSNorm)) and isinstance(
                 self.dropout1, nn.Dropout
             )
@@ -370,11 +348,6 @@ class ParallelBlock(nn.Module):
         """
         # TODO: Ideally we should only do the allgather / allreduce once for
        # the Linear to MLP & Attention
-        fused_add_norm_fn = (
-            dropout_add_rms_norm_parallel_residual
-            if isinstance(self.norm1, RMSNorm)
-            else dropout_add_layer_norm_parallel_residual
-        )
         if not self.fused_dropout_add_ln:
             dropped1 = self.dropout1(hidden_states1)
             # For the very 1st block, we only want 1 dropout, not two different dropouts
@@ -399,21 +372,24 @@ class ParallelBlock(nn.Module):
             weight2, bias2 = (
                 (self.norm2.weight, self.norm2.bias) if not self.tied_norm else (None, None)
             )
-            hidden_states1, hidden_states2, residual = fused_add_norm_fn(
+            hidden_states1, *rest, residual = layer_norm_fn(
                 hidden_states1,
-                hidden_states2,
-                residual,
                 self.norm1.weight,
                 self.norm1.bias,
-                weight2,
-                bias2,
-                self.dropout1.p if self.training else 0.0,
-                self.norm1.eps,
+                residual=residual,
+                x1=hidden_states2,
+                weight1=weight2,
+                bias1=bias2,
+                eps=self.norm1.eps,
+                dropout_p=self.dropout1.p if self.training else 0.0,
                 prenorm=True,
                 residual_in_fp32=self.residual_in_fp32,
+                is_rms_norm=isinstance(self.norm1, RMSNorm)
             )
             if self.tied_norm:
                 hidden_states2 = hidden_states1
+            else:
+                hidden_states2, = rest
         if mixer_kwargs is None:
             mixer_kwargs = {}
         hidden_states1 = self.mixer(hidden_states1, **mixer_kwargs)
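The parallel-residual case normalizes two streams in one kernel launch: `x1` is the second input and `weight1`/`bias1` the second set of norm parameters (left as None when the norms are tied, in which case only one normalized output comes back, as the unpacking above shows). A sketch with untied norms, under the same assumed signature:

```python
import torch
from flash_attn.ops.triton.layer_norm import layer_norm_fn

h1 = torch.randn(2, 256, 1024, device="cuda", dtype=torch.float16)  # e.g. attention branch
h2 = torch.randn_like(h1)                                           # e.g. MLP branch
residual = torch.randn_like(h1)
norm1 = torch.nn.LayerNorm(1024).to(device="cuda", dtype=torch.float16)
norm2 = torch.nn.LayerNorm(1024).to(device="cuda", dtype=torch.float16)

out1, out2, residual = layer_norm_fn(
    h1,
    norm1.weight,
    norm1.bias,
    residual=residual,
    x1=h2,                   # second stream, added into the same residual
    weight1=norm2.weight,    # second norm's parameters (None when tied)
    bias1=norm2.bias,
    eps=norm1.eps,
    dropout_p=0.0,
    prenorm=True,            # returns (out1, out2, updated residual)
)
```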
...
@@ -87,9 +87,5 @@ RUN pip install git+https://github.com/mlcommons/logging.git@2.1.0
 # Install FlashAttention
 RUN pip install flash-attn==2.4.2
-# Install CUDA extensions for fused dense, layer norm
-RUN git clone https://github.com/HazyResearch/flash-attention \
-    && cd flash-attention && git checkout v2.4.2 \
-    && cd csrc/layer_norm && pip install . && cd ../../ \
-    && cd csrc/fused_dense_lib && pip install . && cd ../../ \
-    && cd .. && rm -rf flash-attention
+# Install CUDA extensions for fused dense
+RUN pip install git+https://github.com/HazyResearch/flash-attention@v2.4.2#subdirectory=csrc/fused_dense_lib