r"""Fusing the last linear layer with generalized JSD
Handle the forward and backward pass of the final linear layer via JSD without
materializing the large :math:`(BT, V)` logits tensor.
Args:
jsd_beta (float): coefficient beta of the generalized JSD in the interval [0, 1]. It implements forward KL when beta = 0 and reverse KL when beta = 1. Default: `0.5`
ignore_index (int): The index to ignore in the target. Default: `-100`
temperature (float): temperature applied in the softmax to control the output probability distribution. Default: `1.0`
Shape:
- student_input: :math:`(BT, H)`, where B is batch size, T is sequence length, H is hidden dimension.
- student_weight: :math:`(V, H)`, where V is vocab size.
- teacher_input: :math:`(BT, H')`, where H' is hidden dimension of the teacher model.
- teacher_weight: :math:`(V, H')`, where the hidden sizes H and H' may differ.
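For reference, here is a minimal unfused sketch of the computation this op performs, written in plain PyTorch rather than the actual fused Triton kernel; the name `fused_linear_jsd_reference` is hypothetical, and the `beta = 0` / `beta = 1` special cases and `ignore_index` masking are omitted for brevity:
```python
import torch
import torch.nn.functional as F

def fused_linear_jsd_reference(
    student_input,   # (BT, H)
    student_weight,  # (V, H)
    teacher_input,   # (BT, H')
    teacher_weight,  # (V, H')
    jsd_beta=0.5,
    temperature=1.0,
):
    # Materialize the full (BT, V) logits that the fused kernel avoids,
    # then apply temperature-scaled log-softmax over the vocab dimension.
    log_q = F.log_softmax(student_input @ student_weight.T / temperature, dim=-1)
    log_p = F.log_softmax(teacher_input @ teacher_weight.T / temperature, dim=-1)
    p, q = log_p.exp(), log_q.exp()
    # Generalized JSD with mixture M = beta * P + (1 - beta) * Q.
    log_m = (jsd_beta * p + (1 - jsd_beta) * q).log()
    per_token = (
        jsd_beta * (p * (log_p - log_m)).sum(dim=-1)
        + (1 - jsd_beta) * (q * (log_q - log_m)).sum(dim=-1)
    )
    return per_token.mean()
```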
As with all the other losses in PyTorch, this function expects the first argument,
:attr:`log_q`, to be the predictions (the output of the student model in log-space)
and the second, :attr:`log_p`, to be the observations (the output of the teacher model in log-space).
This differs from the standard mathematical notation :math:`JSD(P || Q)` where
:math:`P` denotes the teacher model and :math:`Q` denotes the student model.
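For reference, the generalized JSD referred to above is commonly defined via a mixture distribution (a definition consistent with the :attr:`beta` behavior described below; the endpoints :math:`\beta = 0` and :math:`\beta = 1` are implemented directly as forward and reverse KL):

.. math::
    \mathrm{JSD}_\beta(P \| Q) = \beta \, \mathrm{KL}(P \| M) + (1 - \beta) \, \mathrm{KL}(Q \| M),
    \qquad M = \beta P + (1 - \beta) Q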
Args:
beta (float): coefficient beta of the generalized JSD in the interval [0, 1]. It implements forward KL when beta = 0 and reverse KL when beta = 1. Default: `0.5`
ignore_index (int): The index to ignore in the target. Default: `-100`
Shape:
- Input: :math:`(BT, V)`, where B is batch size, T is sequence length, V is vocab size.
- Target: :math:`(BT, V)`, same shape as the input.
- shift_labels (Optional): :math:`(BT,)`
- Output: a scalar.
Examples:
```python
>>> (B, T, V) = (2, 2, 5)
>>> jsd = LigerJSD(beta=0.1)
>>> # input should be a distribution in the log space
>>> log_q = torch.randn(B * T, V, requires_grad=True).log_softmax(dim=-1)
>>> log_p = torch.randn(B * T, V).log_softmax(dim=-1)
>>> loss = jsd(log_q, log_p)
```
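If some positions need to be masked (e.g., padding), the optional :attr:`shift_labels` argument from the Shape section can be passed as well; a sketch, assuming it is accepted as a third positional argument:
```python
>>> # Entries equal to ignore_index (-100 by default) are excluded from the loss.
>>> shift_labels = torch.tensor([2, -100, 4, 1])  # shape (BT,) = (4,)
>>> loss = jsd(log_q, log_p, shift_labels)
```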