Commit c9ae23d9 authored by Mauro Bisson
Browse files

Updated pointer offset calculations to use 64-bit integers to prevent overflow...

Updated pointer offset calculations to use 64-bit integers to prevent overflow with large batch or image sizes.
parent 1eed5673
......@@ -214,10 +214,10 @@ void s2_attn_fwd_generic_vec_k(int nchan, // no. of FLOATV_T elements along cha
shy[chan] = __vset<FLOATV_T>(0.f);
}
kx += batch*nlat_in*nlon_in*nchan;
vx += batch*nlat_in*nlon_in*nchan;
qy += batch*nlat_out*nlon_out*nchan + ho*nchan*nlon_out + wo*nchan;
y += batch*nlat_out*nlon_out*nchan + ho*nchan*nlon_out + wo*nchan;
kx += int64_t(batch)*nlat_in*nlon_in*nchan;
vx += int64_t(batch)*nlat_in*nlon_in*nchan;
qy += int64_t(batch)*nlat_out*nlon_out*nchan + int64_t(ho)*nchan*nlon_out + int64_t(wo)*nchan;
y += int64_t(batch)*nlat_out*nlon_out*nchan + int64_t(ho)*nchan*nlon_out + int64_t(wo)*nchan;
float alpha_sum = 0.0f;
float qdotk_max = -FLT_MAX;
......@@ -235,8 +235,8 @@ void s2_attn_fwd_generic_vec_k(int nchan, // no. of FLOATV_T elements along cha
const int wi = col - (hi*nlon_in);
const int wip = (wi+wo) - ((wi+wo) / nlon_in) * nlon_in;
const FLOATV_T *_kx = kx + hi*nlon_in*nchan + wip*nchan;
const FLOATV_T *_vx = vx + hi*nlon_in*nchan + wip*nchan;
const FLOATV_T *_kx = kx + int64_t(hi)*nlon_in*nchan + int64_t(wip)*nchan;
const FLOATV_T *_vx = vx + int64_t(hi)*nlon_in*nchan + int64_t(wip)*nchan;
FLOATV_T qdotkv = __vset<FLOATV_T>(0.f);
......@@ -275,7 +275,6 @@ void s2_attn_fwd_generic_vec_k(int nchan, // no. of FLOATV_T elements along cha
template<typename FLOATV_T>
void launch_gen_attn_kernel(int batch_size,
int nloc,
int nchans,
int nlat_in,
int nlon_in,
......@@ -352,10 +351,10 @@ void s2_attn_fwd_special_vec_k(int nchan, // no. of FLOATV_T elements along chan
const int wo = ctaid - (h*nlon_out);
const int ho = row_idx[h];
kx += batch*nlat_in*nlon_in*nchan + tidx;
vx += batch*nlat_in*nlon_in*nchan + tidx;
qy += batch*nlat_out*nlon_out*nchan + ho*nlon_out*nchan + wo*nchan + tidx;
y += batch*nlat_out*nlon_out*nchan + ho*nlon_out*nchan + wo*nchan + tidx;
kx += int64_t(batch)*nlat_in*nlon_in*nchan + tidx;
vx += int64_t(batch)*nlat_in*nlon_in*nchan + tidx;
qy += int64_t(batch)*nlat_out*nlon_out*nchan + int64_t(ho)*nlon_out*nchan + int64_t(wo)*nchan + tidx;
y += int64_t(batch)*nlat_out*nlon_out*nchan + int64_t(ho)*nlon_out*nchan + int64_t(wo)*nchan + tidx;
#pragma unroll
for(int i = 0; i < NLOC; i++) {
......@@ -386,8 +385,8 @@ void s2_attn_fwd_special_vec_k(int nchan, // no. of FLOATV_T elements along chan
const int wi = col - (hi*nlon_in);
const int wip = (wi+wo) - ((wi+wo) / nlon_in) * nlon_in;
const FLOATV_T *_kx = kx + hi*nlon_in*nchan + wip*nchan;
const FLOATV_T *_vx = vx + hi*nlon_in*nchan + wip*nchan;
const FLOATV_T *_kx = kx + int64_t(hi)*nlon_in*nchan + int64_t(wip)*nchan;
const FLOATV_T *_vx = vx + int64_t(hi)*nlon_in*nchan + int64_t(wip)*nchan;
FLOATV_T qdotkv = __vset<FLOATV_T>(0.f);
......@@ -626,7 +625,7 @@ static void s2_attention_dipatch(int batch_size,
case 256: launch_spc_attn_kernel< 256, 1, 1, MAX_LOCAL_ARR_LEN>(batch_size, nloc, nchans, nlat_in, nlon_in, nlat_out, nlon_out, _kxp, _vxp, _qyp, row_idx, row_off, col_idx, quad_weights, _yp, stream); break;
case 512: launch_spc_attn_kernel< 512, 1, 1, MAX_LOCAL_ARR_LEN>(batch_size, nloc, nchans, nlat_in, nlon_in, nlat_out, nlon_out, _kxp, _vxp, _qyp, row_idx, row_off, col_idx, quad_weights, _yp, stream); break;
case 1024: launch_spc_attn_kernel<1024, 1, 1, MAX_LOCAL_ARR_LEN>(batch_size, nloc, nchans, nlat_in, nlon_in, nlat_out, nlon_out, _kxp, _vxp, _qyp, row_idx, row_off, col_idx, quad_weights, _yp, stream); break;
default: launch_gen_attn_kernel (batch_size, nloc, nchans, nlat_in, nlon_in, nlat_out, nlon_out, _kxp, _vxp, _qyp, row_idx, row_off, col_idx, quad_weights, _yp, stream); break;
default: launch_gen_attn_kernel (batch_size, nchans, nlat_in, nlon_in, nlat_out, nlon_out, _kxp, _vxp, _qyp, row_idx, row_off, col_idx, quad_weights, _yp, stream); break;
}
} else {
......@@ -649,7 +648,7 @@ static void s2_attention_dipatch(int batch_size,
case 256: launch_spc_attn_kernel< 256, 1, 1, MAX_LOCAL_VEC_LEN>(batch_size, nloc, nchans, nlat_in, nlon_in, nlat_out, nlon_out, _kxp4, _vxp4, _qyp4, row_idx, row_off, col_idx, quad_weights, _yp4, stream); break;
case 512: launch_spc_attn_kernel< 512, 1, 1, MAX_LOCAL_VEC_LEN>(batch_size, nloc, nchans, nlat_in, nlon_in, nlat_out, nlon_out, _kxp4, _vxp4, _qyp4, row_idx, row_off, col_idx, quad_weights, _yp4, stream); break;
case 1024: launch_spc_attn_kernel<1024, 1, 1, MAX_LOCAL_VEC_LEN>(batch_size, nloc, nchans, nlat_in, nlon_in, nlat_out, nlon_out, _kxp4, _vxp4, _qyp4, row_idx, row_off, col_idx, quad_weights, _yp4, stream); break;
default: launch_gen_attn_kernel (batch_size, nloc, nchans, nlat_in, nlon_in, nlat_out, nlon_out, _kxp4, _vxp4, _qyp4, row_idx, row_off, col_idx, quad_weights, _yp4, stream); break;
default: launch_gen_attn_kernel (batch_size, nchans, nlat_in, nlon_in, nlat_out, nlon_out, _kxp4, _vxp4, _qyp4, row_idx, row_off, col_idx, quad_weights, _yp4, stream); break;
}
}
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment