Commit cfbf5f80 authored by Thorsten Kurth's avatar Thorsten Kurth Committed by Boris Bonev
Browse files

doubling the indent to 4

parent 4805b39c
--- ---
BasedOnStyle: Webkit BasedOnStyle: Webkit
IndentWidth: 2 IndentWidth: 4
AccessModifierOffset: -2 AccessModifierOffset: -2
AlignAfterOpenBracket: Align AlignAfterOpenBracket: Align
AlignTrailingComments: true AlignTrailingComments: true
...@@ -13,8 +13,8 @@ BreakBeforeTernaryOperators: false ...@@ -13,8 +13,8 @@ BreakBeforeTernaryOperators: false
BreakConstructorInitializers: AfterColon BreakConstructorInitializers: AfterColon
ColumnLimit: 120 ColumnLimit: 120
ConstructorInitializerAllOnOneLineOrOnePerLine: true ConstructorInitializerAllOnOneLineOrOnePerLine: true
ConstructorInitializerIndentWidth: 2 ConstructorInitializerIndentWidth: 4
ContinuationIndentWidth: 2 ContinuationIndentWidth: 4
Cpp11BracedListStyle: true Cpp11BracedListStyle: true
FixNamespaceComments: true FixNamespaceComments: true
NamespaceIndentation: All NamespaceIndentation: All
......
...@@ -70,8 +70,9 @@ ...@@ -70,8 +70,9 @@
class ScopeTimer class ScopeTimer
{ {
public: public:
explicit ScopeTimer(const std::string &label = "") : label_(label), start_(std::chrono::high_resolution_clock::now()) explicit ScopeTimer(const std::string &label = "") :
label_(label), start_(std::chrono::high_resolution_clock::now())
{ {
} }
...@@ -82,7 +83,7 @@ public: ...@@ -82,7 +83,7 @@ public:
std::cout << label_ << "Elapsed time: " << elapsed.count() << " ms" << std::endl; std::cout << label_ << "Elapsed time: " << elapsed.count() << " ms" << std::endl;
} }
private: private:
std::string label_; std::string label_;
std::chrono::high_resolution_clock::time_point start_; std::chrono::high_resolution_clock::time_point start_;
}; };
...@@ -165,7 +166,9 @@ __global__ __launch_bounds__(BDIM_X) void s2_attention_bwd_dkvq_kernel( ...@@ -165,7 +166,9 @@ __global__ __launch_bounds__(BDIM_X) void s2_attention_bwd_dkvq_kernel(
const int wi = col - (hi * nlon_in); const int wi = col - (hi * nlon_in);
const int wip = (wi + wo) - ((wi + wo) / nlon_in) * nlon_in; const int wip = (wi + wo) - ((wi + wo) / nlon_in) * nlon_in;
float qdotk = 0.0f; float qdotk = 0.0f;
for (int chan = tidx; chan < num_channels; chan += WARP_SIZE) { qdotk += sh_qy[chan] * kx[batchId][chan][hi][wip]; } for (int chan = tidx; chan < num_channels; chan += WARP_SIZE) {
qdotk += sh_qy[chan] * kx[batchId][chan][hi][wip];
}
qdotk = __warp_sum_cub(qdotk); qdotk = __warp_sum_cub(qdotk);
qdotk_max = max(qdotk_max, qdotk); qdotk_max = max(qdotk_max, qdotk);
} }
......
...@@ -173,24 +173,24 @@ static void launch_kernel(int BC, int Hi, int Wi, int K, int Ho, int Wo, int64_t ...@@ -173,24 +173,24 @@ static void launch_kernel(int BC, int Hi, int Wi, int K, int Ho, int Wo, int64_t
switch (pscale) { switch (pscale) {
case 1: case 1:
disco_bwd_blk_k<NTH, ELXTH, 1> disco_bwd_blk_k<NTH, ELXTH, 1><<<grid, NTH, shmem, stream>>>(Hi, Wi, K, Ho, Wo, pscale, roff_d, ker_d,
<<<grid, NTH, shmem, stream>>>(Hi, Wi, K, Ho, Wo, pscale, roff_d, ker_d, row_d, col_d, val_d, inp_d, out_d); row_d, col_d, val_d, inp_d, out_d);
break; break;
case 2: case 2:
disco_bwd_blk_k<NTH, ELXTH, 2> disco_bwd_blk_k<NTH, ELXTH, 2><<<grid, NTH, shmem, stream>>>(Hi, Wi, K, Ho, Wo, pscale, roff_d, ker_d,
<<<grid, NTH, shmem, stream>>>(Hi, Wi, K, Ho, Wo, pscale, roff_d, ker_d, row_d, col_d, val_d, inp_d, out_d); row_d, col_d, val_d, inp_d, out_d);
break; break;
case 3: case 3:
disco_bwd_blk_k<NTH, ELXTH, 3> disco_bwd_blk_k<NTH, ELXTH, 3><<<grid, NTH, shmem, stream>>>(Hi, Wi, K, Ho, Wo, pscale, roff_d, ker_d,
<<<grid, NTH, shmem, stream>>>(Hi, Wi, K, Ho, Wo, pscale, roff_d, ker_d, row_d, col_d, val_d, inp_d, out_d); row_d, col_d, val_d, inp_d, out_d);
break; break;
default: default:
disco_bwd_blk_k<NTH, ELXTH, 0> disco_bwd_blk_k<NTH, ELXTH, 0><<<grid, NTH, shmem, stream>>>(Hi, Wi, K, Ho, Wo, pscale, roff_d, ker_d,
<<<grid, NTH, shmem, stream>>>(Hi, Wi, K, Ho, Wo, pscale, roff_d, ker_d, row_d, col_d, val_d, inp_d, out_d); row_d, col_d, val_d, inp_d, out_d);
} }
} else { } else {
launch_kernel<NTH, ELXTH + 1>(BC, Hi, Wi, K, Ho, Wo, nrows, roff_d, ker_d, row_d, col_d, val_d, inp_d, out_d, launch_kernel<NTH, ELXTH + 1>(BC, Hi, Wi, K, Ho, Wo, nrows, roff_d, ker_d, row_d, col_d, val_d, inp_d,
stream); out_d, stream);
} }
} }
return; return;
...@@ -231,36 +231,41 @@ torch::Tensor disco_cuda_bwd(torch::Tensor inp, torch::Tensor roff_idx, torch::T ...@@ -231,36 +231,41 @@ torch::Tensor disco_cuda_bwd(torch::Tensor inp, torch::Tensor roff_idx, torch::T
AT_DISPATCH_FLOATING_TYPES(inp.scalar_type(), "disco_backward_cuda", ([&] { AT_DISPATCH_FLOATING_TYPES(inp.scalar_type(), "disco_backward_cuda", ([&] {
launch_kernel<64, 1, scalar_t>( launch_kernel<64, 1, scalar_t>(
BC, Hi, Wi, K, Ho, Wo, nrows, roff_idx.data_ptr<int64_t>(), BC, Hi, Wi, K, Ho, Wo, nrows, roff_idx.data_ptr<int64_t>(),
ker_idx.data_ptr<int64_t>(), row_idx.data_ptr<int64_t>(), col_idx.data_ptr<int64_t>(), ker_idx.data_ptr<int64_t>(), row_idx.data_ptr<int64_t>(),
val.data_ptr<scalar_t>(), inp.data_ptr<scalar_t>(), out.data_ptr<scalar_t>(), stream); col_idx.data_ptr<int64_t>(), val.data_ptr<scalar_t>(),
inp.data_ptr<scalar_t>(), out.data_ptr<scalar_t>(), stream);
})); }));
} else if (Wo <= 128 * ELXTH_MAX) { } else if (Wo <= 128 * ELXTH_MAX) {
AT_DISPATCH_FLOATING_TYPES(inp.scalar_type(), "disco_backward_cuda", ([&] { AT_DISPATCH_FLOATING_TYPES(inp.scalar_type(), "disco_backward_cuda", ([&] {
launch_kernel<128, (ELXTH_MAX / 2) + 1, scalar_t>( launch_kernel<128, (ELXTH_MAX / 2) + 1, scalar_t>(
BC, Hi, Wi, K, Ho, Wo, nrows, roff_idx.data_ptr<int64_t>(), BC, Hi, Wi, K, Ho, Wo, nrows, roff_idx.data_ptr<int64_t>(),
ker_idx.data_ptr<int64_t>(), row_idx.data_ptr<int64_t>(), col_idx.data_ptr<int64_t>(), ker_idx.data_ptr<int64_t>(), row_idx.data_ptr<int64_t>(),
val.data_ptr<scalar_t>(), inp.data_ptr<scalar_t>(), out.data_ptr<scalar_t>(), stream); col_idx.data_ptr<int64_t>(), val.data_ptr<scalar_t>(),
inp.data_ptr<scalar_t>(), out.data_ptr<scalar_t>(), stream);
})); }));
} else if (Wo <= 256 * ELXTH_MAX) { } else if (Wo <= 256 * ELXTH_MAX) {
AT_DISPATCH_FLOATING_TYPES(inp.scalar_type(), "disco_backward_cuda", ([&] { AT_DISPATCH_FLOATING_TYPES(inp.scalar_type(), "disco_backward_cuda", ([&] {
launch_kernel<256, (ELXTH_MAX / 2) + 1, scalar_t>( launch_kernel<256, (ELXTH_MAX / 2) + 1, scalar_t>(
BC, Hi, Wi, K, Ho, Wo, nrows, roff_idx.data_ptr<int64_t>(), BC, Hi, Wi, K, Ho, Wo, nrows, roff_idx.data_ptr<int64_t>(),
ker_idx.data_ptr<int64_t>(), row_idx.data_ptr<int64_t>(), col_idx.data_ptr<int64_t>(), ker_idx.data_ptr<int64_t>(), row_idx.data_ptr<int64_t>(),
val.data_ptr<scalar_t>(), inp.data_ptr<scalar_t>(), out.data_ptr<scalar_t>(), stream); col_idx.data_ptr<int64_t>(), val.data_ptr<scalar_t>(),
inp.data_ptr<scalar_t>(), out.data_ptr<scalar_t>(), stream);
})); }));
} else if (Wo <= 512 * ELXTH_MAX) { } else if (Wo <= 512 * ELXTH_MAX) {
AT_DISPATCH_FLOATING_TYPES(inp.scalar_type(), "disco_backward_cuda", ([&] { AT_DISPATCH_FLOATING_TYPES(inp.scalar_type(), "disco_backward_cuda", ([&] {
launch_kernel<512, (ELXTH_MAX / 2) + 1, scalar_t>( launch_kernel<512, (ELXTH_MAX / 2) + 1, scalar_t>(
BC, Hi, Wi, K, Ho, Wo, nrows, roff_idx.data_ptr<int64_t>(), BC, Hi, Wi, K, Ho, Wo, nrows, roff_idx.data_ptr<int64_t>(),
ker_idx.data_ptr<int64_t>(), row_idx.data_ptr<int64_t>(), col_idx.data_ptr<int64_t>(), ker_idx.data_ptr<int64_t>(), row_idx.data_ptr<int64_t>(),
val.data_ptr<scalar_t>(), inp.data_ptr<scalar_t>(), out.data_ptr<scalar_t>(), stream); col_idx.data_ptr<int64_t>(), val.data_ptr<scalar_t>(),
inp.data_ptr<scalar_t>(), out.data_ptr<scalar_t>(), stream);
})); }));
} else if (Wo <= 1024 * ELXTH_MAX) { } else if (Wo <= 1024 * ELXTH_MAX) {
AT_DISPATCH_FLOATING_TYPES(inp.scalar_type(), "disco_backward_cuda", ([&] { AT_DISPATCH_FLOATING_TYPES(inp.scalar_type(), "disco_backward_cuda", ([&] {
launch_kernel<1024, (ELXTH_MAX / 2) + 1, scalar_t>( launch_kernel<1024, (ELXTH_MAX / 2) + 1, scalar_t>(
BC, Hi, Wi, K, Ho, Wo, nrows, roff_idx.data_ptr<int64_t>(), BC, Hi, Wi, K, Ho, Wo, nrows, roff_idx.data_ptr<int64_t>(),
ker_idx.data_ptr<int64_t>(), row_idx.data_ptr<int64_t>(), col_idx.data_ptr<int64_t>(), ker_idx.data_ptr<int64_t>(), row_idx.data_ptr<int64_t>(),
val.data_ptr<scalar_t>(), inp.data_ptr<scalar_t>(), out.data_ptr<scalar_t>(), stream); col_idx.data_ptr<int64_t>(), val.data_ptr<scalar_t>(),
inp.data_ptr<scalar_t>(), out.data_ptr<scalar_t>(), stream);
})); }));
} else { } else {
fprintf(stderr, "%s:%d: error, unsupported Wo value (%ld), max supported is %d\n", __FILE__, __LINE__, Wo, fprintf(stderr, "%s:%d: error, unsupported Wo value (%ld), max supported is %d\n", __FILE__, __LINE__, Wo,
......
...@@ -55,7 +55,8 @@ __device__ void disco_fwd_d(const int Hi, const int Wi, const int K, const int H ...@@ -55,7 +55,8 @@ __device__ void disco_fwd_d(const int Hi, const int Wi, const int K, const int H
REAL_T __reg[ELXTH] = {0}; REAL_T __reg[ELXTH] = {0};
// align to larger supported fp type // align to larger supported fp type
extern __shared__ __align__(sizeof(double)) unsigned char __sh_ptr[]; // REAL_T __sh[2*Wi + ppscale*(BDIM_X*ELXTH - Wo)] extern __shared__ __align__(
sizeof(double)) unsigned char __sh_ptr[]; // REAL_T __sh[2*Wi + ppscale*(BDIM_X*ELXTH - Wo)]
REAL_T *__sh = reinterpret_cast<REAL_T *>(__sh_ptr); REAL_T *__sh = reinterpret_cast<REAL_T *>(__sh_ptr);
int col_prev = cols[soff]; int col_prev = cols[soff];
...@@ -172,11 +173,11 @@ static void launch_kernel(int BC, int Hi, int Wi, int K, int Ho, int Wo, int64_t ...@@ -172,11 +173,11 @@ static void launch_kernel(int BC, int Hi, int Wi, int K, int Ho, int Wo, int64_t
const int pscale = Wi / Wo; const int pscale = Wi / Wo;
size_t shmem = sizeof(*out_d) * (Wi * 2 + pscale * (NTH * ELXTH - Wo)); size_t shmem = sizeof(*out_d) * (Wi * 2 + pscale * (NTH * ELXTH - Wo));
disco_fwd_blk_k<NTH, ELXTH> disco_fwd_blk_k<NTH, ELXTH><<<grid, NTH, shmem, stream>>>(Hi, Wi, K, Ho, Wo, pscale, roff_d, ker_d, row_d,
<<<grid, NTH, shmem, stream>>>(Hi, Wi, K, Ho, Wo, pscale, roff_d, ker_d, row_d, col_d, val_d, inp_d, out_d); col_d, val_d, inp_d, out_d);
} else { } else {
launch_kernel<NTH, ELXTH + 1>(BC, Hi, Wi, K, Ho, Wo, nrows, roff_d, ker_d, row_d, col_d, val_d, inp_d, out_d, launch_kernel<NTH, ELXTH + 1>(BC, Hi, Wi, K, Ho, Wo, nrows, roff_d, ker_d, row_d, col_d, val_d, inp_d,
stream); out_d, stream);
} }
} }
return; return;
...@@ -218,36 +219,41 @@ torch::Tensor disco_cuda_fwd(torch::Tensor inp, torch::Tensor roff_idx, torch::T ...@@ -218,36 +219,41 @@ torch::Tensor disco_cuda_fwd(torch::Tensor inp, torch::Tensor roff_idx, torch::T
AT_DISPATCH_FLOATING_TYPES(inp.scalar_type(), "disco_forward_cuda", ([&] { AT_DISPATCH_FLOATING_TYPES(inp.scalar_type(), "disco_forward_cuda", ([&] {
launch_kernel<64, 1, scalar_t>( launch_kernel<64, 1, scalar_t>(
BC, Hi, Wi, K, Ho, Wo, nrows, roff_idx.data_ptr<int64_t>(), BC, Hi, Wi, K, Ho, Wo, nrows, roff_idx.data_ptr<int64_t>(),
ker_idx.data_ptr<int64_t>(), row_idx.data_ptr<int64_t>(), col_idx.data_ptr<int64_t>(), ker_idx.data_ptr<int64_t>(), row_idx.data_ptr<int64_t>(),
val.data_ptr<scalar_t>(), inp.data_ptr<scalar_t>(), out.data_ptr<scalar_t>(), stream); col_idx.data_ptr<int64_t>(), val.data_ptr<scalar_t>(),
inp.data_ptr<scalar_t>(), out.data_ptr<scalar_t>(), stream);
})); }));
} else if (Wo <= 128 * ELXTH_MAX) { } else if (Wo <= 128 * ELXTH_MAX) {
AT_DISPATCH_FLOATING_TYPES(inp.scalar_type(), "disco_forward_cuda", ([&] { AT_DISPATCH_FLOATING_TYPES(inp.scalar_type(), "disco_forward_cuda", ([&] {
launch_kernel<128, (ELXTH_MAX / 2) + 1, scalar_t>( launch_kernel<128, (ELXTH_MAX / 2) + 1, scalar_t>(
BC, Hi, Wi, K, Ho, Wo, nrows, roff_idx.data_ptr<int64_t>(), BC, Hi, Wi, K, Ho, Wo, nrows, roff_idx.data_ptr<int64_t>(),
ker_idx.data_ptr<int64_t>(), row_idx.data_ptr<int64_t>(), col_idx.data_ptr<int64_t>(), ker_idx.data_ptr<int64_t>(), row_idx.data_ptr<int64_t>(),
val.data_ptr<scalar_t>(), inp.data_ptr<scalar_t>(), out.data_ptr<scalar_t>(), stream); col_idx.data_ptr<int64_t>(), val.data_ptr<scalar_t>(),
inp.data_ptr<scalar_t>(), out.data_ptr<scalar_t>(), stream);
})); }));
} else if (Wo <= 256 * ELXTH_MAX) { } else if (Wo <= 256 * ELXTH_MAX) {
AT_DISPATCH_FLOATING_TYPES(inp.scalar_type(), "disco_forward_cuda", ([&] { AT_DISPATCH_FLOATING_TYPES(inp.scalar_type(), "disco_forward_cuda", ([&] {
launch_kernel<256, (ELXTH_MAX / 2) + 1, scalar_t>( launch_kernel<256, (ELXTH_MAX / 2) + 1, scalar_t>(
BC, Hi, Wi, K, Ho, Wo, nrows, roff_idx.data_ptr<int64_t>(), BC, Hi, Wi, K, Ho, Wo, nrows, roff_idx.data_ptr<int64_t>(),
ker_idx.data_ptr<int64_t>(), row_idx.data_ptr<int64_t>(), col_idx.data_ptr<int64_t>(), ker_idx.data_ptr<int64_t>(), row_idx.data_ptr<int64_t>(),
val.data_ptr<scalar_t>(), inp.data_ptr<scalar_t>(), out.data_ptr<scalar_t>(), stream); col_idx.data_ptr<int64_t>(), val.data_ptr<scalar_t>(),
inp.data_ptr<scalar_t>(), out.data_ptr<scalar_t>(), stream);
})); }));
} else if (Wo <= 512 * ELXTH_MAX) { } else if (Wo <= 512 * ELXTH_MAX) {
AT_DISPATCH_FLOATING_TYPES(inp.scalar_type(), "disco_forward_cuda", ([&] { AT_DISPATCH_FLOATING_TYPES(inp.scalar_type(), "disco_forward_cuda", ([&] {
launch_kernel<512, (ELXTH_MAX / 2) + 1, scalar_t>( launch_kernel<512, (ELXTH_MAX / 2) + 1, scalar_t>(
BC, Hi, Wi, K, Ho, Wo, nrows, roff_idx.data_ptr<int64_t>(), BC, Hi, Wi, K, Ho, Wo, nrows, roff_idx.data_ptr<int64_t>(),
ker_idx.data_ptr<int64_t>(), row_idx.data_ptr<int64_t>(), col_idx.data_ptr<int64_t>(), ker_idx.data_ptr<int64_t>(), row_idx.data_ptr<int64_t>(),
val.data_ptr<scalar_t>(), inp.data_ptr<scalar_t>(), out.data_ptr<scalar_t>(), stream); col_idx.data_ptr<int64_t>(), val.data_ptr<scalar_t>(),
inp.data_ptr<scalar_t>(), out.data_ptr<scalar_t>(), stream);
})); }));
} else if (Wo <= 1024 * ELXTH_MAX) { } else if (Wo <= 1024 * ELXTH_MAX) {
AT_DISPATCH_FLOATING_TYPES(inp.scalar_type(), "disco_forward_cuda", ([&] { AT_DISPATCH_FLOATING_TYPES(inp.scalar_type(), "disco_forward_cuda", ([&] {
launch_kernel<1024, (ELXTH_MAX / 2) + 1, scalar_t>( launch_kernel<1024, (ELXTH_MAX / 2) + 1, scalar_t>(
BC, Hi, Wi, K, Ho, Wo, nrows, roff_idx.data_ptr<int64_t>(), BC, Hi, Wi, K, Ho, Wo, nrows, roff_idx.data_ptr<int64_t>(),
ker_idx.data_ptr<int64_t>(), row_idx.data_ptr<int64_t>(), col_idx.data_ptr<int64_t>(), ker_idx.data_ptr<int64_t>(), row_idx.data_ptr<int64_t>(),
val.data_ptr<scalar_t>(), inp.data_ptr<scalar_t>(), out.data_ptr<scalar_t>(), stream); col_idx.data_ptr<int64_t>(), val.data_ptr<scalar_t>(),
inp.data_ptr<scalar_t>(), out.data_ptr<scalar_t>(), stream);
})); }));
} else { } else {
fprintf(stderr, "%s:%d: error, unsupported Wo value (%ld), max supported is %d\n", __FILE__, __LINE__, Wo, fprintf(stderr, "%s:%d: error, unsupported Wo value (%ld), max supported is %d\n", __FILE__, __LINE__, Wo,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment