Commit 896f8935 authored by Gustaf Ahdritz's avatar Gustaf Ahdritz
Browse files

Remove .item() calls

parent 1dcff6aa
...@@ -50,7 +50,7 @@ run e.g. ...@@ -50,7 +50,7 @@ run e.g.
```bash ```bash
python3 run_pretrained_openfold.py \ python3 run_pretrained_openfold.py \
test.fasta \ target.fasta \
data/uniref90/uniref90.fasta \ data/uniref90/uniref90.fasta \
data/mgnify/mgy_clusters_2018_12.fa \ data/mgnify/mgy_clusters_2018_12.fa \
data/pdb70/pdb70 \ data/pdb70/pdb70 \
......
...@@ -1103,7 +1103,7 @@ def random_crop_to_size( ...@@ -1103,7 +1103,7 @@ def random_crop_to_size(
else: else:
num_templates = protein["aatype"].new_zeros((1,)) num_templates = protein["aatype"].new_zeros((1,))
num_res_crop_size = min(seq_length.item(), crop_size) num_res_crop_size = min(int(seq_length), crop_size)
# We want each ensemble to be cropped the same way # We want each ensemble to be cropped the same way
g = torch.Generator(device=protein["seq_length"].device) g = torch.Generator(device=protein["seq_length"].device)
...@@ -1111,13 +1111,13 @@ def random_crop_to_size( ...@@ -1111,13 +1111,13 @@ def random_crop_to_size(
g.manual_seed(seed) g.manual_seed(seed)
def _randint(lower, upper): def _randint(lower, upper):
return torch.randint( return int(torch.randint(
lower, lower,
upper + 1, upper + 1,
(1,), (1,),
device=protein["seq_length"].device, device=protein["seq_length"].device,
generator=g, generator=g,
)[0].item() )[0])
if subsample_templates: if subsample_templates:
templates_crop_start = _randint(0, num_templates) templates_crop_start = _randint(0, num_templates)
......
...@@ -156,7 +156,7 @@ def process_tensors_from_config(tensors, common_cfg, mode_cfg): ...@@ -156,7 +156,7 @@ def process_tensors_from_config(tensors, common_cfg, mode_cfg):
tensors = compose(nonensembled_transform_fns(common_cfg, mode_cfg))(tensors) tensors = compose(nonensembled_transform_fns(common_cfg, mode_cfg))(tensors)
num_ensemble = mode_cfg.num_ensemble num_ensemble = mode_cfg.num_ensemble
num_recycling = tensors["no_recycling_iters"].item() num_recycling = int(tensors["no_recycling_iters"])
if common_cfg.resample_msa_in_recycling: if common_cfg.resample_msa_in_recycling:
# Separate batch per ensembling & recycling step. # Separate batch per ensembling & recycling step.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment