Commit 07fa44d6 authored by Mauro Bisson, committed by Thorsten Kurth

Added missing kernel launch error checks.

parent 9689109f
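
The CHECK_ERROR macro used in the hunks below is defined elsewhere in the repository; this commit only adds calls to it directly after each kernel launch. As a rough, hypothetical sketch (not the actual definition used here), a launch-check macro of this kind typically wraps cudaGetLastError:

// Hypothetical sketch of a kernel-launch error check, assuming it wraps
// cudaGetLastError(); the real CHECK_ERROR definition is not part of this diff.
#include <cstdio>
#include <cstdlib>
#include <cuda_runtime.h>

#define CHECK_ERROR(kname)                                                   \
    do {                                                                     \
        cudaError_t err = cudaGetLastError();                                \
        if (err != cudaSuccess) {                                            \
            fprintf(stderr, "Kernel %s launch failed at %s:%d: %s\n",        \
                    (kname), __FILE__, __LINE__, cudaGetErrorString(err));   \
            exit(EXIT_FAILURE);                                              \
        }                                                                    \
    } while (0)

cudaGetLastError reports launch and configuration errors (for example, an invalid grid or block shape, or too much dynamic shared memory) immediately after the <<<...>>> call, which is why the checks below sit right after each launch.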
@@ -61,6 +61,7 @@ def get_compile_args(module_name):
     nvcc_extra_flags = []
     if profile_mode:
         nvcc_extra_flags.append("-lineinfo")
+        nvcc_extra_flags.append("-Xptxas=-v")
     if debug_mode:
         print(f"WARNING: Compiling {module_name} with debugging flags")
...
@@ -715,6 +715,8 @@ void launch_gen_attn_bwd(int batch_size,
             <<<grid, block, shsize, stream>>>(nchans, nlat_in, nlon_in, nlat_out, nlon_out,
                                               _kxp, _vxp, _qyp, _dyp, _row_idx, _row_off, _col_idx,
                                               _quad_weights, _dkxp, _dvxp, _dqyp);
+        CHECK_ERROR("s2_attn_bwd_generic_vec_k");
         return;
     }
@@ -754,6 +756,8 @@ void launch_spc_attn_bwd(int batch_size,
             <<<grid, block, shsize, stream>>>(nchans, nlat_in, nlon_in, nlat_out, nlon_out,
                                               _kxp, _vxp, _qyp, _dyp, _row_idx, _row_off, _col_idx,
                                               _quad_weights, _dkxp, _dvxp, _dqyp);
+        CHECK_ERROR("s2_attn_bwd_special_vec_k");
         return;
     }
     if constexpr(CUR_LOC_SIZE < MAX_LOC_SIZE) {
...
@@ -311,6 +311,8 @@ void launch_gen_attn_fwd(int batch_size,
         s2_attn_fwd_generic_vec_k<THREADS>
             <<<grid, block, shsize, stream>>>(nchans, nlat_in, nlon_in, nlat_out, nlon_out,
                                               _kxp, _vxp, _qyp, _row_idx, _row_off, _col_idx, _quad_weights, _yp);
+        CHECK_ERROR("s2_attn_fwd_generic_vec_k");
         return;
     }
@@ -346,6 +348,8 @@ void launch_spc_attn_fwd(int batch_size,
         s2_attn_fwd_special_vec_k<BDIM_X, BDIM_Y, CUR_LOC_SIZE>
             <<<grid, block, shsize, stream>>>(nchans, nlat_in, nlon_in, nlat_out, nlon_out,
                                               _kxp, _vxp, _qyp, _row_idx, _row_off, _col_idx, _quad_weights, _yp);
+        CHECK_ERROR("s2_attn_fwd_special_vec_k");
         return;
     }
     if constexpr(CUR_LOC_SIZE < MAX_LOC_SIZE) {
...
@@ -206,6 +206,7 @@ at::Tensor permute_4D_floatT_to0231(at::Tensor src, cudaStream_t stream) {
                 src.size(3),
                 src.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
                 dst.packed_accessor32<float, 4, torch::RestrictPtrTraits>());
+        CHECK_ERROR("permute_to0231_k_tile_generic");
     } else {
         block.y = TRANSP_WARPS_X_TILE_SM100;
         permute_to0231_k<WARP_SIZE, TRANSP_WARPS_X_TILE_SM100>
@@ -214,6 +215,7 @@ at::Tensor permute_4D_floatT_to0231(at::Tensor src, cudaStream_t stream) {
                 src.size(3),
                 src.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
                 dst.packed_accessor32<float, 4, torch::RestrictPtrTraits>());
+        CHECK_ERROR("permute_to0231_k_tile_sm100");
     }
     return dst;
@@ -306,6 +308,7 @@ at::Tensor permute_4D_floatT_to0312(at::Tensor src, cudaStream_t stream) {
                 src.size(2),
                 src.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
                 dst.packed_accessor32<float, 4, torch::RestrictPtrTraits>());
+        CHECK_ERROR("permute_to0312_k_tile_generic");
     } else {
         block.y = TRANSP_WARPS_X_TILE_SM100;
         permute_to0312_k<WARP_SIZE, TRANSP_WARPS_X_TILE_SM100>
@@ -314,6 +317,7 @@ at::Tensor permute_4D_floatT_to0312(at::Tensor src, cudaStream_t stream) {
                 src.size(2),
                 src.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
                 dst.packed_accessor32<float, 4, torch::RestrictPtrTraits>());
+        CHECK_ERROR("permute_to0312_k_tile_sm100");
     }
     return dst;
...
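
Note that cudaGetLastError-style checks only catch errors reported at launch time; failures during kernel execution (such as out-of-bounds accesses) surface asynchronously at the next synchronizing call. A hedged sketch of a stricter, debug-only variant, building on the hypothetical CHECK_ERROR sketch above and not part of this commit:

// Hypothetical debug-only variant: also surface asynchronous execution errors
// by synchronizing the launch stream (costly, so not suitable for production).
#define CHECK_ERROR_SYNC(kname, stream)                                      \
    do {                                                                     \
        CHECK_ERROR(kname);                          /* launch errors */     \
        cudaError_t serr = cudaStreamSynchronize(stream);                    \
        if (serr != cudaSuccess) {                                           \
            fprintf(stderr, "Kernel %s execution failed: %s\n",              \
                    (kname), cudaGetErrorString(serr));                      \
            exit(EXIT_FAILURE);                                              \
        }                                                                    \
    } while (0)

Setting the environment variable CUDA_LAUNCH_BLOCKING=1 achieves a similar effect globally while debugging, without changing the code.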