Unverified commit 75c53890, authored by Wenhao Chen, committed by GitHub

[chat] fix compute_approx_kl (#4338)

parent 03654c0c
@@ -19,7 +19,7 @@ def compute_approx_kl(log_probs: torch.Tensor,
         action_mask: Mask for actions.
     """
-    log_ratio = log_probs - log_probs_base
+    log_ratio = log_probs_base - log_probs
     approx_kl = (log_ratio.exp() - 1) - log_ratio
     if action_mask is not None:
         approx_kl = masked_mean(approx_kl, action_mask, dim=1)
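
Why the sign of log_ratio matters can be checked on a toy example. The sketch below is not part of the commit: it uses plain Python instead of torch, and the two three-outcome distributions (`actor`, `base`) and the helper names are hypothetical. It assumes the tokens are sampled from the distribution behind log_probs, so the expectation is taken under `actor`. Under that assumption, the estimator (exp(log_ratio) - 1) - log_ratio averages exactly to KL(actor || base) when log_ratio = log_probs_base - log_probs, while the previous sign gives a different quantity.

```python
import math


def k3_expectation(q, p, log_ratio_fn):
    """Expectation under q of (exp(log_ratio) - 1) - log_ratio.

    q plays the role of the sampling (actor) distribution and p the base
    distribution; log_ratio_fn receives (log q_i, log p_i).
    """
    total = 0.0
    for qi, pi in zip(q, p):
        log_ratio = log_ratio_fn(math.log(qi), math.log(pi))
        total += qi * ((math.exp(log_ratio) - 1.0) - log_ratio)
    return total


def exact_kl(q, p):
    """Exact KL(q || p) for two discrete distributions."""
    return sum(qi * math.log(qi / pi) for qi, pi in zip(q, p))


# Hypothetical toy distributions over three outcomes.
actor = [0.7, 0.2, 0.1]
base = [0.5, 0.3, 0.2]

# Sign after the fix: log_probs_base - log_probs.
fixed = k3_expectation(actor, base, lambda log_q, log_p: log_p - log_q)
# Sign before the fix: log_probs - log_probs_base.
old = k3_expectation(actor, base, lambda log_q, log_p: log_q - log_p)
kl = exact_kl(actor, base)

print(f"exact KL(actor || base): {kl:.6f}")
print(f"estimator, fixed sign:   {fixed:.6f}")
print(f"estimator, old sign:     {old:.6f}")
```

With the fixed sign the ratio term sums to 1 under the actor distribution, so the "- 1" cancels it and only the KL term remains; with the old sign that cancellation does not happen.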