encode_reg_target.cu 1.9 KB
Newer Older
change3n8's avatar
init  
change3n8 committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
#include <torch/extension.h>
#include <ATen/hip/HIPContext.h>
#include <limits>
#include <vector>

/**
 * Encodes each input row into an output row that is one column wider.
 *
 * Per row, using the column names declared in the body:
 *   out[0..2] = in[X], in[Y], in[Z]                  (copied unchanged)
 *   out[3..5] = log(max(in[W], eps)), same for L, H  (eps = 1e-6)
 *   out[6]    = sin(in[YAW])
 *   out[7]    = cos(in[YAW])
 *   out[8..D] = in[7..D-1]                           (copied unchanged)
 *
 * Only rows n with n < N and n < len are written; all other output rows are
 * left untouched by this kernel.
 *
 * Launch layout: 1-D grid of 1-D blocks. Every thread walks the rows with a
 * step of gridDim.x * blockDim.x, so any grid size covers all rows.
 *
 * Preconditions implied by the indexing below (not checked here):
 *   - D >= 7, because columns 0..6 are read unconditionally;
 *   - Dout >= D + 1 (and at least 8), because the highest column written is D;
 *   - both buffers are dense row-major with row pitch D and Dout respectively.
 *
 * @param input   source rows, [N, D]
 * @param len     number of leading rows to encode
 * @param output  destination rows, [N, Dout]
 * @param N       number of rows in both buffers
 * @param D       number of columns per input row
 * @param Dout    number of columns per output row
 */
__global__ void encode_reg_target_kernel(
    const float* input,      // [N, D]
    const int64_t len,       // number of leading rows to encode
    float* output,           // [N, Dout]
    const int N,
    const int D,
    const int Dout
) {
    const float eps = 1e-6f;
    const int X=0, Y=1, Z=2, W=3, L=4, H=5, YAW=6;

    // A row is processed only when it is below both N and len; a non-positive
    // len therefore processes nothing.
    const int64_t row_limit = (len < static_cast<int64_t>(N)) ? len : static_cast<int64_t>(N);

    // 64-bit row index and stride: 32-bit arithmetic could wrap both in the
    // loop increment and in the n * D / n * Dout offsets for very large inputs.
    const int64_t stride = static_cast<int64_t>(blockDim.x) * gridDim.x;

    for (int64_t n = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
         n < row_limit;
         n += stride) {
        const float* in_row = input + n * D;
        float* out_row = output + n * Dout;

        // 1. copy X,Y,Z
        out_row[0] = in_row[X];
        out_row[1] = in_row[Y];
        out_row[2] = in_row[Z];

        // 2. log(W,L,H); the clamp to eps keeps logf away from zero and
        //    negative arguments
        out_row[3] = logf(fmaxf(in_row[W], eps));
        out_row[4] = logf(fmaxf(in_row[L], eps));
        out_row[5] = logf(fmaxf(in_row[H], eps));

        // 3. sin/cos(YAW)
        out_row[6] = sinf(in_row[YAW]);
        out_row[7] = cosf(in_row[YAW]);

        // 4. rest: input column i lands in output column i + 1, because the
        //    single YAW column expanded into two
        for (int i = 7; i < D; ++i) {
            out_row[i + 1] = in_row[i];
        }
    }
}

/**
 * Host wrapper that launches encode_reg_target_kernel on the current stream.
 *
 * @param input  2-D float32 CUDA tensor of shape [N, D] with D >= 7.
 *               Non-contiguous tensors are accepted; a contiguous copy is made
 *               because the kernel indexes the raw buffer as dense row-major.
 * @param len    number of leading rows to encode. Values <= 0 encode nothing;
 *               values > N are treated as N.
 * @return       tensor of shape [N, D + 1] on the same device and with the
 *               same dtype as input. Rows at index >= len are zero.
 *
 * Raises a c10::Error (via TORCH_CHECK) when the input does not satisfy the
 * requirements above or when the kernel launch is rejected.
 */
torch::Tensor encode_reg_t(
    const torch::Tensor& input,
    const int64_t len
) {
    TORCH_CHECK(input.is_cuda(), "encode_reg_t: input must be a CUDA tensor");
    TORCH_CHECK(input.dim() == 2,
                "encode_reg_t: input must be 2-D [N, D], got ", input.dim(), " dimension(s)");
    TORCH_CHECK(input.scalar_type() == torch::kFloat32,
                "encode_reg_t: input must be float32");
    // The kernel reads columns 0..6 of every row unconditionally.
    TORCH_CHECK(input.size(1) >= 7,
                "encode_reg_t: input needs at least 7 columns, got ", input.size(1));
    // The kernel takes N, D and Dout (= D + 1) as int.
    TORCH_CHECK(input.size(0) <= std::numeric_limits<int>::max() &&
                input.size(1) <  std::numeric_limits<int>::max(),
                "encode_reg_t: input shape does not fit the kernel's int arguments");

    // The kernel addresses rows as base + n * D, which is only valid for a
    // dense row-major buffer.
    const torch::Tensor dense_input = input.contiguous();

    const int N = static_cast<int>(dense_input.size(0));
    const int D = static_cast<int>(dense_input.size(1));
    // X,Y,Z (3) + log W,L,H (3) + sin/cos yaw (2) + remaining D - 7 columns.
    const int Dout = D + 1;

    auto options = torch::TensorOptions().device(input.device()).dtype(input.dtype());
    // Zero-filled rather than uninitialised: the kernel never writes rows at
    // index >= len, and those rows would otherwise hold arbitrary memory.
    torch::Tensor output = torch::zeros({N, Dout}, options);

    // Number of rows the kernel will actually touch.
    const int64_t clamped_len = len < 0 ? 0 : len;
    const int64_t active_rows = clamped_len < N ? clamped_len : static_cast<int64_t>(N);
    if (active_rows == 0) {
        // Nothing to encode; a grid of 0 blocks is an invalid launch configuration.
        return output;
    }

    const int threads = 256;
    // Ceil-division in 64-bit so that active_rows + threads - 1 cannot overflow int.
    const int blocks = static_cast<int>((active_rows + threads - 1) / threads);
    dim3 block(threads);
    dim3 grid(blocks);

    // NOTE(review): the current stream is used without switching to input's
    // device first — confirm callers always invoke this with input's device current.
    cudaStream_t stream = at::cuda::getCurrentCUDAStream();

    encode_reg_target_kernel<<<grid, block, 0, stream>>>(
        dense_input.data_ptr<float>(),
        len,
        output.data_ptr<float>(),
        N,
        D,
        Dout
    );

    // Kernel launches do not report errors directly; query the launch status
    // explicitly so a bad configuration does not pass silently.
    const cudaError_t launch_status = cudaGetLastError();
    TORCH_CHECK(launch_status == cudaSuccess,
                "encode_reg_t: kernel launch failed: ", cudaGetErrorString(launch_status));

    return output;
}