Commit 0bf5eb5f authored by lishen

warpctc for DCU

parent 949fcc19
#pragma once

// Assumption: torch::Tensor comes from the PyTorch C++ extension API; including
// <torch/extension.h> here is assumed to be the intended dependency.
#include <torch/extension.h>

/*
 * Legacy THC-based signature, kept for reference:
 *
 * int gpu_ctc(THCudaTensor *probs,
 *             THCudaTensor *grads,
 *             THIntTensor *labels_ptr,
 *             THIntTensor *label_sizes_ptr,
 *             THIntTensor *sizes,
 *             int minibatch_size,
 *             THFloatTensor *costs,
 *             int blank_label);
 */

int gpu_ctc(torch::Tensor probs,
            torch::Tensor grads,
            torch::Tensor labels,
            torch::Tensor label_sizes,
            torch::Tensor sizes,
            int minibatch_size,
            torch::Tensor costs,
            int blank_label);
import torch
import warpctc_pytorch as warp_ctc


def test_empty_label(test_cpu=True, test_gpu=True):
    # Smoke test: one normal label sequence and one empty label sequence,
    # run through the CPU and/or GPU entry points.
    probs = torch.FloatTensor([
        [[0.1, 0.6, 0.1, 0.1, 0.1], [0.1, 0.1, 0.6, 0.1, 0.1]],
        [[0.6, 0.1, 0.1, 0.1, 0.1], [0.1, 0.1, 0.5, 0.2, 0.1]]
    ]).contiguous()
    grads = torch.zeros(probs.size())
    labels = torch.IntTensor([1, 2])
    label_sizes = torch.IntTensor([2, 0])
    sizes = torch.IntTensor([2, 2])
    minibatch_size = probs.size(1)
    if test_cpu:
        costs = torch.zeros(minibatch_size)
        warp_ctc.cpu_ctc(probs, grads, labels, label_sizes, sizes, minibatch_size, costs, 0)
        print('CPU_cost: %f' % costs.sum())
        print('CPU probs={}\ngrads={}\ncosts={}'.format(probs, grads, costs))
    if test_gpu:
        probs = probs.clone().cuda()
        grads = torch.zeros(probs.size()).cuda()
        costs = torch.zeros(minibatch_size)
        warp_ctc.gpu_ctc(probs, grads, labels, label_sizes, sizes, minibatch_size, costs, 0)
        print('GPU_cost: %f' % costs.sum())
        print(grads.view(grads.size(0) * grads.size(1), grads.size(2)))
        print('GPU probs={}\ngrads={}\ncosts={}'.format(probs, grads, costs))


if __name__ == '__main__':
    print('torch.cuda.is_available() ', torch.cuda.is_available())
    # test_empty_label(test_cpu=True, test_gpu=False)
    test_empty_label(test_cpu=False, test_gpu=True)
    # HIP_VISIBLE_DEVICES=1 python3 test_gpu_new.py
import torch
import warpctc_pytorch_change1 as warp_ctc_new
import warpctc_pytorch as warp_ctc
import time


def test_compare_cpu(repeat_num=20):
    probs = torch.FloatTensor([
        [[0.1, 0.6, 0.1, 0.1, 0.1], [0.1, 0.1, 0.6, 0.1, 0.1]],
        [[0.6, 0.1, 0.1, 0.1, 0.1], [0.1, 0.1, 0.5, 0.2, 0.1]]
    ]).contiguous()
    labels = torch.IntTensor([1, 2])
    label_sizes = torch.IntTensor([2, 0])
    sizes = torch.IntTensor([2, 2])
    minibatch_size = probs.size(1)
    costs = torch.zeros(minibatch_size)
    grads = torch.zeros(probs.size())
    time_st = time.perf_counter()
    # 1. Run the old CPU version
    for i in range(repeat_num):
        probs_old = probs.clone()
        costs_old = costs.clone()
        grads_old = grads.clone()
        warp_ctc.cpu_ctc(probs_old, grads_old, labels, label_sizes, sizes, minibatch_size, costs_old, 0)
        if i == 0:
            print('CPU_costs_old: %f' % costs_old.sum())
            print('CPU probs_old={}\ngrads_old={}\ncosts_old={}'.format(probs_old, grads_old, costs_old))
    time_used = (time.perf_counter() - time_st) / repeat_num
    print('CPU warp_ctc old version using time: ', time_used)
    time_st = time.perf_counter()
    # 2. Run the new CPU version
    for i in range(repeat_num):
        probs_new = probs.clone()
        costs_new = costs.clone()
        grads_new = grads.clone()
        warp_ctc_new.cpu_ctc(probs_new, grads_new, labels, label_sizes, sizes, minibatch_size, costs_new, 0)
        if i == 0:
            print('CPU_costs_new: %f' % costs_new.sum())
            print('CPU probs={}\ngrads_new={}\ncosts_new={}'.format(probs_new, grads_new, costs_new))
    time_used = (time.perf_counter() - time_st) / repeat_num
    print('CPU warp_ctc new version using time: ', time_used)


def test_compare_gpu():
    probs0 = torch.FloatTensor([
        [[0.1, 0.6, 0.1, 0.1, 0.1], [0.1, 0.1, 0.6, 0.1, 0.1]],
        [[0.6, 0.1, 0.1, 0.1, 0.1], [0.1, 0.1, 0.5, 0.2, 0.1]]
    ]).contiguous().cuda()
    labels = torch.IntTensor([1, 2])
    label_sizes = torch.IntTensor([2, 0])
    sizes = torch.IntTensor([2, 2])
    minibatch_size = probs0.size(1)
    # 1. Run the new version's cpu_ctc
    probs_new = probs0.clone().cuda()
    costs_new = torch.zeros(minibatch_size)
    grads_new = torch.zeros(probs0.size())
    warp_ctc_new.cpu_ctc(probs_new, grads_new, labels, label_sizes, sizes, minibatch_size, costs_new, 0)
    print('CPU_costs_new: %f' % costs_new.sum())
    print('CPU probs_new={}\ngrads_new={}\ncosts_new={}'.format(probs_new, grads_new, costs_new))
    # 2. Run the old version's cpu_ctc
    probs = probs0.clone().cuda()
    costs = torch.zeros(minibatch_size)
    grads = torch.zeros(probs0.size())
    warp_ctc.cpu_ctc(probs, grads, labels, label_sizes, sizes, minibatch_size, costs, 0)
    print('CPU_cost: %f' % costs.sum())
    print('CPU probs={}\ngrads={}\ncosts={}'.format(probs, grads, costs))


if __name__ == '__main__':
    print('torch.cuda.is_available() ', torch.cuda.is_available())
    test_compare_cpu()
    test_compare_gpu()
import torch
import warpctc_pytorch as warp_ctc
from torch.autograd import Function
from torch.nn import Module
from _warp_ctc import *  # noqa


def _assert_no_grad(tensor):
    assert not tensor.requires_grad, \
        "gradients only computed for acts - please " \
        "mark other tensors as not requiring gradients"


class _CTC(Function):
    @staticmethod
    def forward(ctx, acts, labels, act_lens, label_lens, size_average=False,
                length_average=False, blank=0):
        is_cuda = acts.is_cuda
        # print('_CTC is_cuda', is_cuda)
        acts = acts.contiguous()
        loss_func = warp_ctc.gpu_ctc if is_cuda else warp_ctc.cpu_ctc
        grads = torch.zeros(acts.size()).type_as(acts)
        minibatch_size = acts.size(1)
        costs = torch.zeros(minibatch_size).cpu()
        loss_func(acts,
                  grads,
                  labels,
                  label_lens,
                  act_lens,
                  minibatch_size,
                  costs,
                  blank)
        costs = torch.FloatTensor([costs.sum()])
        if length_average:
            # Compute the avg. log-probability per batch sample and frame.
            total_length = torch.sum(act_lens).item()
            grads = grads / total_length
            costs = costs / total_length
        elif size_average:
            # Compute the avg. log-probability per batch sample.
            grads = grads / minibatch_size
            costs = costs / minibatch_size
        ctx.grads = grads
        return costs

    @staticmethod
    def backward(ctx, grad_output):
        _grad_output = grad_output.to(ctx.grads.device)
        return ctx.grads.mul_(_grad_output), None, None, None, None, None, None


class CTCLoss(Module):
    """
    Parameters:
        size_average (bool): normalize the loss by the batch size
            (default: `False`)
        length_average (bool): normalize the loss by the total number of frames
            in the batch. If `True`, supersedes `size_average`
            (default: `False`)
    """
    def __init__(self, blank=0, size_average=False, length_average=False):
        super(CTCLoss, self).__init__()
        self.ctc = _CTC.apply
        self.blank = blank
        self.size_average = size_average
        self.length_average = length_average

    def forward(self, acts, labels, act_lens, label_lens):
        """
        acts: Tensor of (seqLength x batch x outputDim) containing output from network
        labels: 1-dimensional Tensor containing all the targets of the batch in one sequence
        act_lens: Tensor of size (batch) containing size of each output sequence from the network
        label_lens: Tensor of (batch) containing label length of each example
        """
        # labels must be 1-dimensional
        if len(labels.size()) != 1:
            raise ValueError('labels must be 1-dimensional, got {} dimensions'.format(len(labels.size())))
        _assert_no_grad(labels)
        _assert_no_grad(act_lens)
        _assert_no_grad(label_lens)
        return self.ctc(acts, labels, act_lens, label_lens, self.size_average,
                        self.length_average, self.blank)
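A minimal usage sketch for the CTCLoss module defined above. It assumes the warpctc_pytorch package built from this binding is importable; the dimensions and values below are invented purely for illustration and follow the docstring layout (acts is seqLength x batch x outputDim, blank index 0, raw pre-softmax activations, int32 CPU tensors for labels and lengths).

# Hypothetical usage sketch (illustrative values only).
import torch
from warpctc_pytorch import CTCLoss

ctc_loss = CTCLoss(blank=0)

seq_len, batch, num_classes = 50, 4, 20                               # assumed toy dimensions
acts = torch.randn(seq_len, batch, num_classes, requires_grad=True)   # raw (pre-softmax) network activations
labels = torch.IntTensor([1, 2, 3, 4, 5, 6, 7])                       # all targets flattened into one 1-D tensor
label_lens = torch.IntTensor([2, 2, 2, 1])                            # per-sample label lengths; must sum to labels.numel()
act_lens = torch.IntTensor([seq_len] * batch)                         # per-sample number of frames

loss = ctc_loss(acts, labels, act_lens, label_lens)                   # 1-element FloatTensor on CPU
loss.backward()                                                       # gradients flow into acts via ctx.grads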
#include <cstddef>
#include <iostream>
#include <algorithm>

#include "ctc.h"
#include "detail/cpu_ctc.h"
#ifdef __HIPCC__
#include "detail/gpu_ctc.h"
#endif

extern "C" {

int get_warpctc_version() {
    return 13;
}

const char *ctcGetStatusString(ctcStatus_t status) {
    switch (status) {
        case CTC_STATUS_SUCCESS:
            return "no error";
        case CTC_STATUS_MEMOPS_FAILED:
            return "cuda memcpy or memset failed";
        case CTC_STATUS_INVALID_VALUE:
            return "invalid value";
        case CTC_STATUS_EXECUTION_FAILED:
            return "execution failed";
        case CTC_STATUS_UNKNOWN_ERROR:
        default:
            return "unknown error";
    }
}

ctcStatus_t compute_ctc_loss(const float *const activations,
                             float *gradients,
                             const int *const flat_labels,
                             const int *const label_lengths,
                             const int *const input_lengths,
                             int alphabet_size,
                             int minibatch,
                             float *costs,
                             void *workspace,
                             ctcOptions options) {
    if (activations == nullptr ||
        flat_labels == nullptr ||
        label_lengths == nullptr ||
        input_lengths == nullptr ||
        costs == nullptr ||
        workspace == nullptr ||
        alphabet_size <= 0 ||
        minibatch <= 0)
        return CTC_STATUS_INVALID_VALUE;

    if (options.loc == CTC_CPU) {
        CpuCTC<float> ctc(alphabet_size, minibatch, workspace, options.num_threads,
                          options.blank_label);

        if (gradients != nullptr)
            return ctc.cost_and_grad(activations, gradients,
                                     costs,
                                     flat_labels, label_lengths,
                                     input_lengths);
        else
            return ctc.score_forward(activations,
                                     costs, flat_labels,
                                     label_lengths, input_lengths);
    } else if (options.loc == CTC_GPU) {
#ifdef __HIPCC__
        GpuCTC<float> ctc(alphabet_size, minibatch, workspace, options.stream,
                          options.blank_label);

        if (gradients != nullptr)
            return ctc.cost_and_grad(activations, gradients, costs,
                                     flat_labels, label_lengths,
                                     input_lengths);
        else
            return ctc.score_forward(activations, costs, flat_labels,
                                     label_lengths, input_lengths);
#else
        std::cerr << "GPU execution requested, but not compiled with GPU support" << std::endl;
        return CTC_STATUS_EXECUTION_FAILED;
#endif
    } else {
        return CTC_STATUS_INVALID_VALUE;
    }
}

ctcStatus_t get_workspace_size(const int *const label_lengths,
                               const int *const input_lengths,
                               int alphabet_size, int minibatch,
                               ctcOptions options,
                               size_t *size_bytes) {
    if (label_lengths == nullptr ||
        input_lengths == nullptr ||
        size_bytes == nullptr ||
        alphabet_size <= 0 ||
        minibatch <= 0)
        return CTC_STATUS_INVALID_VALUE;

    // This is the max of all S and T for all examples in the minibatch.
    int maxL = *std::max_element(label_lengths, label_lengths + minibatch);
    int maxT = *std::max_element(input_lengths, input_lengths + minibatch);

    const int S = 2 * maxL + 1;

    *size_bytes = 0;

    if (options.loc == CTC_GPU) {
        // GPU storage
        // nll_forward, nll_backward
        *size_bytes += 2 * sizeof(float) * minibatch;
        // repeats
        *size_bytes += sizeof(int) * minibatch;
        // label offsets
        *size_bytes += sizeof(int) * minibatch;
        // utt_length
        *size_bytes += sizeof(int) * minibatch;
        // label lengths
        *size_bytes += sizeof(int) * minibatch;
        // labels without blanks - overallocate for now
        *size_bytes += sizeof(int) * maxL * minibatch;
        // labels with blanks
        *size_bytes += sizeof(int) * S * minibatch;
        // alphas
        *size_bytes += sizeof(float) * S * maxT * minibatch;
        // denoms
        *size_bytes += sizeof(float) * maxT * minibatch;
        // probs (since we will pass in activations)
        *size_bytes += sizeof(float) * alphabet_size * maxT * minibatch;
    } else {
        // CPU: could eventually replace all minibatch factors with the max
        // number of concurrent threads if memory is really tight.
        // per-minibatch memory
        size_t per_minibatch_bytes = 0;
        // output
        per_minibatch_bytes += sizeof(float) * alphabet_size;
        // alphas
        per_minibatch_bytes += sizeof(float) * S * maxT;
        // betas
        per_minibatch_bytes += sizeof(float) * S;
        // labels w/blanks, e_inc, s_inc
        per_minibatch_bytes += 3 * sizeof(int) * S;

        *size_bytes = per_minibatch_bytes * minibatch;
        // probs
        *size_bytes += sizeof(float) * alphabet_size * maxT * minibatch;
    }

    return CTC_STATUS_SUCCESS;
}

}
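The GPU branch of get_workspace_size() above simply accumulates a fixed set of per-minibatch buffers. The Python snippet below is an illustrative restatement of that arithmetic (gpu_workspace_bytes is a hypothetical helper, assuming 4-byte int and float), which can help when estimating device memory for a given batch.

# Illustrative restatement of the GPU-workspace accounting in
# get_workspace_size() above; assumes sizeof(int) == sizeof(float) == 4.
def gpu_workspace_bytes(label_lengths, input_lengths, alphabet_size):
    sizeof_f = sizeof_i = 4
    minibatch = len(label_lengths)
    maxL = max(label_lengths)                  # longest label sequence
    maxT = max(input_lengths)                  # longest input (time) sequence
    S = 2 * maxL + 1                           # labels with interleaved blanks

    size = 0
    size += 2 * sizeof_f * minibatch           # nll_forward, nll_backward
    size += 4 * sizeof_i * minibatch           # repeats, label offsets, utt_length, label lengths
    size += sizeof_i * maxL * minibatch        # labels without blanks (over-allocated)
    size += sizeof_i * S * minibatch           # labels with blanks
    size += sizeof_f * S * maxT * minibatch    # alphas
    size += sizeof_f * maxT * minibatch        # denoms
    size += sizeof_f * alphabet_size * maxT * minibatch  # probs (softmax of activations)
    return size

# e.g. gpu_workspace_bytes([2, 0], [2, 2], alphabet_size=5) for the tiny test case above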
// Includes, system
#include <stdio.h>
#include <stdlib.h>

// Includes, cuda
#include <cuda_runtime.h>
//#include <cublas_v2.h>
#include <cuda_runtime_api.h>

// Includes, cuda helper functions
// #include <helper_cuda.h>

// For the functors
#include "detail/ctc_helper.h"
#include "ctc.h"

const int warp_size = 64;
const int kCUDABlockNumThreads = 256;

template<int NT, typename T, typename Rop>
struct CTAReduce;

template<int NT, typename T, typename Rop>
struct CTAReduce {
    enum {
        Size = NT, Capacity = NT
    };
    struct Storage {
        T shared[Capacity];
    };

    __device__ static T reduce(int tid, T x, Storage &storage, int count, Rop g) {
        T *s = storage.shared;
        s[tid] = x;
        __syncthreads();

        // Fold the data in half with each pass.
#pragma unroll
        for (int offset = NT / 2; offset >= warp_size; offset /= 2) {
            if (tid + offset < count && tid < offset) {
                x = g(x, s[offset + tid]);
                s[tid] = x;
            }
            __syncthreads();
        }

        T shuff;
        for (int offset = warp_size / 2; offset > 0; offset /= 2) {
            // shuff = __shfl_down(0xFFFFFFF, x, offset);
            shuff = __shfl_down(x, offset);
            if (tid + offset < count && tid < offset) {
                x = g(x, shuff);
            }
        }
        return x;
    }
};

template<int NT, typename Iop, typename Rop, typename T>
__global__ void reduce_rows(Iop f, Rop g, const T *input, T *output,
                            int num_rows, int num_cols) {
    typedef CTAReduce<NT, T, Rop> R;
    __shared__ typename R::Storage storage;

    int tid = threadIdx.x;
    int idx = tid;
    int col = blockIdx.x;
    T curr;

    // Each block works on a column
    if (idx < num_rows) {
        curr = f(input[idx + col * num_rows]);
    }
    // __syncthreads();
    idx += NT;

    while (idx < num_rows) {
        curr = g(curr, f(input[idx + col * num_rows]));
        idx += NT;
    }

    // Sum thread-totals over the CTA.
    curr = R::reduce(tid, curr, storage, num_rows, g);

    // Store result in out
    if (tid == 0) {
        output[col] = curr;
    }
}

template<int NT, typename Iop, typename Rop, typename T>
__global__ void reduce_cols(Iop f, Rop g, const T *input, T *output,
                            int num_rows, int num_cols) {
    __shared__ T s[NT];

    int warps_per_block = NT / warp_size;
    int row = blockDim.x * blockIdx.x + threadIdx.x;
    int col = threadIdx.y;
    T curr;

    if (row < num_rows && col < num_cols) {
        curr = f(input[row + col * num_rows]);
        col += blockDim.y;

        while (col < num_cols) {
            curr = g(curr, f(input[row + col * num_rows]));
            col += blockDim.y;
        }
    }
    s[threadIdx.x * warps_per_block + threadIdx.y] = curr;
    __syncthreads();

    // Reduce
    if (threadIdx.y == 0 && row < num_rows) {
#pragma unroll
        for (int i = 1; i < warps_per_block && i < num_cols; ++i)
            curr = g(curr, s[i + threadIdx.x * warps_per_block]);
        output[row] = curr;
    }
}

struct ReduceHelper {
    template<typename T, typename Iof, typename Rof>
    static void impl(Iof f, Rof g, const T *input, T *output, int num_rows, int num_cols, bool axis, CUstream stream) {
        int grid_size;

        if (axis) {
            grid_size = num_cols;
            reduce_rows<kCUDABlockNumThreads><<<grid_size, kCUDABlockNumThreads, 0, stream>>>
                (f, g, input, output, num_rows, num_cols);
        } else {
            dim3 tpb(warp_size, kCUDABlockNumThreads / warp_size);
            grid_size = (num_cols + warp_size - 1) / warp_size;
            reduce_cols<kCUDABlockNumThreads><<<grid_size, tpb, 0, stream>>>
                (f, g, input, output, num_rows, num_cols);
        }
    }
};

template<typename T, typename Iof, typename Rof>
ctcStatus_t reduce(Iof f, Rof g, const T *input, T *output, int rows, int cols, bool axis, CUstream stream) {
    ReduceHelper::impl(f, g, input, output, rows, cols, axis, stream);
    hipStreamSynchronize(stream);
    hipError_t err = hipGetLastError();
    if (err != hipSuccess)
        return CTC_STATUS_EXECUTION_FAILED;

    return CTC_STATUS_SUCCESS;
}

ctcStatus_t reduce_negate(const float *input, float *output, int rows, int cols, bool axis, CUstream stream) {
    return reduce(ctc_helper::negate<float>(), ctc_helper::add<float>(), input, output, rows, cols, axis, stream);
}

ctcStatus_t reduce_exp(const float *input, float *output, int rows, int cols, bool axis, CUstream stream) {
    return reduce(ctc_helper::exponential<float>(), ctc_helper::add<float>(), input, output, rows, cols, axis, stream);
}

ctcStatus_t reduce_max(const float *input, float *output, int rows, int cols, bool axis, CUstream stream) {
    auto ctc_status = reduce(ctc_helper::identity<float>(), ctc_helper::maximum<float>(), input, output, rows, cols, axis, stream);
    return ctc_status;
}
// !!! This is a file automatically generated by hipify!!!
// Includes, system
#include <stdio.h>
#include <stdlib.h>

// Includes, cuda
#include <hip/hip_runtime.h>
//#include <rocblas.h>
#include <hip/hip_runtime_api.h>

// Includes, cuda helper functions
// #include <helper_cuda.h>

// For the functors
#include "detail/ctc_helper.h"
#include "ctc.h"

const int warp_size = 64;
const int kCUDABlockNumThreads = 256;

template<int NT, typename T, typename Rop>
struct CTAReduce;

template<int NT, typename T, typename Rop>
struct CTAReduce {
    enum {
        Size = NT, Capacity = NT
    };
    struct Storage {
        T shared[Capacity];
    };

    __device__ static T reduce(int tid, T x, Storage &storage, int count, Rop g) {
        T *s = storage.shared;
        s[tid] = x;
        __syncthreads();

        // Fold the data in half with each pass.
#pragma unroll
        for (int offset = NT / 2; offset >= warp_size; offset /= 2) {
            if (tid + offset < count && tid < offset) {
                x = g(x, s[offset + tid]);
                s[tid] = x;
            }
            __syncthreads();
        }

        T shuff;
        for (int offset = warp_size / 2; offset > 0; offset /= 2) {
            // shuff = __shfl_down(0xFFFFFFF, x, offset);
            shuff = __shfl_down(x, offset);
            if (tid + offset < count && tid < offset) {
                x = g(x, shuff);
            }
        }
        return x;
    }
};

template<int NT, typename Iop, typename Rop, typename T>
__global__ void reduce_rows(Iop f, Rop g, const T *input, T *output,
                            int num_rows, int num_cols) {
    typedef CTAReduce<NT, T, Rop> R;
    __shared__ typename R::Storage storage;

    int tid = threadIdx.x;
    int idx = tid;
    int col = blockIdx.x;
    T curr;

    // Each block works on a column
    if (idx < num_rows) {
        curr = f(input[idx + col * num_rows]);
    }
    // __syncthreads();
    idx += NT;

    while (idx < num_rows) {
        curr = g(curr, f(input[idx + col * num_rows]));
        idx += NT;
    }

    // Sum thread-totals over the CTA.
    curr = R::reduce(tid, curr, storage, num_rows, g);

    // Store result in out
    if (tid == 0) {
        output[col] = curr;
    }
}

template<int NT, typename Iop, typename Rop, typename T>
__global__ void reduce_cols(Iop f, Rop g, const T *input, T *output,
                            int num_rows, int num_cols) {
    __shared__ T s[NT];

    int warps_per_block = NT / warp_size;
    int row = blockDim.x * blockIdx.x + threadIdx.x;
    int col = threadIdx.y;
    T curr;

    if (row < num_rows && col < num_cols) {
        curr = f(input[row + col * num_rows]);
        col += blockDim.y;

        while (col < num_cols) {
            curr = g(curr, f(input[row + col * num_rows]));
            col += blockDim.y;
        }
    }
    s[threadIdx.x * warps_per_block + threadIdx.y] = curr;
    __syncthreads();

    // Reduce
    if (threadIdx.y == 0 && row < num_rows) {
#pragma unroll
        for (int i = 1; i < warps_per_block && i < num_cols; ++i)
            curr = g(curr, s[i + threadIdx.x * warps_per_block]);
        output[row] = curr;
    }
}

struct ReduceHelper {
    template<typename T, typename Iof, typename Rof>
    static void impl(Iof f, Rof g, const T *input, T *output, int num_rows, int num_cols, bool axis, hipStream_t stream) {
        int grid_size;

        if (axis) {
            grid_size = num_cols;
            hipLaunchKernelGGL((reduce_rows<kCUDABlockNumThreads>), dim3(grid_size), dim3(kCUDABlockNumThreads), 0, stream,
                               f, g, input, output, num_rows, num_cols);
        } else {
            dim3 tpb(warp_size, kCUDABlockNumThreads / warp_size);
            grid_size = (num_cols + warp_size - 1) / warp_size;
            hipLaunchKernelGGL((reduce_cols<kCUDABlockNumThreads>), dim3(grid_size), dim3(tpb), 0, stream,
                               f, g, input, output, num_rows, num_cols);
        }
    }
};

template<typename T, typename Iof, typename Rof>
ctcStatus_t reduce(Iof f, Rof g, const T *input, T *output, int rows, int cols, bool axis, hipStream_t stream) {
    ReduceHelper::impl(f, g, input, output, rows, cols, axis, stream);
    hipStreamSynchronize(stream);
    hipError_t err = hipGetLastError();
    if (err != hipSuccess)
        return CTC_STATUS_EXECUTION_FAILED;

    return CTC_STATUS_SUCCESS;
}

ctcStatus_t reduce_negate(const float *input, float *output, int rows, int cols, bool axis, hipStream_t stream) {
    return reduce(ctc_helper::negate<float>(), ctc_helper::add<float>(), input, output, rows, cols, axis, stream);
}

ctcStatus_t reduce_exp(const float *input, float *output, int rows, int cols, bool axis, hipStream_t stream) {
    return reduce(ctc_helper::exponential<float>(), ctc_helper::add<float>(), input, output, rows, cols, axis, stream);
}

ctcStatus_t reduce_max(const float *input, float *output, int rows, int cols, bool axis, hipStream_t stream) {
    auto ctc_status = reduce(ctc_helper::identity<float>(), ctc_helper::maximum<float>(), input, output, rows, cols, axis, stream);
    return ctc_status;
}
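For orientation, the kernels above implement column-major reductions: reduce_rows produces one value per column (the axis == true path) and reduce_cols one value per row. The NumPy sketch below restates that contract with a hypothetical helper name, purely for illustration.

# Illustrative NumPy equivalent of the reduction helpers above.
# `flat_input` is treated as column-major with shape (rows, cols), matching
# the `input[idx + col * num_rows]` indexing in the kernels.
import numpy as np

def reduce_reference(f, g_reduce, flat_input, rows, cols, axis):
    a = f(np.asarray(flat_input, dtype=np.float32).reshape(cols, rows).T)  # column-major view, shape (rows, cols)
    # axis=True  -> reduce_rows: one result per column (reduce over rows)
    # axis=False -> reduce_cols: one result per row    (reduce over columns)
    return g_reduce(a, axis=0) if axis else g_reduce(a, axis=1)

# reduce_max:    reduce_reference(lambda x: x, np.max, flat, rows, cols, axis=True)
# reduce_exp:    reduce_reference(np.exp,      np.sum, flat, rows, cols, axis=True)
# reduce_negate: reduce_reference(np.negative, np.sum, flat, rows, cols, axis=True)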