"__init__.py" did not exist on "c25c636d2323bba42a96242bb6657e66e3d3698b"
Commit 6fc1e07d authored by Tri Dao

[Block] Re-enable DropPath

parent 9ee0ff1d
@@ -8,7 +8,7 @@ import torch.nn as nn
 import torch.nn.functional as F
 from torch import Tensor
-# from torchvision.ops import StochasticDepth
+from torchvision.ops import StochasticDepth
 from flash_attn.modules.mha import MHA
 from flash_attn.modules.mlp import Mlp
@@ -70,12 +70,12 @@ class Block(nn.Module):
             mlp_cls = partial(Mlp, hidden_features=4 * dim)
         self.mixer = mixer_cls(dim)
         self.dropout1 = dropout_cls(resid_dropout1)
-        # self.drop_path1 = StochasticDepth(drop_path1, mode='row')
+        self.drop_path1 = StochasticDepth(drop_path1, mode='row')
         self.norm1 = norm_cls(dim)
         self.mlp = mlp_cls(dim)
         if not isinstance(self.mlp, nn.Identity):
             self.dropout2 = dropout_cls(resid_dropout2)
-            # self.drop_path2 = StochasticDepth(drop_path2, mode='row')
+            self.drop_path2 = StochasticDepth(drop_path2, mode='row')
             self.norm2 = norm_cls(dim)
         if self.fused_dropout_add_ln:
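For context, torchvision's `StochasticDepth(p, mode='row')` is the DropPath module being re-enabled here: in training mode it zeroes each element of the input's first dimension independently with probability `p` and rescales the survivors by `1 / (1 - p)`; in eval mode (or when `p == 0`) it is the identity. A minimal sketch of that behavior (the shape and probability below are illustrative, not taken from the diff):

```python
import torch
from torchvision.ops import StochasticDepth

drop_path = StochasticDepth(p=0.25, mode='row')  # per-row DropPath
drop_path.train()

x = torch.ones(4, 3)  # 4 rows, 3 features (illustrative shape)
out = drop_path(x)
# Each row of `out` is either all zeros (dropped) or all 1 / (1 - 0.25) (kept and rescaled).

drop_path.eval()
assert torch.equal(drop_path(x), x)  # identity at inference time
```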
@@ -129,14 +129,13 @@ class Block(nn.Module):
                 if self.residual_in_fp32:
                     residual = residual.to(torch.float32)
             else:
-                rowscale1 = None
-                # if self.drop_path1.p == 0 or not self.training:
-                #     rowscale1 = None
-                # else:
-                #     rowscale1 = self.drop_path1(torch.ones(
-                #         hidden_states.shape[:-1], device=hidden_states.device,
-                #         dtype=hidden_states.dtype)
-                #     )
+                if self.drop_path1.p == 0 or not self.training:
+                    rowscale1 = None
+                else:
+                    rowscale1 = self.drop_path1(torch.ones(
+                        hidden_states.shape[:-1], device=hidden_states.device,
+                        dtype=hidden_states.dtype)
+                    )
                 hidden_states, residual = fused_add_norm_fn(
                     hidden_states, residual, self.norm1.weight, self.norm1.bias,
                     self.dropout1.p if self.training else 0.0, self.norm1.eps,
@@ -157,14 +156,13 @@ class Block(nn.Module):
                 if self.residual_in_fp32:
                     residual = residual.to(torch.float32)
             else:
-                # if self.drop_path2.p == 0 or not self.training:
-                #     rowscale2 = None
-                # else:
-                #     rowscale2 = self.drop_path2(torch.ones(
-                #         hidden_states.shape[:-1], device=hidden_states.device,
-                #         dtype=hidden_states.dtype)
-                #     )
-                rowscale2 = None
+                if self.drop_path2.p == 0 or not self.training:
+                    rowscale2 = None
+                else:
+                    rowscale2 = self.drop_path2(torch.ones(
+                        hidden_states.shape[:-1], device=hidden_states.device,
+                        dtype=hidden_states.dtype)
+                    )
                 hidden_states, residual = fused_add_norm_fn(
                     hidden_states, residual, self.norm2.weight, self.norm2.bias,
                     self.dropout2.p if self.training else 0.0, self.norm2.eps,
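The restored branches feed DropPath into the fused dropout-add-layer-norm path indirectly: rather than calling `drop_path` on `hidden_states` itself, they apply it to a tensor of ones over the non-feature dimensions, yielding a scale per sample (`0` for dropped samples, `1 / (1 - p)` for kept ones) that the fused kernel can multiply in as `rowscale`. A rough sketch of what that tensor looks like, assuming a `(batch, seqlen, hidden)` layout (the shapes and the stand-alone multiply below are illustrative; in the diff the scaling is fused with the residual add and layer norm):

```python
import torch
from torchvision.ops import StochasticDepth

drop_path1 = StochasticDepth(p=0.1, mode='row')
drop_path1.train()

hidden_states = torch.randn(2, 4, 8)  # (batch, seqlen, hidden) -- assumed layout

# As in the diff: DropPath over a tensor of ones gives per-sample scale factors,
# 0 where a sample's residual branch is dropped, 1 / (1 - p) where it is kept.
rowscale1 = drop_path1(torch.ones(
    hidden_states.shape[:-1], device=hidden_states.device,
    dtype=hidden_states.dtype)
)

# Unfused equivalent of what the kernel does with `rowscale`: a row-wise multiply
# before the residual add + layer norm, mirroring drop_path1(hidden_states).
scaled = hidden_states * rowscale1.unsqueeze(-1)
```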