Commit 07421c47 authored by Geoffrey Yu's avatar Geoffrey Yu
Browse files

start working on masked_msa_loss error

parent 1df55b4e
...@@ -542,6 +542,7 @@ def lddt_loss( ...@@ -542,6 +542,7 @@ def lddt_loss(
eps=eps eps=eps
) )
score = torch.nan_to_num(score,nan=torch.nanmean(score)) score = torch.nan_to_num(score,nan=torch.nanmean(score))
score[score<0] = 0
score = score.detach() score = score.detach()
bin_index = torch.floor(score * no_bins).long() bin_index = torch.floor(score * no_bins).long()
bin_index = torch.clamp(bin_index, max=(no_bins - 1)) bin_index = torch.clamp(bin_index, max=(no_bins - 1))
...@@ -1605,6 +1606,7 @@ def masked_msa_loss(logits, true_msa, bert_mask, num_classes, eps=1e-8, **kwargs ...@@ -1605,6 +1606,7 @@ def masked_msa_loss(logits, true_msa, bert_mask, num_classes, eps=1e-8, **kwargs
Returns: Returns:
Masked MSA loss Masked MSA loss
""" """
print(f"line 1609 logits shape: {logits.shape} and num_classes: {num_classes}")
errors = softmax_cross_entropy( errors = softmax_cross_entropy(
logits, torch.nn.functional.one_hot(true_msa, num_classes=num_classes) logits, torch.nn.functional.one_hot(true_msa, num_classes=num_classes)
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment