Commit d2cf0f50 authored by Gustaf Ahdritz's avatar Gustaf Ahdritz
Browse files

Minor reformatting & default change

parent 0d148a7d
......@@ -54,7 +54,8 @@ class PairTransition(nn.Module):
z = self.relu(z)
# [*, N_res, N_res, C_z]
z = self.linear_2(z) * mask
z = self.linear_2(z)
z = z * mask
return z
......@@ -71,7 +72,6 @@ class PairTransition(nn.Module):
no_batch_dims=len(z.shape[:-2]),
)
def forward(self,
z: torch.Tensor,
mask: Optional[torch.Tensor] = None,
......
......@@ -343,7 +343,7 @@ class ChunkSizeTuner:
def __init__(self,
# Heuristically, runtimes for most of the modules in the network
# plateau earlier than this on all GPUs I've run the model on.
max_chunk_size=256,
max_chunk_size=512,
):
self.max_chunk_size = max_chunk_size
self.cached_chunk_size = None
......@@ -402,7 +402,7 @@ class ChunkSizeTuner:
representative_fn: Callable,
args: Tuple[Any],
min_chunk_size: int,
) -> int:
) -> int:
consistent = True
remove_tensors = lambda a: a.shape if type(a) is torch.Tensor else a
arg_data = tree_map(remove_tensors, args, object)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment