Commit a60ed4d7 authored by lishen's avatar lishen
Browse files

warpctc from github

parent 949fcc19
/* Historical TH-tensor prototype (pre-ATen), kept for reference:
int cpu_ctc(THFloatTensor *probs,
            THFloatTensor *grads,
            THIntTensor *labels_ptr,
            THIntTensor *label_sizes_ptr,
            THIntTensor *sizes,
            int minibatch_size,
            THFloatTensor *costs,
            int blank_label);
*/
// CPU CTC binding.
//   probs:        activations; the tests in this repo pass (T, batch, alphabet)
//                 contiguous tensors — TODO confirm layout against ctc.h.
//   grads:        output gradient buffer, same shape as probs.
//   labels:       all target sequences of the batch flattened into one 1-D tensor.
//   label_sizes:  per-example target lengths.
//   sizes:        per-example input (time) lengths.
//   costs:        per-example loss output (CPU tensor).
//   blank_label:  index reserved for the CTC blank symbol.
// Returns an int status code — meaning defined by the binding; TODO confirm.
int cpu_ctc(torch::Tensor probs,
            torch::Tensor grads,
            torch::Tensor labels,
            torch::Tensor label_sizes,
            torch::Tensor sizes,
            int minibatch_size,
            torch::Tensor costs,
            int blank_label);
/* Historical THC-tensor prototype (pre-ATen), kept for reference:
int gpu_ctc(THCudaTensor *probs,
            THCudaTensor *grads,
            THIntTensor *labels_ptr,
            THIntTensor *label_sizes_ptr,
            THIntTensor *sizes,
            int minibatch_size,
            THFloatTensor *costs,
            int blank_label);
*/
// GPU CTC binding: same contract as cpu_ctc, but probs/grads live on the GPU.
// The GPU tests in this repo still pass labels/label_sizes/sizes/costs as CPU
// tensors — TODO confirm which arguments must be device-resident.
int gpu_ctc(torch::Tensor probs,
            torch::Tensor grads,
            torch::Tensor labels,
            torch::Tensor label_sizes,
            torch::Tensor sizes,
            int minibatch_size,
            torch::Tensor costs,
            int blank_label);
import torch
import warpctc_pytorch as warp_ctc
import pytest
def test_simple():
    """Smallest CPU smoke test: one example, T=2 frames, alphabet of 5.

    Previously this test only printed the cost, so any regression in the
    binding would pass silently; it now asserts the cost is a positive
    negative-log-likelihood.
    """
    # probs is (T=2, batch=1, alphabet=5) after the transpose.
    probs = torch.FloatTensor([[[0.1, 0.6, 0.1, 0.1, 0.1], [0.1, 0.1, 0.6, 0.1, 0.1]]]).transpose(0, 1).contiguous()
    grads = torch.zeros(probs.size())
    labels = torch.IntTensor([1, 2])
    label_sizes = torch.IntTensor([2])
    # One entry per example, each filled with the sequence length T.
    sizes = torch.IntTensor(probs.size(1)).fill_(probs.size(0))
    minibatch_size = probs.size(1)
    costs = torch.zeros(minibatch_size)
    warp_ctc.cpu_ctc(probs,
                     grads,
                     labels,
                     label_sizes,
                     sizes,
                     minibatch_size,
                     costs,
                     0)
    print('CPU_cost: %f' % costs.sum())
    # The CTC cost is -log P(labels | probs); the label probability is < 1
    # for these activations, so the cost must be strictly positive.
    assert costs.sum().item() > 0
@pytest.mark.parametrize("multiplier", [1.0, 200.0])
def test_medium(multiplier):
    # Two-example batch; probs is (T=2, batch=2, alphabet=5).  The multiplier
    # scales the activations to exercise numerical stability of the loss
    # (presumably warp-ctc softmaxes internally — TODO confirm against ctc.h).
    probs = torch.FloatTensor([
        [[0.1, 0.6, 0.1, 0.1, 0.1], [0.1, 0.1, 0.6, 0.1, 0.1]],
        [[0.6, 0.1, 0.1, 0.1, 0.1], [0.1, 0.1, 0.5, 0.2, 0.1]]
    ]).contiguous() * multiplier
    grads = torch.zeros(probs.size())
    # Targets for the whole batch flattened into one tensor: [1, 2] twice.
    labels = torch.IntTensor([1, 2, 1, 2])
    label_sizes = torch.IntTensor([2, 2])
    sizes = torch.IntTensor([2, 2])
    minibatch_size = probs.size(1)
    costs = torch.zeros(minibatch_size)
    warp_ctc.cpu_ctc(probs,
                     grads,
                     labels,
                     label_sizes,
                     sizes,
                     minibatch_size,
                     costs,
                     0)
    # NOTE(review): only prints the cost — no assertion, so regressions pass silently.
    print('CPU_cost: %f' % costs.sum())
def test_empty_label():
    # Second example has label length 0: only the first example's targets
    # appear in the flattened `labels`, and label_sizes carries the zero.
    probs = torch.FloatTensor([
        [[0.1, 0.6, 0.1, 0.1, 0.1], [0.1, 0.1, 0.6, 0.1, 0.1]],
        [[0.6, 0.1, 0.1, 0.1, 0.1], [0.1, 0.1, 0.5, 0.2, 0.1]]
    ]).contiguous()
    grads = torch.zeros(probs.size())
    labels = torch.IntTensor([1, 2])
    label_sizes = torch.IntTensor([2, 0])
    sizes = torch.IntTensor([2, 2])
    minibatch_size = probs.size(1)
    costs = torch.zeros(minibatch_size)
    warp_ctc.cpu_ctc(probs,
                     grads,
                     labels,
                     label_sizes,
                     sizes,
                     minibatch_size,
                     costs,
                     0)
    # NOTE(review): only prints the cost — no assertion, so regressions pass silently.
    print('CPU_cost: %f' % costs.sum())
def test_CTCLoss():
    # End-to-end check of the autograd wrapper: forward and backward must run
    # without error (values are not asserted).
    probs = torch.FloatTensor([[
        [0.1, 0.6, 0.1, 0.1, 0.1], [0.1, 0.1, 0.6, 0.1, 0.1]
    ]]).transpose(0, 1).contiguous()
    labels = torch.IntTensor([1, 2])
    label_sizes = torch.IntTensor([2])
    probs_sizes = torch.IntTensor([2])
    probs.requires_grad_(True)  # gradients are only computed for the activations
    ctc_loss = warp_ctc.CTCLoss()
    cost = ctc_loss(probs, labels, probs_sizes, label_sizes)
    cost.backward()
# Allow running this test module directly, without invoking pytest externally.
if __name__ == '__main__':
    pytest.main([__file__])
import torch
import warpctc_pytorch as warp_ctc
import pytest
@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires GPU")
def test_simple():
    """Run the same small CTC problem on CPU and GPU and check the costs agree.

    Previously the test only printed both costs; it now asserts they match,
    which is the whole point of running the two backends side by side.
    """
    # probs is (T=2, batch=1, alphabet=5) after the transpose.
    probs = torch.FloatTensor([[[0.1, 0.6, 0.1, 0.1, 0.1], [0.1, 0.1, 0.6, 0.1, 0.1]]]).transpose(0, 1).contiguous()
    grads = torch.zeros(probs.size())
    labels = torch.IntTensor([1, 2])
    label_sizes = torch.IntTensor([2])
    sizes = torch.IntTensor(probs.size(1)).fill_(probs.size(0))
    minibatch_size = probs.size(1)
    costs = torch.zeros(minibatch_size)
    warp_ctc.cpu_ctc(probs,
                     grads,
                     labels,
                     label_sizes,
                     sizes,
                     minibatch_size,
                     costs,
                     0)
    cpu_cost = costs.sum().item()
    print('CPU_cost: %f' % cpu_cost)
    # Repeat on the GPU; labels/sizes/costs stay on the CPU as the binding expects.
    probs = probs.clone().cuda()
    grads = torch.zeros(probs.size()).cuda()
    costs = torch.zeros(minibatch_size)
    warp_ctc.gpu_ctc(probs,
                     grads,
                     labels,
                     label_sizes,
                     sizes,
                     minibatch_size,
                     costs,
                     0)
    gpu_cost = costs.sum().item()
    print('GPU_cost: %f' % gpu_cost)
    print(grads.view(grads.size(0) * grads.size(1), grads.size(2)))
    # Both backends implement the same loss; costs must agree to float tolerance.
    assert abs(cpu_cost - gpu_cost) < 1e-4
@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires GPU")
@pytest.mark.parametrize("multiplier", [1.0, 200.0])
def test_medium(multiplier):
    # Two-example batch (T=2, batch=2, alphabet=5), scaled by `multiplier` to
    # exercise numerical stability; runs on CPU first, then on GPU.
    probs = torch.FloatTensor([
        [[0.1, 0.6, 0.1, 0.1, 0.1], [0.1, 0.1, 0.6, 0.1, 0.1]],
        [[0.6, 0.1, 0.1, 0.1, 0.1], [0.1, 0.1, 0.5, 0.2, 0.1]]
    ]).contiguous() * multiplier
    grads = torch.zeros(probs.size())
    labels = torch.IntTensor([1, 2, 1, 2])
    label_sizes = torch.IntTensor([2, 2])
    sizes = torch.IntTensor([2, 2])
    minibatch_size = probs.size(1)
    costs = torch.zeros(minibatch_size)
    warp_ctc.cpu_ctc(probs,
                     grads,
                     labels,
                     label_sizes,
                     sizes,
                     minibatch_size,
                     costs,
                     0)
    print('CPU_cost: %f' % costs.sum())
    # Same problem on the GPU; only probs/grads move to the device.
    probs = probs.clone().cuda()
    grads = torch.zeros(probs.size()).cuda()
    costs = torch.zeros(minibatch_size)
    warp_ctc.gpu_ctc(probs,
                     grads,
                     labels,
                     label_sizes,
                     sizes,
                     minibatch_size,
                     costs,
                     0)
    # NOTE(review): CPU and GPU costs are printed but never compared.
    print('GPU_cost: %f' % costs.sum())
    print(grads.view(grads.size(0) * grads.size(1), grads.size(2)))
@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires GPU")
def test_empty_label():
    # Second example has label length 0 (see label_sizes); exercised on both
    # the CPU and the GPU backend.
    probs = torch.FloatTensor([
        [[0.1, 0.6, 0.1, 0.1, 0.1], [0.1, 0.1, 0.6, 0.1, 0.1]],
        [[0.6, 0.1, 0.1, 0.1, 0.1], [0.1, 0.1, 0.5, 0.2, 0.1]]
    ]).contiguous()
    grads = torch.zeros(probs.size())
    labels = torch.IntTensor([1, 2])
    label_sizes = torch.IntTensor([2, 0])
    sizes = torch.IntTensor([2, 2])
    minibatch_size = probs.size(1)
    costs = torch.zeros(minibatch_size)
    warp_ctc.cpu_ctc(probs,
                     grads,
                     labels,
                     label_sizes,
                     sizes,
                     minibatch_size,
                     costs,
                     0)
    print('CPU_cost: %f' % costs.sum())
    # Same problem on the GPU; only probs/grads move to the device.
    probs = probs.clone().cuda()
    grads = torch.zeros(probs.size()).cuda()
    costs = torch.zeros(minibatch_size)
    warp_ctc.gpu_ctc(probs,
                     grads,
                     labels,
                     label_sizes,
                     sizes,
                     minibatch_size,
                     costs,
                     0)
    # NOTE(review): CPU and GPU costs are printed but never compared.
    print('GPU_cost: %f' % costs.sum())
    print(grads.view(grads.size(0) * grads.size(1), grads.size(2)))
@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires GPU")
def test_CTCLoss():
    # End-to-end autograd check on the GPU: activations live on the device,
    # the length/label tensors stay on the CPU.
    probs = torch.FloatTensor([[
        [0.1, 0.6, 0.1, 0.1, 0.1], [0.1, 0.1, 0.6, 0.1, 0.1]
    ]]).transpose(0, 1).contiguous().cuda()
    labels = torch.IntTensor([1, 2])
    label_sizes = torch.IntTensor([2])
    probs_sizes = torch.IntTensor([2])
    probs.requires_grad_(True)  # gradients are only computed for the activations
    ctc_loss = warp_ctc.CTCLoss()
    cost = ctc_loss(probs, labels, probs_sizes, label_sizes)
    cost.backward()
# Allow running this test module directly, without invoking pytest externally.
if __name__ == '__main__':
    pytest.main([__file__])
import torch
import warpctc_pytorch as warp_ctc
from torch.autograd import Function
from torch.nn import Module
from ._warp_ctc import * # noqa
def _assert_no_grad(tensor):
assert not tensor.requires_grad, \
"gradients only computed for acts - please " \
"mark other tensors as not requiring gradients"
class _CTC(Function):
    """Autograd bridge to the warp-ctc C bindings.

    forward() computes the summed batch cost and, as a side effect, the
    gradients w.r.t. the activations, which are cached on ctx and replayed
    in backward().
    """
    @staticmethod
    def forward(ctx, acts, labels, act_lens, label_lens, size_average=False,
                length_average=False, blank=0):
        # Select the CPU or GPU binding based on where the activations live.
        is_cuda = True if acts.is_cuda else False
        acts = acts.contiguous()
        loss_func = warp_ctc.gpu_ctc if is_cuda else warp_ctc.cpu_ctc
        # The binding fills this with d(cost)/d(acts); same device/dtype as acts.
        grads = torch.zeros(acts.size()).type_as(acts)
        minibatch_size = acts.size(1)  # acts is (T, batch, alphabet)
        # Costs are collected on the CPU even for the GPU path.
        costs = torch.zeros(minibatch_size).cpu()
        loss_func(acts,
                  grads,
                  labels,
                  label_lens,
                  act_lens,
                  minibatch_size,
                  costs,
                  blank)
        # Collapse per-example costs into a single scalar loss tensor.
        costs = torch.FloatTensor([costs.sum()])
        if length_average:
            # Compute the avg. log-probability per batch sample and frame.
            total_length = torch.sum(act_lens).item()
            grads = grads / total_length
            costs = costs / total_length
        elif size_average:
            # Compute the avg. log-probability per batch sample.
            grads = grads / minibatch_size
            costs = costs / minibatch_size
        # Cached directly on ctx (not save_for_backward), so the in-place
        # mutation in backward() is never checked by autograd.
        ctx.grads = grads
        return costs
    @staticmethod
    def backward(ctx, grad_output):
        _grad_output = grad_output.to(ctx.grads.device)
        # NOTE(review): mul_ mutates the cached grads in place — a second
        # backward() through the same graph would compound the scaling.
        # Confirm double-backward is never required here.
        return ctx.grads.mul_(_grad_output), None, None, None, None, None, None
class CTCLoss(Module):
    """
    CTC loss module backed by warp-ctc.

    Parameters:
        blank (int): label index used for the CTC blank symbol (default: `0`)
        size_average (bool): normalize the loss by the batch size
            (default: `False`)
        length_average (bool): normalize the loss by the total number of frames
            in the batch. If `True`, supersedes `size_average`
            (default: `False`)
    """
    def __init__(self, blank=0, size_average=False, length_average=False):
        super(CTCLoss, self).__init__()
        self.ctc = _CTC.apply
        self.blank = blank
        self.size_average = size_average
        self.length_average = length_average
    def forward(self, acts, labels, act_lens, label_lens):
        """
        acts: Tensor of (seqLength x batch x outputDim) containing output from network
        labels: 1 dimensional Tensor containing all the targets of the batch in one sequence
        act_lens: Tensor of size (batch) containing size of each output sequence from the network
        label_lens: Tensor of (batch) containing label length of each example
        """
        assert len(labels.size()) == 1  # labels must be 1 dimensional
        # Only the activations may carry gradients; the binding cannot
        # differentiate w.r.t. the other inputs, so fail loudly if asked to.
        _assert_no_grad(labels)
        _assert_no_grad(act_lens)
        _assert_no_grad(label_lens)
        return self.ctc(acts, labels, act_lens, label_lens, self.size_average,
                        self.length_average, self.blank)
