Commit bae75e40 authored by change3n8's avatar change3n8
Browse files

init

parent cb6ef609
Pipeline #3406 failed with stages
in 0 seconds
# encode_reg_target

# build
python3 setup.py build_ext --inplace

# run
python3 test_encode_reg_target.py
\ No newline at end of file
#include <torch/extension.h>
#include <ATen/hip/HIPContext.h>

#include <numeric>
#include <vector>
// Encode one box row per thread (grid-stride over rows):
//   out = [x, y, z, log(w), log(l), log(h), sin(yaw), cos(yaw), trailing...]
// Rows with index >= len are skipped and their output left untouched.
__global__ void encode_reg_target_kernel(
    const float* input,   // [N, D] row-major box features
    const int64_t len,    // number of valid rows
    float* output,        // [N, Dout] encoded targets
    const int N,
    const int D,
    const int Dout
) {
    constexpr float kEps = 1e-6f;
    // Column layout of an input row.
    const int X = 0, Y = 1, Z = 2, W = 3, L = 4, H = 5, YAW = 6;

    const int stride = blockDim.x * gridDim.x;
    for (int row = threadIdx.x + blockIdx.x * blockDim.x; row < N; row += stride) {
        if (row >= len) {
            continue;  // beyond the valid range: leave output as-is
        }
        const float* src = input + row * D;
        float* dst = output + row * Dout;

        // Position passes through unchanged.
        dst[0] = src[X];
        dst[1] = src[Y];
        dst[2] = src[Z];
        // Box dimensions are log-encoded; clamp to avoid log(0).
        dst[3] = logf(fmaxf(src[W], kEps));
        dst[4] = logf(fmaxf(src[L], kEps));
        dst[5] = logf(fmaxf(src[H], kEps));
        // Heading angle becomes a (sin, cos) pair.
        dst[6] = sinf(src[YAW]);
        dst[7] = cosf(src[YAW]);
        // Any trailing features (columns after YAW) are copied verbatim,
        // shifted right by one because yaw expanded into two outputs.
        for (int col = YAW + 1; col < D; ++col) {
            dst[col + 1] = src[col];
        }
    }
}
// Launch wrapper for encode_reg_target_kernel.
//
// input: [N, D] float32, contiguous, on the current device; D >= 7
//        (x, y, z, w, l, h, yaw, optional trailing features).
// len:   number of valid rows; rows >= len are left uninitialized in the
//        output (it is allocated with torch::empty).
// Returns: [N, Dout] with Dout = 3 + 3 + 2 + (D - 7).
torch::Tensor encode_reg_t(
    const torch::Tensor& input,
    const int64_t len
) {
    // Validate the assumptions the raw-pointer kernel relies on; the
    // original silently produced garbage or read out of bounds otherwise.
    TORCH_CHECK(input.dim() == 2, "input must be 2-D [N, D]");
    TORCH_CHECK(input.scalar_type() == torch::kFloat32, "input must be float32");
    TORCH_CHECK(input.is_contiguous(), "input must be contiguous");

    const int N = input.size(0);
    const int D = input.size(1);
    TORCH_CHECK(D >= 7, "input feature dim must be >= 7 (x,y,z,w,l,h,yaw), got ", D);

    // Output layout: xyz(3) + log-dims(3) + sin/cos yaw(2) + passthrough(D-7).
    // NOTE: std::accumulate requires <numeric>.
    const std::vector<int> dims = {3, 3, 2, std::max(0, D - 7)};
    const int Dout = std::accumulate(dims.begin(), dims.end(), 0);

    auto options = torch::TensorOptions().device(input.device()).dtype(input.dtype());
    torch::Tensor output = torch::empty({N, Dout}, options);

    // A zero-block launch is an invalid CUDA/HIP configuration; bail out
    // early for empty input instead of erroring.
    if (N == 0) {
        return output;
    }

    const int threads = 256;
    const int blocks = (N + threads - 1) / threads;
    cudaStream_t stream = at::cuda::getCurrentCUDAStream();
    encode_reg_target_kernel<<<dim3(blocks), dim3(threads), 0, stream>>>(
        input.data_ptr<float>(),
        len,
        output.data_ptr<float>(),
        N,
        D,
        Dout
    );
    return output;
}
#include <torch/extension.h>

// Forward declaration; implemented in encode_reg_target.cu.
// Parameter renamed from "lengths" to "len" to match the definition.
torch::Tensor encode_reg_t(const torch::Tensor& input, const int64_t len);

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    // Named args are backward-compatible and enable keyword calls from Python.
    m.def("encode_reg_t", &encode_reg_t, "encode_reg_t",
          pybind11::arg("input"), pybind11::arg("len"));
}
# build
python3 setup.py build_ext --inplace
# run
python3 test_encode_reg_target.py
\ No newline at end of file
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

# Optimization flags applied to both the host compiler and nvcc/hipcc.
_COMPILE_ARGS = {
    'cxx': ['-O3'],
    'nvcc': ['-O3'],
}

