Commit 896f8935 authored by Gustaf Ahdritz's avatar Gustaf Ahdritz
Browse files

Remove .item() calls

parent 1dcff6aa
......@@ -50,7 +50,7 @@ run e.g.
```bash
python3 run_pretrained_openfold.py \
test.fasta \
target.fasta \
data/uniref90/uniref90.fasta \
data/mgnify/mgy_clusters_2018_12.fa \
data/pdb70/pdb70 \
......
......@@ -1103,7 +1103,7 @@ def random_crop_to_size(
else:
num_templates = protein["aatype"].new_zeros((1,))
num_res_crop_size = min(seq_length.item(), crop_size)
num_res_crop_size = min(int(seq_length), crop_size)
# We want each ensemble to be cropped the same way
g = torch.Generator(device=protein["seq_length"].device)
......@@ -1111,13 +1111,13 @@ def random_crop_to_size(
g.manual_seed(seed)
def _randint(lower, upper):
return torch.randint(
return int(torch.randint(
lower,
upper + 1,
(1,),
device=protein["seq_length"].device,
generator=g,
)[0].item()
)[0])
if subsample_templates:
templates_crop_start = _randint(0, num_templates)
......
......@@ -156,7 +156,7 @@ def process_tensors_from_config(tensors, common_cfg, mode_cfg):
tensors = compose(nonensembled_transform_fns(common_cfg, mode_cfg))(tensors)
num_ensemble = mode_cfg.num_ensemble
num_recycling = tensors["no_recycling_iters"].item()
num_recycling = int(tensors["no_recycling_iters"])
if common_cfg.resample_msa_in_recycling:
# Separate batch per ensembling & recycling step.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment