"...git@developer.sourcefind.cn:OpenDAS/openpcdet.git" did not exist on "43baf787bcc3ceefc007c0bafb3b122a56911cc9"
Commit 68ba77e5 authored by Gustaf Ahdritz's avatar Gustaf Ahdritz
Browse files

Continue fixing loss bugs, clean up structure module docs

parent 33941e46
......@@ -486,12 +486,12 @@ def _frames_and_literature_positions_to_atom14_pos(
):
# [*, N, 14, 4, 4]
default_4x4 = default_frames[f,...]
default_4x4 = default_frames[f, ...]
# [*, N, 14]
group_mask = group_idx[f,...]
group_mask = group_idx[f, ...]
# [N, 14, 8]
# [*, N, 14, 8]
group_mask = nn.functional.one_hot(
group_mask, num_classes=default_frames.shape[-3],
)
......@@ -504,11 +504,11 @@ def _frames_and_literature_positions_to_atom14_pos(
lambda x: torch.sum(x, dim=-1)
)
# [N, 14, 1]
# [*, N, 14, 1]
atom_mask = atom_mask[f,...].unsqueeze(-1)
# [N, 14, 3]
lit_positions = lit_positions[f,...]
# [*, N, 14, 3]
lit_positions = lit_positions[f, ...]
pred_positions = t_atoms_to_global.apply(lit_positions)
pred_positions *= atom_mask
......@@ -758,19 +758,27 @@ class StructureModule(nn.Module):
def _init_residue_constants(self, device):
if(self.default_frames is None):
self.default_frames = torch.tensor(
restype_rigid_group_default_frame, device=device,
restype_rigid_group_default_frame,
device=device,
requires_grad=False,
)
if(self.group_idx is None):
self.group_idx = torch.tensor(
restype_atom14_to_rigid_group, device=device,
restype_atom14_to_rigid_group,
device=device,
requires_grad=False,
)
if(self.atom_mask is None):
self.atom_mask = torch.tensor(
restype_atom14_mask, device=device,
restype_atom14_mask,
device=device,
requires_grad=False,
)
if(self.lit_positions is None):
self.lit_positions = torch.tensor(
restype_atom14_rigid_group_positions, device=device,
restype_atom14_rigid_group_positions,
device=device,
requires_grad=False,
)
def torsion_angles_to_frames(self, t, alpha, f):
......
......@@ -366,6 +366,7 @@ residue_atoms = {
# (The LDDT paper lists 7 amino acids as ambiguous, but the naming ambiguities
# in LEU, VAL and ARG can be resolved by using the 3d constellations of
# the 'ambiguous' atoms and their neighbours)
# TODO: ^ interpret this
residue_atom_renaming_swaps = {
'ASP': {'OD1': 'OD2'},
'GLU': {'OE1': 'OE2'},
......@@ -895,3 +896,25 @@ def make_atom14_dists_bounds(overlap_tolerance=1.5,
'upper_bound': restype_atom14_bond_upper_bound, # shape (21,14,14)
'stddev': restype_atom14_bond_stddev, # shape (21,14,14)
}
restype_atom14_ambiguous_atoms = np.zeros((21, 14), dtype=np.float32)
restype_atom14_ambiguous_atoms_swap_idx = (
np.tile(np.arange(14, dtype=np.int), (21, 1))
)
def _make_atom14_ambiguity_feats():
for res, pairs in residue_atom_renaming_swaps.items():
res_idx = restype_order[restype_3to1[res]]
for atom1, atom2 in pairs.items():
atom1_idx = restype_name_to_atom14_names[res].index(atom1)
atom2_idx = restype_name_to_atom14_names[res].index(atom2)
restype_atom14_ambiguous_atoms[res_idx, atom1_idx] = 1
restype_atom14_ambiguous_atoms[res_idx, atom2_idx] = 1
restype_atom14_ambiguous_atoms_swap_idx[res_idx, atom1_idx] = (
atom2_idx
)
restype_atom14_ambiguous_atoms_swap_idx[res_idx, atom2_idx] = (
atom1_idx
)
_make_atom14_ambiguity_feats()
......@@ -335,3 +335,15 @@ def quat_to_rot(
# [*, 3, 3]
return torch.sum(quat, dim=(-3, -4))
def affine_vector_to_4x4(vector):
quats = vector[..., :4]
trans = vector[..., 4:]
four_by_four = torch.zeros(
(*vector.shape[:-1], 4, 4), device=vector.device
)
four_by_four[..., :3, :3] = quat_to_rot(quats)
four_by_four[..., :3, 3] = trans
four_by_four[..., 3, 3] = 1
return four_by_four
......@@ -18,7 +18,7 @@ import torch
import torch.nn as nn
from typing import Dict
import openfold.np.residue_constants as residue_constants
import openfold.np.residue_constants as rc
from openfold.utils.affine_utils import T
from openfold.utils.tensor_utils import (
batched_gather,
......@@ -27,9 +27,9 @@ from openfold.utils.tensor_utils import (
def pseudo_beta_fn(aatype, all_atom_positions, all_atom_masks):
is_gly = (aatype == residue_constants.restype_order['G'])
ca_idx = residue_constants.atom_order['CA']
cb_idx = residue_constants.atom_order['CB']
is_gly = (aatype == rc.restype_order['G'])
ca_idx = rc.atom_order['CA']
cb_idx = rc.atom_order['CB']
pseudo_beta = torch.where(
is_gly[..., None].expand(*((-1,) * len(is_gly.shape)), 3),
all_atom_positions[..., ca_idx, :],
......@@ -52,18 +52,18 @@ def get_chi_atom_indices():
Returns:
A tensor of shape [residue_types=21, chis=4, atoms=4]. The residue types are
in the order specified in residue_constants.restypes + unknown residue type
in the order specified in rc.restypes + unknown residue type
at the end. For chi angles which are not defined on the residue, the
positions indices are by default set to 0.
"""
chi_atom_indices = []
for residue_name in residue_constants.restypes:
residue_name = residue_constants.restype_1to3[residue_name]
residue_chi_angles = residue_constants.chi_angles_atoms[residue_name]
for residue_name in rc.restypes:
residue_name = rc.restype_1to3[residue_name]
residue_chi_angles = rc.chi_angles_atoms[residue_name]
atom_indices = []
for chi_angle in residue_chi_angles:
atom_indices.append(
[residue_constants.atom_order[atom] for atom in chi_angle])
[rc.atom_order[atom] for atom in chi_angle])
for _ in range(4 - len(atom_indices)):
atom_indices.append([0, 0, 0, 0]) # For chi angles not defined on the AA.
chi_atom_indices.append(atom_indices)
......@@ -74,6 +74,7 @@ def get_chi_atom_indices():
def compute_residx(batch):
out = {}
float_type = batch["seq_mask"].dtype
aatype = batch["aatype"]
......@@ -81,19 +82,20 @@ def compute_residx(batch):
restype_atom37_to_atom14 = [] # mapping (restype, atom37) --> atom14
restype_atom14_mask = []
for rt in residue_constants.restypes:
atom_names = residue_constants.restype_name_to_atom14_names[
residue_constants.restype_1to3[rt]]
for rt in rc.restypes:
atom_names = rc.restype_name_to_atom14_names[
rc.restype_1to3[rt]
]
restype_atom14_to_atom37.append([
(residue_constants.atom_order[name] if name else 0)
(rc.atom_order[name] if name else 0)
for name in atom_names
])
atom_name_to_idx14 = {name: i for i, name in enumerate(atom_names)}
restype_atom37_to_atom14.append([
(atom_name_to_idx14[name] if name in atom_name_to_idx14 else 0)
for name in residue_constants.atom_types
for name in rc.atom_types
])
restype_atom14_mask.append(
......@@ -118,24 +120,27 @@ def compute_residx(batch):
residx_atom14_to_atom37 = restype_atom14_to_atom37[aatype]
residx_atom14_mask = restype_atom14_mask[aatype]
batch['atom14_atom_exists'] = residx_atom14_mask
batch['residx_atom14_to_atom37'] = residx_atom14_to_atom37
out["residx_atom14_to_atom37"] = residx_atom14_to_atom37
out["atom14_atom_exists"] = residx_atom14_mask
# create the gather indices for mapping back
residx_atom37_to_atom14 = restype_atom37_to_atom14[aatype]
batch['residx_atom37_to_atom14'] = residx_atom37_to_atom14
out["residx_atom37_to_atom14"] = residx_atom37_to_atom14
# create the corresponding mask
restype_atom37_mask = torch.zeros([21, 37], dtype=float_type)
for restype, restype_letter in enumerate(residue_constants.restypes):
restype_name = residue_constants.restype_1to3[restype_letter]
atom_names = residue_constants.residue_atoms[restype_name]
for restype, restype_letter in enumerate(rc.restypes):
restype_name = rc.restype_1to3[restype_letter]
atom_names = rc.residue_atoms[restype_name]
for atom_name in atom_names:
atom_type = residue_constants.atom_order[atom_name]
atom_type = rc.atom_order[atom_name]
restype_atom37_mask[restype, atom_type] = 1
residx_atom37_mask = restype_atom37_mask[aatype]
batch['atom37_atom_exists'] = residx_atom37_mask
out["atom37_atom_exists"] = residx_atom37_mask
return out
def atom14_to_atom37(atom14, batch):
......@@ -225,9 +230,9 @@ def atom37_to_torsion_angles(
all_atom_pos, atom_indices, -2, len(atom_indices.shape[:-2])
)
chi_angles_mask = list(residue_constants.chi_angles_mask)
chi_angles_mask = list(rc.chi_angles_mask)
chi_angles_mask.append([0., 0., 0., 0.])
chi_angles_mask = all_atom_pos.new_tensor(chi_angles_mask)
chi_angles_mask = all_atom_mask.new_tensor(chi_angles_mask)
chis_mask = chi_angles_mask[aatype, :]
......@@ -282,7 +287,7 @@ def atom37_to_torsion_angles(
)[((None,) * len(torsion_angles_sin_cos.shape[:-2])) + (slice(None), None)]
chi_is_ambiguous = torsion_angles_sin_cos.new_tensor(
residue_constants.chi_pi_periodic,
rc.chi_pi_periodic,
)[aatype, ...]
mirror_torsion_angles = torch.cat(
......@@ -307,6 +312,7 @@ def atom37_to_frames(
aatype: torch.Tensor,
all_atom_positions: torch.Tensor,
all_atom_mask: torch.Tensor,
**kwargs,
) -> Dict[str, torch.Tensor]:
batch_dims = len(aatype.shape[:-1])
......@@ -314,13 +320,14 @@ def atom37_to_frames(
restype_rigidgroup_base_atom_names[:, 0, :] = ['C', 'CA', 'N']
restype_rigidgroup_base_atom_names[:, 3, :] = ['CA', 'C', 'O']
for restype, restype_letter in enumerate(residue_constants.restypes):
resname = residue_constants.restype_1to3[restype_letter]
for restype, restype_letter in enumerate(rc.restypes):
resname = rc.restype_1to3[restype_letter]
for chi_idx in range(4):
if(residue_constants.chi_angles_mask[restype][chi_idx]):
names = residue_constants.chi_angles_atoms[resname][chi_idx]
if(rc.chi_angles_mask[restype][chi_idx]):
names = rc.chi_angles_atoms[resname][chi_idx]
restype_rigidgroup_base_atom_names[
restype, chi_idx + 4, :] = atom_names[1:]
restype, chi_idx + 4, :
] = names[1:]
restype_rigidgroup_mask = torch.zeros(
(*aatype.shape[:-1], 21, 8),
......@@ -330,9 +337,11 @@ def atom37_to_frames(
)
restype_rigidgroup_mask[:, 0] = 1
restype_rigidgroup_mask[:, 3] = 1
restype_rigidgroup_mask[:20, 4:] = residue_constants.chi_angles_mask
restype_rigidgroup_mask[:20, 4:] = (
all_atom_mask.new_tensor(rc.chi_angles_mask)
)
lookuptable = residue_constants.atom_order.copy()
lookuptable = rc.atom_order.copy()
lookuptable[''] = 0
lookup = np.vectorize(lambda x: lookuptable[x])
restype_rigidgroup_base_atom37_idx = lookup(
......@@ -349,7 +358,7 @@ def atom37_to_frames(
)
residx_rigidgroup_base_atom37_idx = batched_gather(
residx_rigidgroup_base_atom37_idx,
restype_rigidgroup_base_atom37_idx,
aatype,
dim=-3,
no_batch_dims=batch_dims,
......@@ -363,9 +372,9 @@ def atom37_to_frames(
)
gt_frames = T.from_3_points(
point_on_neg_x_axis=base_atom_pos[..., 0, :],
p_neg_x_axis=base_atom_pos[..., 0, :],
origin=base_atom_pos[..., 1, :],
point_on_xy_plane=base_atom_pos[..., 2, :],
p_xy_plane=base_atom_pos[..., 2, :],
)
group_exists = batched_gather(
......@@ -381,33 +390,31 @@ def atom37_to_frames(
dim=-1,
no_batch_dims=len(all_atom_mask.shape[:-1])
)
gt_exists = torch.min(gt_atoms_exist, dim=-1) * group_exists
gt_exists = torch.min(gt_atoms_exist, dim=-1)[0] * group_exists
rots = torch.eye(3, device=aatype.device, requires_grad=False)
rots = rots.view(*((1,) * batch_dims), 1, 3, 3)
rots = rots.expand(*((-1,) * batch_dims), 8, -1, -1)
rots = torch.tile(rots, (*((1,) * batch_dims), 8, 1, 1))
rots[..., 0, 0, 0] = -1
rots[..., 0, 2, 2] = -1
gt_frames = gt_frames.compose(T(rots, None))
gt_frames = gt_frames.compose(T(rots, None))
restype_rigidgroup_is_ambiguous = all_atom_mask.new_zeros(
*((1,) * batch_dims), 21, 8
)
restype_rigidgroup_rots = torch.eye(
3, device=aatype.device, requires_grad=False
)
restype_rigidgroup_rots = restype_rigidgroup_rots.view(
*((1,) * batch_dims), 1, 1, 3, 3
)
restype_rigidgroup_rots = restype_rigidgroup_rots.expand(
*((-1,) * batch_dims), 21, 8, 3, 3
restype_rigidgroup_rots = torch.tile(
restype_rigidgroup_rots,
(*((1,) * batch_dims), 21, 8, 1, 1),
)
for resname, _ in residue_constants.residue_atom_renaming_swaps.items():
restype = residue_constants.restype_order[
residue_constants.restype3to1[resname]
for resname, _ in rc.residue_atom_renaming_swaps.items():
restype = rc.restype_order[
rc.restype_3to1[resname]
]
chi_idx = int(sum(residue_constants.chi_angles_mask[restype]) - 1)
chi_idx = int(sum(rc.chi_angles_mask[restype]) - 1)
restype_rigidgroup_is_ambiguous[..., restype, chi_idx + 4] = 1
restype_rigidgroup_rots[..., restype, chi_idx + 4, 1, 1] = -1
restype_rigidgroup_rots[..., restype, chi_idx + 4, 2, 2] = -1
......@@ -419,18 +426,17 @@ def atom37_to_frames(
no_batch_dims=batch_dims,
)
residx_rigidgroup_ambiguity_rot = utils.batched_gather(
residx_rigidgroup_ambiguity_rot = batched_gather(
restype_rigidgroup_rots,
aatype,
dim=-4,
no_batch_dims=batch_dims,
)
alt_gt_frames = gt_frames.apply(T(residx_rigidgroup_ambiguity_rot, None))
alt_gt_frames = gt_frames.compose(T(residx_rigidgroup_ambiguity_rot, None))
# TODO: Verify that I can get away with skipping the flat12 format
gt_frames_tensor = gt_frames.to_tensor()
alt_gt_frames_tensor = alt_gt_frames.to_tensor()
gt_frames_tensor = gt_frames.to_4x4()
alt_gt_frames_tensor = alt_gt_frames.to_4x4()
return {
'rigidgroups_gt_frames': gt_frames_tensor,
......@@ -477,7 +483,7 @@ def build_template_pair_feat(batch, min_bin, max_bin, no_bins, eps=1e-6, inf=1e8
to_concat = [dgram, template_mask_2d[..., None]]
aatype_one_hot = nn.functional.one_hot(
batch["template_aatype"], residue_constants.restype_num + 2,
batch["template_aatype"], rc.restype_num + 2,
)
n_res = batch["template_aatype"].shape[-1]
......@@ -492,7 +498,7 @@ def build_template_pair_feat(batch, min_bin, max_bin, no_bins, eps=1e-6, inf=1e8
)
)
n, ca, c = [residue_constants.atom_order[a] for a in ['N', 'CA', 'C']]
n, ca, c = [rc.atom_order[a] for a in ['N', 'CA', 'C']]
t_aa_masks = batch["template_all_atom_masks"]
template_mask = (
......@@ -522,7 +528,7 @@ def build_extra_msa_feat(batch):
# adapted from model/tf/data_transforms.py
def build_msa_feat(protein):
def build_msa_feat(batch):
"""Create and concatenate MSA features."""
# Whether there is a domain break. Always zero for chains, but keeping
# for compatibility with domain datasets.
......@@ -544,7 +550,7 @@ def build_msa_feat(protein):
deletion_value.unsqueeze(-1),
]
if 'cluster_profile' in protein:
if 'cluster_profile' in batch:
deletion_mean_value = (
tf.atan(batch['cluster_deletion_mean'] / 3.) * (2. / np.pi))
msa_feat.extend([
......@@ -560,4 +566,53 @@ def build_msa_feat(protein):
batch['msa_feat'] = torch.cat(msa_feat, dim=-1)
batch['target_feat'] = torch.cat(target_feat, dim=-1)
return protein
return batch
def build_ambiguity_feats(batch: Dict[str, torch.Tensor]) -> None:
"""
Compute features required by compute_renamed_ground_truth (Alg. 26)
Args:
batch:
str/tensor dictionary containing:
* atom14_gt_positions: [*, N, 14, 3] ground truth pos.
* atom14_gt_exists: [*, N, 14] atom mask
* aatype: [*, N] residue indices
Returns:
str/tensor dictionary containing:
* atom14_atom_is_ambiguous: [*, N, 14] mask of ambiguous atoms
* atom14_alt_gt_positions: [*, N, 14, 3] renamed positions
"""
ambiguous_atoms = (
batch["atom14_gt_positions"].new_tensor(
rc.restype_atom14_ambiguous_atoms, requires_grad=False,
)
)
atom14_atom_is_ambiguous = ambiguous_atoms[batch["aatype"], ...]
# Swap pairs of ambiguous positions
swap_idx = rc.restype_atom14_ambiguous_atoms_swap_idx
swap_mat = np.eye(swap_idx.shape[-1])[swap_idx] # one-hot swap_idx
swap_mat = batch["atom14_gt_positions"].new_tensor(
swap_mat, requires_grad=False
)
swap_mat = swap_mat[batch["aatype"], ...]
atom14_alt_gt_positions = (
torch.sum(
batch["atom14_gt_positions"][..., None, :] * swap_mat[..., None],
dim=-3
)
)
atom14_alt_gt_exists = (
torch.sum(
batch["atom14_gt_exists"][..., None] * swap_mat, dim=-2
)
)
return {
"atom14_atom_is_ambiguous": atom14_atom_is_ambiguous,
"atom14_alt_gt_positions": atom14_alt_gt_positions,
"atom14_alt_gt_exists": atom14_alt_gt_exists,
}
......@@ -89,15 +89,15 @@ def compute_fape(
target_positions[..., None, :, :],
)
error_dist = torch.sqrt(
(pred_positions - target_positions)**2 + eps
torch.sum((local_pred_pos - local_target_pos)**2, dim=-1) + eps
)
if(l1_clamp_distance is not None):
error_dist = torch.clamp(error_dist, min=0, max=l1_clamp_distance)
normed_error = error_dist / length_scale
normed_error *= frames_mask.unsqueeze(-1)
normed_error *= positions_mask.unsqueeze(-2)
normed_error *= frames_mask[..., None]
normed_error *= positions_mask[..., None, :]
norm_factor = (
torch.sum(frames_mask, dim=-1) *
......@@ -109,67 +109,71 @@ def compute_fape(
return normed_error
# DISCREPANCY: figure out if loss clamping happens in 90% of each bach or in 90% of batches
def backbone_loss(
batch: Dict[str, torch.Tensor],
pred_aff_tensor: torch.Tensor,
backbone_affine_tensor: torch.Tensor,
backbone_affine_mask: torch.Tensor,
traj: torch.Tensor,
use_clamped_fape: Optional[torch.Tensor] = None,
clamp_distance: float = 10.,
loss_unit_distance: float = 10.,
**kwargs,
) -> torch.Tensor:
pred_aff = T.from_tensor(pred_aff_tensor)
gt_aff = T.from_tensor(batch["backbone_affine_tensor"])
backbone_mask = batch["backbone_affine_mask"]
pred_aff = T.from_tensor(traj)
gt_aff = T.from_tensor(backbone_affine_tensor)
fape_loss = compute_fape(
pred_aff,
gt_aff,
backbone_mask,
gt_aff[..., None, :],
backbone_affine_mask[..., None, :],
pred_aff.get_trans(),
gt_aff.get_trans(),
backbone_mask,
gt_aff[..., None, :].get_trans(),
backbone_affine_mask[..., None, :],
l1_clamp_distance=clamp_distance,
length_scale=loss_unit_distance,
)
if('use_clamped_fape' in batch):
use_clamped_fape = batch["use_clamped_fape"]
if(use_clamped_fape is not None):
unclamped_fape_loss = compute_fape(
pred_aff,
gt_aff,
backbone_mask,
gt_aff[..., None, :],
backbone_affine_mask[..., None, :],
pred_aff.get_trans(),
gt_aff.get_trans(),
backbone_mask,
gt_aff[..., None, :].get_trans(),
backbone_affine_mask[..., None, :],
l1_clamp_distance=None,
length_scale=loss_unit_distance,
)
fape_loss = (
fape_loss * use_clamped_fape +
fape_loss_unclamped * (1 - use_clamped_fape)
unclamped_fape_loss * (1 - use_clamped_fape)
)
return torch.mean(fape_loss, dim=-1)
def sidechain_loss(
sidechain_frames,
sidechain_atom_pos,
rigidgroups_gt_frames,
rigidgroups_alt_gt_frames,
rigidgroups_gt_exists,
renamed_atom14_gt_positions,
renamed_atom14_gt_exists,
alt_naming_is_better,
clamp_distance=10.,
length_scale=10.,
):
sidechain_frames: torch.Tensor,
sidechain_atom_pos: torch.Tensor,
rigidgroups_gt_frames: torch.Tensor,
rigidgroups_alt_gt_frames: torch.Tensor,
rigidgroups_gt_exists: torch.Tensor,
renamed_atom14_gt_positions: torch.Tensor,
renamed_atom14_gt_exists: torch.Tensor,
alt_naming_is_better: torch.Tensor,
clamp_distance: float = 10.,
length_scale: float = 10.,
**kwargs,
) -> torch.Tensor:
renamed_gt_frames = (
(1. - alt_naming_is_better[..., None, None, None, None]) *
gt_frames +
rigidgroups_gt_frames +
alt_naming_is_better[..., None, None, None, None] *
alt_gt_frames
rigidgroups_alt_gt_frames
)
sidechain_frames = T.from_4x4(sidechain_frames)
renamed_gt_frames = T.from_4x4(renamed_gt_frames)
fape = compute_fape(
......@@ -192,16 +196,13 @@ def fape_loss(
config: ml_collections.ConfigDict,
) -> torch.Tensor:
bb_loss = backbone_loss(
batch, out["sm"]["frames"][-1], **config.backbone
traj=out["sm"]["frames"], **{**batch, **config.backbone},
)
sc_loss = sidechain_loss(
out["sm"]["sidechain_frames"],
out["sm"]["positions"],
{
**batch,
**config.sidechain,
},
**{**batch, **config.sidechain}
)
return (
......
......@@ -13,6 +13,7 @@
# limitations under the License.
import numpy as np
from scipy.spatial.transform import Rotation
def random_template_feats(n_templ, n, batch_size=None):
......@@ -35,6 +36,7 @@ def random_template_feats(n_templ, n, batch_size=None):
batch["template_aatype"] = batch["template_aatype"].astype(np.int64)
return batch
def random_extra_msa_feats(n_extra, n, batch_size=None):
b = []
if(batch_size is not None):
......@@ -50,3 +52,34 @@ def random_extra_msa_feats(n_extra, n, batch_size=None):
np.random.randint(0, 2, (*b, n_extra, n)).astype(np.float32),
}
return batch
def random_affine_vectors(dim):
prod_dim = 1
for d in dim:
prod_dim *= d
affines = np.zeros((prod_dim, 7))
for i in range(prod_dim):
affines[i, :4] = Rotation.random(random_state=42).as_quat()
affines[i, 4:] = np.random.rand(3,)
return affines.reshape(*dim, 7)
def random_affine_4x4s(dim):
prod_dim = 1
for d in dim:
prod_dim *= d
affines = np.zeros((prod_dim, 4, 4))
for i in range(prod_dim):
affines[i, :3, :3] = Rotation.random(random_state=42).as_matrix()
affines[i, :3, 3] = np.random.rand(3,)
affines[:, 3, 3] = 1
return affines.reshape(*dim, 4, 4)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment