# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from functools import partial

import torch
import torch.nn as nn


def permute_final_dims(tensor, *inds):
    # Permute only the last len(inds) dimensions of the tensor, leaving
    # the leading (batch) dimensions untouched
    zero_index = -1 * len(inds)
    first_inds = list(range(len(tensor.shape[:zero_index])))
    return tensor.permute(*first_inds, *[zero_index + i for i in inds])


def flatten_final_dims(tensor, no_dims):
    # Collapse the final no_dims dimensions into a single dimension
    return tensor.reshape(*tensor.shape[:-no_dims], -1)


def masked_mean(mask, value, dim, eps=1e-10):
    # Mean of value over dim, counting only positions where mask is nonzero.
    # eps guards against division by zero when the mask is all zeros.
    mask = mask.expand(*value.shape)
    return torch.sum(mask * value, dim=dim) / (eps + torch.sum(mask, dim=dim))


def pts_to_distogram(pts, min_bin=2.3125, max_bin=21.6875, no_bins=64):
    # Bucketize pairwise distances between points into no_bins distance bins
    boundaries = torch.linspace(
        min_bin, max_bin, no_bins - 1, device=pts.device
    )
    dists = torch.sqrt(
        torch.sum((pts.unsqueeze(-2) - pts.unsqueeze(-3)) ** 2, dim=-1)
    )
    return torch.bucketize(dists, boundaries)


def stack_tensor_dicts(dicts):
    # Recursively stack a list of (nested) dicts of tensors along a new
    # leading dimension. All dicts must share the same structure.
    first = dicts[0]
    new_dict = {}
    for k, v in first.items():
        all_v = [d[k] for d in dicts]
        if type(v) is dict:
            new_dict[k] = stack_tensor_dicts(all_v)
        else:
            new_dict[k] = torch.stack(all_v)

    return new_dict


def one_hot(x, v_bins):
    # One-hot encode each value of x according to its nearest bin in v_bins
    reshaped_bins = v_bins.view(*((1,) * len(x.shape) + (len(v_bins),)))
    diffs = x[..., None] - reshaped_bins
    am = torch.argmin(torch.abs(diffs), dim=-1)
    return nn.functional.one_hot(am, num_classes=len(v_bins)).float()


def batched_gather(data, inds, dim=0, no_batch_dims=0):
    # Gather along dim while treating the first no_batch_dims dimensions
    # of data as batch dimensions that are indexed elementwise
    ranges = []
    for i, s in enumerate(data.shape[:no_batch_dims]):
        r = torch.arange(s)
        r = r.view(*(*((1,) * i), -1, *((1,) * (len(inds.shape) - i - 1))))
        ranges.append(r)

    remaining_dims = [
        slice(None) for _ in range(len(data.shape) - no_batch_dims)
    ]
    remaining_dims[dim - no_batch_dims if dim >= 0 else dim] = inds
    ranges.extend(remaining_dims)

    return data[ranges]


# dict_map and tree_map together form a poor man's version of JAX's tree_map
def dict_map(fn, dic, leaf_type):
    new_dict = {}
    for k, v in dic.items():
        if type(v) is dict:
            new_dict[k] = dict_map(fn, v, leaf_type)
        else:
            new_dict[k] = tree_map(fn, v, leaf_type)

    return new_dict


def tree_map(fn, tree, leaf_type):
    # Apply fn to every leaf of type leaf_type in a (nested) structure of
    # dicts, lists, and tuples
    tree_type = type(tree)
    if tree_type is dict:
        return dict_map(fn, tree, leaf_type)
    elif tree_type is list:
        return [tree_map(fn, x, leaf_type) for x in tree]
    elif tree_type is tuple:
        return tuple(tree_map(fn, x, leaf_type) for x in tree)
    elif tree_type is leaf_type:
        return fn(tree)
    else:
        raise ValueError(f"Tree of type {tree_type} not supported")


tensor_tree_map = partial(tree_map, leaf_type=torch.Tensor)
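
# Illustrative usage sketch (not part of the original module): shows how
# tensor_tree_map applies a function to every tensor leaf of a nested
# structure, and how batched_gather indexes each batch element
# independently. The tensors and shapes below are arbitrary examples.
def _demo_tree_utils():
    tree = {
        "a": torch.ones(2, 3),
        "b": [torch.zeros(4), (torch.ones(5),)],
    }
    # Double every tensor leaf; the nesting structure is preserved
    doubled = tensor_tree_map(lambda t: t * 2, tree)
    assert doubled["b"][1][0][0].item() == 2.0

    # Gather columns [2, 0], [1, 3], and [0, 0] from the three rows of
    # data, respectively, by treating the first dimension as a batch dim
    data = torch.arange(12).view(3, 4)
    inds = torch.tensor([[2, 0], [1, 3], [0, 0]])
    gathered = batched_gather(data, inds, dim=-1, no_batch_dims=1)
    assert gathered.shape == (3, 2)
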
def chunk_layer(layer, inputs, chunk_size, no_batch_dims):
    """
    Implements the "chunking" procedure described in section 1.11.8.

    Layer outputs and inputs are interpreted as simplified "pytrees,"
    consisting only of (nested) lists, tuples, and dicts with tensor
    leaves.

    Args:
        layer:
            The layer to be applied chunk-wise
        inputs:
            A (nested) dictionary of keyworded inputs. All leaves must be
            tensors and must share the same batch dimensions.
        chunk_size:
            The number of sub-batches per chunk. If multiple batch
            dimensions are specified, a "sub-batch" is defined as a single
            indexing of all batch dimensions simultaneously (s.t. the
            number of sub-batches is the product of the batch dimensions).
        no_batch_dims:
            How many of the initial dimensions of each input tensor can
            be considered batch dimensions.
    Returns:
        The reassembled output of the layer on the inputs.
    """
    if not (len(inputs) > 0):
        raise ValueError("Must provide at least one input")

    def fetch_dims(tree):
        # Collect the shapes of all tensor leaves in the input pytree
        shapes = []
        tree_type = type(tree)
        if tree_type is dict:
            for v in tree.values():
                shapes.extend(fetch_dims(v))
        elif tree_type is list or tree_type is tuple:
            for t in tree:
                shapes.extend(fetch_dims(t))
        elif tree_type is torch.Tensor:
            shapes.append(tree.shape)
        else:
            raise ValueError("Not supported")

        return shapes

    # Determine the broadcast batch shape shared by all inputs
    initial_dims = [shape[:no_batch_dims] for shape in fetch_dims(inputs)]
    orig_batch_dims = [max(s) for s in zip(*initial_dims)]

    def prep_inputs(t):
        # Broadcast each input to the full batch shape, then flatten the
        # batch dimensions into a single dimension
        t = t.expand(*orig_batch_dims, *t.shape[no_batch_dims:])
        t = t.reshape(-1, *t.shape[no_batch_dims:])
        return t

    flattened_inputs = tensor_tree_map(prep_inputs, inputs)

    flat_batch_dim = 1
    for d in orig_batch_dims:
        flat_batch_dim *= d

    # Ceiling division: the last chunk may be smaller than chunk_size
    no_chunks = flat_batch_dim // chunk_size + (
        flat_batch_dim % chunk_size != 0
    )

    i = 0
    out = None
    for _ in range(no_chunks):
        # Chunk the input
        select_chunk = lambda t: t[i : i + chunk_size]
        chunks = tensor_tree_map(select_chunk, flattened_inputs)

        # Run the layer on the chunk
        output_chunk = layer(**chunks)

        # Allocate space for the output
        if out is None:
            allocate = lambda t: t.new_zeros(flat_batch_dim, *t.shape[1:])
            out = tensor_tree_map(allocate, output_chunk)

        # Put the chunk in its pre-allocated space
        out_type = type(output_chunk)
        if out_type is dict:
            def assign(d1, d2):
                for k, v in d1.items():
                    if type(v) is dict:
                        assign(v, d2[k])
                    else:
                        v[i : i + chunk_size] = d2[k]

            assign(out, output_chunk)
        elif out_type is tuple:
            for x1, x2 in zip(out, output_chunk):
                x1[i : i + chunk_size] = x2
        elif out_type is torch.Tensor:
            out[i : i + chunk_size] = output_chunk
        else:
            raise ValueError("Not supported")

        i += chunk_size

    # Restore the original batch dimensions
    reshape = lambda t: t.reshape(*orig_batch_dims, *t.shape[1:])
    out = tensor_tree_map(reshape, out)

    return out
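
# Illustrative usage sketch (not part of the original module): applies
# chunk_layer to a hypothetical layer whose keyword arguments match the
# keys of the inputs dict. The layer, shapes, and chunk size below are
# arbitrary examples chosen to show that chunked and unchunked execution
# produce the same result.
def _demo_chunk_layer():
    linear = nn.Linear(8, 8)

    def layer(x, bias):
        return linear(x) + bias

    inputs = {
        "x": torch.randn(4, 16, 8),
        "bias": torch.randn(4, 16, 8),
    }

    # Treat the first two dimensions (4 * 16 = 64 sub-batches) as batch
    # dimensions and process them 10 sub-batches at a time, which takes
    # 7 chunks (the last one holds only 4 sub-batches)
    chunked = chunk_layer(layer, inputs, chunk_size=10, no_batch_dims=2)
    full = layer(**inputs)
    assert torch.allclose(chunked, full, atol=1e-6)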