# coding=utf-8

# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
#    list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
#    this list of conditions and the following disclaimer in the documentation
#    and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
#    contributors may be used to endorse or promote products derived from
#    this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#

import torch

import tensorly as tl
tl.set_backend('pytorch')

from tltorch.factorized_tensors.core import FactorizedTensor

einsum_symbols = 'abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ'


def _contract_dense(x, weight, separable=False, operator_type='diagonal'):
    order = tl.ndim(x)
    # batch-size, in_channels, x, y...
    x_syms = list(einsum_symbols[:order])

    # in_channels, out_channels, x, y...
    weight_syms = list(x_syms[1:])  # no batch-size

    # batch-size, out_channels, x, y...
    if separable:
        out_syms = [x_syms[0]] + list(weight_syms)
    else:
        weight_syms.insert(1, einsum_symbols[order])  # outputs
        out_syms = list(weight_syms)
        out_syms[0] = x_syms[0]

    if operator_type == 'diagonal':
        pass
    elif operator_type == 'block-diagonal':
        weight_syms.insert(-1, einsum_symbols[order+1])
        out_syms[-1] = weight_syms[-2]
    elif operator_type == 'vector':
        weight_syms.pop()
    else:
        raise ValueError(f"Unknown operator type {operator_type}")

    eq = ''.join(x_syms) + ',' + ''.join(weight_syms) + '->' + ''.join(out_syms)

    if not torch.is_tensor(weight):
        weight = weight.to_tensor()

    return tl.einsum(eq, x, weight)
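
# Worked example of the equation built above (the shapes are illustrative
# assumptions, not fixed by the library): for a 4D input x of shape
# (batch, in_channels, h, w) with separable=False and
# operator_type='diagonal', _contract_dense produces
#
#     eq = 'abcd,becd->aecd'
#
# i.e. a dense weight of shape (in_channels, out_channels, h, w) mixes the
# channel dimension while acting pointwise on each spatial mode (c, d).
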
def _contract_cp(x, cp_weight, separable=False, operator_type='diagonal'):
    order = tl.ndim(x)

    x_syms = str(einsum_symbols[:order])
    rank_sym = einsum_symbols[order]
    out_sym = einsum_symbols[order+1]
    out_syms = list(x_syms)

    if separable:
        factor_syms = [einsum_symbols[1]+rank_sym]  # in only
    else:
        out_syms[1] = out_sym
        factor_syms = [einsum_symbols[1]+rank_sym, out_sym+rank_sym]  # in, out

    factor_syms += [xs+rank_sym for xs in x_syms[2:]]  # x, y, ...

    if operator_type == 'diagonal':
        pass
    elif operator_type == 'block-diagonal':
        out_syms[-1] = einsum_symbols[order+2]
        factor_syms += [out_syms[-1] + rank_sym]
    elif operator_type == 'vector':
        factor_syms.pop()
    else:
        raise ValueError(f"Unknown operator type {operator_type}")

    eq = x_syms + ',' + rank_sym + ',' + ','.join(factor_syms) + '->' + ''.join(out_syms)

    return tl.einsum(eq, x, cp_weight.weights, *cp_weight.factors)


def _contract_tucker(x, tucker_weight, separable=False, operator_type='diagonal'):
    order = tl.ndim(x)

    x_syms = str(einsum_symbols[:order])
    out_sym = einsum_symbols[order]
    out_syms = list(x_syms)

    if separable:
        core_syms = einsum_symbols[order+1:2*order]
        # factor_syms = [einsum_symbols[1]+core_syms[0]]  # in only
        factor_syms = [xs+rs for (xs, rs) in zip(x_syms[1:], core_syms)]  # x, y, ...
    else:
        core_syms = einsum_symbols[order+1:2*order+1]
        out_syms[1] = out_sym
        factor_syms = [einsum_symbols[1]+core_syms[0], out_sym+core_syms[1]]  # in, out
        factor_syms += [xs+rs for (xs, rs) in zip(x_syms[2:], core_syms[2:])]  # x, y, ...

    if operator_type == 'diagonal':
        pass
    elif operator_type == 'block-diagonal':
        raise NotImplementedError(f"Operator type {operator_type} not implemented for Tucker")
    else:
        raise ValueError(f"Unknown operator type {operator_type}")

    eq = x_syms + ',' + core_syms + ',' + ','.join(factor_syms) + '->' + ''.join(out_syms)

    return tl.einsum(eq, x, tucker_weight.core, *tucker_weight.factors)


def _contract_tt(x, tt_weight, separable=False, operator_type='diagonal'):
    order = tl.ndim(x)

    x_syms = list(einsum_symbols[:order])
    weight_syms = list(x_syms[1:])  # no batch-size

    if not separable:
        weight_syms.insert(1, einsum_symbols[order])  # outputs
        out_syms = list(weight_syms)
        out_syms[0] = x_syms[0]
    else:
        out_syms = list(x_syms)

    if operator_type == 'diagonal':
        pass
    elif operator_type == 'block-diagonal':
        weight_syms.insert(-1, einsum_symbols[order+1])
        out_syms[-1] = weight_syms[-2]
    elif operator_type == 'vector':
        weight_syms.pop()
    else:
        raise ValueError(f"Unknown operator type {operator_type}")

    rank_syms = list(einsum_symbols[order+2:])
    tt_syms = []
    for i, s in enumerate(weight_syms):
        tt_syms.append([rank_syms[i], s, rank_syms[i+1]])

    eq = ''.join(x_syms) + ',' + ','.join(''.join(f) for f in tt_syms) + '->' + ''.join(out_syms)

    return tl.einsum(eq, x, *tt_weight.factors)


def get_contract_fun(weight, implementation='reconstructed', separable=False):
    """Generic N-D implementation of the Fourier spectral convolution contraction.

    Parameters
    ----------
    weight : tensorly-torch's FactorizedTensor
    implementation : {'reconstructed', 'factorized'}, default is 'reconstructed'
        whether to reconstruct the weight and do a forward pass (reconstructed)
        or directly contract the factors of the factorized weight with the input (factorized)

    Returns
    -------
    function : (x, weight) -> x * weight in Fourier space
    """
    if implementation == 'reconstructed':
        return _contract_dense
    elif implementation == 'factorized':
        if torch.is_tensor(weight):
            return _contract_dense
        elif isinstance(weight, FactorizedTensor):
            if weight.name.lower() == 'complexdense':
                return _contract_dense
            elif weight.name.lower() == 'complextucker':
                return _contract_tucker
            elif weight.name.lower() == 'complextt':
                return _contract_tt
            elif weight.name.lower() == 'complexcp':
                return _contract_cp
            else:
                raise ValueError(f'Got unexpected factorized weight type {weight.name}')
        else:
            raise ValueError(f'Got unexpected weight type of class {weight.__class__.__name__}')
    else:
        raise ValueError(f'Got {implementation=}, expected "reconstructed" or "factorized"')
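
# Minimal usage sketch (illustrative only; the shapes and the dense-weight
# setup below are assumptions, not part of the library API). It exercises
# the 'reconstructed' path, which returns _contract_dense.
if __name__ == "__main__":
    batch, in_channels, out_channels, h, w = 2, 4, 6, 8, 8

    # Dense complex weight for operator_type='diagonal':
    # shape (in_channels, out_channels, h, w)
    weight = torch.randn(in_channels, out_channels, h, w, dtype=torch.cfloat)
    x = torch.randn(batch, in_channels, h, w, dtype=torch.cfloat)

    contract = get_contract_fun(weight, implementation='reconstructed')
    y = contract(x, weight, separable=False, operator_type='diagonal')
    print(y.shape)  # torch.Size([2, 6, 8, 8])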