#include <torch/extension.h>



/*
  Forward of mutual_information.  See also """... """ comment of
  `mutual_information` in mutual_information.py.  This is the core recursion
  in the sequence-to-sequence mutual information computation.

  Args:
      px:     Tensor of shape [B][S][T + 1]; contains the log-odds ratio of
              generating the next x in the sequence, i.e.
              px[b][s][t] is the log of
                 p(x_s | x_0..x_{s-1}, y_0..y_{t-1}) / p(x_s),
              i.e. the log-prob of generating x_s given subsequences of lengths
              (s, t), divided by the prior probability of generating x_s.  (See
              mutual_information.py for more info).
      py:     The log-odds ratio of generating the next y in the sequence.
              Shape [B][S + 1][T]
      p:      This function writes to p[b][s][t] the mutual information between
              sub-sequences of x and y of length s and t respectively, from the
              b'th sequences in the batch.  Its shape is [B][S + 1][T + 1].
              Concretely, this function implements the following recursion,
              in the case where s_begin == t_begin == 0:

               p[b,0,0] = 0.0
               p[b,s,t] = log_add(p[b,s-1,t] + px[b,s-1,t],
                                  p[b,s,t-1] + py[b,s,t-1])
                       if s > 0 or t > 0,
                       treating values with any -1 index as -infinity.
              .. if `boundary` is set, we start from p[b,s_begin,t_begin]=0.0.
   boundary:  If set, a tensor of shape [B][4] of type int64_t, where for
              each batch element b, boundary[b] equals
              [s_begin, t_begin, s_end, t_end]
              which are the beginning and end (i.e. one-past-the-last) of the
              x and y sequences that we should process.  If not set, these
              default to (0, 0, S, T); and they should not exceed these bounds.
     ans: a tensor `ans` of shape [B], where this function will set
            ans[b] = p[b][s_end][t_end],
            with s_end and t_end being (S, T) if `boundary` was not specified,
            and (boundary[b][2], boundary[b][3]) otherwise.
            `ans` represents the mutual information between each pair of
            sequences (i.e. x[b] and y[b], although the sequences are not
            supplied directly to this function).

   The block-dim and grid-dim must both be 1-dimensional, and the block-dim must
   be at least 128.
*/
// Forward pass (declaration only; the definition is not in this file).
// Shapes: px is [B][S][T+1], py is [B][S+1][T], the optional boundary_info is
// [B][4] of int64_t, and p -- an output written by this function -- is
// [B][S+1][T+1].
torch::Tensor mutual_information_cuda(
    torch::Tensor px,
    torch::Tensor py,
    std::optional<torch::Tensor> boundary_info,
    torch::Tensor p);


/*
  backward of mutual_information; returns (grad_px, grad_py)

  if overwrite_ans_grad == true, this function will overwrite ans_grad with a
  value that, if the computation worked correctly, should be identical to or
  very close to the value of ans_grad at entry.  This can be used
  to validate the correctness of this code.
*/
// Backward pass (declaration only; the definition is not in this file).
// The returned vector holds (grad_px, grad_py); ans_grad may be rewritten in
// place when overwrite_ans_grad is true.
std::vector<torch::Tensor> mutual_information_backward_cuda(
    torch::Tensor px, torch::Tensor py,
    std::optional<torch::Tensor> boundary_info,
    torch::Tensor p, torch::Tensor ans_grad,
    bool overwrite_ans_grad);




// Python bindings: registers the forward and backward CUDA entry points on the
// extension module, each under the same name as its C++ function.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("mutual_information_cuda",
        &mutual_information_cuda,
        "Mutual information forward function (CUDA)");
  m.def("mutual_information_backward_cuda",
        &mutual_information_backward_cuda,
        "Mutual information backward function (CUDA)");
}
