# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from functools import partial
import torch
import torch.nn as nn
from typing import Tuple, List, Callable, Any, Dict


def permute_final_dims(tensor: torch.Tensor, inds: List[int]) -> torch.Tensor:
    """
    Permutes the last len(inds) dimensions of a tensor, leaving all
    leading dimensions in place.

    Args:
        tensor:
            The tensor to permute
        inds:
            The new order of the final dimensions, expressed as
            non-negative offsets into the last len(inds) dimensions
            (e.g. [1, 0] swaps the final two dimensions).
    Returns:
        The permuted tensor (as returned by torch.Tensor.permute).
    """
    zero_index = -1 * len(inds)
    # Indices of the leading dimensions, which keep their positions
    first_inds = list(range(len(tensor.shape[:zero_index])))
    return tensor.permute(first_inds + [zero_index + i for i in inds])


def flatten_final_dims(t: torch.Tensor, no_dims: int) -> torch.Tensor:
    """
    Collapses the last no_dims dimensions of t into a single dimension.

    Args:
        t:
            The tensor to reshape
        no_dims:
            How many trailing dimensions to merge
    Returns:
        A tensor of shape t.shape[:-no_dims] + (-1,)
    """
    return t.reshape(t.shape[:-no_dims] + (-1,))


def masked_mean(mask, value, dim, eps=1e-4):
    """
    Computes the mean of value along dim, weighting each entry by mask.

    Args:
        mask:
            Mask tensor; it is expanded to value's shape, so it must be
            expandable to that shape
        value:
            The tensor to average
        dim:
            Dimension(s) to reduce, passed straight to torch.sum
        eps:
            Constant added to the denominator so that a fully masked-out
            slice does not divide by zero
    Returns:
        sum(mask * value) / (eps + sum(mask)), reduced over dim
    """
    mask = mask.expand(*value.shape)
    return torch.sum(mask * value, dim=dim) / (eps + torch.sum(mask, dim=dim))


def pts_to_distogram(pts, min_bin=2.3125, max_bin=21.6875, no_bins=64):
    """
    Bins the pairwise Euclidean distances between points.

    Args:
        pts:
            [*, N, C] tensor of point coordinates
        min_bin:
            Value of the first bin boundary
        max_bin:
            Value of the last bin boundary
        no_bins:
            Number of bins; no_bins - 1 evenly spaced boundaries are used
    Returns:
        [*, N, N] tensor of bin indices produced by torch.bucketize
    """
    boundaries = torch.linspace(
        min_bin, max_bin, no_bins - 1, device=pts.device
    )
    # Broadcast [*, N, 1, C] against [*, 1, N, C] for all pairwise offsets
    dists = torch.sqrt(
        torch.sum((pts.unsqueeze(-2) - pts.unsqueeze(-3)) ** 2, dim=-1)
    )
    return torch.bucketize(dists, boundaries)


def dict_multimap(fn, dicts):
    """
    Applies fn key-wise across a list of identically structured dicts.

    Args:
        fn:
            Function that takes the list of values found under a given
            key (one per dict) and returns the combined value
        dicts:
            Non-empty list of dicts. The keys of the first dict are used,
            and nested dicts are processed recursively.
    Returns:
        A new dict with the structure of dicts[0] and combined leaves
    """
    first = dicts[0]
    new_dict = {}
    for k, v in first.items():
        all_v = [d[k] for d in dicts]
        if type(v) is dict:
            new_dict[k] = dict_multimap(fn, all_v)
        else:
            new_dict[k] = fn(all_v)

    return new_dict


def one_hot(x, v_bins):
    """
    One-hot encodes each element of x by its nearest value in v_bins.

    Args:
        x:
            Tensor of values to encode
        v_bins:
            Tensor of bin values; it is viewed as
            (1, ..., 1, len(v_bins)) so that it broadcasts against x
    Returns:
        Float tensor of shape x.shape + (len(v_bins),)
    """
    reshaped_bins = v_bins.view(((1,) * len(x.shape)) + (len(v_bins),))
    diffs = x[..., None] - reshaped_bins
    am = torch.argmin(torch.abs(diffs), dim=-1)
    return nn.functional.one_hot(am, num_classes=len(v_bins)).float()


def batched_gather(data, inds, dim=0, no_batch_dims=0):
    """
    Indexes data along dim with inds, treating the first no_batch_dims
    dimensions of data and inds as shared batch dimensions.

    Args:
        data:
            The tensor to index
        inds:
            Integer tensor of indices into dimension dim of data
        dim:
            The dimension of data to index (may be negative)
        no_batch_dims:
            Number of leading dimensions of data that are batch
            dimensions
    Returns:
        The result of advanced indexing into data
    """
    ranges = []
    for i, s in enumerate(data.shape[:no_batch_dims]):
        # Create the batch index next to inds so that all index tensors
        # live on the same device
        r = torch.arange(s, device=inds.device)
        # Shape r as (1, ..., s, ..., 1) so it broadcasts against inds
        r = r.view(*(*((1,) * i), -1, *((1,) * (len(inds.shape) - i - 1))))
        ranges.append(r)

    remaining_dims = [
        slice(None) for _ in range(len(data.shape) - no_batch_dims)
    ]
    remaining_dims[dim - no_batch_dims if dim >= 0 else dim] = inds
    ranges.extend(remaining_dims)
    # Index with a tuple: a list as a multidimensional index is deprecated
    return data[tuple(ranges)]


# With tree_map, a poor man's JAX tree_map
def dict_map(fn, dic, leaf_type):
    """
    Applies fn to every leaf of type leaf_type in a (possibly nested) dict.

    Args:
        fn:
            Function applied to each leaf
        dic:
            The dict to traverse
        leaf_type:
            The type (or tuple of types) considered a leaf
    Returns:
        A new dict with the same keys and mapped values
    """
    new_dict = {}
    for k, v in dic.items():
        if type(v) is dict:
            new_dict[k] = dict_map(fn, v, leaf_type)
        else:
            new_dict[k] = tree_map(fn, v, leaf_type)

    return new_dict


def tree_map(fn, tree, leaf_type):
    """
    Applies fn to every leaf of a tree made of dicts, lists and tuples.

    Args:
        fn:
            Function applied to each leaf
        tree:
            A dict, list, tuple, or leaf_type instance
        leaf_type:
            The type (or tuple of types) considered a leaf
    Returns:
        A tree with the same container structure and mapped leaves
    Raises:
        ValueError:
            If a node is neither a supported container nor a leaf_type
            instance
    """
    if isinstance(tree, dict):
        return dict_map(fn, tree, leaf_type)
    elif isinstance(tree, list):
        return [tree_map(fn, x, leaf_type) for x in tree]
    elif isinstance(tree, tuple):
        return tuple([tree_map(fn, x, leaf_type) for x in tree])
    elif isinstance(tree, leaf_type):
        return fn(tree)
    else:
        # Report the offending type before failing
        print(type(tree))
        raise ValueError("Not supported")


tensor_tree_map = partial(tree_map, leaf_type=torch.Tensor)


def _fetch_dims(tree):
    """Collects the shapes of all tensors in a tree, in traversal order."""
    shapes = []
    if isinstance(tree, dict):
        for v in tree.values():
            shapes.extend(_fetch_dims(v))
    elif isinstance(tree, (list, tuple)):
        for t in tree:
            shapes.extend(_fetch_dims(t))
    elif isinstance(tree, torch.Tensor):
        shapes.append(tree.shape)
    else:
        raise ValueError("Not supported")

    return shapes


def _assign_chunk(dest, src, start, end):
    """Writes the tensors of src into rows [start, end) of those in dest."""
    if isinstance(dest, dict):
        for k, v in dest.items():
            _assign_chunk(v, src[k], start, end)
    elif isinstance(dest, (list, tuple)):
        for d, s in zip(dest, src):
            _assign_chunk(d, s, start, end)
    elif isinstance(dest, torch.Tensor):
        dest[start:end] = src
    else:
        raise ValueError("Not supported")


def chunk_layer(
    layer: Callable,
    inputs: Dict[str, Any],
    chunk_size: int,
    no_batch_dims: int,
) -> Any:
    """
    Implements the "chunking" procedure described in section 1.11.8.

    Layer outputs and inputs are assumed to be simple "pytrees,"
    consisting only of (arbitrarily nested) lists, tuples, and dicts with
    torch.Tensor leaves.

    Args:
        layer:
            The layer to be applied chunk-wise
        inputs:
            A (non-nested) dictionary of keyworded inputs. All leaves must
            be tensors and must share the same batch dimensions.
        chunk_size:
            The number of sub-batches per chunk. If multiple batch
            dimensions are specified, a "sub-batch" is defined as a single
            indexing of all batch dimensions simultaneously (s.t. the
            number of sub-batches is the product of the batch dimensions).
        no_batch_dims:
            How many of the initial dimensions of each input tensor can
            be considered batch dimensions.
    Returns:
        The reassembled output of the layer on the inputs.
    Raises:
        ValueError:
            If no inputs are given, if chunk_size is not positive, or if
            the inputs or outputs contain unsupported node types.
    """
    if not (len(inputs) > 0):
        raise ValueError("Must provide at least one input")
    if chunk_size < 1:
        raise ValueError("chunk_size must be a positive integer")

    initial_dims = [shape[:no_batch_dims] for shape in _fetch_dims(inputs)]
    # The largest size seen in each batch dimension across all inputs
    orig_batch_dims = tuple([max(s) for s in zip(*initial_dims)])

    def prep_inputs(t):
        # TODO: make this more memory efficient. This sucks
        # Tensors whose batch dimensions are all 1 are left unexpanded;
        # they are passed whole to every chunk below
        if any(d != 1 for d in t.shape[:no_batch_dims]):
            t = t.expand(*orig_batch_dims, *t.shape[no_batch_dims:])
        # Flatten all batch dimensions into a single leading dimension
        t = t.reshape(-1, *t.shape[no_batch_dims:])
        return t

    flattened_inputs = tensor_tree_map(prep_inputs, inputs)

    flat_batch_dim = 1
    for d in orig_batch_dims:
        flat_batch_dim *= d

    out = None
    # One iteration per chunk; the final chunk may be shorter
    for start in range(0, flat_batch_dim, chunk_size):
        end = start + chunk_size

        # Chunk the input
        def select_chunk(t):
            return t[start:end] if t.shape[0] != 1 else t

        chunks = tensor_tree_map(select_chunk, flattened_inputs)

        # Run the layer on the chunk
        output_chunk = layer(**chunks)

        # Allocate space for the output
        if out is None:
            def allocate(t):
                return t.new_zeros((flat_batch_dim,) + t.shape[1:])

            out = tensor_tree_map(allocate, output_chunk)

        # Put the chunk in its pre-allocated space
        _assign_chunk(out, output_chunk, start, end)

    # Restore the original batch dimensions
    def reshape(t):
        return t.reshape(orig_batch_dims + t.shape[1:])

    out = tensor_tree_map(reshape, out)

    return out