// binding.cu
#include <iostream>
#include <vector>
#include <numeric>
#include <algorithm>
#include <cstring>
#include <torch/extension.h>

#ifdef WARPCTC_ENABLE_GPU
    #include <ATen/cuda/CUDAContext.h>
    #include <ATen/cuda/CUDAEvent.h>
    #include <c10/cuda/CUDAGuard.h>
    #include <THC/THCGeneral.h>

    extern THCState* state;
#endif

#include "ctc.h"

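// Computes the CTC loss on the CPU by calling warp-ctc's compute_ctc_loss.
// `probs` is assumed to follow warp-ctc's (seq_len, minibatch, alphabet_size)
// layout, so probs.size(2) is the alphabet size. `grads` may be an empty
// tensor when only the loss is wanted; `costs` receives one loss per sample.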
int cpu_ctc(torch::Tensor probs,
            torch::Tensor grads,
            torch::Tensor labels,
            torch::Tensor label_sizes,
            torch::Tensor sizes,
            int minibatch_size,
            torch::Tensor costs,
            int blank_label)
{
    float* probs_ptr       = (float*)probs.data_ptr();
    // grads is optional: an empty tensor (no storage) means "loss only, no gradient".
    float* grads_ptr       = grads.storage() ? (float*)grads.data_ptr() : NULL;
    int*   sizes_ptr       = (int*)sizes.data_ptr();
    int*   labels_ptr      = (int*)labels.data_ptr();
    int*   label_sizes_ptr = (int*)label_sizes.data_ptr();
    float* costs_ptr       = (float*)costs.data_ptr();

    const int probs_size = probs.size(2);

    ctcOptions options;
    memset(&options, 0, sizeof(options));
    options.loc = CTC_CPU;
    options.num_threads = 0; // will use default number of threads
    options.blank_label = blank_label;

#if defined(CTC_DISABLE_OMP) || defined(APPLE)
    // have to use at least one
    options.num_threads = std::max(options.num_threads, (unsigned int) 1);
#endif

    // Ask warp-ctc how much scratch space this problem needs, then allocate it
    // for the duration of the call (rounded up so a byte count that is not a
    // multiple of sizeof(float) still gets enough room).
    size_t cpu_size_bytes;
    get_workspace_size(label_sizes_ptr, sizes_ptr,
                       probs_size, minibatch_size,
                       options, &cpu_size_bytes);

    float* cpu_workspace = new float[(cpu_size_bytes + sizeof(float) - 1) / sizeof(float)];

    compute_ctc_loss(probs_ptr, grads_ptr,
                     labels_ptr, label_sizes_ptr,
                     sizes_ptr, probs_size,
                     minibatch_size, costs_ptr,
                     cpu_workspace, options);

    delete[] cpu_workspace;
    return 1;
}


#ifdef WARPCTC_ENABLE_GPU
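// GPU counterpart of cpu_ctc: same arguments and tensor layout, but the
// tensors are expected to live in CUDA memory and the scratch workspace is
// allocated through the legacy THC allocator.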
int gpu_ctc(torch::Tensor probs,
            torch::Tensor grads,
            torch::Tensor labels,
            torch::Tensor label_sizes,
            torch::Tensor sizes,
            int minibatch_size,
            torch::Tensor costs,
            int blank_label)
{
    float* probs_ptr       = (float*)probs.data_ptr();
    float* grads_ptr       = grads.storage() ? (float*)grads.data_ptr() : NULL;
    int*   sizes_ptr       = (int*)sizes.data_ptr();
    int*   labels_ptr      = (int*)labels.data_ptr();
    int*   label_sizes_ptr = (int*)label_sizes.data_ptr();
    float* costs_ptr       = (float*)costs.data_ptr();

    const int probs_size = probs.size(2);

    ctcOptions options;
    memset(&options, 0, sizeof(options));
    options.loc = CTC_GPU;
    options.blank_label = blank_label;
    // Launch warp-ctc's kernels on the stream PyTorch is currently using.
    options.stream = at::cuda::getCurrentCUDAStream();

    // Query the required scratch size and allocate it on the device.
    size_t gpu_size_bytes;
    get_workspace_size(label_sizes_ptr, sizes_ptr,
                       probs_size, minibatch_size,
                       options, &gpu_size_bytes);

    void* gpu_workspace = THCudaMalloc(state, gpu_size_bytes);

    compute_ctc_loss(probs_ptr, grads_ptr,
                     labels_ptr, label_sizes_ptr,
                     sizes_ptr, probs_size,
                     minibatch_size, costs_ptr,
                     gpu_workspace, options);

    THCudaFree(state, (void *) gpu_workspace);
    return 1;
}
#endif

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("cpu_ctc", &cpu_ctc, "CTC loss computed on the CPU");
#ifdef WARPCTC_ENABLE_GPU
  m.def("gpu_ctc", &gpu_ctc, "CTC loss computed on the GPU");
#endif
}
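
// A rough sketch of how this extension might be driven from Python once built
// with torch.utils.cpp_extension. The module name `warpctc`, the shapes, and
// the example values below are assumptions for illustration, not part of this
// file; warp-ctc takes raw (unnormalized) activations and int32 size/label
// tensors.
//
//   import torch
//   import warpctc  # hypothetical name for this compiled extension
//
//   seq_len, batch, alphabet = 5, 2, 6
//   probs  = torch.randn(seq_len, batch, alphabet).contiguous()  # raw activations
//   grads  = torch.zeros_like(probs)                             # filled in by the call
//   labels = torch.IntTensor([1, 2, 1, 2, 3])   # all targets, concatenated
//   label_sizes = torch.IntTensor([2, 3])       # target length per sample
//   sizes  = torch.IntTensor([seq_len, seq_len])
//   costs  = torch.zeros(batch)                 # one loss per sample
//
//   warpctc.cpu_ctc(probs, grads, labels, label_sizes, sizes, batch, costs, 0)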