// compute.h
#pragma once

#include <cstdint>
#include <optional>
#include <tuple>

#include <torch/script.h>

5
std::tuple<torch::Tensor, std::optional<torch::Tensor>> rnnt_loss(
6
7
    torch::Tensor& logits,
    const torch::Tensor& targets,
8
9
    const torch::Tensor& logit_lengths,
    const torch::Tensor& target_lengths,
10
    int64_t blank,
11
12
    double clamp,
    bool fused_log_softmax);