Commit 43941b82 authored by Arkadiusz Nowaczynski's avatar Arkadiusz Nowaczynski
Browse files

Set clamped vs unclamped FAPE for each sample in batch independently

parent a1f77ad0
...@@ -440,11 +440,6 @@ class OpenFoldDataLoader(torch.utils.data.DataLoader): ...@@ -440,11 +440,6 @@ class OpenFoldDataLoader(torch.utils.data.DataLoader):
stage_cfg = self.config[self.stage] stage_cfg = self.config[self.stage]
max_iters = self.config.common.max_recycling_iters max_iters = self.config.common.max_recycling_iters
if(stage_cfg.supervised):
clamp_prob = self.config.supervised.clamp_prob
keyed_probs.append(
("use_clamped_fape", [1 - clamp_prob, clamp_prob])
)
if(stage_cfg.uniform_recycling): if(stage_cfg.uniform_recycling):
recycling_probs = [ recycling_probs = [
......
...@@ -94,6 +94,21 @@ def np_example_to_features( ...@@ -94,6 +94,21 @@ def np_example_to_features(
cfg[mode], cfg[mode],
) )
if mode == "train":
p = torch.rand(1).item()
use_clamped_fape_value = float(p < cfg.supervised.clamp_prob)
features["use_clamped_fape"] = torch.full(
size=[cfg.common.max_recycling_iters + 1],
fill_value=use_clamped_fape_value,
dtype=torch.float32,
)
else:
features["use_clamped_fape"] = torch.full(
size=[cfg.common.max_recycling_iters + 1],
fill_value=0.0,
dtype=torch.float32,
)
return {k: v for k, v in features.items()} return {k: v for k, v in features.items()}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment