Unverified commit 5cee0714, authored by Tri Dao and committed by GitHub
Browse files

Merge pull request #164 from ZhiyuanChen/patch-1

Make Mlp hidden_features default to 4 * in_features
parents 853ff729 8c424156
...@@ -17,7 +17,7 @@ class Mlp(nn.Module): ...@@ -17,7 +17,7 @@ class Mlp(nn.Module):
factory_kwargs = {'device': device, 'dtype': dtype} factory_kwargs = {'device': device, 'dtype': dtype}
super().__init__() super().__init__()
out_features = out_features or in_features out_features = out_features or in_features
hidden_features = hidden_features or in_features hidden_features = hidden_features or in_features * 4
self.return_residual = return_residual self.return_residual = return_residual
self.fc1 = nn.Linear(in_features, hidden_features, **factory_kwargs) self.fc1 = nn.Linear(in_features, hidden_features, **factory_kwargs)
self.activation = activation self.activation = activation
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment