Commit 1ec1a824 authored by lishen

PyTorch 1.13: replace THC

parent 5e65c1c3
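The substance of the change: the GPU workspace was previously allocated through THC (`THCudaMalloc`/`THCudaFree`), which is no longer available in PyTorch 1.13, so the commit switches to the c10 caching allocator. A minimal sketch of the replacement pattern, assuming a C++/CUDA extension built against PyTorch 1.11 or newer; the helper names below are illustrative, not part of the repository:

```cpp
#include <c10/cuda/CUDACachingAllocator.h>
#include <cstddef>

// Illustrative helpers only: the commit applies this pattern directly inside
// gpu_ctc() rather than through wrapper functions like these.
void* alloc_gpu_workspace(std::size_t gpu_size_bytes) {
    // Device memory now comes from PyTorch's caching allocator on the current
    // CUDA device; no THCState handle is needed.
    return c10::cuda::CUDACachingAllocator::raw_alloc(gpu_size_bytes);
}

void free_gpu_workspace(void* gpu_workspace) {
    // Hands the block back to the caching allocator, which may keep it cached
    // for reuse instead of calling cudaFree immediately.
    c10::cuda::CUDACachingAllocator::raw_delete(gpu_workspace);
}
```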
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="PublishConfigData" autoUpload="On explicit save action" serverName="10.6.10.62" preserveTimestamps="false" deleteMissingItems="true" createEmptyFolders="true" filePermissions="420" folderPermissions="493" confirmBeforeUploading="false" confirmBeforeDeletion="false" autoUploadExternalChanges="true">
<component name="PublishConfigData" serverName="10.6.10.62" preserveTimestamps="false" deleteMissingItems="true" createEmptyFolders="true" confirmBeforeUploading="false" confirmBeforeDeletion="false">
<option name="confirmBeforeDeletion" value="false" />
<option name="confirmBeforeUploading" value="false" />
<serverData>
......@@ -61,6 +61,5 @@
</serverdata>
</paths>
</serverData>
<option name="myAutoUpload" value="ON_EXPLICIT_SAVE" />
</component>
</project>
\ No newline at end of file
......@@ -130,15 +130,14 @@ void compute_alpha_kernel(const ProbT *probs, const int *label_sizes,
const int *label_global = &labels[blockIdx.x * S_memoffset];
ProbT *alpha = &alphas[blockIdx.x * (S_memoffset * T_memoffset)];
- // Set the first row of alpha neg_inf - it is much more efficient to do it
- // here than outside
- #pragma unroll
+ // Set the first row of alpha neg_inf - it is much more efficient to do it here than outside
+ //#pragma unroll
for (int idx = tid; idx < min(S, NV); idx += blockDim.x) {
alpha[idx] = ctc_helper::neg_inf<ProbT>();
}
// Load labels into shared memory
- #pragma unroll
+ //#pragma unroll
for (int i = tid; i < S; i += NT) {
label[i] = label_global[i];
}
......@@ -272,8 +271,8 @@ void compute_betas_and_grad_kernel(const ProbT *probs, const int *label_sizes,
int start = S > 1 ? (S - 2) : 0;
int end = (L + repeats < T) ? S : S - 1;
- // Setup shared memory buffers
- #pragma unroll
+ // // Setup shared memory buffers
+ //#pragma unroll
for (int idx = tid; idx < NV; idx += NT) {
label[idx] = (idx < S) ? label_global[idx] : INT_MAX;
}
......@@ -290,7 +289,7 @@ void compute_betas_and_grad_kernel(const ProbT *probs, const int *label_sizes,
int key[VT];
int gather_val[VT];
- #pragma unroll
+ //#pragma unroll
for (int i = 0; i < VT; ++i) {
const int idx = tid * VT + i;
gather_val[i] = idx;
......
......@@ -7,9 +7,7 @@
#include "ATen/cuda/CUDAContext.h"
#include <c10/cuda/CUDAGuard.h>
#include "ATen/cuda/CUDAEvent.h"
- #include <THC/THCGeneral.h>
- extern THCState* state;
+ #include <ATen/cuda/ThrustAllocator.h>
#endif
#include "ctc.h"
......@@ -91,7 +89,7 @@ int gpu_ctc(torch::Tensor probs,
probs_size, minibatch_size,
options, &gpu_size_bytes);
- void* gpu_workspace = THCudaMalloc(state, gpu_size_bytes);
+ void* gpu_workspace = c10::cuda::CUDACachingAllocator::raw_alloc(gpu_size_bytes);
compute_ctc_loss(probs_ptr, grads_ptr,
labels_ptr, label_sizes_ptr,
......@@ -99,7 +97,8 @@ int gpu_ctc(torch::Tensor probs,
minibatch_size, costs_ptr,
gpu_workspace, options);
- THCudaFree(state, (void *) gpu_workspace);
+ c10::cuda::CUDACachingAllocator::raw_delete((void *) gpu_workspace);
return 1;
}
#endif
......
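The same file's include block swaps `<THC/THCGeneral.h>` for `<ATen/cuda/ThrustAllocator.h>`. The call sites are elided in the hunks above; the header's purpose is to let Thrust draw its temporary storage from the same caching allocator. A hypothetical usage sketch, compiled as a .cu file; the function name, data, and the sort call are illustrative assumptions, not code from this commit:

```cpp
#include <ATen/cuda/ThrustAllocator.h>
#include <ATen/cuda/CUDAContext.h>
#include <thrust/execution_policy.h>
#include <thrust/sort.h>

// Hypothetical example: sort a device buffer while routing Thrust's scratch
// allocations through PyTorch's caching allocator.
void sort_device_keys(int* d_keys, int n) {
    at::cuda::ThrustAllocator allocator;
    auto stream = at::cuda::getCurrentCUDAStream();
    // par(allocator).on(stream) makes Thrust allocate via raw_alloc/raw_delete
    // and run on PyTorch's current CUDA stream.
    auto policy = thrust::cuda::par(allocator).on(stream.stream());
    thrust::sort(policy, d_keys, d_keys + n);
}
```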
#pragma once
/*
int gpu_ctc(THCudaTensor *probs,
THCudaTensor *grads,
THIntTensor *labels_ptr,
THIntTensor *label_sizes_ptr,
THIntTensor *sizes,
int minibatch_size,
THFloatTensor *costs,
int blank_label);
*/
int gpu_ctc(torch::Tensor probs,
torch::Tensor grads,
torch::Tensor labels,
......
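For completeness, a hypothetical way the new `torch::Tensor`-based declaration could be exposed to Python. The module name, the trailing parameters, and the use of pybind11 are assumptions here, since the actual binding code is not part of this diff:

```cpp
#include <torch/extension.h>

// Hypothetical signature: the parameter list past `labels` is elided in the
// diff above, so the remaining arguments mirror the commented-out THC version.
int gpu_ctc(torch::Tensor probs,
            torch::Tensor grads,
            torch::Tensor labels,
            torch::Tensor label_sizes,
            torch::Tensor sizes,
            int minibatch_size,
            torch::Tensor costs,
            int blank_label);

// Hypothetical registration; the real extension may bind this differently.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("gpu_ctc", &gpu_ctc, "warp-ctc CTC loss on the GPU");
}
```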