"vscode:/vscode.git/clone" did not exist on "dfaf2b20fb7c7e2f24553342af27ef67989426a7"
Commit 07fa44d6 authored by Mauro Bisson, committed by Thorsten Kurth

Added missing kernel launch error checks.

parent 9689109f
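The checks below call a CHECK_ERROR macro immediately after each kernel launch. Its definition is not part of this diff; a launch check of this kind is usually a thin wrapper around cudaGetLastError(), roughly like the following sketch (the macro body shown here is an assumption for illustration, not the repository's actual definition):

// Hypothetical sketch of a kernel-launch error check (not the repository's
// actual CHECK_ERROR definition): query the last launch error and abort
// with a readable message naming the offending kernel.
#include <cstdio>
#include <cstdlib>
#include <cuda_runtime.h>

#define CHECK_ERROR(kname)                                                       \
    do {                                                                         \
        cudaError_t err_ = cudaGetLastError();                                   \
        if (err_ != cudaSuccess) {                                               \
            std::fprintf(stderr, "CUDA error after launch of %s at %s:%d: %s\n", \
                         (kname), __FILE__, __LINE__, cudaGetErrorString(err_)); \
            std::exit(EXIT_FAILURE);                                             \
        }                                                                        \
    } while (0)

cudaGetLastError() returns and clears the error state from the most recent runtime call on the calling host thread, so invoking it right after the <<<...>>> launch attributes a failure to the kernel that was just launched.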
@@ -61,6 +61,7 @@ def get_compile_args(module_name):
    nvcc_extra_flags = []
    if profile_mode:
        nvcc_extra_flags.append("-lineinfo")
        nvcc_extra_flags.append("-Xptxas=-v")
    if debug_mode:
        print(f"WARNING: Compiling {module_name} with debugging flags")
@@ -715,6 +715,8 @@ void launch_gen_attn_bwd(int batch_size,
        <<<grid, block, shsize, stream>>>(nchans, nlat_in, nlon_in, nlat_out, nlon_out,
                                          _kxp, _vxp, _qyp, _dyp, _row_idx, _row_off, _col_idx,
                                          _quad_weights, _dkxp, _dvxp, _dqyp);
        CHECK_ERROR("s2_attn_bwd_generic_vec_k");
        return;
    }
@@ -754,6 +756,8 @@ void launch_spc_attn_bwd(int batch_size,
        <<<grid, block, shsize, stream>>>(nchans, nlat_in, nlon_in, nlat_out, nlon_out,
                                          _kxp, _vxp, _qyp, _dyp, _row_idx, _row_off, _col_idx,
                                          _quad_weights, _dkxp, _dvxp, _dqyp);
        CHECK_ERROR("s2_attn_bwd_special_vec_k");
        return;
    }
    if constexpr(CUR_LOC_SIZE < MAX_LOC_SIZE) {
@@ -311,6 +311,8 @@ void launch_gen_attn_fwd(int batch_size,
        s2_attn_fwd_generic_vec_k<THREADS>
            <<<grid, block, shsize, stream>>>(nchans, nlat_in, nlon_in, nlat_out, nlon_out,
                                              _kxp, _vxp, _qyp, _row_idx, _row_off, _col_idx, _quad_weights, _yp);
        CHECK_ERROR("s2_attn_fwd_generic_vec_k");
        return;
    }
@@ -346,6 +348,8 @@ void launch_spc_attn_fwd(int batch_size,
        s2_attn_fwd_special_vec_k<BDIM_X, BDIM_Y, CUR_LOC_SIZE>
            <<<grid, block, shsize, stream>>>(nchans, nlat_in, nlon_in, nlat_out, nlon_out,
                                              _kxp, _vxp, _qyp, _row_idx, _row_off, _col_idx, _quad_weights, _yp);
        CHECK_ERROR("s2_attn_fwd_special_vec_k");
        return;
    }
    if constexpr(CUR_LOC_SIZE < MAX_LOC_SIZE) {
@@ -206,6 +206,7 @@ at::Tensor permute_4D_floatT_to0231(at::Tensor src, cudaStream_t stream) {
                             src.size(3),
                             src.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
                             dst.packed_accessor32<float, 4, torch::RestrictPtrTraits>());
        CHECK_ERROR("permute_to0231_k_tile_generic");
    } else {
        block.y = TRANSP_WARPS_X_TILE_SM100;
        permute_to0231_k<WARP_SIZE, TRANSP_WARPS_X_TILE_SM100>
@@ -214,6 +215,7 @@ at::Tensor permute_4D_floatT_to0231(at::Tensor src, cudaStream_t stream) {
                             src.size(3),
                             src.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
                             dst.packed_accessor32<float, 4, torch::RestrictPtrTraits>());
        CHECK_ERROR("permute_to0231_k_tile_sm100");
    }
    return dst;
@@ -306,6 +308,7 @@ at::Tensor permute_4D_floatT_to0312(at::Tensor src, cudaStream_t stream) {
                             src.size(2),
                             src.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
                             dst.packed_accessor32<float, 4, torch::RestrictPtrTraits>());
        CHECK_ERROR("permute_to0312_k_tile_generic");
    } else {
        block.y = TRANSP_WARPS_X_TILE_SM100;
        permute_to0312_k<WARP_SIZE, TRANSP_WARPS_X_TILE_SM100>
@@ -314,6 +317,7 @@ at::Tensor permute_4D_floatT_to0312(at::Tensor src, cudaStream_t stream) {
                             src.size(2),
                             src.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
                             dst.packed_accessor32<float, 4, torch::RestrictPtrTraits>());
        CHECK_ERROR("permute_to0312_k_tile_sm100");
    }
    return dst;
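One caveat on the pattern used throughout this commit: a cudaGetLastError()-style check right after a launch catches launch-time failures (invalid launch configuration, excessive shared-memory or register demand, missing kernel image), but errors raised while the kernel executes are reported asynchronously and only surface at a later synchronizing call. A debug build could additionally synchronize the stream right after the check, roughly as in this sketch (the DEBUG_SYNC_LAUNCH guard and surrounding names are illustrative, not from this repository):

// Illustrative debug-only follow-up to the launch check; assumes a kernel has
// just been launched on `stream` and CHECK_ERROR has already run.
#ifdef DEBUG_SYNC_LAUNCH
    {
        // Block until the kernel finishes so asynchronous execution errors
        // (e.g. illegal memory accesses) are reported at this call site
        // instead of at an unrelated later API call.
        cudaError_t sync_err = cudaStreamSynchronize(stream);
        if (sync_err != cudaSuccess) {
            std::fprintf(stderr, "kernel execution failed: %s\n",
                         cudaGetErrorString(sync_err));
        }
    }
#endif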