# Build script for the encode_reg_target CUDA/HIP extension.
# Usage: python3 setup.py build_ext --inplace
setup(
    name='encode_reg_target_ext',
    ext_modules=[
        CUDAExtension(
            name='encode_reg_target_ext',
            sources=['export.cpp', 'encode_reg_target.cu'],
            extra_compile_args=_COMPILE_ARGS,
        ),
    ],
    cmdclass={'build_ext': BuildExtension},
)
import torch
import encode_reg_target_ext
import time
# -----------------------------
# reference implementation
# -----------------------------
def encode_reg_target_ref(box_target, device=None):
    """Reference (pure PyTorch) encoding of box regression targets.

    Each [N, D] box tensor (D >= 7) is encoded to [N, D + 1] columns:
    x, y, z, log(w), log(l), log(h), sin(yaw), cos(yaw), trailing features.

    Args:
        box_target: iterable of float tensors shaped [N, D].
        device: optional device to move each encoded tensor to.

    Returns:
        List of encoded tensors, one per input tensor.
    """
    # Column indices of an input row. Defined locally — the original read
    # module globals that are only assigned inside the __main__ guard, so
    # importing and calling this function from elsewhere raised NameError.
    X, Y, Z, W, L, H = 0, 1, 2, 3, 4, 5
    YAW = 6
    outputs = []
    for box in box_target:
        output = torch.cat(
            [
                box[..., [X, Y, Z]],
                box[..., [W, L, H]].log(),
                torch.sin(box[..., YAW]).unsqueeze(-1),
                torch.cos(box[..., YAW]).unsqueeze(-1),
                box[..., YAW + 1:],
            ],
            dim=-1,
        )
        if device is not None:
            output = output.to(device=device)
        outputs.append(output)
    return outputs
# -----------------------------
# Optimized HIP/C++ implementation
# -----------------------------
def encode_reg_target_optimized(box_target_list, device=None):
    """Encode each box tensor with the custom CUDA/HIP extension kernel.

    Mirrors encode_reg_target_ref: takes a list of [N, D] float tensors and
    returns a list of encoded tensors, moving inputs to `device` first
    (or to the first tensor's device when `device` is None).
    """
    if len(box_target_list) == 0:
        return []
    target_dev = box_target_list[0].device if device is None else device
    # Move only the tensors that are not already on the target device.
    staged = [
        t if t.device == target_dev else t.to(target_dev)
        for t in box_target_list
    ]
    results = []
    for tensor in staged:
        rows = tensor.shape[0]
        # All rows are valid, so len == N for the kernel.
        results.append(encode_reg_target_ext.encode_reg_t(tensor, rows))
    return results
# -----------------------------
# Main benchmark
# -----------------------------
if __name__ == "__main__":
    # Column-index constants. These are module globals on purpose:
    # encode_reg_target_ref reads them at call time.
    X, Y, Z, W, L, H, SIN_YAW, COS_YAW, VX, VY, VZ = list(range(11))
    YAW = 6

    device = "cuda:0"
    box_target_len = 100
    N_dim = 900
    D_dim = 10

    # Per-tensor row counts for the synthetic workload.
    m_dims = [
        40, 12, 25, 111, 45, 84, 19, 14, 52, 13,
        20, 19, 28, 28, 6, 62, 26, 56, 13, 33,
    ]
    box_target = [
        torch.rand(rows, D_dim, dtype=torch.float32, device=device)
        for rows in m_dims
    ]

    num_warmup = 10
    num_runs = 100

    # Warm up both paths so first-call overhead doesn't skew the timings.
    for _ in range(num_warmup):
        encode_reg_target_ref(box_target, device)
    for _ in range(num_warmup):
        encode_reg_target_optimized(box_target, device)
    torch.cuda.synchronize()

    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)

    # --- Time the pure-PyTorch reference. ---
    start_event.record()
    for _ in range(num_runs):
        out_ref = encode_reg_target_ref(box_target, device)
    end_event.record()
    torch.cuda.synchronize()
    avg_time_ms_ref = start_event.elapsed_time(end_event) / num_runs
    print(f"reference avg time: {avg_time_ms_ref:.4f} ms")

    # --- Time the custom-kernel path. ---
    start_event.record()
    for _ in range(num_runs):
        out_optimized = encode_reg_target_optimized(box_target, device)
    end_event.record()
    torch.cuda.synchronize()
    avg_time_ms_opt = start_event.elapsed_time(end_event) / num_runs
    print(f"Optimized kernel avg time: {avg_time_ms_opt:.4f} ms")

    # --- Accuracy: compare the two result lists element-wise. ---
    all_close = True
    max_diff = 0.0
    mean_diff = 0.0
    for ref, opt in zip(out_ref, out_optimized):
        diff = torch.abs(ref - opt)
        if not torch.allclose(ref, opt, atol=1e-5):
            all_close = False
        max_diff = max(max_diff, diff.max().item())
        mean_diff += diff.mean().item()
    mean_diff /= len(out_ref)

    print("Accuracy check:")
    print("Pass" if all_close else "Fail")
    print(f"Max diff: {max_diff}")
    print(f"Mean diff: {mean_diff}")
    print("Test end.")
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment