Unverified Commit 49a61eee authored by Thorsten Kurth, committed by GitHub

Merge pull request #83 from NVIDIA/maurob/devel

Optimized forward kernel for attention
parents c485a1fb c90b421a
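The optimized forward path keeps the numerically stable streaming softmax of the original kernel: it tracks a running maximum of the q·k scores and rescales the partial sums whenever that maximum grows (the `qdotk_max`, `exp_save`, and `alpha_sum` variables in the kernels below). A minimal NumPy sketch of that recurrence for a single output point, with illustrative names only:

import numpy as np

def online_softmax_attention(scores, values, quad_weights):
    # scores: (n,) dot products q.k over the neighborhood
    # values: (n, c) value vectors, quad_weights: (n,) quadrature weights
    # Running-max formulation: rescale the partial sums whenever the max grows,
    # mirroring qdotk_max / exp_save / alpha_sum in the CUDA kernels.
    m = -np.inf                      # running max of the scores
    alpha_sum = 0.0                  # running normalizer
    acc = np.zeros(values.shape[1])  # running weighted sum of values
    for s, v, w in zip(scores, values, quad_weights):
        m_new = max(m, s)
        alpha = np.exp(s - m_new) * w
        rescale = np.exp(m - m_new)
        alpha_sum = alpha + alpha_sum * rescale
        acc = acc * rescale + alpha * v
        m = m_new
    return acc / alpha_sum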
@@ -81,7 +81,7 @@ class TestNeighborhoodAttentionS2(unittest.TestCase):
            ],
            skip_on_empty=True,
        )
-    def test_custom_implementation(self, batch_size, channels, heads, in_shape, out_shape, grid_in, grid_out, atol, rtol, verbose=True):
    def test_custom_implementation(self, batch_size, channels, heads, in_shape, out_shape, grid_in, grid_out, atol, rtol, verbose=False):
        """Tests numerical equivalence between the custom (CUDA) implementation and the reference torch implementation"""
        nlat_in, nlon_in = in_shape
@@ -161,7 +161,7 @@ class TestNeighborhoodAttentionS2(unittest.TestCase):
            ],
            skip_on_empty=True,
        )
-    def test_neighborhood_global_equivalence(self, batch_size, channels, heads, in_shape, out_shape, grid_in, grid_out, atol, rtol, verbose=True):
    def test_neighborhood_global_equivalence(self, batch_size, channels, heads, in_shape, out_shape, grid_in, grid_out, atol, rtol, verbose=False):
        """Tests numerical equivalence between the global spherical attention module and the neighborhood spherical attention module with the neighborhood set to the whole sphere"""
        nlat_in, nlon_in = in_shape
@@ -223,7 +223,7 @@ class TestNeighborhoodAttentionS2(unittest.TestCase):
            skip_on_empty=True,
        )
    @unittest.skipUnless((torch.cuda.is_available() and _cuda_extension_available), "skipping performance test because CUDA is not available")
-    def test_perf(self, batch_size, channels, heads, in_shape, out_shape, grid_in, grid_out, atol, rtol, verbose=True):
    def test_perf(self, batch_size, channels, heads, in_shape, out_shape, grid_in, grid_out, atol, rtol, verbose=False):
        # extract some parameters
        nlat_in, nlon_in = in_shape
@@ -479,9 +479,10 @@ class _NeighborhoodAttentionS2Cuda(torch.autograd.Function):
        qw = qw.reshape(B*nh, -1, H, W)

        # convert to float32
-        kw = kw.to(torch.float32)
-        vw = vw.to(torch.float32)
-        qw = qw.to(torch.float32)
        inp_dtype = kw.dtype
        kw = kw.to(torch.float32).contiguous()
        vw = vw.to(torch.float32).contiguous()
        qw = qw.to(torch.float32).contiguous()

        output = attention_cuda_extension.forward(kw, vw, qw, quad_weights,
                                                  col_idx, row_off,
@@ -490,6 +491,9 @@ class _NeighborhoodAttentionS2Cuda(torch.autograd.Function):
        _, C, H, W = output.shape
        output = output.reshape(B, -1, H, W)

        # convert back precision
        output = output.to(dtype=inp_dtype)

        return output

    @staticmethod
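The pattern added above exists because the CUDA extension walks raw float32 pointers with dense-layout arithmetic, so the inputs are upcast and made contiguous before the call and the result is cast back afterwards. A schematic sketch of that round trip; `run_extension` is a placeholder standing in for `attention_cuda_extension.forward` and is not part of the real code:

import torch

def _float32_roundtrip(kw, vw, qw, run_extension):
    # Remember the incoming dtype, run the extension in contiguous float32,
    # then return the result in the original precision.
    inp_dtype = kw.dtype
    kw = kw.to(torch.float32).contiguous()
    vw = vw.to(torch.float32).contiguous()
    qw = qw.to(torch.float32).contiguous()
    output = run_extension(kw, vw, qw)
    return output.to(dtype=inp_dtype)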
@@ -291,7 +291,7 @@ class NeighborhoodAttentionS2(nn.Module):
            # set the last value
            row_offset[row + 1] = idz + 1

-        row_offset = torch.from_numpy(row_offset)
        row_offset = torch.from_numpy(row_offset).contiguous()

        self.max_psi_nnz = col_idx.max().item() + 1

        self.register_buffer("psi_row_idx", row_idx, persistent=False)
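The `row_offset` buffer built here, together with `col_idx`, is a CSR-style description of the spherical neighborhood: the nonzeros of output row `ho` live in `col_idx[row_offset[ho]:row_offset[ho+1]]`, and each entry encodes an input point as `hi * nlon_in + wi`. This is exactly the range the CUDA kernels below iterate (`rbeg = row_off[ho]`, `rend = row_off[ho+1]`). A small illustrative helper, hypothetical and not part of the module:

import torch

def neighborhood_columns(row_offset: torch.Tensor, col_idx: torch.Tensor, ho: int) -> torch.Tensor:
    # CSR lookup: the neighborhood of output latitude `ho` occupies the
    # half-open range [row_offset[ho], row_offset[ho + 1]) of col_idx.
    rbeg = int(row_offset[ho])
    rend = int(row_offset[ho + 1])
    return col_idx[rbeg:rend]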
@@ -2,7 +2,7 @@
//
// SPDX-FileCopyrightText: Copyright (c) 2025 The torch-harmonics Authors. All rights reserved.
// SPDX-License-Identifier: BSD-3-Clause
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//
@@ -39,111 +39,214 @@

#include <cub/cub.cuh>
#include <limits>

-using BlockReduceFloat256 = cub::BlockReduce<float, 256>;
-using BlockReduceFloat512 = cub::BlockReduce<float, 512>;

#define WARP_SIZE (32)
#define FULL_MASK (0xFFFFFFFF)
#define THREADS (64)
#define DIV_UP(a,b) (((a)+((b)-1))/(b))

-#define NNZ_TRESH (32)

#define TRANSP_WARPS_X_TILE_GENERIC (32)
#define TRANSP_WARPS_X_TILE_SM100 (4)

#define MAX_LOCAL_ARR_LEN (16)

#define CHECK_CUDA(call) { \
    cudaError_t err = call; \
    if( cudaSuccess != err) { \
        fprintf(stderr, "Cuda error in file '%s' in line %i : %s.\n", \
                __FILE__, __LINE__, cudaGetErrorString( err) ); \
        exit(EXIT_FAILURE); \
    }}

#define CHECK_ERROR(errorMessage) { \
    cudaError_t err = cudaGetLastError(); \
    if( cudaSuccess != err) { \
        fprintf(stderr, "Cuda error: %s in file '%s' in line %i : %s.\n", \
                errorMessage, __FILE__, __LINE__, cudaGetErrorString( err) ); \
        exit(EXIT_FAILURE); \
    }}

// BEGIN - forward kernels and functions

-static __device__ float __warp_sum(float val)
-{
-#pragma unroll
-for (int i = WARP_SIZE / 2; i; i /= 2) { val += __shfl_xor_sync(FULL_MASK, val, i); }
-return val;
-}

template<typename VAL_T>
__device__ VAL_T __warp_sum(VAL_T val) {

#pragma unroll
for(int i = WARP_SIZE/2; i; i /= 2) {
val += __shfl_xor_sync(FULL_MASK, val, i);
}
return val;
}
template<int BDIM_X,
int BDIM_Y=1,
int BDIM_Z=1,
typename VAL_T>
__device__ VAL_T __block_sum(VAL_T val) {
const int NWARP = (BDIM_X*BDIM_Y*BDIM_Z) / WARP_SIZE;
val = __warp_sum(val);
if constexpr(NWARP > 1) {

int tid = threadIdx.x;
if constexpr(BDIM_Y > 1) { tid += threadIdx.y*BDIM_X; }
if constexpr(BDIM_Z > 1) { tid += threadIdx.z*BDIM_X*BDIM_Y; }
const int lid = tid%WARP_SIZE;
const int wid = tid/WARP_SIZE;
__shared__ VAL_T sh[NWARP];
if (lid == 0) {
sh[wid] = val;
}
__syncthreads();
if (wid == 0) {
val = (lid < NWARP) ? sh[lid] : 0;
val = __warp_sum(val);
__syncwarp();
if (!lid) {
sh[0] = val;
}
}
__syncthreads();
val = sh[0];
__syncthreads();
}
return val;
}
-// easier to understand version of manual shfl_xor_sync, performance appears similar
-static __device__ float __warp_sum_cub(float val)
-{
-// use cub to reduce within a warp
-__shared__ typename cub::WarpReduce<float>::TempStorage temp_storage;
-// 1. Compute sum (initially only in lane 0)
-float sum = cub::WarpReduce<float>(temp_storage).Sum(val);
-// 2. Broadcast sum to all threads
-sum = __shfl_sync(0xFFFFFFFF, sum, 0);
-return sum;
-}

template<typename FLOATV_T>
__device__ FLOATV_T __vset(float x) {}

template<>
__device__ float __forceinline__ __vset<float>(float x) {
return x;
}
__device__ float __forceinline__ __vmul(float a, float b) {
return a*b;
}
__device__ float __forceinline__ __vadd(float a, float b) {
return a+b;
}
__device__ float __forceinline__ __vred(float a) {
return a;
}
__device__ float __forceinline__ __vscale(float s, float v) {
return v*s;
}
__device__ float __forceinline__ __vdiv(float s, float v) {
return v/s;
}
template<>
__device__ float4 __forceinline__ __vset<float4>(float x) {
return make_float4(x, x, x, x);
}
__device__ float4 __forceinline__ __vmul(float4 a, float4 b) {
return make_float4(a.x*b.x, a.y*b.y, a.z*b.z, a.w*b.w);
}
__device__ float4 __forceinline__ __vadd(float4 a, float4 b) {
return make_float4(a.x+b.x, a.y+b.y, a.z+b.z, a.w+b.w);
}
__device__ float __forceinline__ __vred(float4 a) {
return a.x + a.y + a.z + a.w;
}

__device__ float4 __forceinline__ __vscale(float s, float4 v) {
return make_float4(s*v.x, s*v.y, s*v.z, s*v.w);
}

__device__ float4 __forceinline__ __vdiv(float s, float4 v) {
return make_float4(s/v.x, s/v.y, s/v.z, s/v.w);
}

-// one warp per (ho,wo)
-template <int BDIM_X>
-__global__ __launch_bounds__(BDIM_X) void s2_attention_kernel(
-int num_channels, int nlon_in, int nlat_out, int nlon_out,
-const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> kx,
-const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> vx,
-const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> qy,
-torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> y,
-const torch::PackedTensorAccessor64<int64_t, 1, torch::RestrictPtrTraits> psi_col_idx,
-const torch::PackedTensorAccessor64<int64_t, 1, torch::RestrictPtrTraits> psi_row_offset,
-const torch::PackedTensorAccessor32<float, 1, torch::RestrictPtrTraits> quad_weights)
-{
-extern __shared__ float sh[];
-float *shy = sh + threadIdx.y * num_channels;
-const uint64_t batchId = blockIdx.y;
-const uint64_t wid = uint64_t(blockIdx.x) * blockDim.y + threadIdx.y;
-if (wid >= uint64_t(nlat_out) * nlon_in) { return; }
-const int tidx = threadIdx.x;
-const int ho = wid / nlon_out;
-const int wo = wid - (ho * nlon_out);
-for (int chan = tidx; chan < num_channels; chan += WARP_SIZE) {
-#if 0
-// useless read, y is always zeroed before kernel is called
-shy[chan] = y[batchId][chan][ho][wo];
-#else
-shy[chan] = 0;
-#endif
-}
-float alpha_sum = 0.0f;
-float qdotk_max = -FLT_MAX;
-const int64_t rbeg = psi_row_offset[ho];
-const int64_t rend = psi_row_offset[ho + 1];
-const int rlen = rend - rbeg;
-for (int off = 0; off < rlen; off++) {
-const int64_t col = psi_col_idx[rbeg + off];
-const int hi = col / nlon_in;
-const int wi = col - (hi * nlon_in);
-const int wip = (wi + wo) - ((wi + wo) / nlon_in) * nlon_in;
-float qdotk = 0.0f;
-for (int chan = tidx; chan < num_channels; chan += WARP_SIZE) {
-qdotk += qy[batchId][chan][ho][wo] * kx[batchId][chan][hi][wip];
-}
-qdotk = __warp_sum_cub(qdotk);

// called with (blockDim.x=32 and blockDim.y>1, BDIM=blockDim.x*blockDim.y)
template<int BDIM,
typename FLOATV_T> // either float or float4
__global__
__launch_bounds__(BDIM)
void s2_attn_fwd_generic_vec_k(int nchan, // no. of FLOATV_T elements along channel dim
int nlat_in,
int nlon_in,
int nlat_out,
int nlon_out,
const FLOATV_T *__restrict__ kx,
const FLOATV_T *__restrict__ vx,
const FLOATV_T *__restrict__ qy,
const torch::PackedTensorAccessor32< int, 1, torch::RestrictPtrTraits> row_idx,
const torch::PackedTensorAccessor64<int64_t, 1, torch::RestrictPtrTraits> row_off,
const torch::PackedTensorAccessor64<int64_t, 1, torch::RestrictPtrTraits> col_idx,
const torch::PackedTensorAccessor32< float, 1, torch::RestrictPtrTraits> quad_weights,
FLOATV_T *__restrict__ y) {
extern __shared__ __align__(sizeof(float4)) float shext[];
FLOATV_T *shy = reinterpret_cast<FLOATV_T *>(shext) + threadIdx.y*nchan;
const int batch = blockIdx.y;
const int wid = blockIdx.x*blockDim.y + threadIdx.y;
if (wid >= nlat_out*nlon_out) {
return;
}
const int tidx = threadIdx.x;

const int h = wid / nlon_out;
const int wo = wid - (h*nlon_out);
const int ho = row_idx[h];

for(int chan = tidx; chan < nchan; chan += WARP_SIZE) {
shy[chan] = __vset<FLOATV_T>(0.f);
}

kx += batch*nlat_in*nlon_in*nchan;
vx += batch*nlat_in*nlon_in*nchan;
qy += batch*nlat_out*nlon_out*nchan + ho*nchan*nlon_out + wo*nchan;
y += batch*nlat_out*nlon_out*nchan + ho*nchan*nlon_out + wo*nchan;

float alpha_sum = 0.0f;
float qdotk_max = -FLT_MAX;

const int64_t rbeg = row_off[ho];
const int64_t rend = row_off[ho+1];
const int rlen = rend-rbeg;

for(int off = 0; off < rlen; off++) {
const int64_t col = col_idx[rbeg+off];

const int hi = col / nlon_in;
const int wi = col - (hi*nlon_in);
const int wip = (wi+wo) - ((wi+wo) / nlon_in) * nlon_in;

const FLOATV_T *_kx = kx + hi*nlon_in*nchan + wip*nchan;
const FLOATV_T *_vx = vx + hi*nlon_in*nchan + wip*nchan;

FLOATV_T qdotkv = __vset<FLOATV_T>(0.f);

for(int chan = tidx; chan < nchan; chan += WARP_SIZE) {
qdotkv = __vadd(qdotkv,
__vmul( qy[chan],
_kx[chan]));
}

float qdotk = __warp_sum(__vred(qdotkv));

float qdotk_max_tmp;
float alpha;
@@ -153,24 +256,626 @@ __global__ __launch_bounds__(BDIM_X) void s2_attention_kernel(

alpha = expf(qdotk - qdotk_max_tmp) * quad_weights[hi];
exp_save = expf(qdotk_max - qdotk_max_tmp);
alpha_sum = alpha + alpha_sum*exp_save;

for(int chan = tidx; chan < nchan; chan += WARP_SIZE) {
-shy[chan] = shy[chan] * exp_save + vx[batchId][chan][hi][wip] * alpha;
shy[chan] = __vadd(__vscale(exp_save, shy[chan]),
__vscale( alpha, _vx[chan]));
}
qdotk_max = qdotk_max_tmp;
}

-for (int chan = tidx; chan < num_channels; chan += WARP_SIZE) { y[batchId][chan][ho][wo] = shy[chan] / alpha_sum; }
alpha_sum = 1.0f / alpha_sum;
for(int chan = tidx; chan < nchan; chan += WARP_SIZE) {
y[chan] = __vscale(alpha_sum, shy[chan]);
}
return;
}
template<typename FLOATV_T>
void launch_gen_attn_kernel(int batch_size,
int nloc,
int nchans,
int nlat_in,
int nlon_in,
int nlat_out,
int nlon_out,
FLOATV_T *__restrict__ _kxp,
FLOATV_T *__restrict__ _vxp,
FLOATV_T *__restrict__ _qyp,
at::Tensor row_idx,
at::Tensor row_off,
at::Tensor col_idx,
at::Tensor quad_weights,
FLOATV_T *__restrict__ _yp,
cudaStream_t stream) {
dim3 block(WARP_SIZE, THREADS/WARP_SIZE);
dim3 grid(DIV_UP(nlat_out*nlon_out, block.y), batch_size);
size_t shsize = sizeof(FLOATV_T)*nchans * block.y;
auto _row_idx = row_idx.packed_accessor32< int, 1, torch::RestrictPtrTraits>();
auto _row_off = row_off.packed_accessor64<int64_t, 1, torch::RestrictPtrTraits>();
auto _col_idx = col_idx.packed_accessor64<int64_t, 1, torch::RestrictPtrTraits>();
auto _quad_weights = quad_weights.packed_accessor32< float, 1, torch::RestrictPtrTraits>();
s2_attn_fwd_generic_vec_k<THREADS>
<<<grid, block, shsize, stream>>>(nchans, nlat_in, nlon_in, nlat_out, nlon_out,
_kxp, _vxp, _qyp, _row_idx, _row_off, _col_idx, _quad_weights, _yp);
return;
}

// called with either (BDIM_X=32 and BDIM_Y>1) || (2^K=BDIM_X > 32 and BDIM_Y=1)
template<int BDIM_X,
int BDIM_Y,
int NLOC,
typename FLOATV_T> // either float or float4
__global__
__launch_bounds__(BDIM_X*BDIM_Y)
void s2_attn_fwd_special_vec_k(int nchan, // no. of FLOATV_T elements along channel dim
int nlat_in,
int nlon_in,
int nlat_out,
int nlon_out,
const FLOATV_T *__restrict__ kx,
const FLOATV_T *__restrict__ vx,
const FLOATV_T *__restrict__ qy,
const torch::PackedTensorAccessor32< int, 1, torch::RestrictPtrTraits> row_idx,
const torch::PackedTensorAccessor64<int64_t, 1, torch::RestrictPtrTraits> row_off,
const torch::PackedTensorAccessor64<int64_t, 1, torch::RestrictPtrTraits> col_idx,
const torch::PackedTensorAccessor32< float, 1, torch::RestrictPtrTraits> quad_weights,
FLOATV_T *__restrict__ y) {
static_assert(0 == (BDIM_X & (BDIM_X-1)));
static_assert(0 == (BDIM_Y & (BDIM_Y-1)));
static_assert((BDIM_X == 32 && BDIM_Y > 1) ||
(BDIM_X > 32 && BDIM_Y == 1));
constexpr int NLOC_M1 = NLOC-1;
const int tidx = threadIdx.x;
const int batch = blockIdx.y;
const int ctaid = blockIdx.x*blockDim.y + threadIdx.y;
if (ctaid >= nlat_out*nlon_out) {
return;
}
FLOATV_T locy[NLOC];
extern __shared__ __align__(sizeof(float4)) float shext[];
FLOATV_T *shq = reinterpret_cast<FLOATV_T *>(shext) + threadIdx.y*nchan + tidx;
const int h = ctaid / nlon_out;
const int wo = ctaid - (h*nlon_out);
const int ho = row_idx[h];
kx += batch*nlat_in*nlon_in*nchan + tidx;
vx += batch*nlat_in*nlon_in*nchan + tidx;
qy += batch*nlat_out*nlon_out*nchan + ho*nlon_out*nchan + wo*nchan + tidx;
y += batch*nlat_out*nlon_out*nchan + ho*nlon_out*nchan + wo*nchan + tidx;
#pragma unroll
for(int i = 0; i < NLOC; i++) {
locy[i] = __vset<FLOATV_T>(0.f);
}
#pragma unroll
for(int i = 0; i < NLOC_M1; i++) {
shq[i*BDIM_X] = qy[i*BDIM_X];
}
if (NLOC_M1*BDIM_X+tidx < nchan) {
shq[NLOC_M1*BDIM_X] = qy[NLOC_M1*BDIM_X];
}
float alpha_sum = 0.0f;
float qdotk_max = -FLT_MAX;
const int64_t rbeg = row_off[ho];
const int64_t rend = row_off[ho+1];
const int rlen = rend-rbeg;
for(int off = 0; off < rlen; off++) {
const int64_t col = col_idx[rbeg+off];
const int hi = col / nlon_in;
const int wi = col - (hi*nlon_in);
const int wip = (wi+wo) - ((wi+wo) / nlon_in) * nlon_in;
const FLOATV_T *_kx = kx + hi*nlon_in*nchan + wip*nchan;
const FLOATV_T *_vx = vx + hi*nlon_in*nchan + wip*nchan;
FLOATV_T qdotkv = __vset<FLOATV_T>(0.f);
#pragma unroll
for(int i = 0; i < NLOC_M1; i++) {
qdotkv = __vadd(qdotkv,
__vmul(shq[i*BDIM_X],
_kx[i*BDIM_X]));
}
if (NLOC_M1*BDIM_X+tidx < nchan) {
qdotkv = __vadd(qdotkv,
__vmul(shq[NLOC_M1*BDIM_X],
_kx[NLOC_M1*BDIM_X]));
}
float qdotk = __vred(qdotkv);
if constexpr(BDIM_X == 32) { qdotk = __warp_sum(qdotk); }
else { qdotk = __block_sum<BDIM_X>(qdotk); }
float qdotk_max_tmp;
float alpha;
float exp_save;
qdotk_max_tmp = max(qdotk_max, qdotk);
alpha = expf(qdotk - qdotk_max_tmp) * quad_weights[hi];
exp_save = expf(qdotk_max - qdotk_max_tmp);
alpha_sum = alpha + alpha_sum*exp_save;
#pragma unroll
for(int i = 0; i < NLOC_M1; i++) {
locy[i] = __vadd(__vscale(exp_save, locy[i]),
__vscale(alpha, _vx[i*BDIM_X]));
}
if (NLOC_M1*BDIM_X+tidx < nchan) {
locy[NLOC_M1] = __vadd(__vscale(exp_save, locy[NLOC_M1]),
__vscale(alpha, _vx[NLOC_M1*BDIM_X]));
}
qdotk_max = qdotk_max_tmp;
}
alpha_sum = 1.0f / alpha_sum;
#pragma unroll
for(int i = 0; i < NLOC_M1; i++) {
y[i*BDIM_X] = __vscale(alpha_sum, locy[i]);
}
if (NLOC_M1*BDIM_X+tidx < nchan) {
y[NLOC_M1*BDIM_X] = __vscale(alpha_sum, locy[NLOC_M1]);
}
return;
}
template<int BDIM_X,
int BDIM_Y,
int CUR_LOC_SIZE,
int MAX_LOC_SIZE, // max size of FLOATV_T[] local array
typename FLOATV_T>
void launch_spc_attn_kernel(int batch_size,
int nloc, // "BDIM_X*nloc" >= nchans
int nchans,
int nlat_in,
int nlon_in,
int nlat_out,
int nlon_out,
FLOATV_T *__restrict__ _kxp,
FLOATV_T *__restrict__ _vxp,
FLOATV_T *__restrict__ _qyp,
at::Tensor row_idx,
at::Tensor row_off,
at::Tensor col_idx,
at::Tensor quad_weights,
FLOATV_T *__restrict__ _yp,
cudaStream_t stream) {
if (CUR_LOC_SIZE == nloc) {
auto _row_idx = row_idx.packed_accessor32< int, 1, torch::RestrictPtrTraits>();
auto _row_off = row_off.packed_accessor64<int64_t, 1, torch::RestrictPtrTraits>();
auto _col_idx = col_idx.packed_accessor64<int64_t, 1, torch::RestrictPtrTraits>();
auto _quad_weights = quad_weights.packed_accessor32<float, 1, torch::RestrictPtrTraits>();
dim3 block(BDIM_X, BDIM_Y);
dim3 grid(DIV_UP(nlat_out*nlon_out, block.y), batch_size);
//printf("block: (%d, %d)\n", block.x, block.y);
//printf("grid: (%d, %d)\n", grid.x, grid.y);
size_t shsize = sizeof(FLOATV_T)*nchans * block.y; // block.y > 1 iff block.x==32
s2_attn_fwd_special_vec_k<BDIM_X, BDIM_Y, CUR_LOC_SIZE>
<<<grid, block, shsize, stream>>>(nchans, nlat_in, nlon_in, nlat_out, nlon_out,
_kxp, _vxp, _qyp, _row_idx, _row_off, _col_idx, _quad_weights, _yp);
return;
}
if constexpr(CUR_LOC_SIZE < MAX_LOC_SIZE) {
launch_spc_attn_kernel<BDIM_X,
BDIM_Y,
CUR_LOC_SIZE+1,
MAX_LOC_SIZE>(batch_size, nloc, nchans, nlat_in, nlon_in, nlat_out, nlon_out,
_kxp, _vxp, _qyp, row_idx, row_off, col_idx, quad_weights, _yp,
stream);
}
return;
}
__global__ void set_rlen_rids_k(const int n,
const int64_t *__restrict__ offs,
int *__restrict__ rids,
int *__restrict__ rlen) {
const int nth = gridDim.x*blockDim.x;
const int tid = blockIdx.x*blockDim.x + threadIdx.x;
for(int i = tid; i < n; i += nth) {
rids[i] = i;
rlen[i] = offs[i+1]-offs[i];
}
return;
}
at::Tensor sortRows(int nlat_out, at::Tensor row_off, cudaStream_t stream) {
int64_t *_row_off_d = reinterpret_cast<int64_t *>(row_off.data_ptr());
auto options = torch::TensorOptions().dtype(torch::kInt32).device(row_off.device());
torch::Tensor rids_d = torch::empty({nlat_out}, options);
torch::Tensor rlen_d = torch::empty({nlat_out}, options);
int *_rids_d = reinterpret_cast<int *>(rids_d.data_ptr());
int *_rlen_d = reinterpret_cast<int *>(rlen_d.data_ptr());
const int grid = DIV_UP(nlat_out, THREADS);
const int block = THREADS;
set_rlen_rids_k<<<grid, block, 0, stream>>>(nlat_out,
_row_off_d,
_rids_d,
_rlen_d);
torch::Tensor rids_sort_d = torch::empty({nlat_out}, options);
torch::Tensor rlen_sort_d = torch::empty({nlat_out}, options);
int *_rids_sort_d = reinterpret_cast<int *>(rids_sort_d.data_ptr());
int *_rlen_sort_d = reinterpret_cast<int *>(rlen_sort_d.data_ptr());
size_t temp_storage_bytes = 0;
CHECK_CUDA(cub::DeviceRadixSort::SortPairsDescending(NULL, temp_storage_bytes,
_rlen_d, _rlen_sort_d,
_rids_d, _rids_sort_d,
nlat_out, 0, sizeof(*_rlen_d)*8, stream));
options = torch::TensorOptions().dtype(torch::kByte).device(row_off.device());
torch::Tensor temp_storage_d = torch::empty({int64_t(temp_storage_bytes)}, options);
void *_temp_storage_d = reinterpret_cast<void *>(temp_storage_d.data_ptr());
CHECK_CUDA(cub::DeviceRadixSort::SortPairsDescending(_temp_storage_d, temp_storage_bytes,
_rlen_d, _rlen_sort_d,
_rids_d, _rids_sort_d,
nlat_out, 0, sizeof(*_rlen_d)*8, stream));
return rids_sort_d;
}
template<unsigned int ALIGN>
int is_aligned(const void *ptr) {
static_assert(0 == (ALIGN & (ALIGN-1)));
return (0 == (uintptr_t(ptr) & (ALIGN-1)));
}
static unsigned int next_pow2(unsigned int x) {
x -= 1;
#pragma unroll
for(int i = 1; i <= sizeof(x)*8 / 2; i *= 2) {
x |= x >> i;
}
return x+1;
}
static void s2_attention_dipatch(int batch_size,
int nchans,
int nlon_in,
int nlat_out,
int nlon_out,
at::Tensor kxP,
at::Tensor vxP,
at::Tensor qyP,
at::Tensor row_off,
at::Tensor col_idx,
at::Tensor quad_weights,
at::Tensor yP,
cudaStream_t stream) {
static_assert(0 == (MAX_LOCAL_ARR_LEN & (MAX_LOCAL_ARR_LEN-1)));
// sort row indices (ho-s) in descending order
// based on (row_off[ho+1]-row_off[ho])
at::Tensor row_idx = sortRows(nlat_out, row_off, stream);
const int nlat_in = kxP.size(1);
// smallest power of two "bdimx" (>=32) s.t. bdimx*MAX_LOCAL_ARR_LEN >= nchans
int bdimx;
bdimx = DIV_UP(nchans, MAX_LOCAL_ARR_LEN);
bdimx = max(bdimx, WARP_SIZE);
bdimx = next_pow2(bdimx);
float *_kxp = reinterpret_cast<float *>(kxP.data_ptr());
float *_vxp = reinterpret_cast<float *>(vxP.data_ptr());
float *_qyp = reinterpret_cast<float *>(qyP.data_ptr());
float *_yp = reinterpret_cast<float *>(yP.data_ptr());
constexpr int VEC_SIZE = sizeof(float4) / sizeof(float);
if (!is_aligned<sizeof(float4)>(_kxp) ||
!is_aligned<sizeof(float4)>(_vxp) ||
!is_aligned<sizeof(float4)>(_qyp) ||
!is_aligned<sizeof(float4)>(_yp) ||
(nchans % VEC_SIZE) != 0) {
const int nloc = DIV_UP(nchans, bdimx);
// use 2D blocks only if 32 threads are enough
switch(bdimx) {
case 32: launch_spc_attn_kernel< 32, 2, 1, MAX_LOCAL_ARR_LEN>(batch_size, nloc, nchans, nlat_in, nlon_in, nlat_out, nlon_out, _kxp, _vxp, _qyp, row_idx, row_off, col_idx, quad_weights, _yp, stream); break;
case 64: launch_spc_attn_kernel< 64, 1, 1, MAX_LOCAL_ARR_LEN>(batch_size, nloc, nchans, nlat_in, nlon_in, nlat_out, nlon_out, _kxp, _vxp, _qyp, row_idx, row_off, col_idx, quad_weights, _yp, stream); break;
case 128: launch_spc_attn_kernel< 128, 1, 1, MAX_LOCAL_ARR_LEN>(batch_size, nloc, nchans, nlat_in, nlon_in, nlat_out, nlon_out, _kxp, _vxp, _qyp, row_idx, row_off, col_idx, quad_weights, _yp, stream); break;
case 256: launch_spc_attn_kernel< 256, 1, 1, MAX_LOCAL_ARR_LEN>(batch_size, nloc, nchans, nlat_in, nlon_in, nlat_out, nlon_out, _kxp, _vxp, _qyp, row_idx, row_off, col_idx, quad_weights, _yp, stream); break;
case 512: launch_spc_attn_kernel< 512, 1, 1, MAX_LOCAL_ARR_LEN>(batch_size, nloc, nchans, nlat_in, nlon_in, nlat_out, nlon_out, _kxp, _vxp, _qyp, row_idx, row_off, col_idx, quad_weights, _yp, stream); break;
case 1024: launch_spc_attn_kernel<1024, 1, 1, MAX_LOCAL_ARR_LEN>(batch_size, nloc, nchans, nlat_in, nlon_in, nlat_out, nlon_out, _kxp, _vxp, _qyp, row_idx, row_off, col_idx, quad_weights, _yp, stream); break;
default: launch_gen_attn_kernel (batch_size, nloc, nchans, nlat_in, nlon_in, nlat_out, nlon_out, _kxp, _vxp, _qyp, row_idx, row_off, col_idx, quad_weights, _yp, stream); break;
}
} else {
float4 *_kxp4 = reinterpret_cast<float4 *>(_kxp);
float4 *_vxp4 = reinterpret_cast<float4 *>(_vxp);
float4 *_qyp4 = reinterpret_cast<float4 *>(_qyp);
float4 *_yp4 = reinterpret_cast<float4 *>(_yp);
nchans /= VEC_SIZE;
const int nloc = DIV_UP(nchans, bdimx);
static constexpr int MAX_LOCAL_VEC_LEN = MAX_LOCAL_ARR_LEN / VEC_SIZE;
// use 2D blocks only if 32 threads are enough
switch(bdimx) {
case 32: launch_spc_attn_kernel< 32, 2, 1, MAX_LOCAL_VEC_LEN>(batch_size, nloc, nchans, nlat_in, nlon_in, nlat_out, nlon_out, _kxp4, _vxp4, _qyp4, row_idx, row_off, col_idx, quad_weights, _yp4, stream); break;
case 64: launch_spc_attn_kernel< 64, 1, 1, MAX_LOCAL_VEC_LEN>(batch_size, nloc, nchans, nlat_in, nlon_in, nlat_out, nlon_out, _kxp4, _vxp4, _qyp4, row_idx, row_off, col_idx, quad_weights, _yp4, stream); break;
case 128: launch_spc_attn_kernel< 128, 1, 1, MAX_LOCAL_VEC_LEN>(batch_size, nloc, nchans, nlat_in, nlon_in, nlat_out, nlon_out, _kxp4, _vxp4, _qyp4, row_idx, row_off, col_idx, quad_weights, _yp4, stream); break;
case 256: launch_spc_attn_kernel< 256, 1, 1, MAX_LOCAL_VEC_LEN>(batch_size, nloc, nchans, nlat_in, nlon_in, nlat_out, nlon_out, _kxp4, _vxp4, _qyp4, row_idx, row_off, col_idx, quad_weights, _yp4, stream); break;
case 512: launch_spc_attn_kernel< 512, 1, 1, MAX_LOCAL_VEC_LEN>(batch_size, nloc, nchans, nlat_in, nlon_in, nlat_out, nlon_out, _kxp4, _vxp4, _qyp4, row_idx, row_off, col_idx, quad_weights, _yp4, stream); break;
case 1024: launch_spc_attn_kernel<1024, 1, 1, MAX_LOCAL_VEC_LEN>(batch_size, nloc, nchans, nlat_in, nlon_in, nlat_out, nlon_out, _kxp4, _vxp4, _qyp4, row_idx, row_off, col_idx, quad_weights, _yp4, stream); break;
default: launch_gen_attn_kernel (batch_size, nloc, nchans, nlat_in, nlon_in, nlat_out, nlon_out, _kxp4, _vxp4, _qyp4, row_idx, row_off, col_idx, quad_weights, _yp4, stream); break;
}
}
return;
}
// END - forward kernels and functions
// BEGIN - tensor permutation kernels and functions
template<int BDIM_X,
int BDIM_Y,
typename VAL_T>
__global__
__launch_bounds__(BDIM_X*BDIM_Y)
void permute_to0231_k(const int nchn,
const int nlat,
const int nlon,
const torch::PackedTensorAccessor32<VAL_T, 4, torch::RestrictPtrTraits> src,
torch::PackedTensorAccessor32<VAL_T, 4, torch::RestrictPtrTraits> dst) {
static_assert(!(BDIM_X & (BDIM_X-1)));
static_assert(!(BDIM_Y & (BDIM_Y-1)));
static_assert(BDIM_X >= BDIM_Y);
__shared__ VAL_T sh[BDIM_X][BDIM_X+1];
const int tidx = threadIdx.x;
const int tidy = threadIdx.y;
const int coff = blockIdx.x*BDIM_X; // channel offset
const int woff = blockIdx.y*BDIM_X; // width offset
const int batch = blockIdx.z / nlat; // batch (same for the whole block)
const int h = blockIdx.z - (batch * nlat); // height (same for the whole block)
const int nchn_full = (nchn-coff) >= BDIM_X;
const int nlon_full = (nlon-woff) >= BDIM_X;
if (nchn_full && nlon_full) {
#pragma unroll
for(int j = 0; j < BDIM_X; j += BDIM_Y) {
sh[j+tidy][tidx] = src[batch][coff + j+tidy][h][woff+tidx];
}
__syncthreads();
#pragma unroll
for(int j = 0; j < BDIM_X; j += BDIM_Y) {
dst[batch][h][woff + j+tidy][coff+tidx] = sh[tidx][j+tidy];
}
} else {
if (woff+tidx < nlon) {
#pragma unroll
for(int j = 0; j < BDIM_X; j += BDIM_Y) {
sh[j+tidy][tidx] = (coff + j+tidy < nchn) ? src[batch][coff + j+tidy][h][woff+tidx] : 0.f;
}
}
__syncthreads();
if (coff+tidx < nchn) {
#pragma unroll
for(int j = 0; j < BDIM_X; j += BDIM_Y) {
if (woff + j+tidy < nlon) {
dst[batch][h][woff + j+tidy][coff+tidx] = sh[tidx][j+tidy];
}
}
}
}
return;
}
__global__ void empty_k() {}
static int getPtxver() {
cudaFuncAttributes attrs;
CHECK_CUDA(cudaFuncGetAttributes(&attrs, empty_k));
return attrs.ptxVersion*10;
}
static at::Tensor permute_4D_floatT_to0231(at::Tensor src, cudaStream_t stream) {
dim3 block;
dim3 grid;
block.x = WARP_SIZE;
grid.x = DIV_UP(src.size(1), block.x);
grid.y = DIV_UP(src.size(3), block.x);
grid.z = src.size(2)*src.size(0);
assert(grid.y < 65536);
assert(grid.z < 65536);
auto options = torch::TensorOptions().dtype(torch::kFloat32).device(src.device());
torch::Tensor dst = torch::empty({src.size(0), src.size(2), src.size(3), src.size(1)}, options);
const int ptxv = getPtxver();
// to be further specialized for additional archs, if necessary
if (ptxv < 100) {
block.y = TRANSP_WARPS_X_TILE_GENERIC;
permute_to0231_k<WARP_SIZE, TRANSP_WARPS_X_TILE_GENERIC>
<<<grid, block, 0, stream>>>(src.size(1),
src.size(2),
src.size(3),
src.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
dst.packed_accessor32<float, 4, torch::RestrictPtrTraits>());
} else {
block.y = TRANSP_WARPS_X_TILE_SM100;
permute_to0231_k<WARP_SIZE, TRANSP_WARPS_X_TILE_SM100>
<<<grid, block, 0, stream>>>(src.size(1),
src.size(2),
src.size(3),
src.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
dst.packed_accessor32<float, 4, torch::RestrictPtrTraits>());
}
return dst;
}
template<int BDIM_X,
int BDIM_Y,
typename VAL_T>
__global__
__launch_bounds__(BDIM_X*BDIM_Y)
void permute_to0312_k(const int nchn,
const int nlat,
const int nlon,
const torch::PackedTensorAccessor32<VAL_T, 4, torch::RestrictPtrTraits> src,
torch::PackedTensorAccessor32<VAL_T, 4, torch::RestrictPtrTraits> dst) {
static_assert(!(BDIM_X & (BDIM_X-1)));
static_assert(!(BDIM_Y & (BDIM_Y-1)));
static_assert(BDIM_X >= BDIM_Y);
__shared__ VAL_T sh[BDIM_X][BDIM_X+1];
const int tidx = threadIdx.x;
const int tidy = threadIdx.y;
const int woff = blockIdx.x*BDIM_X; // width offset
const int coff = blockIdx.y*BDIM_X; // channel offset
const int batch = blockIdx.z / nlat; // batch (same for the whole block)
const int h = blockIdx.z - (batch * nlat); // height (same for the whole block)
const int nchn_full = (nchn-coff) >= BDIM_X;
const int nlon_full = (nlon-woff) >= BDIM_X;
if (nchn_full && nlon_full) {
#pragma unroll
for(int j = 0; j < BDIM_X; j += BDIM_Y) {
sh[j+tidy][tidx] = src[batch][h][woff + j+tidy][coff+tidx];
}
__syncthreads();
#pragma unroll
for(int j = 0; j < BDIM_X; j += BDIM_Y) {
dst[batch][coff + j+tidy][h][woff+tidx] = sh[tidx][j+tidy];
}
} else {
if (coff+tidx < nchn) {
#pragma unroll
for(int j = 0; j < BDIM_X; j += BDIM_Y) {
sh[j+tidy][tidx] = (woff + j+tidy < nlon) ? src[batch][h][woff + j+tidy][coff+tidx] : 0.f;
}
}
__syncthreads();
if (woff+tidx < nlon) {
#pragma unroll
for(int j = 0; j < BDIM_X; j += BDIM_Y) {
if (coff + j+tidy < nchn) {
dst[batch][coff + j+tidy][h][woff+tidx] = sh[tidx][j+tidy];
}
}
}
}
return;
}
static at::Tensor permute_4D_floatT_to0312(at::Tensor src, cudaStream_t stream) {
dim3 block;
dim3 grid;
block.x = WARP_SIZE;
grid.x = DIV_UP(src.size(2), block.x);
grid.y = DIV_UP(src.size(3), block.x);
grid.z = src.size(1)*src.size(0);
assert(grid.y < 65536);
assert(grid.z < 65536);
auto options = torch::TensorOptions().dtype(torch::kFloat32).device(src.device());
torch::Tensor dst = torch::empty({src.size(0), src.size(3), src.size(1), src.size(2)}, options);
const int ptxv = getPtxver();
// to be further specialized for additional archs, if necessary
if (ptxv < 100) {
block.y = TRANSP_WARPS_X_TILE_GENERIC;
permute_to0312_k<WARP_SIZE, TRANSP_WARPS_X_TILE_GENERIC>
<<<grid, block, 0, stream>>>(src.size(3),
src.size(1),
src.size(2),
src.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
dst.packed_accessor32<float, 4, torch::RestrictPtrTraits>());
} else {
block.y = TRANSP_WARPS_X_TILE_SM100;
permute_to0312_k<WARP_SIZE, TRANSP_WARPS_X_TILE_SM100>
<<<grid, block, 0, stream>>>(src.size(3),
src.size(1),
src.size(2),
src.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
dst.packed_accessor32<float, 4, torch::RestrictPtrTraits>());
}
return dst;
}
// END - tensor permutation kernels and functions
torch::Tensor s2_attention_fwd_cuda(at::Tensor kx,
at::Tensor vx,
at::Tensor qy,
at::Tensor quad_weights,
at::Tensor psi_col_idx,
at::Tensor psi_row_off,
int nlon_in,
int nlat_out,
int nlon_out) {
CHECK_CUDA_TENSOR(kx);
CHECK_CUDA_TENSOR(vx);
CHECK_CUDA_TENSOR(qy);

@@ -186,68 +891,34 @@ torch::Tensor s2_attention_fwd_cuda(at::Tensor kx, at::Tensor vx, at::Tensor qy,

const int batch_size = kx.size(0);

torch::Tensor kxP = kx;
torch::Tensor vxP = vx;
torch::Tensor qyP = qy;

auto k_channel_first = kx.strides()[1] == 1;
auto v_channel_first = vx.strides()[1] == 1;
auto q_channel_first = qy.strides()[1] == 1;

-// transpose inputs so that channels are in the last dimension, allowing for
-// coalesced memory access
-nvtxRangePush("s2_attention_fwd_kernel_mbT permute inputs");
-// Permute kx,vx,qy,dy to [batch, ho, wo, channel] in memory layout, but keep the original shape [batch, channel, ho, wo]
-auto kxP = at::Tensor();
-if (!k_channel_first) {
-kxP = kx.permute({0, 2, 3, 1}).contiguous().permute({0, 3, 1, 2});
-} else {
-kxP = kx;
-}
-auto vxP = at::Tensor();
-if (!v_channel_first) {
-vxP = vx.permute({0, 2, 3, 1}).contiguous().permute({0, 3, 1, 2});
-} else {
-vxP = vx;
-}
-auto qyP = at::Tensor();
-if (!q_channel_first) {
-qyP = qy.permute({0, 2, 3, 1}).contiguous().permute({0, 3, 1, 2});
-} else {
-qyP = qy;
-}
-cudaDeviceSynchronize();
-nvtxRangePop();
-
-torch::Tensor y = torch::empty_like(qy);
-
-dim3 block(WARP_SIZE, THREADS / WARP_SIZE);
-dim3 grid(DIV_UP(nlat_out * nlon_out, block.y), batch_size);
-size_t shared_size = sizeof(float) * uo_num_channels * block.y;
-
-cudaEvent_t start, stop;
-float milliseconds = 0;
-CHECK_CUDA(cudaEventCreate(&start));
-CHECK_CUDA(cudaEventCreate(&stop));
-CHECK_CUDA(cudaEventRecord(start, stream));
-
-s2_attention_kernel<THREADS><<<grid, block, shared_size, stream>>>(
-uo_num_channels, nlon_in, nlat_out, nlon_out, kxP.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
-vxP.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
-qyP.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
-y.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
-psi_col_idx.packed_accessor64<int64_t, 1, torch::RestrictPtrTraits>(),
-psi_row_off.packed_accessor64<int64_t, 1, torch::RestrictPtrTraits>(),
-quad_weights.packed_accessor32<float, 1, torch::RestrictPtrTraits>());
-
-CHECK_CUDA(cudaEventRecord(stop, stream));
-CHECK_CUDA(cudaEventSynchronize(stop));
-CHECK_CUDA(cudaEventElapsedTime(&milliseconds, start, stop));
-// printf("s2_attention_kernel_fwd execution time: %f ms\n", milliseconds);
-CHECK_CUDA(cudaEventDestroy(start));
-CHECK_CUDA(cudaEventDestroy(stop));
-
-// match output layout to input layout
-if (!q_channel_first) y = y.contiguous();

if (!k_channel_first) { kxP = permute_4D_floatT_to0231(kx, stream); }
if (!v_channel_first) { vxP = permute_4D_floatT_to0231(vx, stream); }
if (!q_channel_first) { qyP = permute_4D_floatT_to0231(qy, stream); }

torch::Tensor yP = torch::empty_like(qyP);

s2_attention_dipatch(batch_size,
uo_num_channels,
nlon_in,
nlat_out,
nlon_out,
kxP, vxP, qyP,
psi_row_off,
psi_col_idx,
quad_weights,
yP, // out tensor
stream);

torch::Tensor y = yP;
if (!q_channel_first) { y = permute_4D_floatT_to0312(yP, stream); }

C10_CUDA_KERNEL_LAUNCH_CHECK();
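The host wrapper above permutes the [batch, channel, lat, lon] inputs to a channels-last layout before dispatch (`permute_4D_floatT_to0231`) and permutes the result back (`permute_4D_floatT_to0312`), so that consecutive threads of a warp read consecutive channels from coalesced addresses. In PyTorch terms the data movement corresponds roughly to the sketch below; the extension performs it with its own tiled transpose kernels rather than these calls:

import torch

def to_channels_last(x: torch.Tensor) -> torch.Tensor:
    # [B, C, H, W] -> [B, H, W, C], analogous to permute_4D_floatT_to0231
    return x.permute(0, 2, 3, 1).contiguous()

def from_channels_last(y: torch.Tensor) -> torch.Tensor:
    # [B, H, W, C] -> [B, C, H, W], analogous to permute_4D_floatT_to0312
    return y.permute(0, 3, 1, 2).contiguous()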
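For intuition on the dispatch in `s2_attention_dipatch`: the block width is the smallest power of two `bdimx >= 32` such that `bdimx * MAX_LOCAL_ARR_LEN >= nchans`, and each thread then keeps `nloc = ceil(nchans / bdimx)` values in registers. A quick sketch of that arithmetic, with the constants mirroring the source as an assumption:

def pick_block_width(nchans: int, max_local_arr_len: int = 16, warp_size: int = 32):
    # Mirrors the bdimx/nloc computation in s2_attention_dipatch:
    # smallest power of two >= 32 with bdimx * max_local_arr_len >= nchans.
    bdimx = max((nchans + max_local_arr_len - 1) // max_local_arr_len, warp_size)
    bdimx = 1 << (bdimx - 1).bit_length()        # next power of two
    nloc = (nchans + bdimx - 1) // bdimx
    return bdimx, nloc

# e.g. pick_block_width(384) -> (32, 12); pick_block_width(1536) -> (128, 12)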