#include <cstddef>
#include <iostream>
#include <algorithm>
#include <ctc.h>
#include "detail/cpu_ctc.h"
#ifdef __CUDACC__
#include "detail/gpu_ctc.h"
#endif
extern "C" {
// ABI/version tag for this warp-ctc build; callers (see main() in the CPU
// test) refuse to run against a mismatched library.
int get_warpctc_version() {
    return 2;
}
// Map a ctcStatus_t to a human-readable description.  Unknown values fall
// through to the generic "unknown error" string.
const char* ctcGetStatusString(ctcStatus_t status) {
    switch (status) {
        case CTC_STATUS_SUCCESS:          return "no error";
        case CTC_STATUS_MEMOPS_FAILED:    return "cuda memcpy or memset failed";
        case CTC_STATUS_INVALID_VALUE:    return "invalid value";
        case CTC_STATUS_EXECUTION_FAILED: return "execution failed";
        case CTC_STATUS_UNKNOWN_ERROR:
        default:                          return "unknown error";
    }
}
// Compute the CTC loss (and, when `gradients` is non-NULL, the gradients
// w.r.t. the activations) on the backend selected by options.loc.
//   activations:    network outputs; layout assumed (T, minibatch, alphabet)
//                   from the rest of this repo — TODO confirm against ctc.h.
//   flat_labels:    all target sequences concatenated.
//   label_lengths:  per-example target lengths.
//   input_lengths:  per-example time lengths.
//   costs:          per-example loss output (length `minibatch`).
//   workspace:      scratch buffer sized via get_workspace_size(); must live
//                   in the memory space matching options.loc.
// Returns CTC_STATUS_INVALID_VALUE for null/invalid arguments, otherwise the
// status of the backend computation.
ctcStatus_t compute_ctc_loss(const float* const activations,
                             float* gradients,
                             const int* const flat_labels,
                             const int* const label_lengths,
                             const int* const input_lengths,
                             int alphabet_size,
                             int minibatch,
                             float *costs,
                             void *workspace,
                             ctcOptions options) {
    // Validate every pointer and size up front; gradients may legally be NULL.
    if (activations == nullptr ||
        flat_labels == nullptr ||
        label_lengths == nullptr ||
        input_lengths == nullptr ||
        costs == nullptr ||
        workspace == nullptr ||
        alphabet_size <= 0 ||
        minibatch <= 0)
        return CTC_STATUS_INVALID_VALUE;

    if (options.loc == CTC_CPU) {
        CpuCTC<float> ctc(alphabet_size, minibatch, workspace, options.num_threads,
                          options.blank_label);

        // With gradients: full forward/backward; without: score only.
        if (gradients != NULL)
            return ctc.cost_and_grad(activations, gradients,
                                     costs,
                                     flat_labels, label_lengths,
                                     input_lengths);
        else
            return ctc.score_forward(activations, costs, flat_labels,
                                     label_lengths, input_lengths);
    } else if (options.loc == CTC_GPU) {
#ifdef __CUDACC__
        GpuCTC<float> ctc(alphabet_size, minibatch, workspace, options.stream,
                          options.blank_label);

        if (gradients != NULL)
            return ctc.cost_and_grad(activations, gradients, costs,
                                     flat_labels, label_lengths,
                                     input_lengths);
        else
            return ctc.score_forward(activations, costs, flat_labels,
                                     label_lengths, input_lengths);
#else
        // CPU-only build: a GPU request cannot be honored.
        std::cerr << "GPU execution requested, but not compiled with GPU support" << std::endl;
        return CTC_STATUS_EXECUTION_FAILED;
#endif
    } else {
        return CTC_STATUS_INVALID_VALUE;
    }
}
// Report (in *size_bytes) the scratch-buffer size compute_ctc_loss needs for
// the given problem.  The caller allocates the buffer in CPU or GPU memory
// according to options.loc.  Sizing is based on the worst case across the
// minibatch (max label length and max input length).
ctcStatus_t get_workspace_size(const int* const label_lengths,
                               const int* const input_lengths,
                               int alphabet_size, int minibatch,
                               ctcOptions options,
                               size_t* size_bytes)
{
    if (label_lengths == nullptr ||
        input_lengths == nullptr ||
        size_bytes == nullptr ||
        alphabet_size <= 0 ||
        minibatch <= 0)
        return CTC_STATUS_INVALID_VALUE;

    // This is the max of all S and T for all examples in the minibatch.
    int maxL = *std::max_element(label_lengths, label_lengths + minibatch);
    int maxT = *std::max_element(input_lengths, input_lengths + minibatch);

    // S = blank-extended label length: a blank around every label symbol.
    const int S = 2 * maxL + 1;

    *size_bytes = 0;

    if (options.loc == CTC_GPU) {
        // GPU storage
        //nll_forward, nll_backward
        *size_bytes += 2 * sizeof(float) * minibatch;

        //repeats
        *size_bytes += sizeof(int) * minibatch;

        //label offsets
        *size_bytes += sizeof(int) * minibatch;

        //utt_length
        *size_bytes += sizeof(int) * minibatch;

        //label lengths
        *size_bytes += sizeof(int) * minibatch;

        //labels without blanks - overallocate for now
        *size_bytes += sizeof(int) * maxL * minibatch;

        //labels with blanks
        *size_bytes += sizeof(int) * S * minibatch;

        //alphas
        *size_bytes += sizeof(float) * S * maxT * minibatch;

        //denoms
        *size_bytes += sizeof(float) * maxT * minibatch;

        //probs (since we will pass in activations)
        *size_bytes += sizeof(float) * alphabet_size * maxT * minibatch;
    } else {
        //cpu can eventually replace all minibatch with
        //max number of concurrent threads if memory is
        //really tight

        //per minibatch memory
        size_t per_minibatch_bytes = 0;

        //output
        per_minibatch_bytes += sizeof(float) * alphabet_size ;

        //alphas
        per_minibatch_bytes += sizeof(float) * S * maxT;

        //betas
        per_minibatch_bytes += sizeof(float) * S;

        //labels w/blanks, e_inc, s_inc
        per_minibatch_bytes += 3 * sizeof(int) * S;

        *size_bytes = per_minibatch_bytes * minibatch;

        //probs
        *size_bytes += sizeof(float) * alphabet_size * maxT * minibatch;
    }

    return CTC_STATUS_SUCCESS;
}
}
ctc_entrypoint.cpp
\ No newline at end of file
// Includes, system
// #include <stdio.h>
// #include <stdlib.h>
// Includes, cuda
// #include <cuda_runtime.h>
// #include <cublas_v2.h>
// Includes, cuda helper functions
// #include <helper_cuda.h>
// For the functors
#include "detail/ctc_helper.h"
#include "ctc.h"
const int warp_size = 32;
template<int NT, typename T, typename Rop>
struct CTAReduce;

// Block-wide (CTA) reduction of NT per-thread partial values with the binary
// functor Rop: a shared-memory tree reduction folds the data down to a single
// warp, then warp shuffles finish the job without shared memory.
template<int NT, typename T, typename Rop>
struct CTAReduce {
    enum { Size = NT, Capacity = NT };
    // One shared slot per participating thread.
    struct Storage { T shared[Capacity]; };

    // Reduce `x` across the CTA; `count` bounds the number of valid lanes.
    // Only thread 0's return value is meaningful to callers here.
    __device__ static T reduce(int tid, T x, Storage& storage, int count, Rop g) {
        T* s = storage.shared;
        s[tid] = x;
        __syncthreads();

        // Fold the data in half with each pass.
#pragma unroll
        for(int offset = NT / 2; offset >= warp_size; offset /= 2) {
            if(tid + offset < count && tid < offset) {
                // Read from the right half and store to the left half.
                x = g(x, s[offset + tid]);
                s[tid] = x;
            }
            __syncthreads();
        }

        // Final warp: shuffle-based reduction, no shared memory needed.
        // NOTE(review): the full 0xFFFFFFFF mask assumes all 32 lanes reach
        // this loop together — confirm behavior for counts below warp_size.
        T shuff;
        for (int offset = warp_size / 2; offset > 0; offset /= 2) {
            shuff = __shfl_down_sync(0xFFFFFFFF, x, offset);
            if (tid + offset < count && tid < offset)
                x = g(x, shuff);
        }
        return x;
    }
};
// Reduce each column of a (num_rows x num_cols) column-major matrix to one
// value: output[col] = Rop-fold of Iop(input[row, col]) over all rows.
// Launch: one block of NT threads per column.
template <int NT, typename Iop, typename Rop, typename T>
__global__ void reduce_rows(Iop f, Rop g, const T* input, T* output,
                            int num_rows, int num_cols) {

    typedef CTAReduce<NT, T, Rop> R;
    __shared__ typename R::Storage storage;

    int tid = threadIdx.x;
    int idx = tid;
    int col = blockIdx.x;
    T curr;

    // Each block works on a column
    // NOTE(review): `curr` stays uninitialized when tid >= num_rows; the
    // count-guarded CTA reduction should ignore those lanes — confirm.
    if (idx < num_rows)
        curr = f(input[idx + col*num_rows]);
    idx += NT;

    // Each thread strides over the column with step NT.
    while (idx < num_rows) {
        curr = g(curr, f(input[idx + col*num_rows]));
        idx += NT;
    }

    // Sum thread-totals over the CTA.
    curr = R::reduce(tid, curr, storage, num_rows, g);

    // Store result in out
    if (tid == 0)
        output[col] = curr;
}
// Reduce each row of a (num_rows x num_cols) column-major matrix to one
// value: output[row] = Rop-fold of Iop(input[row, col]) over all columns.
// Launch: 2-D blocks of (warp_size x NT/warp_size) threads; threadIdx.y
// strides over columns.
template <int NT, typename Iop, typename Rop, typename T>
__global__ void reduce_cols(Iop f, Rop g, const T* input, T* output,
                            int num_rows, int num_cols) {

    __shared__ T s[NT];

    int warps_per_block = NT / warp_size;
    int row = blockDim.x * blockIdx.x + threadIdx.x;
    int col = threadIdx.y;
    T curr;

    if (row < num_rows && col < num_cols) {
        curr = f(input[row + col*num_rows]);
        col += blockDim.y;

        while (col < num_cols) {
            curr = g(curr, f(input[row + col*num_rows]));
            col += blockDim.y;
        }
    }
    // NOTE(review): `curr` may be written uninitialized for out-of-range
    // threads; the final guarded loop should never read those slots — confirm.
    s[threadIdx.x * warps_per_block + threadIdx.y] = curr;
    __syncthreads();

    // Reduce the per-threadIdx.y partials for this row.
    if (threadIdx.y == 0 && row < num_rows) {
#pragma unroll
        for (int i = 1; i < warps_per_block && i < num_cols; ++i)
            curr = g(curr, s[i + threadIdx.x * warps_per_block]);
        output[row] = curr;
    }
}
// Thin launcher: picks the row- or column-reduction kernel.
// axis == true  -> reduce over rows (one result per column, reduce_rows)
// axis == false -> reduce over columns (one result per row, reduce_cols)
struct ReduceHelper {

    template<typename T, typename Iof, typename Rof>
    static void impl(Iof f, Rof g, const T* input, T* output, int num_rows, int num_cols, bool axis, cudaStream_t stream) {

        int grid_size;

        if (axis) {
            // One 128-thread block per column.
            grid_size = num_cols;
            reduce_rows<128><<<grid_size, 128, 0, stream>>>
               (f, g, input, output, num_rows, num_cols);

        } else {
            // 2-D blocks: warp_size rows per block, 128/warp_size column lanes.
            dim3 tpb(warp_size, 128 / warp_size);
            grid_size = (num_cols + warp_size - 1)/warp_size;
            reduce_cols<128><<<grid_size, tpb, 0, stream>>>
                (f, g, input, output, num_rows, num_cols);
        }
    }
};
// Run a reduction on `stream`, block until it finishes, and translate any
// CUDA failure into CTC_STATUS_EXECUTION_FAILED.
template<typename T, typename Iof, typename Rof>
ctcStatus_t reduce(Iof f, Rof g, const T* input, T* output, int rows, int cols, bool axis, cudaStream_t stream) {
    ReduceHelper::impl(f, g, input, output, rows, cols, axis, stream);
    // Synchronize so cudaGetLastError reflects the kernel just launched.
    cudaStreamSynchronize(stream);
    cudaError_t err = cudaGetLastError();
    if (err != cudaSuccess)
        return CTC_STATUS_EXECUTION_FAILED;

    return CTC_STATUS_SUCCESS;
}
// Public reduction entry points used by the CTC kernels.  Each reduces the
// (rows x cols) `input` into `output` on `stream`:
//   reduce_negate: sum of -x   (negated-sum)
//   reduce_exp:    sum of e^x  (softmax denominators)
//   reduce_max:    max of x    (softmax stabilization)
ctcStatus_t reduce_negate(const float *input, float *output, int rows, int cols, bool axis, cudaStream_t stream) {
    return reduce(ctc_helper::negate<float>(), ctc_helper::add<float>(), input, output, rows, cols, axis, stream);
}

ctcStatus_t reduce_exp(const float *input, float *output, int rows, int cols, bool axis, cudaStream_t stream) {
    return reduce(ctc_helper::exponential<float>(), ctc_helper::add<float>(), input, output, rows, cols, axis, stream);
}

ctcStatus_t reduce_max(const float *input, float *output, int rows, int cols, bool axis, cudaStream_t stream) {
    return reduce(ctc_helper::identity<float>(), ctc_helper::maximum<float>(),input, output, rows, cols, axis, stream);
}
#include <vector>
#include <random>
// Deterministic pseudo-random activations in [0, 1).  The generator is
// re-seeded with 0 on every call, so repeated calls yield identical data —
// handy for reproducible tests.
std::vector<float>
genActs(int size) {
    std::mt19937 gen(0);
    std::uniform_real_distribution<> dis(0, 1);

    std::vector<float> arr(size);
    for (auto& value : arr)
        value = dis(gen);

    return arr;
}
// Deterministic pseudo-random label sequence of length L drawn from
// [1, alphabet_size - 1].  The generator is re-seeded with 1 each call.
// For L >= 3 the three middle entries are forced equal so repeated-label
// handling gets exercised.
std::vector<int>
genLabels(int alphabet_size, int L) {
    std::mt19937 gen(1);
    std::uniform_int_distribution<> dis(1, alphabet_size - 1);

    std::vector<int> label(L);
    for (auto& symbol : label)
        symbol = dis(gen);

    // guarantee repeats for testing
    if (L >= 3) {
        const int mid = L / 2;
        label[mid] = label[mid + 1];
        label[mid - 1] = label[mid];
    }
    return label;
}
#pragma once
#include <stdexcept>
#include <vector>
#include <limits>
#include <numeric>
#include <ctc.h>
// Convert a failing warp-ctc status into a std::runtime_error whose message
// is "<message>, stat = <status string>".  Success is a no-op.
inline void throw_on_error(ctcStatus_t status, const char* message) {
    if (status == CTC_STATUS_SUCCESS)
        return;
    std::string what(message);
    what += ", stat = ";
    what += ctcGetStatusString(status);
    throw std::runtime_error(what);
}
#ifdef __CUDACC__
#include <thrust/system_error.h>
#include <thrust/system/cuda/error.h>
// CUDA overload: wrap a non-success cudaError_t in a thrust::system_error so
// the CUDA error description is appended to `message`.
inline void throw_on_error(cudaError_t error, const char* message) {
    if (error) {
        throw thrust::system_error(error, thrust::cuda_category(), message);
    }
}
#endif
std::vector<float> genActs(int size);
std::vector<int> genLabels(int alphabet_size, int L);
// Relative squared difference between an analytic gradient and a numerical
// one: sum((g - n)^2) / sum(g^2).  Assumes both vectors have the same length.
//
// Fix: the original returned 0/0 == NaN when the reference gradient was all
// zeros, which silently poisoned the `diff < tol` comparison downstream.
// Identical zero vectors now compare as 0; a nonzero diff against a zero
// reference reports +infinity (maximally different).
float rel_diff(const std::vector<float>& grad,
               const std::vector<float>& num_grad) {
    float diff = 0.;
    float tot = 0.;
    for(size_t idx = 0; idx < grad.size(); ++idx) {
        const float d = grad[idx] - num_grad[idx];
        diff += d * d;
        tot += grad[idx] * grad[idx];
    }
    if (tot == 0.f)
        return diff == 0.f ? 0.f
                           : std::numeric_limits<float>::infinity();
    return diff / tot;
}
// Numerically stable softmax for a minibatch of 1: each of the T rows of
// `acts` (length alphabet_size) is shifted by its maximum before
// exponentiation, and the normalized result is written to `probs`.
void softmax(const float* const acts,
             int alphabet_size, int T,
             float *probs) {
    for (int t = 0; t < T; ++t) {
        const float* row = acts + t * alphabet_size;
        float* out = probs + t * alphabet_size;

        // Shift by the row maximum so exp() cannot overflow.
        float row_max = -std::numeric_limits<float>::infinity();
        for (int a = 0; a < alphabet_size; ++a)
            row_max = std::max(row_max, row[a]);

        // Exponentiate once per entry, accumulating the normalizer.
        float denom = 0;
        for (int a = 0; a < alphabet_size; ++a) {
            out[a] = std::exp(row[a] - row_max);
            denom += out[a];
        }

        for (int a = 0; a < alphabet_size; ++a)
            out[a] /= denom;
    }
}
#include <cmath>
#include <random>
#include <tuple>
#include <vector>
#include <iostream>
#include <ctc.h>
#include "test.h"
// Smallest end-to-end CPU check: one example, T == 2 frames, alphabet of 5,
// labels {1, 2}.  With T equal to the label length, the CTC score is the
// product of the two softmaxed target probabilities, computed analytically
// and compared against compute_ctc_loss within eps.
bool small_test() {
    const int alphabet_size = 5;
    const int T = 2;

    std::vector<float> activations = {0.1, 0.6, 0.1, 0.1, 0.1,
                                      0.1, 0.1, 0.6, 0.1, 0.1};

    // Calculate the score analytically
    float expected_score;
    {
        std::vector<float> probs(activations.size());
        softmax(activations.data(), alphabet_size, T, probs.data());

        // Score calculation is specific to the given activations above
        expected_score = probs[1] * probs[7];
    }

    std::vector<int> labels = {1, 2};
    std::vector<int> label_lengths = {2};

    std::vector<int> lengths;
    lengths.push_back(T);

    float score;

    ctcOptions options{};
    options.loc = CTC_CPU;
    options.num_threads = 1;

    size_t cpu_alloc_bytes;
    throw_on_error(get_workspace_size(label_lengths.data(), lengths.data(),
                                      alphabet_size, lengths.size(), options,
                                      &cpu_alloc_bytes),
                   "Error: get_workspace_size in small_test");

    // NOTE(review): the malloc result is unchecked, and the workspace leaks
    // if compute_ctc_loss throws before the free below.
    void* ctc_cpu_workspace = malloc(cpu_alloc_bytes);

    // NULL gradients: score-only forward pass.
    throw_on_error(compute_ctc_loss(activations.data(), NULL,
                                    labels.data(), label_lengths.data(),
                                    lengths.data(),
                                    alphabet_size,
                                    lengths.size(),
                                    &score,
                                    ctc_cpu_workspace,
                                    options),
                   "Error: compute_ctc_loss in small_test");

    free(ctc_cpu_workspace);

    // The library returns -log(P); invert to compare probabilities directly.
    score = std::exp(-score);
    const float eps = 1e-6;

    const float lb = expected_score - eps;
    const float ub = expected_score + eps;

    return (score > lb && score < ub);
}
// Flatten a (time, batch, alphabet) index triple into the options_test
// activation layout, which hard-codes minibatch = 2 and alphabet = 6.
int offset(int t, int n, int a) {
    constexpr int minibatch = 2;
    constexpr int alphabet_size = 6;
    const int frame = t * minibatch + n;
    return frame * alphabet_size + a;
}
// Exercise the non-default blank_label option (blank = 5) on a 2-example,
// 5-frame, alphabet-6 problem and compare both the gradients and the scores
// against reference values (gradients/second score from TensorFlow, first
// score computed analytically).
bool options_test() {
    const int alphabet_size = 6;
    const int T = 5;
    const int minibatch = 2;

    // Activations are softmax probabilities, laid out time-major with the
    // two batch entries interleaved (see offset()).
    std::vector<float> activations =
            {0.633766, 0.221185, 0.0917319, 0.0129757, 0.0142857, 0.0260553,
             0.30176, 0.28562, 0.0831517, 0.0862751, 0.0816851, 0.161508,

             0.111121, 0.588392, 0.278779, 0.0055756, 0.00569609, 0.010436,
             0.24082, 0.397533, 0.0557226, 0.0546814, 0.0557528, 0.19549,

             0.0357786, 0.633813, 0.321418, 0.00249248, 0.00272882, 0.0037688,
             0.230246, 0.450868, 0.0389607, 0.038309, 0.0391602, 0.202456,

             0.0663296, 0.643849, 0.280111, 0.00283995, 0.0035545, 0.00331533,
             0.280884, 0.429522, 0.0326593, 0.0339046, 0.0326856, 0.190345,

             0.458235, 0.396634, 0.123377, 0.00648837, 0.00903441, 0.00623107,
             0.423286, 0.315517, 0.0338439, 0.0393744, 0.0339315, 0.154046};

    std::vector<float> expected_grads = // from tensorflow
            {-0.366234, 0.221185, 0.0917319, 0.0129757, 0.0142857, 0.0260553,
             -0.69824, 0.28562, 0.0831517, 0.0862751, 0.0816851, 0.161508,

             0.111121, -0.411608, 0.278779, 0.0055756, 0.00569609, 0.010436,
             0.24082, -0.602467, 0.0557226, 0.0546814, 0.0557528, 0.19549,

             0.0357786, 0.633813, -0.678582, 0.00249248, 0.00272882, 0.0037688,
             0.230246, 0.450868, 0.0389607, 0.038309, 0.0391602, -0.797544,

             0.0663296, -0.356151, 0.280111, 0.00283995, 0.0035545, 0.00331533,
             0.280884, -0.570478, 0.0326593, 0.0339046, 0.0326856, 0.190345,

             -0.541765, 0.396634, 0.123377, 0.00648837, 0.00903441, 0.00623107,
             -0.576714, 0.315517, 0.0338439, 0.0393744, 0.0339315, 0.154046};

    // Calculate the expected scores analytically
    std::vector<double> expected_scores(2);
    auto& a = activations;
    expected_scores[0] =
            -std::log(a[offset(0, 0, 0)] * a[offset(1, 0, 1)] * a[offset(2, 0, 2)]
                      * a[offset(3, 0, 1)] * a[offset(4, 0, 0)]);
    expected_scores[1] = 5.42262; // from tensorflow

    // now take the log to account for the softmax
    for (auto& a : activations) {
        a = std::log(a);
    }

    std::vector<int> labels = {0, 1, 2, 1, 0,
                               0, 1, 1, 0};

    std::vector<int> label_lengths = {5, 4};

    std::vector<int> lengths = {5, 5};

    std::vector<float> grads(alphabet_size * T * minibatch);

    std::vector<float> scores(2);

    ctcOptions options{};
    options.loc = CTC_CPU;
    options.num_threads = 1;
    options.blank_label = 5;  // the option under test

    size_t cpu_alloc_bytes;
    throw_on_error(get_workspace_size(label_lengths.data(), lengths.data(),
                                      alphabet_size, lengths.size(), options,
                                      &cpu_alloc_bytes),
                   "Error: get_workspace_size in options_test");

    // NOTE(review): unchecked malloc; leaks if compute_ctc_loss throws.
    void* ctc_cpu_workspace = malloc(cpu_alloc_bytes);

    throw_on_error(compute_ctc_loss(activations.data(), grads.data(),
                                    labels.data(), label_lengths.data(),
                                    lengths.data(),
                                    alphabet_size,
                                    lengths.size(),
                                    scores.data(),
                                    ctc_cpu_workspace,
                                    options),
                   "Error: compute_ctc_loss in options_test");

    free(ctc_cpu_workspace);

    const double eps = 1e-4;

    bool result = true;
    // Compare every gradient entry within eps.
    for (int i = 0; i < grads.size(); i++) {
        const double lb = expected_grads[i] - eps;
        const double ub = expected_grads[i] + eps;
        if (!(grads[i] > lb && grads[i] < ub)) {
            std::cerr << "grad mismatch in options_test"
                      << " expected grad: " << expected_grads[i]
                      << " calculated score: " << grads[i]
                      << " !(" << lb << " < " << grads[i]
                      << " < " << ub << ")" << std::endl;
            result = false;
        }
    }

    // Compare the two per-example scores within eps.
    for (int i = 0; i < 2; i++) {
        const double lb = expected_scores[i] - eps;
        const double ub = expected_scores[i] + eps;

        if (!(scores[i] > lb && scores[i] < ub)) {
            std::cerr << "score mismatch in options_test"
                      << " expected score: " << expected_scores[i]
                      << " calculated score: " << scores[i]
                      << " !(" << lb << " < " << scores[i]
                      << " < " << ub << ")" << std::endl;
            result = false;
        }
    }
    return result;
}
// Force the first target symbol (label 2) to have ~zero probability at every
// frame by pinning its activation to -1e30.  The resulting cost must be
// +infinity while the gradients must remain finite (no NaNs).
bool inf_test() {
    const int alphabet_size = 15;
    const int T = 50;
    const int L = 10;
    const int minibatch = 1;

    std::vector<int> labels = genLabels(alphabet_size, L);
    labels[0] = 2;  // make sure the suppressed symbol is actually required
    std::vector<int> label_lengths = {L};

    std::vector<float> acts = genActs(alphabet_size * T * minibatch);

    // Suppress symbol 2 in every frame.
    for (int i = 0; i < T; ++i)
        acts[alphabet_size * i + 2] = -1e30;

    std::vector<int> sizes;
    sizes.push_back(T);

    std::vector<float> grads(alphabet_size * T);

    float cost;

    ctcOptions options{};
    options.loc = CTC_CPU;
    options.num_threads = 1;

    size_t cpu_alloc_bytes;
    throw_on_error(get_workspace_size(label_lengths.data(), sizes.data(),
                                      alphabet_size, sizes.size(), options,
                                      &cpu_alloc_bytes),
                   "Error: get_workspace_size in inf_test");

    // NOTE(review): unchecked malloc; leaks if compute_ctc_loss throws.
    void* ctc_cpu_workspace = malloc(cpu_alloc_bytes);

    throw_on_error(compute_ctc_loss(acts.data(), grads.data(),
                                    labels.data(), label_lengths.data(),
                                    sizes.data(),
                                    alphabet_size,
                                    sizes.size(),
                                    &cost,
                                    ctc_cpu_workspace,
                                    options),
                   "Error: compute_ctc_loss in inf_test");

    free(ctc_cpu_workspace);

    bool status = true;
    status &= std::isinf(cost);

    for (int i = 0; i < alphabet_size * T; ++i)
        status &= !std::isnan(grads[i]);

    return status;
}
// Numerical gradient check for the CPU backend: compare analytic gradients
// from compute_ctc_loss against second-order central differences of the cost,
// perturbing every activation by +/- epsilon.  Returns rel_diff(analytic,
// numeric); callers compare it against a per-problem tolerance.
float grad_check(int T, int alphabet_size,
                 std::vector<float>& acts,
                 const std::vector<std::vector<int>>& labels,
                 const std::vector<int>& sizes) {

    float epsilon = 1e-2;

    const int minibatch = labels.size();

    // Flatten the per-example label vectors into warp-ctc's packed format.
    std::vector<int> flat_labels;
    std::vector<int> label_lengths;
    for (const auto& l : labels) {
        flat_labels.insert(flat_labels.end(), l.begin(), l.end());
        label_lengths.push_back(l.size());
    }

    std::vector<float> costs(minibatch);

    std::vector<float> grads(acts.size());

    ctcOptions options{};
    options.loc = CTC_CPU;
    options.num_threads = 1;

    size_t cpu_alloc_bytes;
    throw_on_error(get_workspace_size(label_lengths.data(), sizes.data(),
                                      alphabet_size, sizes.size(), options,
                                      &cpu_alloc_bytes),
                   "Error: get_workspace_size in grad_check");

    // NOTE(review): unchecked malloc; leaks if any compute_ctc_loss below
    // throws.  The workspace is reused for every finite-difference call.
    void* ctc_cpu_workspace = malloc(cpu_alloc_bytes);

    // Analytic gradients.
    throw_on_error(compute_ctc_loss(acts.data(), grads.data(),
                                    flat_labels.data(), label_lengths.data(),
                                    sizes.data(),
                                    alphabet_size,
                                    minibatch,
                                    costs.data(),
                                    ctc_cpu_workspace,
                                    options),
                   "Error: compute_ctc_loss (0) in grad_check");

    float cost = std::accumulate(costs.begin(), costs.end(), 0.);

    std::vector<float> num_grad(grads.size());

    //perform 2nd order central differencing
    for (int i = 0; i < T * alphabet_size * minibatch; ++i) {

        std::vector<float> costsP1(minibatch);
        std::vector<float> costsP2(minibatch);

        // Evaluate cost at acts[i] + epsilon ...
        acts[i] += epsilon;
        throw_on_error(compute_ctc_loss(acts.data(), NULL,
                                        flat_labels.data(), label_lengths.data(),
                                        sizes.data(),
                                        alphabet_size,
                                        minibatch,
                                        costsP1.data(),
                                        ctc_cpu_workspace,
                                        options),
                       "Error: compute_ctc_loss (1) in grad_check");

        // ... and at acts[i] - epsilon.
        acts[i] -= 2 * epsilon;
        throw_on_error(compute_ctc_loss(acts.data(), NULL,
                                        flat_labels.data(), label_lengths.data(),
                                        sizes.data(),
                                        alphabet_size,
                                        minibatch,
                                        costsP2.data(),
                                        ctc_cpu_workspace,
                                        options),
                       "Error: compute_ctc_loss (2) in grad_check");

        float costP1 = std::accumulate(costsP1.begin(), costsP1.end(), 0.);
        float costP2 = std::accumulate(costsP2.begin(), costsP2.end(), 0.);

        // Restore the original activation before moving on.
        acts[i] += epsilon;
        num_grad[i] = (costP1 - costP2) / (2 * epsilon);
    }

    free(ctc_cpu_workspace);

    float diff = rel_diff(grads, num_grad);

    return diff;
}
// Gradient-check the CPU implementation on a couple of problem sizes, each a
// tuple of (alphabet_size, T, L, minibatch, tolerance).  Returns true iff
// every numeric-vs-analytic gradient comparison stays under its tolerance.
//
// Fix: removed the unused local `std::mt19937 gen(2)` — genActs/genLabels
// seed their own generators, so it never influenced anything.
bool run_tests() {
    std::vector<std::tuple<int, int, int, int, float>> problem_sizes =
        {std::make_tuple(20, 50, 15, 1, 1e-5),
         std::make_tuple(5, 10, 5, 65, 1e-4)
        };

    bool status = true;
    for (auto problem : problem_sizes) {
        int alphabet_size, T, L, minibatch;
        float tol;
        std::tie(alphabet_size, T, L, minibatch, tol) = problem;

        std::vector<float> acts = genActs(alphabet_size * T * minibatch);

        std::vector<std::vector<int>> labels;
        std::vector<int> sizes;
        for (int mb = 0; mb < minibatch; ++mb) {
            int actual_length = L;
            labels.push_back(genLabels(alphabet_size, actual_length));
            sizes.push_back(T);
        }

        float diff = grad_check(T, alphabet_size, acts, labels, sizes);
        status &= (diff < tol);
    }

    return status;
}
// CPU test driver: verifies the library version first, then runs every CPU
// test; exits 0 only when all of them pass.
int main(void) {
    if (get_warpctc_version() != 2) {
        std::cerr << "Invalid WarpCTC version." << std::endl;
        return 1;
    }

    std::cout << "Running CPU tests" << std::endl;

    bool status = true;
    status &= small_test();
    status &= options_test();
    status &= inf_test();
    status &= run_tests();

    if (status) {
        std::cout << "Tests pass" << std::endl;
        return 0;
    } else {
        std::cout << "Some or all tests fail" << std::endl;
        return 1;
    }
}
#include <cmath>
#include <tuple>
#include <vector>
#include <iostream>
#include <ctc.h>
#include "test.h"
// GPU counterpart of the CPU small_test: same single 2-frame example with
// alphabet 5 and labels {1, 2}, but activations and workspace live in device
// memory on a dedicated stream.
bool small_test() {
    const int alphabet_size = 5;
    const int T = 2;

    std::vector<float> activations = {0.1, 0.6, 0.1, 0.1, 0.1,
                                      0.1, 0.1, 0.6, 0.1, 0.1};

    // Calculate the score analytically
    float expected_score;
    {
        std::vector<float> probs(activations.size());
        softmax(activations.data(), alphabet_size, T, probs.data());

        // Score calculation is specific to the given activations above
        expected_score = probs[1] * probs[7];
    }

    cudaStream_t stream;
    throw_on_error(cudaStreamCreate(&stream),
                   "cudaStreamCreate");

    // NOTE(review): none of the device allocations below are released if a
    // later throw_on_error fires — acceptable in a test, but it leaks.
    float *activations_gpu;
    throw_on_error(cudaMalloc(&activations_gpu,
                   activations.size() * sizeof(float)),
                   "cudaMalloc");
    throw_on_error(cudaMemcpyAsync(activations_gpu, activations.data(),
                                   activations.size() * sizeof(float),
                                   cudaMemcpyHostToDevice, stream),
                   "cudaMemcpyAsync");

    std::vector<int> labels = {1, 2};
    std::vector<int> label_lengths = {2};

    std::vector<int> lengths;
    lengths.push_back(T);

    float score;

    ctcOptions options{};
    options.loc = CTC_GPU;
    options.stream = stream;

    size_t gpu_alloc_bytes;
    throw_on_error(get_workspace_size(label_lengths.data(), lengths.data(),
                                      alphabet_size, lengths.size(), options,
                                      &gpu_alloc_bytes),
                   "Error: get_workspace_size in small_test");

    char *ctc_gpu_workspace;
    throw_on_error(cudaMalloc(&ctc_gpu_workspace, gpu_alloc_bytes),
                   "cudaMalloc");

    // nullptr gradients: score-only forward pass.  `score` is read right
    // after without an explicit stream sync — presumably compute_ctc_loss
    // synchronizes internally; TODO confirm.
    throw_on_error(compute_ctc_loss(activations_gpu, nullptr,
                                    labels.data(), label_lengths.data(),
                                    lengths.data(),
                                    alphabet_size,
                                    lengths.size(),
                                    &score,
                                    ctc_gpu_workspace,
                                    options),
                   "Error: compute_ctc_loss in small_test");

    // The library returns -log(P); invert to compare probabilities directly.
    score = std::exp(-score);
    const float eps = 1e-6;

    const float lb = expected_score - eps;
    const float ub = expected_score + eps;

    throw_on_error(cudaFree(activations_gpu),
                   "cudaFree");
    throw_on_error(cudaFree(ctc_gpu_workspace),
                   "cudaFree");
    throw_on_error(cudaStreamDestroy(stream),
                   "cudaStreamDestroy");

    return (score > lb && score < ub);
}
int offset(int t, int n, int a) {
    // Hard-coded dimensions of the options_test activations:
    // minibatch = 2 examples, alphabet = 6 symbols, time-major layout.
    constexpr int minibatch = 2;
    constexpr int alphabet_size = 6;
    return a + alphabet_size * (n + minibatch * t);
}
// End-to-end GPU test of the ctcOptions struct: a T=5, minibatch=2,
// alphabet=6 problem run with a non-default blank label (5) on an explicit
// CUDA stream. Gradients are checked element-wise against TensorFlow
// reference values; scores are checked against an analytic value (batch 0)
// and a TensorFlow value (batch 1).
// Returns true iff every gradient and score lies within eps of expectation.
//
// Fix: the original leaked grads_gpu — it was cudaMalloc'd but never freed;
// the cleanup below now releases it along with the other device buffers.
bool options_test() {
    const int alphabet_size = 6;
    const int T = 5;
    const int minibatch = 2;
    // Softmaxed network outputs, flattened as (t, n, a) — see offset().
    std::vector<float> activations =
            {0.633766, 0.221185, 0.0917319, 0.0129757, 0.0142857, 0.0260553,
             0.30176, 0.28562, 0.0831517, 0.0862751, 0.0816851, 0.161508,
             0.111121, 0.588392, 0.278779, 0.0055756, 0.00569609, 0.010436,
             0.24082, 0.397533, 0.0557226, 0.0546814, 0.0557528, 0.19549,
             0.0357786, 0.633813, 0.321418, 0.00249248, 0.00272882, 0.0037688,
             0.230246, 0.450868, 0.0389607, 0.038309, 0.0391602, 0.202456,
             0.0663296, 0.643849, 0.280111, 0.00283995, 0.0035545, 0.00331533,
             0.280884, 0.429522, 0.0326593, 0.0339046, 0.0326856, 0.190345,
             0.458235, 0.396634, 0.123377, 0.00648837, 0.00903441, 0.00623107,
             0.423286, 0.315517, 0.0338439, 0.0393744, 0.0339315, 0.154046};
    std::vector<float> expected_grads = // from tensorflow
            {-0.366234, 0.221185, 0.0917319, 0.0129757, 0.0142857, 0.0260553,
             -0.69824, 0.28562, 0.0831517, 0.0862751, 0.0816851, 0.161508,
             0.111121, -0.411608, 0.278779, 0.0055756, 0.00569609, 0.010436,
             0.24082, -0.602467, 0.0557226, 0.0546814, 0.0557528, 0.19549,
             0.0357786, 0.633813, -0.678582, 0.00249248, 0.00272882, 0.0037688,
             0.230246, 0.450868, 0.0389607, 0.038309, 0.0391602, -0.797544,
             0.0663296, -0.356151, 0.280111, 0.00283995, 0.0035545, 0.00331533,
             0.280884, -0.570478, 0.0326593, 0.0339046, 0.0326856, 0.190345,
             -0.541765, 0.396634, 0.123377, 0.00648837, 0.00903441, 0.00623107,
             -0.576714, 0.315517, 0.0338439, 0.0393744, 0.0339315, 0.154046};
    // Batch 0's expected score follows analytically from the single-path
    // label sequence {0,1,2,1,0}; batch 1's comes from TensorFlow.
    auto& a = activations;
    double expected_score[2];
    expected_score[0] =
            -std::log(a[offset(0, 0, 0)] * a[offset(1, 0, 1)] * a[offset(2, 0, 2)]
                      * a[offset(3, 0, 1)] * a[offset(4, 0, 0)]);
    expected_score[1] = 5.42262; // from tensorflow
    // compute_ctc_loss expects pre-softmax log-space inputs, so take the log
    // of the softmax outputs above.
    for (auto& a : activations) {
        a = std::log(a);
    }
    cudaStream_t stream;
    throw_on_error(cudaStreamCreate(&stream),
                   "cudaStreamCreate");
    float *activations_gpu;
    throw_on_error(cudaMalloc(&activations_gpu,
                   activations.size() * sizeof(float)),
                   "cudaMalloc");
    throw_on_error(cudaMemcpyAsync(activations_gpu, activations.data(),
                                   activations.size() * sizeof(float),
                                   cudaMemcpyHostToDevice, stream),
                   "cudaMemcpyAsync");
    // Concatenated labels for both minibatch items; lengths 5 and 4.
    std::vector<int> labels = {0, 1, 2, 1, 0,
                               0, 1, 1, 0};
    std::vector<int> label_lengths = {5, 4};
    std::vector<int> lengths = {5, 5};
    float score[2];
    float *grads_gpu;
    throw_on_error(cudaMalloc(&grads_gpu, (alphabet_size * T * minibatch) * sizeof(float)),
                   "cudaMalloc");
    // The point of this test: exercise the explicit-options path with a
    // custom stream and blank_label = 5 (instead of the default 0).
    ctcOptions options{};
    options.loc = CTC_GPU;
    options.stream = stream;
    options.blank_label = 5;
    size_t gpu_alloc_bytes;
    throw_on_error(get_workspace_size(label_lengths.data(), lengths.data(),
                                      alphabet_size, lengths.size(), options,
                                      &gpu_alloc_bytes),
                   "Error: get_workspace_size in options_test");
    char *ctc_gpu_workspace;
    throw_on_error(cudaMalloc(&ctc_gpu_workspace, gpu_alloc_bytes),
                   "cudaMalloc");
    throw_on_error(compute_ctc_loss(activations_gpu, grads_gpu,
                                    labels.data(), label_lengths.data(),
                                    lengths.data(),
                                    alphabet_size,
                                    lengths.size(),
                                    &score[0],
                                    ctc_gpu_workspace,
                                    options),
                   "Error: compute_ctc_loss in options_test");
    std::vector<float> grads(alphabet_size * T * minibatch);
    throw_on_error(cudaMemcpyAsync(grads.data(), grads_gpu,
                                   grads.size() * sizeof(float),
                                   cudaMemcpyDeviceToHost, stream),
                   "cudaMemcpyAsync");
    throw_on_error(cudaStreamSynchronize(stream), "cudaStreamSynchronize");
    // Release every device allocation (grads_gpu was previously leaked).
    throw_on_error(cudaFree(activations_gpu),
                   "cudaFree");
    throw_on_error(cudaFree(grads_gpu),
                   "cudaFree");
    throw_on_error(cudaFree(ctc_gpu_workspace),
                   "cudaFree");
    throw_on_error(cudaStreamDestroy(stream),
                   "cudaStreamDestroy");
    const double eps = 1e-4;
    bool result = true;
    // Element-wise gradient comparison against the TensorFlow reference.
    for (int i = 0; i < grads.size(); i++) {
        const double lb = expected_grads[i] - eps;
        const double ub = expected_grads[i] + eps;
        if (!(grads[i] > lb && grads[i] < ub)) {
            std::cerr << "grad mismatch in options_test"
                      << " expected grad: " << expected_grads[i]
                      << " calculated grad: " << grads[i]
                      << " !(" << lb << " < " << grads[i]
                      << " < " << ub << ")" << std::endl;
            result = false;
        }
    }
    // Per-minibatch score comparison.
    for (int i = 0; i < 2; i++) {
        const double lb = expected_score[i] - eps;
        const double ub = expected_score[i] + eps;
        if (!(score[i] > lb && score[i] < ub)) {
            std::cerr << "score mismatch in options_test"
                      << " expected score: " << expected_score[i]
                      << " calculated score: " << score[i] << std::endl;
            result = false;
        }
    }
    return result;
}
// Verifies graceful handling of an impossible alignment: the first target
// label (2) has its activation forced to -1e30 at every timestep, so no
// valid CTC path exists. The loss must come back +inf and — crucially —
// the gradients must still be finite (no NaNs propagated).
// Returns true iff cost is inf and no gradient entry is NaN.
bool inf_test() {
    const int alphabet_size = 15;
    const int T = 50;
    const int L = 10;
    const int minibatch = 1;
    std::vector<int> labels = genLabels(alphabet_size, L);
    labels[0] = 2;  // force the first label to the symbol we will suppress
    std::vector<int> label_lengths = {L};
    std::vector<float> acts = genActs(alphabet_size * T * minibatch);
    // Make symbol 2 effectively impossible at every timestep
    // (minibatch == 1, so the stride is just alphabet_size).
    for (int i = 0; i < T; ++i)
        acts[alphabet_size * i + 2] = -1e30;
    cudaStream_t stream;
    throw_on_error(cudaStreamCreate(&stream),
                   "cudaStreamCreate");
    float *acts_gpu;
    throw_on_error(cudaMalloc(&acts_gpu, acts.size() * sizeof(float)),
                   "cudaMalloc");
    throw_on_error(cudaMemcpyAsync(acts_gpu, acts.data(),
                                   acts.size() * sizeof(float),
                                   cudaMemcpyHostToDevice, stream),
                   "cudaMemcpyAsync");
    std::vector<int> lengths;
    lengths.push_back(T);
    float *grads_gpu;
    throw_on_error(cudaMalloc(&grads_gpu, (alphabet_size * T) * sizeof(float)),
                   "cudaMalloc");
    float cost;
    // Default options (blank label 0), GPU backend on our stream.
    ctcOptions options{};
    options.loc = CTC_GPU;
    options.stream = stream;
    size_t gpu_alloc_bytes;
    throw_on_error(get_workspace_size(label_lengths.data(), lengths.data(),
                                      alphabet_size, lengths.size(), options,
                                      &gpu_alloc_bytes),
                   "Error: get_workspace_size in inf_test");
    char *ctc_gpu_workspace;
    throw_on_error(cudaMalloc(&ctc_gpu_workspace, gpu_alloc_bytes),
                   "cudaMalloc");
    throw_on_error(compute_ctc_loss(acts_gpu, grads_gpu,
                                    labels.data(), label_lengths.data(),
                                    lengths.data(),
                                    alphabet_size,
                                    lengths.size(),
                                    &cost,
                                    ctc_gpu_workspace,
                                    options),
                   "Error: compute_ctc_loss in inf_test");
    // The impossible label sequence must drive the loss to infinity.
    bool status = std::isinf(cost);
    std::vector<float> grads(alphabet_size * T);
    throw_on_error(cudaMemcpyAsync(grads.data(), grads_gpu,
                                   grads.size() * sizeof(float),
                                   cudaMemcpyDeviceToHost, stream),
                   "cudaMemcpyAsync");
    throw_on_error(cudaStreamSynchronize(stream), "cudaStreamSynchronize");
    // Even with an infinite loss, gradients must remain well-defined.
    for (int i = 0; i < alphabet_size * T; ++i)
        status &= !std::isnan(grads[i]);
    throw_on_error(cudaFree(acts_gpu), "cudaFree");
    throw_on_error(cudaFree(grads_gpu), "cudaFree");
    throw_on_error(cudaFree(ctc_gpu_workspace), "cudaFree");
    throw_on_error(cudaStreamDestroy(stream), "cudaStreamDestroy");
    return status;
}
// Validates the analytic CTC gradient against a numerical one.
// The analytic gradient comes from a single compute_ctc_loss call with a
// grads buffer; the numerical gradient perturbs each activation by +/-epsilon
// and applies 2nd-order central differencing on the total cost.
// Returns rel_diff(analytic, numerical); small values indicate agreement.
// Note: `acts` is perturbed in place during the loop but each entry is
// restored to its original value before moving on.
float grad_check(int T, int alphabet_size,
                 std::vector<float>& acts,
                 const std::vector<std::vector<int>>& labels,
                 const std::vector<int>& lengths) {
    float epsilon = 1e-2;
    const int minibatch = labels.size();
    cudaStream_t stream;
    throw_on_error(cudaStreamCreate(&stream),
                   "cudaStreamCreate");
    float *acts_gpu;
    throw_on_error(cudaMalloc(&acts_gpu, acts.size() * sizeof(float)),
                   "cudaMalloc");
    throw_on_error(cudaMemcpyAsync(acts_gpu, acts.data(),
                                   acts.size() * sizeof(float),
                                   cudaMemcpyHostToDevice, stream),
                   "cudaMemcpyAsync");
    // Flatten the per-item label vectors into the concatenated form
    // compute_ctc_loss expects, recording each item's length.
    std::vector<int> flat_labels;
    std::vector<int> label_lengths;
    for (const auto& l : labels) {
        flat_labels.insert(flat_labels.end(), l.begin(), l.end());
        label_lengths.push_back(l.size());
    }
    std::vector<float> costs(minibatch);
    float *grads_gpu;
    throw_on_error(cudaMalloc(&grads_gpu, acts.size() * sizeof(float)),
                   "cudaMalloc");
    ctcOptions options{};
    options.loc = CTC_GPU;
    options.stream = stream;
    size_t gpu_alloc_bytes;
    throw_on_error(get_workspace_size(label_lengths.data(),
                                      lengths.data(),
                                      alphabet_size,
                                      lengths.size(),
                                      options,
                                      &gpu_alloc_bytes),
                   "Error: get_workspace_size in grad_check");
    char *ctc_gpu_workspace;
    throw_on_error(cudaMalloc(&ctc_gpu_workspace, gpu_alloc_bytes),
                   "cudaMalloc");
    // Analytic gradient: one loss evaluation with a gradient buffer.
    throw_on_error(compute_ctc_loss(acts_gpu, grads_gpu,
                                    flat_labels.data(),
                                    label_lengths.data(),
                                    lengths.data(),
                                    alphabet_size,
                                    minibatch,
                                    costs.data(),
                                    ctc_gpu_workspace,
                                    options),
                   "Error: compute_ctc_loss (0) in grad_check");
    std::vector<float> grads(acts.size());
    throw_on_error(cudaMemcpyAsync(grads.data(),
                                   grads_gpu, grads.size() * sizeof(float),
                                   cudaMemcpyDeviceToHost, stream),
                   "cudaMemcpyAsync");
    throw_on_error(cudaStreamSynchronize(stream), "cudaStreamSynchronize");
    std::vector<float> num_grad(grads.size());
    //perform 2nd order central differencing
    for (int i = 0; i < T * alphabet_size * minibatch; ++i) {
        // Evaluate cost at acts[i] + epsilon (gradient buffer NULL:
        // cost-only evaluation).
        acts[i] += epsilon;
        throw_on_error(cudaMemcpyAsync(acts_gpu, acts.data(),
                                       acts.size() * sizeof(float),
                                       cudaMemcpyHostToDevice, stream),
                       "cudaMemcpyAsync");
        std::vector<float> costsP1(minibatch);
        std::vector<float> costsP2(minibatch);
        throw_on_error(compute_ctc_loss(acts_gpu, NULL,
                                        flat_labels.data(),
                                        label_lengths.data(),
                                        lengths.data(),
                                        alphabet_size,
                                        minibatch,
                                        costsP1.data(),
                                        ctc_gpu_workspace,
                                        options),
                       "Error: compute_ctc_loss (1) in grad_check");
        // Evaluate cost at acts[i] - epsilon.
        acts[i] -= 2 * epsilon;
        throw_on_error(cudaMemcpyAsync(acts_gpu, acts.data(),
                                       acts.size() * sizeof(float),
                                       cudaMemcpyHostToDevice, stream),
                       "cudaMemcpyAsync");
        throw_on_error(compute_ctc_loss(acts_gpu, NULL,
                                        flat_labels.data(),
                                        label_lengths.data(),
                                        lengths.data(),
                                        alphabet_size,
                                        minibatch,
                                        costsP2.data(),
                                        ctc_gpu_workspace,
                                        options),
                       "Error: compute_ctc_loss (2) in grad_check");
        // Sum costs over the minibatch (double-precision accumulator).
        float costP1 = std::accumulate(costsP1.begin(), costsP1.end(), 0.);
        float costP2 = std::accumulate(costsP2.begin(), costsP2.end(), 0.);
        acts[i] += epsilon;  // restore the original activation
        num_grad[i] = (costP1 - costP2) / (2 * epsilon);
    }
    float diff = rel_diff(grads, num_grad);
    throw_on_error(cudaFree(acts_gpu),
                   "cudaFree");
    throw_on_error(cudaFree(grads_gpu),
                   "cudaFree");
    throw_on_error(cudaFree(ctc_gpu_workspace),
                   "cudaFree");
    throw_on_error(cudaStreamDestroy(stream),
                   "cudaStreamDestroy");
    return diff;
}
// Runs the gradient-check suite over a table of problem configurations.
// Each tuple is (alphabet_size, T, L, minibatch, tolerance); the suite
// passes only if every configuration's analytic-vs-numerical gradient
// difference stays below its tolerance.
bool run_tests() {
    std::vector<std::tuple<int, int, int, int, float>> problem_sizes =
            { std::make_tuple(28, 50, 15, 1, 1e-5) };
    bool all_passed = true;
    for (const auto& problem : problem_sizes) {
        int alphabet_size, timesteps, label_len, batch;
        float tolerance;
        std::tie(alphabet_size, timesteps, label_len, batch, tolerance) = problem;
        // Random activations plus one random label sequence per batch item;
        // every item uses the full number of timesteps.
        std::vector<float> acts = genActs(alphabet_size * timesteps * batch);
        std::vector<std::vector<int>> labels;
        std::vector<int> sizes;
        labels.reserve(batch);
        sizes.reserve(batch);
        for (int mb = 0; mb < batch; ++mb) {
            labels.push_back(genLabels(alphabet_size, label_len));
            sizes.push_back(timesteps);
        }
        const float diff = grad_check(timesteps, alphabet_size, acts, labels, sizes);
        all_passed = all_passed && (diff < tolerance);
    }
    return all_passed;
}
// Test driver: checks the warp-ctc ABI version, selects GPU 0, runs every
// GPU test case unconditionally, and reports the aggregate result.
// Exit code 0 on success, 1 on any failure.
int main(void) {
    if (get_warpctc_version() != 2) {
        std::cerr << "Invalid WarpCTC version." << std::endl;
        return 1;
    }
    std::cout << "Running GPU tests" << std::endl;
    throw_on_error(cudaSetDevice(0), "cudaSetDevice");
    // &= (not &&) so every test runs even after an earlier failure.
    bool ok = small_test();
    ok &= options_test();
    ok &= inf_test();
    ok &= run_tests();
    if (!ok) {
        std::cout << "Some or all tests fail" << std::endl;
        return 1;
    }
    std::cout << "Tests pass" << std::endl;
    return 0;
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment