Commit c46b6925 authored by Max Rietmann's avatar Max Rietmann
Browse files

Applied new formatting

parent 1ea5c4ca
---
BasedOnStyle: Webkit
IndentWidth: 2
IndentWidth: 4
AccessModifierOffset: -2
AlignAfterOpenBracket: Align
AlignTrailingComments: true
......@@ -13,8 +13,8 @@ BreakBeforeTernaryOperators: false
BreakConstructorInitializers: AfterColon
ColumnLimit: 120
ConstructorInitializerAllOnOneLineOrOnePerLine: true
ConstructorInitializerIndentWidth: 2
ContinuationIndentWidth: 2
ConstructorInitializerIndentWidth: 4
ContinuationIndentWidth: 4
Cpp11BracedListStyle: true
FixNamespaceComments: true
NamespaceIndentation: All
......
......@@ -51,7 +51,7 @@
#define THREADS (64)
#endif
#ifndef DIV_UP
#define DIV_UP(a, b) (((a) + ((b)-1)) / (b))
#define DIV_UP(a, b) (((a) + ((b) - 1)) / (b))
#endif
#ifndef CHECK_CUDA
#define CHECK_CUDA(call) \
......@@ -70,8 +70,9 @@
class ScopeTimer
{
public:
explicit ScopeTimer(const std::string &label = "") : label_(label), start_(std::chrono::high_resolution_clock::now())
public:
explicit ScopeTimer(const std::string &label = "") :
label_(label), start_(std::chrono::high_resolution_clock::now())
{
}
......@@ -82,7 +83,7 @@ public:
std::cout << label_ << "Elapsed time: " << elapsed.count() << " ms" << std::endl;
}
private:
private:
std::string label_;
std::chrono::high_resolution_clock::time_point start_;
};
......
......@@ -45,7 +45,7 @@ using BlockReduceFloat512 = cub::BlockReduce<float, 512>;
#define WARP_SIZE (32)
#define FULL_MASK (0xFFFFFFFF)
#define THREADS (64)
#define DIV_UP(a, b) (((a) + ((b)-1)) / (b))
#define DIV_UP(a, b) (((a) + ((b) - 1)) / (b))
#define NNZ_TRESH (32)
......
......@@ -40,7 +40,7 @@
CHECK_CUDA_TENSOR(x); \
CHECK_CONTIGUOUS_TENSOR(x)
#define DIV_UP(a, b) (((a) + ((b)-1)) / (b))
#define DIV_UP(a, b) (((a) + ((b) - 1)) / (b))
#define MIN_THREADS (64)
#define ELXTH_MAX (32)
......
......@@ -140,7 +140,7 @@ __device__ void disco_bwd_d(const int Hi, const int Wi, const int K, const int H
template <int BDIM_X, int ELXTH, int PSCALE, typename REAL_T>
__global__
__launch_bounds__(BDIM_X) void disco_bwd_blk_k(const int Hi, const int Wi, const int K, const int Ho, const int Wo,
__launch_bounds__(BDIM_X) void disco_bwd_blk_k(const int Hi, const int Wi, const int K, const int Ho, const int Wo,
const int pscale, const int64_t *__restrict__ roff,
const int64_t *__restrict__ kers, const int64_t *__restrict__ rows,
const int64_t *__restrict__ cols, const REAL_T *__restrict__ vals,
......@@ -173,24 +173,24 @@ static void launch_kernel(int BC, int Hi, int Wi, int K, int Ho, int Wo, int64_t
switch (pscale) {
case 1:
disco_bwd_blk_k<NTH, ELXTH, 1>
<<<grid, NTH, shmem, stream>>>(Hi, Wi, K, Ho, Wo, pscale, roff_d, ker_d, row_d, col_d, val_d, inp_d, out_d);
disco_bwd_blk_k<NTH, ELXTH, 1><<<grid, NTH, shmem, stream>>>(Hi, Wi, K, Ho, Wo, pscale, roff_d, ker_d,
row_d, col_d, val_d, inp_d, out_d);
break;
case 2:
disco_bwd_blk_k<NTH, ELXTH, 2>
<<<grid, NTH, shmem, stream>>>(Hi, Wi, K, Ho, Wo, pscale, roff_d, ker_d, row_d, col_d, val_d, inp_d, out_d);
disco_bwd_blk_k<NTH, ELXTH, 2><<<grid, NTH, shmem, stream>>>(Hi, Wi, K, Ho, Wo, pscale, roff_d, ker_d,
row_d, col_d, val_d, inp_d, out_d);
break;
case 3:
disco_bwd_blk_k<NTH, ELXTH, 3>
<<<grid, NTH, shmem, stream>>>(Hi, Wi, K, Ho, Wo, pscale, roff_d, ker_d, row_d, col_d, val_d, inp_d, out_d);
disco_bwd_blk_k<NTH, ELXTH, 3><<<grid, NTH, shmem, stream>>>(Hi, Wi, K, Ho, Wo, pscale, roff_d, ker_d,
row_d, col_d, val_d, inp_d, out_d);
break;
default:
disco_bwd_blk_k<NTH, ELXTH, 0>
<<<grid, NTH, shmem, stream>>>(Hi, Wi, K, Ho, Wo, pscale, roff_d, ker_d, row_d, col_d, val_d, inp_d, out_d);
disco_bwd_blk_k<NTH, ELXTH, 0><<<grid, NTH, shmem, stream>>>(Hi, Wi, K, Ho, Wo, pscale, roff_d, ker_d,
row_d, col_d, val_d, inp_d, out_d);
}
} else {
launch_kernel<NTH, ELXTH + 1>(BC, Hi, Wi, K, Ho, Wo, nrows, roff_d, ker_d, row_d, col_d, val_d, inp_d, out_d,
stream);
launch_kernel<NTH, ELXTH + 1>(BC, Hi, Wi, K, Ho, Wo, nrows, roff_d, ker_d, row_d, col_d, val_d, inp_d,
out_d, stream);
}
}
return;
......@@ -231,36 +231,41 @@ torch::Tensor disco_cuda_bwd(torch::Tensor inp, torch::Tensor roff_idx, torch::T
AT_DISPATCH_FLOATING_TYPES(inp.scalar_type(), "disco_backward_cuda", ([&] {
launch_kernel<64, 1, scalar_t>(
BC, Hi, Wi, K, Ho, Wo, nrows, roff_idx.data_ptr<int64_t>(),
ker_idx.data_ptr<int64_t>(), row_idx.data_ptr<int64_t>(), col_idx.data_ptr<int64_t>(),
val.data_ptr<scalar_t>(), inp.data_ptr<scalar_t>(), out.data_ptr<scalar_t>(), stream);
ker_idx.data_ptr<int64_t>(), row_idx.data_ptr<int64_t>(),
col_idx.data_ptr<int64_t>(), val.data_ptr<scalar_t>(),
inp.data_ptr<scalar_t>(), out.data_ptr<scalar_t>(), stream);
}));
} else if (Wo <= 128 * ELXTH_MAX) {
AT_DISPATCH_FLOATING_TYPES(inp.scalar_type(), "disco_backward_cuda", ([&] {
launch_kernel<128, (ELXTH_MAX / 2) + 1, scalar_t>(
BC, Hi, Wi, K, Ho, Wo, nrows, roff_idx.data_ptr<int64_t>(),
ker_idx.data_ptr<int64_t>(), row_idx.data_ptr<int64_t>(), col_idx.data_ptr<int64_t>(),
val.data_ptr<scalar_t>(), inp.data_ptr<scalar_t>(), out.data_ptr<scalar_t>(), stream);
ker_idx.data_ptr<int64_t>(), row_idx.data_ptr<int64_t>(),
col_idx.data_ptr<int64_t>(), val.data_ptr<scalar_t>(),
inp.data_ptr<scalar_t>(), out.data_ptr<scalar_t>(), stream);
}));
} else if (Wo <= 256 * ELXTH_MAX) {
AT_DISPATCH_FLOATING_TYPES(inp.scalar_type(), "disco_backward_cuda", ([&] {
launch_kernel<256, (ELXTH_MAX / 2) + 1, scalar_t>(
BC, Hi, Wi, K, Ho, Wo, nrows, roff_idx.data_ptr<int64_t>(),
ker_idx.data_ptr<int64_t>(), row_idx.data_ptr<int64_t>(), col_idx.data_ptr<int64_t>(),
val.data_ptr<scalar_t>(), inp.data_ptr<scalar_t>(), out.data_ptr<scalar_t>(), stream);
ker_idx.data_ptr<int64_t>(), row_idx.data_ptr<int64_t>(),
col_idx.data_ptr<int64_t>(), val.data_ptr<scalar_t>(),
inp.data_ptr<scalar_t>(), out.data_ptr<scalar_t>(), stream);
}));
} else if (Wo <= 512 * ELXTH_MAX) {
AT_DISPATCH_FLOATING_TYPES(inp.scalar_type(), "disco_backward_cuda", ([&] {
launch_kernel<512, (ELXTH_MAX / 2) + 1, scalar_t>(
BC, Hi, Wi, K, Ho, Wo, nrows, roff_idx.data_ptr<int64_t>(),
ker_idx.data_ptr<int64_t>(), row_idx.data_ptr<int64_t>(), col_idx.data_ptr<int64_t>(),
val.data_ptr<scalar_t>(), inp.data_ptr<scalar_t>(), out.data_ptr<scalar_t>(), stream);
ker_idx.data_ptr<int64_t>(), row_idx.data_ptr<int64_t>(),
col_idx.data_ptr<int64_t>(), val.data_ptr<scalar_t>(),
inp.data_ptr<scalar_t>(), out.data_ptr<scalar_t>(), stream);
}));
} else if (Wo <= 1024 * ELXTH_MAX) {
AT_DISPATCH_FLOATING_TYPES(inp.scalar_type(), "disco_backward_cuda", ([&] {
launch_kernel<1024, (ELXTH_MAX / 2) + 1, scalar_t>(
BC, Hi, Wi, K, Ho, Wo, nrows, roff_idx.data_ptr<int64_t>(),
ker_idx.data_ptr<int64_t>(), row_idx.data_ptr<int64_t>(), col_idx.data_ptr<int64_t>(),
val.data_ptr<scalar_t>(), inp.data_ptr<scalar_t>(), out.data_ptr<scalar_t>(), stream);
ker_idx.data_ptr<int64_t>(), row_idx.data_ptr<int64_t>(),
col_idx.data_ptr<int64_t>(), val.data_ptr<scalar_t>(),
inp.data_ptr<scalar_t>(), out.data_ptr<scalar_t>(), stream);
}));
} else {
fprintf(stderr, "%s:%d: error, unsupported Wo value (%ld), max supported is %d\n", __FILE__, __LINE__, Wo,
......
......@@ -55,7 +55,8 @@ __device__ void disco_fwd_d(const int Hi, const int Wi, const int K, const int H
REAL_T __reg[ELXTH] = {0};
// align to larger supported fp type
extern __shared__ __align__(sizeof(double)) unsigned char __sh_ptr[]; // REAL_T __sh[2*Wi + ppscale*(BDIM_X*ELXTH - Wo)]
extern __shared__ __align__(
sizeof(double)) unsigned char __sh_ptr[]; // REAL_T __sh[2*Wi + ppscale*(BDIM_X*ELXTH - Wo)]
REAL_T *__sh = reinterpret_cast<REAL_T *>(__sh_ptr);
int col_prev = cols[soff];
......@@ -145,7 +146,7 @@ __device__ void disco_fwd_d(const int Hi, const int Wi, const int K, const int H
template <int BDIM_X, int ELXTH, typename REAL_T>
__global__
__launch_bounds__(BDIM_X) void disco_fwd_blk_k(const int Hi, const int Wi, const int K, const int Ho, const int Wo,
__launch_bounds__(BDIM_X) void disco_fwd_blk_k(const int Hi, const int Wi, const int K, const int Ho, const int Wo,
const int pscale, const int64_t *__restrict__ roff,
const int64_t *__restrict__ kers, const int64_t *__restrict__ rows,
const int64_t *__restrict__ cols, const REAL_T *__restrict__ vals,
......@@ -172,11 +173,11 @@ static void launch_kernel(int BC, int Hi, int Wi, int K, int Ho, int Wo, int64_t
const int pscale = Wi / Wo;
size_t shmem = sizeof(*out_d) * (Wi * 2 + pscale * (NTH * ELXTH - Wo));
disco_fwd_blk_k<NTH, ELXTH>
<<<grid, NTH, shmem, stream>>>(Hi, Wi, K, Ho, Wo, pscale, roff_d, ker_d, row_d, col_d, val_d, inp_d, out_d);
disco_fwd_blk_k<NTH, ELXTH><<<grid, NTH, shmem, stream>>>(Hi, Wi, K, Ho, Wo, pscale, roff_d, ker_d, row_d,
col_d, val_d, inp_d, out_d);
} else {
launch_kernel<NTH, ELXTH + 1>(BC, Hi, Wi, K, Ho, Wo, nrows, roff_d, ker_d, row_d, col_d, val_d, inp_d, out_d,
stream);
launch_kernel<NTH, ELXTH + 1>(BC, Hi, Wi, K, Ho, Wo, nrows, roff_d, ker_d, row_d, col_d, val_d, inp_d,
out_d, stream);
}
}
return;
......@@ -218,36 +219,41 @@ torch::Tensor disco_cuda_fwd(torch::Tensor inp, torch::Tensor roff_idx, torch::T
AT_DISPATCH_FLOATING_TYPES(inp.scalar_type(), "disco_forward_cuda", ([&] {
launch_kernel<64, 1, scalar_t>(
BC, Hi, Wi, K, Ho, Wo, nrows, roff_idx.data_ptr<int64_t>(),
ker_idx.data_ptr<int64_t>(), row_idx.data_ptr<int64_t>(), col_idx.data_ptr<int64_t>(),
val.data_ptr<scalar_t>(), inp.data_ptr<scalar_t>(), out.data_ptr<scalar_t>(), stream);
ker_idx.data_ptr<int64_t>(), row_idx.data_ptr<int64_t>(),
col_idx.data_ptr<int64_t>(), val.data_ptr<scalar_t>(),
inp.data_ptr<scalar_t>(), out.data_ptr<scalar_t>(), stream);
}));
} else if (Wo <= 128 * ELXTH_MAX) {
AT_DISPATCH_FLOATING_TYPES(inp.scalar_type(), "disco_forward_cuda", ([&] {
launch_kernel<128, (ELXTH_MAX / 2) + 1, scalar_t>(
BC, Hi, Wi, K, Ho, Wo, nrows, roff_idx.data_ptr<int64_t>(),
ker_idx.data_ptr<int64_t>(), row_idx.data_ptr<int64_t>(), col_idx.data_ptr<int64_t>(),
val.data_ptr<scalar_t>(), inp.data_ptr<scalar_t>(), out.data_ptr<scalar_t>(), stream);
ker_idx.data_ptr<int64_t>(), row_idx.data_ptr<int64_t>(),
col_idx.data_ptr<int64_t>(), val.data_ptr<scalar_t>(),
inp.data_ptr<scalar_t>(), out.data_ptr<scalar_t>(), stream);
}));
} else if (Wo <= 256 * ELXTH_MAX) {
AT_DISPATCH_FLOATING_TYPES(inp.scalar_type(), "disco_forward_cuda", ([&] {
launch_kernel<256, (ELXTH_MAX / 2) + 1, scalar_t>(
BC, Hi, Wi, K, Ho, Wo, nrows, roff_idx.data_ptr<int64_t>(),
ker_idx.data_ptr<int64_t>(), row_idx.data_ptr<int64_t>(), col_idx.data_ptr<int64_t>(),
val.data_ptr<scalar_t>(), inp.data_ptr<scalar_t>(), out.data_ptr<scalar_t>(), stream);
ker_idx.data_ptr<int64_t>(), row_idx.data_ptr<int64_t>(),
col_idx.data_ptr<int64_t>(), val.data_ptr<scalar_t>(),
inp.data_ptr<scalar_t>(), out.data_ptr<scalar_t>(), stream);
}));
} else if (Wo <= 512 * ELXTH_MAX) {
AT_DISPATCH_FLOATING_TYPES(inp.scalar_type(), "disco_forward_cuda", ([&] {
launch_kernel<512, (ELXTH_MAX / 2) + 1, scalar_t>(
BC, Hi, Wi, K, Ho, Wo, nrows, roff_idx.data_ptr<int64_t>(),
ker_idx.data_ptr<int64_t>(), row_idx.data_ptr<int64_t>(), col_idx.data_ptr<int64_t>(),
val.data_ptr<scalar_t>(), inp.data_ptr<scalar_t>(), out.data_ptr<scalar_t>(), stream);
ker_idx.data_ptr<int64_t>(), row_idx.data_ptr<int64_t>(),
col_idx.data_ptr<int64_t>(), val.data_ptr<scalar_t>(),
inp.data_ptr<scalar_t>(), out.data_ptr<scalar_t>(), stream);
}));
} else if (Wo <= 1024 * ELXTH_MAX) {
AT_DISPATCH_FLOATING_TYPES(inp.scalar_type(), "disco_forward_cuda", ([&] {
launch_kernel<1024, (ELXTH_MAX / 2) + 1, scalar_t>(
BC, Hi, Wi, K, Ho, Wo, nrows, roff_idx.data_ptr<int64_t>(),
ker_idx.data_ptr<int64_t>(), row_idx.data_ptr<int64_t>(), col_idx.data_ptr<int64_t>(),
val.data_ptr<scalar_t>(), inp.data_ptr<scalar_t>(), out.data_ptr<scalar_t>(), stream);
ker_idx.data_ptr<int64_t>(), row_idx.data_ptr<int64_t>(),
col_idx.data_ptr<int64_t>(), val.data_ptr<scalar_t>(),
inp.data_ptr<scalar_t>(), out.data_ptr<scalar_t>(), stream);
}));
} else {
fprintf(stderr, "%s:%d: error, unsupported Wo value (%ld), max supported is %d\n", __FILE__, __LINE__, Wo,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment