Unverified commit 744d2269, authored by Thorsten Kurth, committed by GitHub
Browse files

Merge pull request #87 from NVIDIA/maurob/devel

Use 64-bit for pointer offsets
parents 1eed5673 c9ae23d9
...@@ -214,10 +214,10 @@ void s2_attn_fwd_generic_vec_k(int nchan, // no. of FLOATV_T elements along cha ...@@ -214,10 +214,10 @@ void s2_attn_fwd_generic_vec_k(int nchan, // no. of FLOATV_T elements along cha
shy[chan] = __vset<FLOATV_T>(0.f); shy[chan] = __vset<FLOATV_T>(0.f);
} }
kx += batch*nlat_in*nlon_in*nchan; kx += int64_t(batch)*nlat_in*nlon_in*nchan;
vx += batch*nlat_in*nlon_in*nchan; vx += int64_t(batch)*nlat_in*nlon_in*nchan;
qy += batch*nlat_out*nlon_out*nchan + ho*nchan*nlon_out + wo*nchan; qy += int64_t(batch)*nlat_out*nlon_out*nchan + int64_t(ho)*nchan*nlon_out + int64_t(wo)*nchan;
y += batch*nlat_out*nlon_out*nchan + ho*nchan*nlon_out + wo*nchan; y += int64_t(batch)*nlat_out*nlon_out*nchan + int64_t(ho)*nchan*nlon_out + int64_t(wo)*nchan;
float alpha_sum = 0.0f; float alpha_sum = 0.0f;
float qdotk_max = -FLT_MAX; float qdotk_max = -FLT_MAX;
...@@ -235,8 +235,8 @@ void s2_attn_fwd_generic_vec_k(int nchan, // no. of FLOATV_T elements along cha ...@@ -235,8 +235,8 @@ void s2_attn_fwd_generic_vec_k(int nchan, // no. of FLOATV_T elements along cha
const int wi = col - (hi*nlon_in); const int wi = col - (hi*nlon_in);
const int wip = (wi+wo) - ((wi+wo) / nlon_in) * nlon_in; const int wip = (wi+wo) - ((wi+wo) / nlon_in) * nlon_in;
const FLOATV_T *_kx = kx + hi*nlon_in*nchan + wip*nchan; const FLOATV_T *_kx = kx + int64_t(hi)*nlon_in*nchan + int64_t(wip)*nchan;
const FLOATV_T *_vx = vx + hi*nlon_in*nchan + wip*nchan; const FLOATV_T *_vx = vx + int64_t(hi)*nlon_in*nchan + int64_t(wip)*nchan;
FLOATV_T qdotkv = __vset<FLOATV_T>(0.f); FLOATV_T qdotkv = __vset<FLOATV_T>(0.f);
...@@ -275,7 +275,6 @@ void s2_attn_fwd_generic_vec_k(int nchan, // no. of FLOATV_T elements along cha ...@@ -275,7 +275,6 @@ void s2_attn_fwd_generic_vec_k(int nchan, // no. of FLOATV_T elements along cha
template<typename FLOATV_T> template<typename FLOATV_T>
void launch_gen_attn_kernel(int batch_size, void launch_gen_attn_kernel(int batch_size,
int nloc,
int nchans, int nchans,
int nlat_in, int nlat_in,
int nlon_in, int nlon_in,
...@@ -352,10 +351,10 @@ void s2_attn_fwd_special_vec_k(int nchan, // no. of FLOATV_T elements along chan ...@@ -352,10 +351,10 @@ void s2_attn_fwd_special_vec_k(int nchan, // no. of FLOATV_T elements along chan
const int wo = ctaid - (h*nlon_out); const int wo = ctaid - (h*nlon_out);
const int ho = row_idx[h]; const int ho = row_idx[h];
kx += batch*nlat_in*nlon_in*nchan + tidx; kx += int64_t(batch)*nlat_in*nlon_in*nchan + tidx;
vx += batch*nlat_in*nlon_in*nchan + tidx; vx += int64_t(batch)*nlat_in*nlon_in*nchan + tidx;
qy += batch*nlat_out*nlon_out*nchan + ho*nlon_out*nchan + wo*nchan + tidx; qy += int64_t(batch)*nlat_out*nlon_out*nchan + int64_t(ho)*nlon_out*nchan + int64_t(wo)*nchan + tidx;
y += batch*nlat_out*nlon_out*nchan + ho*nlon_out*nchan + wo*nchan + tidx; y += int64_t(batch)*nlat_out*nlon_out*nchan + int64_t(ho)*nlon_out*nchan + int64_t(wo)*nchan + tidx;
#pragma unroll #pragma unroll
for(int i = 0; i < NLOC; i++) { for(int i = 0; i < NLOC; i++) {
...@@ -386,8 +385,8 @@ void s2_attn_fwd_special_vec_k(int nchan, // no. of FLOATV_T elements along chan ...@@ -386,8 +385,8 @@ void s2_attn_fwd_special_vec_k(int nchan, // no. of FLOATV_T elements along chan
const int wi = col - (hi*nlon_in); const int wi = col - (hi*nlon_in);
const int wip = (wi+wo) - ((wi+wo) / nlon_in) * nlon_in; const int wip = (wi+wo) - ((wi+wo) / nlon_in) * nlon_in;
const FLOATV_T *_kx = kx + hi*nlon_in*nchan + wip*nchan; const FLOATV_T *_kx = kx + int64_t(hi)*nlon_in*nchan + int64_t(wip)*nchan;
const FLOATV_T *_vx = vx + hi*nlon_in*nchan + wip*nchan; const FLOATV_T *_vx = vx + int64_t(hi)*nlon_in*nchan + int64_t(wip)*nchan;
FLOATV_T qdotkv = __vset<FLOATV_T>(0.f); FLOATV_T qdotkv = __vset<FLOATV_T>(0.f);
...@@ -626,7 +625,7 @@ static void s2_attention_dipatch(int batch_size, ...@@ -626,7 +625,7 @@ static void s2_attention_dipatch(int batch_size,
case 256: launch_spc_attn_kernel< 256, 1, 1, MAX_LOCAL_ARR_LEN>(batch_size, nloc, nchans, nlat_in, nlon_in, nlat_out, nlon_out, _kxp, _vxp, _qyp, row_idx, row_off, col_idx, quad_weights, _yp, stream); break; case 256: launch_spc_attn_kernel< 256, 1, 1, MAX_LOCAL_ARR_LEN>(batch_size, nloc, nchans, nlat_in, nlon_in, nlat_out, nlon_out, _kxp, _vxp, _qyp, row_idx, row_off, col_idx, quad_weights, _yp, stream); break;
case 512: launch_spc_attn_kernel< 512, 1, 1, MAX_LOCAL_ARR_LEN>(batch_size, nloc, nchans, nlat_in, nlon_in, nlat_out, nlon_out, _kxp, _vxp, _qyp, row_idx, row_off, col_idx, quad_weights, _yp, stream); break; case 512: launch_spc_attn_kernel< 512, 1, 1, MAX_LOCAL_ARR_LEN>(batch_size, nloc, nchans, nlat_in, nlon_in, nlat_out, nlon_out, _kxp, _vxp, _qyp, row_idx, row_off, col_idx, quad_weights, _yp, stream); break;
case 1024: launch_spc_attn_kernel<1024, 1, 1, MAX_LOCAL_ARR_LEN>(batch_size, nloc, nchans, nlat_in, nlon_in, nlat_out, nlon_out, _kxp, _vxp, _qyp, row_idx, row_off, col_idx, quad_weights, _yp, stream); break; case 1024: launch_spc_attn_kernel<1024, 1, 1, MAX_LOCAL_ARR_LEN>(batch_size, nloc, nchans, nlat_in, nlon_in, nlat_out, nlon_out, _kxp, _vxp, _qyp, row_idx, row_off, col_idx, quad_weights, _yp, stream); break;
default: launch_gen_attn_kernel (batch_size, nloc, nchans, nlat_in, nlon_in, nlat_out, nlon_out, _kxp, _vxp, _qyp, row_idx, row_off, col_idx, quad_weights, _yp, stream); break; default: launch_gen_attn_kernel (batch_size, nchans, nlat_in, nlon_in, nlat_out, nlon_out, _kxp, _vxp, _qyp, row_idx, row_off, col_idx, quad_weights, _yp, stream); break;
} }
} else { } else {
...@@ -649,7 +648,7 @@ static void s2_attention_dipatch(int batch_size, ...@@ -649,7 +648,7 @@ static void s2_attention_dipatch(int batch_size,
case 256: launch_spc_attn_kernel< 256, 1, 1, MAX_LOCAL_VEC_LEN>(batch_size, nloc, nchans, nlat_in, nlon_in, nlat_out, nlon_out, _kxp4, _vxp4, _qyp4, row_idx, row_off, col_idx, quad_weights, _yp4, stream); break; case 256: launch_spc_attn_kernel< 256, 1, 1, MAX_LOCAL_VEC_LEN>(batch_size, nloc, nchans, nlat_in, nlon_in, nlat_out, nlon_out, _kxp4, _vxp4, _qyp4, row_idx, row_off, col_idx, quad_weights, _yp4, stream); break;
case 512: launch_spc_attn_kernel< 512, 1, 1, MAX_LOCAL_VEC_LEN>(batch_size, nloc, nchans, nlat_in, nlon_in, nlat_out, nlon_out, _kxp4, _vxp4, _qyp4, row_idx, row_off, col_idx, quad_weights, _yp4, stream); break; case 512: launch_spc_attn_kernel< 512, 1, 1, MAX_LOCAL_VEC_LEN>(batch_size, nloc, nchans, nlat_in, nlon_in, nlat_out, nlon_out, _kxp4, _vxp4, _qyp4, row_idx, row_off, col_idx, quad_weights, _yp4, stream); break;
case 1024: launch_spc_attn_kernel<1024, 1, 1, MAX_LOCAL_VEC_LEN>(batch_size, nloc, nchans, nlat_in, nlon_in, nlat_out, nlon_out, _kxp4, _vxp4, _qyp4, row_idx, row_off, col_idx, quad_weights, _yp4, stream); break; case 1024: launch_spc_attn_kernel<1024, 1, 1, MAX_LOCAL_VEC_LEN>(batch_size, nloc, nchans, nlat_in, nlon_in, nlat_out, nlon_out, _kxp4, _vxp4, _qyp4, row_idx, row_off, col_idx, quad_weights, _yp4, stream); break;
default: launch_gen_attn_kernel (batch_size, nloc, nchans, nlat_in, nlon_in, nlat_out, nlon_out, _kxp4, _vxp4, _qyp4, row_idx, row_off, col_idx, quad_weights, _yp4, stream); break; default: launch_gen_attn_kernel (batch_size, nchans, nlat_in, nlon_in, nlat_out, nlon_out, _kxp4, _vxp4, _qyp4, row_idx, row_off, col_idx, quad_weights, _yp4, stream); break;
} }
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment