Commit cfbf5f80 authored by Thorsten Kurth's avatar Thorsten Kurth Committed by Boris Bonev
Browse files

doubling the indent to 4

parent 4805b39c
---
BasedOnStyle: Webkit
IndentWidth: 2
IndentWidth: 4
AccessModifierOffset: -2
AlignAfterOpenBracket: Align
AlignTrailingComments: true
......@@ -13,8 +13,8 @@ BreakBeforeTernaryOperators: false
BreakConstructorInitializers: AfterColon
ColumnLimit: 120
ConstructorInitializerAllOnOneLineOrOnePerLine: true
ConstructorInitializerIndentWidth: 2
ContinuationIndentWidth: 2
ConstructorInitializerIndentWidth: 4
ContinuationIndentWidth: 4
Cpp11BracedListStyle: true
FixNamespaceComments: true
NamespaceIndentation: All
......
......@@ -70,8 +70,9 @@
class ScopeTimer
{
public:
explicit ScopeTimer(const std::string &label = "") : label_(label), start_(std::chrono::high_resolution_clock::now())
public:
explicit ScopeTimer(const std::string &label = "") :
label_(label), start_(std::chrono::high_resolution_clock::now())
{
}
......@@ -82,7 +83,7 @@ public:
std::cout << label_ << "Elapsed time: " << elapsed.count() << " ms" << std::endl;
}
private:
private:
std::string label_;
std::chrono::high_resolution_clock::time_point start_;
};
......@@ -165,7 +166,9 @@ __global__ __launch_bounds__(BDIM_X) void s2_attention_bwd_dkvq_kernel(
const int wi = col - (hi * nlon_in);
const int wip = (wi + wo) - ((wi + wo) / nlon_in) * nlon_in;
float qdotk = 0.0f;
for (int chan = tidx; chan < num_channels; chan += WARP_SIZE) { qdotk += sh_qy[chan] * kx[batchId][chan][hi][wip]; }
for (int chan = tidx; chan < num_channels; chan += WARP_SIZE) {
qdotk += sh_qy[chan] * kx[batchId][chan][hi][wip];
}
qdotk = __warp_sum_cub(qdotk);
qdotk_max = max(qdotk_max, qdotk);
}
......
......@@ -173,24 +173,24 @@ static void launch_kernel(int BC, int Hi, int Wi, int K, int Ho, int Wo, int64_t
switch (pscale) {
case 1:
disco_bwd_blk_k<NTH, ELXTH, 1>
<<<grid, NTH, shmem, stream>>>(Hi, Wi, K, Ho, Wo, pscale, roff_d, ker_d, row_d, col_d, val_d, inp_d, out_d);
disco_bwd_blk_k<NTH, ELXTH, 1><<<grid, NTH, shmem, stream>>>(Hi, Wi, K, Ho, Wo, pscale, roff_d, ker_d,
row_d, col_d, val_d, inp_d, out_d);
break;
case 2:
disco_bwd_blk_k<NTH, ELXTH, 2>
<<<grid, NTH, shmem, stream>>>(Hi, Wi, K, Ho, Wo, pscale, roff_d, ker_d, row_d, col_d, val_d, inp_d, out_d);
disco_bwd_blk_k<NTH, ELXTH, 2><<<grid, NTH, shmem, stream>>>(Hi, Wi, K, Ho, Wo, pscale, roff_d, ker_d,
row_d, col_d, val_d, inp_d, out_d);
break;
case 3:
disco_bwd_blk_k<NTH, ELXTH, 3>
<<<grid, NTH, shmem, stream>>>(Hi, Wi, K, Ho, Wo, pscale, roff_d, ker_d, row_d, col_d, val_d, inp_d, out_d);
disco_bwd_blk_k<NTH, ELXTH, 3><<<grid, NTH, shmem, stream>>>(Hi, Wi, K, Ho, Wo, pscale, roff_d, ker_d,
row_d, col_d, val_d, inp_d, out_d);
break;
default:
disco_bwd_blk_k<NTH, ELXTH, 0>
<<<grid, NTH, shmem, stream>>>(Hi, Wi, K, Ho, Wo, pscale, roff_d, ker_d, row_d, col_d, val_d, inp_d, out_d);
disco_bwd_blk_k<NTH, ELXTH, 0><<<grid, NTH, shmem, stream>>>(Hi, Wi, K, Ho, Wo, pscale, roff_d, ker_d,
row_d, col_d, val_d, inp_d, out_d);
}
} else {
launch_kernel<NTH, ELXTH + 1>(BC, Hi, Wi, K, Ho, Wo, nrows, roff_d, ker_d, row_d, col_d, val_d, inp_d, out_d,
stream);
launch_kernel<NTH, ELXTH + 1>(BC, Hi, Wi, K, Ho, Wo, nrows, roff_d, ker_d, row_d, col_d, val_d, inp_d,
out_d, stream);
}
}
return;
......@@ -231,36 +231,41 @@ torch::Tensor disco_cuda_bwd(torch::Tensor inp, torch::Tensor roff_idx, torch::T
AT_DISPATCH_FLOATING_TYPES(inp.scalar_type(), "disco_backward_cuda", ([&] {
launch_kernel<64, 1, scalar_t>(
BC, Hi, Wi, K, Ho, Wo, nrows, roff_idx.data_ptr<int64_t>(),
ker_idx.data_ptr<int64_t>(), row_idx.data_ptr<int64_t>(), col_idx.data_ptr<int64_t>(),
val.data_ptr<scalar_t>(), inp.data_ptr<scalar_t>(), out.data_ptr<scalar_t>(), stream);
ker_idx.data_ptr<int64_t>(), row_idx.data_ptr<int64_t>(),
col_idx.data_ptr<int64_t>(), val.data_ptr<scalar_t>(),
inp.data_ptr<scalar_t>(), out.data_ptr<scalar_t>(), stream);
}));
} else if (Wo <= 128 * ELXTH_MAX) {
AT_DISPATCH_FLOATING_TYPES(inp.scalar_type(), "disco_backward_cuda", ([&] {
launch_kernel<128, (ELXTH_MAX / 2) + 1, scalar_t>(
BC, Hi, Wi, K, Ho, Wo, nrows, roff_idx.data_ptr<int64_t>(),
ker_idx.data_ptr<int64_t>(), row_idx.data_ptr<int64_t>(), col_idx.data_ptr<int64_t>(),
val.data_ptr<scalar_t>(), inp.data_ptr<scalar_t>(), out.data_ptr<scalar_t>(), stream);
ker_idx.data_ptr<int64_t>(), row_idx.data_ptr<int64_t>(),
col_idx.data_ptr<int64_t>(), val.data_ptr<scalar_t>(),
inp.data_ptr<scalar_t>(), out.data_ptr<scalar_t>(), stream);
}));
} else if (Wo <= 256 * ELXTH_MAX) {
AT_DISPATCH_FLOATING_TYPES(inp.scalar_type(), "disco_backward_cuda", ([&] {
launch_kernel<256, (ELXTH_MAX / 2) + 1, scalar_t>(
BC, Hi, Wi, K, Ho, Wo, nrows, roff_idx.data_ptr<int64_t>(),
ker_idx.data_ptr<int64_t>(), row_idx.data_ptr<int64_t>(), col_idx.data_ptr<int64_t>(),
val.data_ptr<scalar_t>(), inp.data_ptr<scalar_t>(), out.data_ptr<scalar_t>(), stream);
ker_idx.data_ptr<int64_t>(), row_idx.data_ptr<int64_t>(),
col_idx.data_ptr<int64_t>(), val.data_ptr<scalar_t>(),
inp.data_ptr<scalar_t>(), out.data_ptr<scalar_t>(), stream);
}));
} else if (Wo <= 512 * ELXTH_MAX) {
AT_DISPATCH_FLOATING_TYPES(inp.scalar_type(), "disco_backward_cuda", ([&] {
launch_kernel<512, (ELXTH_MAX / 2) + 1, scalar_t>(
BC, Hi, Wi, K, Ho, Wo, nrows, roff_idx.data_ptr<int64_t>(),
ker_idx.data_ptr<int64_t>(), row_idx.data_ptr<int64_t>(), col_idx.data_ptr<int64_t>(),
val.data_ptr<scalar_t>(), inp.data_ptr<scalar_t>(), out.data_ptr<scalar_t>(), stream);
ker_idx.data_ptr<int64_t>(), row_idx.data_ptr<int64_t>(),
col_idx.data_ptr<int64_t>(), val.data_ptr<scalar_t>(),
inp.data_ptr<scalar_t>(), out.data_ptr<scalar_t>(), stream);
}));
} else if (Wo <= 1024 * ELXTH_MAX) {
AT_DISPATCH_FLOATING_TYPES(inp.scalar_type(), "disco_backward_cuda", ([&] {
launch_kernel<1024, (ELXTH_MAX / 2) + 1, scalar_t>(
BC, Hi, Wi, K, Ho, Wo, nrows, roff_idx.data_ptr<int64_t>(),
ker_idx.data_ptr<int64_t>(), row_idx.data_ptr<int64_t>(), col_idx.data_ptr<int64_t>(),
val.data_ptr<scalar_t>(), inp.data_ptr<scalar_t>(), out.data_ptr<scalar_t>(), stream);
ker_idx.data_ptr<int64_t>(), row_idx.data_ptr<int64_t>(),
col_idx.data_ptr<int64_t>(), val.data_ptr<scalar_t>(),
inp.data_ptr<scalar_t>(), out.data_ptr<scalar_t>(), stream);
}));
} else {
fprintf(stderr, "%s:%d: error, unsupported Wo value (%ld), max supported is %d\n", __FILE__, __LINE__, Wo,
......
......@@ -55,7 +55,8 @@ __device__ void disco_fwd_d(const int Hi, const int Wi, const int K, const int H
REAL_T __reg[ELXTH] = {0};
// align to larger supported fp type
extern __shared__ __align__(sizeof(double)) unsigned char __sh_ptr[]; // REAL_T __sh[2*Wi + ppscale*(BDIM_X*ELXTH - Wo)]
extern __shared__ __align__(
sizeof(double)) unsigned char __sh_ptr[]; // REAL_T __sh[2*Wi + ppscale*(BDIM_X*ELXTH - Wo)]
REAL_T *__sh = reinterpret_cast<REAL_T *>(__sh_ptr);
int col_prev = cols[soff];
......@@ -172,11 +173,11 @@ static void launch_kernel(int BC, int Hi, int Wi, int K, int Ho, int Wo, int64_t
const int pscale = Wi / Wo;
size_t shmem = sizeof(*out_d) * (Wi * 2 + pscale * (NTH * ELXTH - Wo));
disco_fwd_blk_k<NTH, ELXTH>
<<<grid, NTH, shmem, stream>>>(Hi, Wi, K, Ho, Wo, pscale, roff_d, ker_d, row_d, col_d, val_d, inp_d, out_d);
disco_fwd_blk_k<NTH, ELXTH><<<grid, NTH, shmem, stream>>>(Hi, Wi, K, Ho, Wo, pscale, roff_d, ker_d, row_d,
col_d, val_d, inp_d, out_d);
} else {
launch_kernel<NTH, ELXTH + 1>(BC, Hi, Wi, K, Ho, Wo, nrows, roff_d, ker_d, row_d, col_d, val_d, inp_d, out_d,
stream);
launch_kernel<NTH, ELXTH + 1>(BC, Hi, Wi, K, Ho, Wo, nrows, roff_d, ker_d, row_d, col_d, val_d, inp_d,
out_d, stream);
}
}
return;
......@@ -218,36 +219,41 @@ torch::Tensor disco_cuda_fwd(torch::Tensor inp, torch::Tensor roff_idx, torch::T
AT_DISPATCH_FLOATING_TYPES(inp.scalar_type(), "disco_forward_cuda", ([&] {
launch_kernel<64, 1, scalar_t>(
BC, Hi, Wi, K, Ho, Wo, nrows, roff_idx.data_ptr<int64_t>(),
ker_idx.data_ptr<int64_t>(), row_idx.data_ptr<int64_t>(), col_idx.data_ptr<int64_t>(),
val.data_ptr<scalar_t>(), inp.data_ptr<scalar_t>(), out.data_ptr<scalar_t>(), stream);
ker_idx.data_ptr<int64_t>(), row_idx.data_ptr<int64_t>(),
col_idx.data_ptr<int64_t>(), val.data_ptr<scalar_t>(),
inp.data_ptr<scalar_t>(), out.data_ptr<scalar_t>(), stream);
}));
} else if (Wo <= 128 * ELXTH_MAX) {
AT_DISPATCH_FLOATING_TYPES(inp.scalar_type(), "disco_forward_cuda", ([&] {
launch_kernel<128, (ELXTH_MAX / 2) + 1, scalar_t>(
BC, Hi, Wi, K, Ho, Wo, nrows, roff_idx.data_ptr<int64_t>(),
ker_idx.data_ptr<int64_t>(), row_idx.data_ptr<int64_t>(), col_idx.data_ptr<int64_t>(),
val.data_ptr<scalar_t>(), inp.data_ptr<scalar_t>(), out.data_ptr<scalar_t>(), stream);
ker_idx.data_ptr<int64_t>(), row_idx.data_ptr<int64_t>(),
col_idx.data_ptr<int64_t>(), val.data_ptr<scalar_t>(),
inp.data_ptr<scalar_t>(), out.data_ptr<scalar_t>(), stream);
}));
} else if (Wo <= 256 * ELXTH_MAX) {
AT_DISPATCH_FLOATING_TYPES(inp.scalar_type(), "disco_forward_cuda", ([&] {
launch_kernel<256, (ELXTH_MAX / 2) + 1, scalar_t>(
BC, Hi, Wi, K, Ho, Wo, nrows, roff_idx.data_ptr<int64_t>(),
ker_idx.data_ptr<int64_t>(), row_idx.data_ptr<int64_t>(), col_idx.data_ptr<int64_t>(),
val.data_ptr<scalar_t>(), inp.data_ptr<scalar_t>(), out.data_ptr<scalar_t>(), stream);
ker_idx.data_ptr<int64_t>(), row_idx.data_ptr<int64_t>(),
col_idx.data_ptr<int64_t>(), val.data_ptr<scalar_t>(),
inp.data_ptr<scalar_t>(), out.data_ptr<scalar_t>(), stream);
}));
} else if (Wo <= 512 * ELXTH_MAX) {
AT_DISPATCH_FLOATING_TYPES(inp.scalar_type(), "disco_forward_cuda", ([&] {
launch_kernel<512, (ELXTH_MAX / 2) + 1, scalar_t>(
BC, Hi, Wi, K, Ho, Wo, nrows, roff_idx.data_ptr<int64_t>(),
ker_idx.data_ptr<int64_t>(), row_idx.data_ptr<int64_t>(), col_idx.data_ptr<int64_t>(),
val.data_ptr<scalar_t>(), inp.data_ptr<scalar_t>(), out.data_ptr<scalar_t>(), stream);
ker_idx.data_ptr<int64_t>(), row_idx.data_ptr<int64_t>(),
col_idx.data_ptr<int64_t>(), val.data_ptr<scalar_t>(),
inp.data_ptr<scalar_t>(), out.data_ptr<scalar_t>(), stream);
}));
} else if (Wo <= 1024 * ELXTH_MAX) {
AT_DISPATCH_FLOATING_TYPES(inp.scalar_type(), "disco_forward_cuda", ([&] {
launch_kernel<1024, (ELXTH_MAX / 2) + 1, scalar_t>(
BC, Hi, Wi, K, Ho, Wo, nrows, roff_idx.data_ptr<int64_t>(),
ker_idx.data_ptr<int64_t>(), row_idx.data_ptr<int64_t>(), col_idx.data_ptr<int64_t>(),
val.data_ptr<scalar_t>(), inp.data_ptr<scalar_t>(), out.data_ptr<scalar_t>(), stream);
ker_idx.data_ptr<int64_t>(), row_idx.data_ptr<int64_t>(),
col_idx.data_ptr<int64_t>(), val.data_ptr<scalar_t>(),
inp.data_ptr<scalar_t>(), out.data_ptr<scalar_t>(), stream);
}));
} else {
fprintf(stderr, "%s:%d: error, unsupported Wo value (%ld), max supported is %d\n", __FILE__, __LINE__, Wo,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment