Unverified Commit 4285874d authored by Xin Yao's avatar Xin Yao Committed by GitHub
Browse files

[Common] Add checks to CUDA kernel launch and CUDA API calls (#2074)



* add checks to cuda kernel launch and cuda API calls
Signed-off-by: default avatarXin Yao <xiny@nvidia.com>

* Remove exceptions from destructors
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>

* fix weird dispatch in ln/rmsnorm
Signed-off-by: default avatarXin Yao <xiny@nvidia.com>

---------
Signed-off-by: default avatarXin Yao <xiny@nvidia.com>
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>
Co-authored-by: default avatarTim Moon <tmoon@nvidia.com>
Co-authored-by: default avatarTim Moon <4406448+timmoon10@users.noreply.github.com>
parent 715c3bb8
...@@ -13,7 +13,7 @@ using namespace transformer_engine::normalization; ...@@ -13,7 +13,7 @@ using namespace transformer_engine::normalization;
template <typename weight_t, typename input_t, typename output_t, typename compute_t, template <typename weight_t, typename input_t, typename output_t, typename compute_t,
typename index_t, int HIDDEN_SIZE, int CTAS_PER_ROW, int WARPS_M, int WARPS_N, typename index_t, int HIDDEN_SIZE, int CTAS_PER_ROW, int WARPS_M, int WARPS_N,
int BYTES_PER_LDG> int BYTES_PER_LDG>
void launch_tuned_(LaunchParams<ForwardKernelParams> &launch_params, void launch_rmsnorm_fwd_tuned_(LaunchParams<ForwardKernelParams> &launch_params,
const bool configure_params) { // NOLINT(*) const bool configure_params) { // NOLINT(*)
using Kernel_traits = Kernel_traits<weight_t, input_t, output_t, compute_t, index_t, HIDDEN_SIZE, using Kernel_traits = Kernel_traits<weight_t, input_t, output_t, compute_t, index_t, HIDDEN_SIZE,
CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG>; CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG>;
...@@ -21,8 +21,8 @@ void launch_tuned_(LaunchParams<ForwardKernelParams> &launch_params, ...@@ -21,8 +21,8 @@ void launch_tuned_(LaunchParams<ForwardKernelParams> &launch_params,
if (configure_params) { if (configure_params) {
int ctas_per_sm; int ctas_per_sm;
cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor( NVTE_CHECK_CUDA(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES_FWD); &ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES_FWD));
launch_params.params.ctas_per_row = CTAS_PER_ROW; launch_params.params.ctas_per_row = CTAS_PER_ROW;
launch_params.params.ctas_per_col = launch_params.params.ctas_per_col =
launch_params.multiprocessorCount * ctas_per_sm / launch_params.params.ctas_per_row; launch_params.multiprocessorCount * ctas_per_sm / launch_params.params.ctas_per_row;
...@@ -46,18 +46,20 @@ void launch_tuned_(LaunchParams<ForwardKernelParams> &launch_params, ...@@ -46,18 +46,20 @@ void launch_tuned_(LaunchParams<ForwardKernelParams> &launch_params,
if (ctas_per_row == 1) { if (ctas_per_row == 1) {
kernel<<<ctas_per_col, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES_FWD, stream>>>( kernel<<<ctas_per_col, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES_FWD, stream>>>(
launch_params.params); launch_params.params);
NVTE_CHECK_CUDA(cudaGetLastError());
} else { } else {
dim3 grid(ctas_per_row * ctas_per_col); dim3 grid(ctas_per_row * ctas_per_col);
dim3 block(Kernel_traits::THREADS_PER_CTA); dim3 block(Kernel_traits::THREADS_PER_CTA);
void *params_ = reinterpret_cast<void *>(&launch_params.params); void *params_ = reinterpret_cast<void *>(&launch_params.params);
cudaLaunchCooperativeKernel((void *)kernel, grid, block, (void **)&params_, // NOLINT(*) NVTE_CHECK_CUDA(cudaLaunchCooperativeKernel(reinterpret_cast<void *>(kernel), grid, block,
Kernel_traits::SMEM_BYTES_FWD, stream); reinterpret_cast<void **>(&params_),
Kernel_traits::SMEM_BYTES_FWD, stream));
} }
} }
template <typename weight_t, typename input_t, typename output_t, typename compute_t, template <typename weight_t, typename input_t, typename output_t, typename compute_t,
typename index_t, int HIDDEN_SIZE, int WARPS_M, int WARPS_N, int BYTES_PER_LDG> typename index_t, int HIDDEN_SIZE, int WARPS_M, int WARPS_N, int BYTES_PER_LDG>
void launch_general_(LaunchParams<ForwardKernelParams> &launch_params, void launch_rmsnorm_fwd_general_(LaunchParams<ForwardKernelParams> &launch_params,
const bool configure_params) { // NOLINT(*) const bool configure_params) { // NOLINT(*)
using Kernel_traits = Kernel_traits<weight_t, input_t, output_t, compute_t, index_t, HIDDEN_SIZE, using Kernel_traits = Kernel_traits<weight_t, input_t, output_t, compute_t, index_t, HIDDEN_SIZE,
1, WARPS_M, WARPS_N, BYTES_PER_LDG>; 1, WARPS_M, WARPS_N, BYTES_PER_LDG>;
...@@ -71,8 +73,8 @@ void launch_general_(LaunchParams<ForwardKernelParams> &launch_params, ...@@ -71,8 +73,8 @@ void launch_general_(LaunchParams<ForwardKernelParams> &launch_params,
int ctas_per_row = launch_params.params.ctas_per_row; int ctas_per_row = launch_params.params.ctas_per_row;
if (configure_params) { if (configure_params) {
int ctas_per_sm; int ctas_per_sm;
cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor( NVTE_CHECK_CUDA(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, 0); &ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, 0));
const int max_ctas = launch_params.multiprocessorCount * ctas_per_sm; const int max_ctas = launch_params.multiprocessorCount * ctas_per_sm;
ctas_per_row = ceil_div(cols, HIDDEN_SIZE); ctas_per_row = ceil_div(cols, HIDDEN_SIZE);
ctas_per_col = std::min(ceil_div(rows, WARPS_M), max_ctas / ctas_per_row); ctas_per_col = std::min(ceil_div(rows, WARPS_M), max_ctas / ctas_per_row);
...@@ -92,10 +94,11 @@ void launch_general_(LaunchParams<ForwardKernelParams> &launch_params, ...@@ -92,10 +94,11 @@ void launch_general_(LaunchParams<ForwardKernelParams> &launch_params,
dim3 block(Kernel_traits::THREADS_PER_CTA); dim3 block(Kernel_traits::THREADS_PER_CTA);
if (ctas_per_row == 1) { if (ctas_per_row == 1) {
kernel<<<grid, block, 0, stream>>>(launch_params.params); kernel<<<grid, block, 0, stream>>>(launch_params.params);
NVTE_CHECK_CUDA(cudaGetLastError());
} else { } else {
void *params_ = reinterpret_cast<void *>(&launch_params.params); void *params_ = reinterpret_cast<void *>(&launch_params.params);
cudaLaunchCooperativeKernel(reinterpret_cast<void *>(kernel), grid, block, NVTE_CHECK_CUDA(cudaLaunchCooperativeKernel(reinterpret_cast<void *>(kernel), grid, block,
reinterpret_cast<void **>(&params_), 0, stream); reinterpret_cast<void **>(&params_), 0, stream));
} }
} }
...@@ -105,8 +108,8 @@ void launch_general_(LaunchParams<ForwardKernelParams> &launch_params, ...@@ -105,8 +108,8 @@ void launch_general_(LaunchParams<ForwardKernelParams> &launch_params,
void \ void \
norm_##NORM_TYPE##_##NORM_STAGE##_##LAUNCH_TYPE##_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \ norm_##NORM_TYPE##_##NORM_STAGE##_##LAUNCH_TYPE##_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \
LaunchParams<NORM_STAGE##KernelParams> &launch_params, const bool configure_params) { \ LaunchParams<NORM_STAGE##KernelParams> &launch_params, const bool configure_params) { \
launch_##LAUNCH_TYPE##_<WTYPE, ITYPE, OTYPE, CTYPE, uint32_t, HIDDEN_SIZE, __VA_ARGS__>( \ launch_rmsnorm_fwd_##LAUNCH_TYPE##_<WTYPE, ITYPE, OTYPE, CTYPE, uint32_t, HIDDEN_SIZE, \
launch_params, configure_params); \ __VA_ARGS__>(launch_params, configure_params); \
} \ } \
REGISTER_NORM_BASE( \ REGISTER_NORM_BASE( \
NORM_TYPE, NORM_STAGE, LAUNCH_TYPE, HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, \ NORM_TYPE, NORM_STAGE, LAUNCH_TYPE, HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, \
......
...@@ -35,17 +35,20 @@ void nvshmem_wait_on_stream(uint64_t* sig_addr, WaitKind wait_kind, cudaStream_t ...@@ -35,17 +35,20 @@ void nvshmem_wait_on_stream(uint64_t* sig_addr, WaitKind wait_kind, cudaStream_t
switch (wait_kind) { switch (wait_kind) {
case WaitKind::KERNEL_WAIT: case WaitKind::KERNEL_WAIT:
wait_until_on_stream_and_reset<<<1, 1, 0, cur_stream>>>(sig_addr, wait_value, signal_reset); wait_until_on_stream_and_reset<<<1, 1, 0, cur_stream>>>(sig_addr, wait_value, signal_reset);
NVTE_CHECK_CUDA(cudaGetLastError());
break; break;
case WaitKind::NVSHMEM_WAIT: case WaitKind::NVSHMEM_WAIT:
nvshmemx_uint64_wait_until_on_stream(sig_addr, NVSHMEM_CMP_EQ, wait_value, cur_stream); nvshmemx_uint64_wait_until_on_stream(sig_addr, NVSHMEM_CMP_EQ, wait_value, cur_stream);
cuStreamWriteValue64((CUstream)cur_stream, (CUdeviceptr)sig_addr, (cuuint64_t)signal_reset, NVTE_CHECK_CUDA_DRIVER(cuStreamWriteValue64((CUstream)cur_stream, (CUdeviceptr)sig_addr,
CU_STREAM_WRITE_VALUE_DEFAULT); (cuuint64_t)signal_reset,
CU_STREAM_WRITE_VALUE_DEFAULT));
break; break;
case WaitKind::STREAM_WAIT: case WaitKind::STREAM_WAIT:
cuStreamWaitValue64((CUstream)cur_stream, (CUdeviceptr)sig_addr, (cuuint64_t)wait_value, NVTE_CHECK_CUDA_DRIVER(cuStreamWaitValue64((CUstream)cur_stream, (CUdeviceptr)sig_addr,
CU_STREAM_WAIT_VALUE_GEQ); (cuuint64_t)wait_value, CU_STREAM_WAIT_VALUE_GEQ));
cuStreamWriteValue64((CUstream)cur_stream, (CUdeviceptr)sig_addr, (cuuint64_t)signal_reset, NVTE_CHECK_CUDA_DRIVER(cuStreamWriteValue64((CUstream)cur_stream, (CUdeviceptr)sig_addr,
CU_STREAM_WRITE_VALUE_DEFAULT); (cuuint64_t)signal_reset,
CU_STREAM_WRITE_VALUE_DEFAULT));
break; break;
} }
} }
...@@ -243,11 +243,13 @@ void nvte_permute_launcher(const T *input, T *output, const int *sorted_row_id, ...@@ -243,11 +243,13 @@ void nvte_permute_launcher(const T *input, T *output, const int *sorted_row_id,
moe_permute_row_map<<<blocks, threads, 0, stream>>>(sorted_row_id, row_id_map, num_rows, topK, moe_permute_row_map<<<blocks, threads, 0, stream>>>(sorted_row_id, row_id_map, num_rows, topK,
num_out_tokens); num_out_tokens);
NVTE_CHECK_CUDA(cudaGetLastError());
blocks = num_rows; blocks = num_rows;
threads = std::min(num_cols / kElementsPerAccess, 1024); threads = std::min(num_cols / kElementsPerAccess, 1024);
moe_permute_kernel<T, TCompute, 128, false><<<blocks, threads, 0, stream>>>( moe_permute_kernel<T, TCompute, 128, false><<<blocks, threads, 0, stream>>>(
input, nullptr, output, nullptr, nullptr, row_id_map, num_rows, topK, num_cols); input, nullptr, output, nullptr, nullptr, row_id_map, num_rows, topK, num_cols);
NVTE_CHECK_CUDA(cudaGetLastError());
} else { } else {
// moe_unpermute_bwd // moe_unpermute_bwd
...@@ -259,6 +261,7 @@ void nvte_permute_launcher(const T *input, T *output, const int *sorted_row_id, ...@@ -259,6 +261,7 @@ void nvte_permute_launcher(const T *input, T *output, const int *sorted_row_id,
moe_permute_kernel<T, TCompute, 1, false><<<blocks, threads, 0, stream>>>( moe_permute_kernel<T, TCompute, 1, false><<<blocks, threads, 0, stream>>>(
input, input_fwd, output, nullptr, nullptr, row_id_map, num_rows, topK, num_cols); input, input_fwd, output, nullptr, nullptr, row_id_map, num_rows, topK, num_cols);
NVTE_CHECK_CUDA(cudaGetLastError());
} else { } else {
// moe_unpermute_bwd with probs // moe_unpermute_bwd with probs
...@@ -282,6 +285,7 @@ void nvte_permute_launcher(const T *input, T *output, const int *sorted_row_id, ...@@ -282,6 +285,7 @@ void nvte_permute_launcher(const T *input, T *output, const int *sorted_row_id,
} else { } else {
NVTE_ERROR("topK cannot exceed 128."); NVTE_ERROR("topK cannot exceed 128.");
} }
NVTE_CHECK_CUDA(cudaGetLastError());
} }
} }
} }
...@@ -306,11 +310,13 @@ void nvte_unpermute_launcher(const T *input, T *output, int *row_id_map, const f ...@@ -306,11 +310,13 @@ void nvte_unpermute_launcher(const T *input, T *output, int *row_id_map, const f
moe_unpermute_kernel<T, TCompute, false><<<blocks, threads, smem_bytes, stream>>>( moe_unpermute_kernel<T, TCompute, false><<<blocks, threads, smem_bytes, stream>>>(
input, output, row_id_map, nullptr, num_rows, topK, num_cols); input, output, row_id_map, nullptr, num_rows, topK, num_cols);
NVTE_CHECK_CUDA(cudaGetLastError());
} else { } else {
// moe_unpermute_fwd with probs // moe_unpermute_fwd with probs
moe_unpermute_kernel<T, TCompute, true><<<blocks, threads, smem_bytes, stream>>>( moe_unpermute_kernel<T, TCompute, true><<<blocks, threads, smem_bytes, stream>>>(
input, output, row_id_map, prob, num_rows, topK, num_cols); input, output, row_id_map, prob, num_rows, topK, num_cols);
NVTE_CHECK_CUDA(cudaGetLastError());
} }
} }
......
...@@ -60,7 +60,7 @@ __launch_bounds__(amax_kernel_threads) __global__ ...@@ -60,7 +60,7 @@ __launch_bounds__(amax_kernel_threads) __global__
template <int nvec, typename InputType> template <int nvec, typename InputType>
void launch_amax_kernel(const InputType *input, float *amax, const size_t N, cudaStream_t stream) { void launch_amax_kernel(const InputType *input, float *amax, const size_t N, cudaStream_t stream) {
// Zero out amax so we can update with atomic max // Zero out amax so we can update with atomic max
cudaMemsetAsync(amax, 0, sizeof(float), stream); NVTE_CHECK_CUDA(cudaMemsetAsync(amax, 0, sizeof(float), stream));
// Return immediately if tensor is empty // Return immediately if tensor is empty
if (N == 0) { if (N == 0) {
......
...@@ -183,6 +183,7 @@ void fp8_block_scaling_compute_partial_amax(const Tensor inp, Tensor amax, size_ ...@@ -183,6 +183,7 @@ void fp8_block_scaling_compute_partial_amax(const Tensor inp, Tensor amax, size_
reinterpret_cast<float *>(amax.data.dptr), reinterpret_cast<float *>(amax.data.dptr),
amax_stride_h, amax_stride_w, h, w, start_offset, amax_stride_h, amax_stride_w, h, w, start_offset,
len);) len);)
NVTE_CHECK_CUDA(cudaGetLastError());
} }
void fp8_block_scaling_partial_cast(const Tensor inp, Tensor out, const Tensor scale, size_t h, void fp8_block_scaling_partial_cast(const Tensor inp, Tensor out, const Tensor scale, size_t h,
...@@ -215,6 +216,7 @@ void fp8_block_scaling_partial_cast(const Tensor inp, Tensor out, const Tensor s ...@@ -215,6 +216,7 @@ void fp8_block_scaling_partial_cast(const Tensor inp, Tensor out, const Tensor s
reinterpret_cast<fp8_type *>(out.data.dptr), reinterpret_cast<fp8_type *>(out.data.dptr),
reinterpret_cast<const float *>(scale.data.dptr), scale_stride_h, scale_stride_w, reinterpret_cast<const float *>(scale.data.dptr), scale_stride_h, scale_stride_w,
h, w, start_offset, len);))) h, w, start_offset, len);)))
NVTE_CHECK_CUDA(cudaGetLastError());
} }
} // namespace fp8_block_scaling_recipe } // namespace fp8_block_scaling_recipe
......
...@@ -387,22 +387,25 @@ void swizzle_scaling_factors(const Tensor* input, Tensor* output, cudaStream_t s ...@@ -387,22 +387,25 @@ void swizzle_scaling_factors(const Tensor* input, Tensor* output, cudaStream_t s
const int original_K = input->flat_last_dim() / MXFP8_BLOCK_SIZE; const int original_K = input->flat_last_dim() / MXFP8_BLOCK_SIZE;
switch (vec_load_size) { switch (vec_load_size) {
case 4: case 4:
NVTE_CHECK_CUDA(
cudaFuncSetAttribute(swizzle_row_scaling_kernel<int4, SF_TILE_DIM_M, SF_TILE_DIM_K>, cudaFuncSetAttribute(swizzle_row_scaling_kernel<int4, SF_TILE_DIM_M, SF_TILE_DIM_K>,
cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size); cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size));
swizzle_row_scaling_kernel<int4, SF_TILE_DIM_M, SF_TILE_DIM_K> swizzle_row_scaling_kernel<int4, SF_TILE_DIM_M, SF_TILE_DIM_K>
<<<num_blocks, block_size, slm_size, stream>>>( <<<num_blocks, block_size, slm_size, stream>>>(
input->scale_inv.dptr, output->scale_inv.dptr, m, k, original_M, original_K); input->scale_inv.dptr, output->scale_inv.dptr, m, k, original_M, original_K);
break; break;
case 2: case 2:
NVTE_CHECK_CUDA(
cudaFuncSetAttribute(swizzle_row_scaling_kernel<int2, SF_TILE_DIM_M, SF_TILE_DIM_K>, cudaFuncSetAttribute(swizzle_row_scaling_kernel<int2, SF_TILE_DIM_M, SF_TILE_DIM_K>,
cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size); cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size));
swizzle_row_scaling_kernel<int2, SF_TILE_DIM_M, SF_TILE_DIM_K> swizzle_row_scaling_kernel<int2, SF_TILE_DIM_M, SF_TILE_DIM_K>
<<<num_blocks, block_size, slm_size, stream>>>( <<<num_blocks, block_size, slm_size, stream>>>(
input->scale_inv.dptr, output->scale_inv.dptr, m, k, original_M, original_K); input->scale_inv.dptr, output->scale_inv.dptr, m, k, original_M, original_K);
break; break;
case 1: case 1:
NVTE_CHECK_CUDA(
cudaFuncSetAttribute(swizzle_row_scaling_kernel<int, SF_TILE_DIM_M, SF_TILE_DIM_K>, cudaFuncSetAttribute(swizzle_row_scaling_kernel<int, SF_TILE_DIM_M, SF_TILE_DIM_K>,
cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size); cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size));
swizzle_row_scaling_kernel<int, SF_TILE_DIM_M, SF_TILE_DIM_K> swizzle_row_scaling_kernel<int, SF_TILE_DIM_M, SF_TILE_DIM_K>
<<<num_blocks, block_size, slm_size, stream>>>( <<<num_blocks, block_size, slm_size, stream>>>(
input->scale_inv.dptr, output->scale_inv.dptr, m, k, original_M, original_K); input->scale_inv.dptr, output->scale_inv.dptr, m, k, original_M, original_K);
...@@ -411,6 +414,7 @@ void swizzle_scaling_factors(const Tensor* input, Tensor* output, cudaStream_t s ...@@ -411,6 +414,7 @@ void swizzle_scaling_factors(const Tensor* input, Tensor* output, cudaStream_t s
NVTE_ERROR("Not valid vec_load_size."); NVTE_ERROR("Not valid vec_load_size.");
break; break;
} }
NVTE_CHECK_CUDA(cudaGetLastError());
} }
if (input->has_columnwise_data()) { if (input->has_columnwise_data()) {
int vec_load_size = (num_tiles_m - 1) % 4 + 1; int vec_load_size = (num_tiles_m - 1) % 4 + 1;
...@@ -422,24 +426,27 @@ void swizzle_scaling_factors(const Tensor* input, Tensor* output, cudaStream_t s ...@@ -422,24 +426,27 @@ void swizzle_scaling_factors(const Tensor* input, Tensor* output, cudaStream_t s
const int original_K = input->flat_first_dim() / MXFP8_BLOCK_SIZE; const int original_K = input->flat_first_dim() / MXFP8_BLOCK_SIZE;
switch (vec_load_size) { switch (vec_load_size) {
case 4: case 4:
NVTE_CHECK_CUDA(
cudaFuncSetAttribute(swizzle_col_scaling_kernel<int4, SF_TILE_DIM_M, SF_TILE_DIM_K>, cudaFuncSetAttribute(swizzle_col_scaling_kernel<int4, SF_TILE_DIM_M, SF_TILE_DIM_K>,
cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size); cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size));
swizzle_col_scaling_kernel<int4, SF_TILE_DIM_M, SF_TILE_DIM_K> swizzle_col_scaling_kernel<int4, SF_TILE_DIM_M, SF_TILE_DIM_K>
<<<num_blocks, block_size, slm_size, stream>>>(input->columnwise_scale_inv.dptr, <<<num_blocks, block_size, slm_size, stream>>>(input->columnwise_scale_inv.dptr,
output->columnwise_scale_inv.dptr, m, output->columnwise_scale_inv.dptr, m,
k, original_M, original_K); k, original_M, original_K);
break; break;
case 2: case 2:
NVTE_CHECK_CUDA(
cudaFuncSetAttribute(swizzle_col_scaling_kernel<int2, SF_TILE_DIM_M, SF_TILE_DIM_K>, cudaFuncSetAttribute(swizzle_col_scaling_kernel<int2, SF_TILE_DIM_M, SF_TILE_DIM_K>,
cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size); cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size));
swizzle_col_scaling_kernel<int2, SF_TILE_DIM_M, SF_TILE_DIM_K> swizzle_col_scaling_kernel<int2, SF_TILE_DIM_M, SF_TILE_DIM_K>
<<<num_blocks, block_size, slm_size, stream>>>(input->columnwise_scale_inv.dptr, <<<num_blocks, block_size, slm_size, stream>>>(input->columnwise_scale_inv.dptr,
output->columnwise_scale_inv.dptr, m, output->columnwise_scale_inv.dptr, m,
k, original_M, original_K); k, original_M, original_K);
break; break;
case 1: case 1:
NVTE_CHECK_CUDA(
cudaFuncSetAttribute(swizzle_col_scaling_kernel<int, SF_TILE_DIM_M, SF_TILE_DIM_K>, cudaFuncSetAttribute(swizzle_col_scaling_kernel<int, SF_TILE_DIM_M, SF_TILE_DIM_K>,
cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size); cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size));
swizzle_col_scaling_kernel<int, SF_TILE_DIM_M, SF_TILE_DIM_K> swizzle_col_scaling_kernel<int, SF_TILE_DIM_M, SF_TILE_DIM_K>
<<<num_blocks, block_size, slm_size, stream>>>(input->columnwise_scale_inv.dptr, <<<num_blocks, block_size, slm_size, stream>>>(input->columnwise_scale_inv.dptr,
output->columnwise_scale_inv.dptr, m, output->columnwise_scale_inv.dptr, m,
...@@ -449,6 +456,7 @@ void swizzle_scaling_factors(const Tensor* input, Tensor* output, cudaStream_t s ...@@ -449,6 +456,7 @@ void swizzle_scaling_factors(const Tensor* input, Tensor* output, cudaStream_t s
NVTE_ERROR("Not valid vec_load_size."); NVTE_ERROR("Not valid vec_load_size.");
break; break;
} }
NVTE_CHECK_CUDA(cudaGetLastError());
} }
// 2D block scaling // 2D block scaling
...@@ -489,23 +497,23 @@ void launch_multi_tensor_swizzle_scaling_factors(MultiSwizzleArgs& kernel_args, ...@@ -489,23 +497,23 @@ void launch_multi_tensor_swizzle_scaling_factors(MultiSwizzleArgs& kernel_args,
if (is_rowwise) { if (is_rowwise) {
switch (vec_load_size) { switch (vec_load_size) {
case 4: case 4:
cudaFuncSetAttribute( NVTE_CHECK_CUDA(cudaFuncSetAttribute(
multi_tensor_swizzle_row_scaling_kernel<int4, SF_TILE_DIM_M, SF_TILE_DIM_K>, multi_tensor_swizzle_row_scaling_kernel<int4, SF_TILE_DIM_M, SF_TILE_DIM_K>,
cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size); cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size));
multi_tensor_swizzle_row_scaling_kernel<int4, SF_TILE_DIM_M, SF_TILE_DIM_K> multi_tensor_swizzle_row_scaling_kernel<int4, SF_TILE_DIM_M, SF_TILE_DIM_K>
<<<num_blocks, block_size, slm_size, stream>>>(kernel_args); <<<num_blocks, block_size, slm_size, stream>>>(kernel_args);
break; break;
case 2: case 2:
cudaFuncSetAttribute( NVTE_CHECK_CUDA(cudaFuncSetAttribute(
multi_tensor_swizzle_row_scaling_kernel<int2, SF_TILE_DIM_M, SF_TILE_DIM_K>, multi_tensor_swizzle_row_scaling_kernel<int2, SF_TILE_DIM_M, SF_TILE_DIM_K>,
cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size); cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size));
multi_tensor_swizzle_row_scaling_kernel<int2, SF_TILE_DIM_M, SF_TILE_DIM_K> multi_tensor_swizzle_row_scaling_kernel<int2, SF_TILE_DIM_M, SF_TILE_DIM_K>
<<<num_blocks, block_size, slm_size, stream>>>(kernel_args); <<<num_blocks, block_size, slm_size, stream>>>(kernel_args);
break; break;
case 1: case 1:
cudaFuncSetAttribute( NVTE_CHECK_CUDA(cudaFuncSetAttribute(
multi_tensor_swizzle_row_scaling_kernel<int, SF_TILE_DIM_M, SF_TILE_DIM_K>, multi_tensor_swizzle_row_scaling_kernel<int, SF_TILE_DIM_M, SF_TILE_DIM_K>,
cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size); cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size));
multi_tensor_swizzle_row_scaling_kernel<int, SF_TILE_DIM_M, SF_TILE_DIM_K> multi_tensor_swizzle_row_scaling_kernel<int, SF_TILE_DIM_M, SF_TILE_DIM_K>
<<<num_blocks, block_size, slm_size, stream>>>(kernel_args); <<<num_blocks, block_size, slm_size, stream>>>(kernel_args);
break; break;
...@@ -516,23 +524,23 @@ void launch_multi_tensor_swizzle_scaling_factors(MultiSwizzleArgs& kernel_args, ...@@ -516,23 +524,23 @@ void launch_multi_tensor_swizzle_scaling_factors(MultiSwizzleArgs& kernel_args,
} else { } else {
switch (vec_load_size) { switch (vec_load_size) {
case 4: case 4:
cudaFuncSetAttribute( NVTE_CHECK_CUDA(cudaFuncSetAttribute(
multi_tensor_swizzle_col_scaling_kernel<int4, SF_TILE_DIM_M, SF_TILE_DIM_K>, multi_tensor_swizzle_col_scaling_kernel<int4, SF_TILE_DIM_M, SF_TILE_DIM_K>,
cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size); cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size));
multi_tensor_swizzle_col_scaling_kernel<int4, SF_TILE_DIM_M, SF_TILE_DIM_K> multi_tensor_swizzle_col_scaling_kernel<int4, SF_TILE_DIM_M, SF_TILE_DIM_K>
<<<num_blocks, block_size, slm_size, stream>>>(kernel_args); <<<num_blocks, block_size, slm_size, stream>>>(kernel_args);
break; break;
case 2: case 2:
cudaFuncSetAttribute( NVTE_CHECK_CUDA(cudaFuncSetAttribute(
multi_tensor_swizzle_col_scaling_kernel<int2, SF_TILE_DIM_M, SF_TILE_DIM_K>, multi_tensor_swizzle_col_scaling_kernel<int2, SF_TILE_DIM_M, SF_TILE_DIM_K>,
cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size); cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size));
multi_tensor_swizzle_col_scaling_kernel<int2, SF_TILE_DIM_M, SF_TILE_DIM_K> multi_tensor_swizzle_col_scaling_kernel<int2, SF_TILE_DIM_M, SF_TILE_DIM_K>
<<<num_blocks, block_size, slm_size, stream>>>(kernel_args); <<<num_blocks, block_size, slm_size, stream>>>(kernel_args);
break; break;
case 1: case 1:
cudaFuncSetAttribute( NVTE_CHECK_CUDA(cudaFuncSetAttribute(
multi_tensor_swizzle_col_scaling_kernel<int, SF_TILE_DIM_M, SF_TILE_DIM_K>, multi_tensor_swizzle_col_scaling_kernel<int, SF_TILE_DIM_M, SF_TILE_DIM_K>,
cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size); cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size));
multi_tensor_swizzle_col_scaling_kernel<int, SF_TILE_DIM_M, SF_TILE_DIM_K> multi_tensor_swizzle_col_scaling_kernel<int, SF_TILE_DIM_M, SF_TILE_DIM_K>
<<<num_blocks, block_size, slm_size, stream>>>(kernel_args); <<<num_blocks, block_size, slm_size, stream>>>(kernel_args);
break; break;
......
...@@ -544,11 +544,11 @@ void nvte_zero_tensor(const NVTETensor tensor, cudaStream_t stream) { ...@@ -544,11 +544,11 @@ void nvte_zero_tensor(const NVTETensor tensor, cudaStream_t stream) {
// Zero out tensor data if allocated // Zero out tensor data if allocated
if (t.data.dptr != nullptr) { if (t.data.dptr != nullptr) {
const size_t size_in_bytes = nvte_tensor_size_bytes(tensor); const size_t size_in_bytes = nvte_tensor_size_bytes(tensor);
cudaMemsetAsync(t.data.dptr, 0, size_in_bytes, stream); NVTE_CHECK_CUDA(cudaMemsetAsync(t.data.dptr, 0, size_in_bytes, stream));
} }
// Set amax to 0 if allocated // Set amax to 0 if allocated
if (t.amax.dptr != nullptr) { if (t.amax.dptr != nullptr) {
cudaMemsetAsync(t.amax.dptr, 0, sizeof(float), stream); NVTE_CHECK_CUDA(cudaMemsetAsync(t.amax.dptr, 0, sizeof(float), stream));
} }
} }
......
...@@ -335,6 +335,7 @@ void cast_transpose(const Tensor &input, const Tensor &noop, Tensor *output_, cu ...@@ -335,6 +335,7 @@ void cast_transpose(const Tensor &input, const Tensor &noop, Tensor *output_, cu
static_cast<const CType *>(output.scale.dptr), static_cast<const CType *>(output.scale.dptr),
static_cast<CType *>(output.amax.dptr), static_cast<CType *>(output.amax.dptr),
static_cast<CType *>(output.scale_inv.dptr), row_length, num_rows); static_cast<CType *>(output.scale_inv.dptr), row_length, num_rows);
NVTE_CHECK_CUDA(cudaGetLastError());
} }
} else { } else {
NVTE_ERROR("Not implemented scaling mode: ", to_string(output.scaling_mode)); NVTE_ERROR("Not implemented scaling mode: ", to_string(output.scaling_mode));
......
...@@ -264,6 +264,7 @@ void reduce_dbias(const Tensor &workspace, Tensor *dbias, const size_t row_lengt ...@@ -264,6 +264,7 @@ void reduce_dbias(const Tensor &workspace, Tensor *dbias, const size_t row_lengt
reinterpret_cast<InputType *>(dbias->data.dptr), reinterpret_cast<InputType *>(dbias->data.dptr),
reinterpret_cast<const fp32 *>(workspace.data.dptr), reduce_dbias_row_length, reinterpret_cast<const fp32 *>(workspace.data.dptr), reduce_dbias_row_length,
reduce_dbias_num_rows); reduce_dbias_num_rows);
NVTE_CHECK_CUDA(cudaGetLastError());
} }
template <bool IS_DBIAS, bool IS_DACT, bool IS_ACT, typename ComputeType, typename Param, template <bool IS_DBIAS, bool IS_DACT, bool IS_ACT, typename ComputeType, typename Param,
...@@ -737,14 +738,15 @@ void cast_transpose_fused(const Tensor &input, const Tensor *act_input, Tensor * ...@@ -737,14 +738,15 @@ void cast_transpose_fused(const Tensor &input, const Tensor *act_input, Tensor *
NVTE_CHECK(row_length % nvec_in == 0, "Unsupported shape."); NVTE_CHECK(row_length % nvec_in == 0, "Unsupported shape.");
NVTE_CHECK(num_rows % nvec_out == 0, "Unsupported shape."); NVTE_CHECK(num_rows % nvec_out == 0, "Unsupported shape.");
cudaFuncSetAttribute( NVTE_CHECK_CUDA(cudaFuncSetAttribute(
cast_transpose_fused_kernel_notaligned<IS_DBIAS, IS_DACT, IS_ACT, ComputeType, cast_transpose_fused_kernel_notaligned<IS_DBIAS, IS_DACT, IS_ACT, ComputeType,
Param, nvec_in, nvec_out, Empty, OP>, Param, nvec_in, nvec_out, Empty, OP>,
cudaFuncAttributePreferredSharedMemoryCarveout, 100); cudaFuncAttributePreferredSharedMemoryCarveout, 100));
cast_transpose_fused_kernel_notaligned<IS_DBIAS, IS_DACT, IS_ACT, ComputeType, Param, cast_transpose_fused_kernel_notaligned<IS_DBIAS, IS_DACT, IS_ACT, ComputeType, Param,
nvec_in, nvec_out, Empty, OP> nvec_in, nvec_out, Empty, OP>
<<<num_blocks, cast_transpose_num_threads, shared_size_transpose, stream>>>( <<<num_blocks, cast_transpose_num_threads, shared_size_transpose, stream>>>(
param, row_length, num_rows, num_tiles); param, row_length, num_rows, num_tiles);
NVTE_CHECK_CUDA(cudaGetLastError());
} }
if constexpr (IS_DBIAS) { if constexpr (IS_DBIAS) {
...@@ -1197,10 +1199,10 @@ void dgated_act_cast_transpose(const Tensor &input, const Tensor &gated_act_inpu ...@@ -1197,10 +1199,10 @@ void dgated_act_cast_transpose(const Tensor &input, const Tensor &gated_act_inpu
const size_t shmem_size = cast_transpose_num_threads / n_warps_per_tile * const size_t shmem_size = cast_transpose_num_threads / n_warps_per_tile *
(THREADS_PER_WARP + 1) * sizeof(Vec<OutputType, nvec_out>); (THREADS_PER_WARP + 1) * sizeof(Vec<OutputType, nvec_out>);
if (full_tile) { if (full_tile) {
cudaFuncSetAttribute( NVTE_CHECK_CUDA(cudaFuncSetAttribute(
dgated_act_cast_transpose_kernel<nvec_in, nvec_out, ComputeType, InputType, dgated_act_cast_transpose_kernel<nvec_in, nvec_out, ComputeType, InputType,
OutputType, Empty, OP1, OP2>, OutputType, Empty, OP1, OP2>,
cudaFuncAttributePreferredSharedMemoryCarveout, 100); cudaFuncAttributePreferredSharedMemoryCarveout, 100));
dgated_act_cast_transpose_kernel<nvec_in, nvec_out, ComputeType, InputType, OutputType, dgated_act_cast_transpose_kernel<nvec_in, nvec_out, ComputeType, InputType, OutputType,
Empty, OP1, OP2> Empty, OP1, OP2>
...@@ -1213,11 +1215,12 @@ void dgated_act_cast_transpose(const Tensor &input, const Tensor &gated_act_inpu ...@@ -1213,11 +1215,12 @@ void dgated_act_cast_transpose(const Tensor &input, const Tensor &gated_act_inpu
reinterpret_cast<fp32 *>(output->amax.dptr), reinterpret_cast<fp32 *>(output->amax.dptr),
reinterpret_cast<fp32 *>(output->scale_inv.dptr), row_length, num_rows, reinterpret_cast<fp32 *>(output->scale_inv.dptr), row_length, num_rows,
n_tiles); n_tiles);
NVTE_CHECK_CUDA(cudaGetLastError());
} else { } else {
cudaFuncSetAttribute( NVTE_CHECK_CUDA(cudaFuncSetAttribute(
dgated_act_cast_transpose_kernel_notaligned<nvec_in, nvec_out, ComputeType, dgated_act_cast_transpose_kernel_notaligned<nvec_in, nvec_out, ComputeType,
InputType, OutputType, Empty, OP1, OP2>, InputType, OutputType, Empty, OP1, OP2>,
cudaFuncAttributePreferredSharedMemoryCarveout, 100); cudaFuncAttributePreferredSharedMemoryCarveout, 100));
dgated_act_cast_transpose_kernel_notaligned<nvec_in, nvec_out, ComputeType, InputType, dgated_act_cast_transpose_kernel_notaligned<nvec_in, nvec_out, ComputeType, InputType,
OutputType, Empty, OP1, OP2> OutputType, Empty, OP1, OP2>
<<<n_blocks, cast_transpose_num_threads, shmem_size, stream>>>( <<<n_blocks, cast_transpose_num_threads, shmem_size, stream>>>(
...@@ -1229,6 +1232,7 @@ void dgated_act_cast_transpose(const Tensor &input, const Tensor &gated_act_inpu ...@@ -1229,6 +1232,7 @@ void dgated_act_cast_transpose(const Tensor &input, const Tensor &gated_act_inpu
reinterpret_cast<fp32 *>(output->amax.dptr), reinterpret_cast<fp32 *>(output->amax.dptr),
reinterpret_cast<fp32 *>(output->scale_inv.dptr), row_length, num_rows, reinterpret_cast<fp32 *>(output->scale_inv.dptr), row_length, num_rows,
n_tiles); n_tiles);
NVTE_CHECK_CUDA(cudaGetLastError());
}); // NOLINT(*) }); // NOLINT(*)
); // NOLINT(*) ); // NOLINT(*)
} }
......
...@@ -258,6 +258,7 @@ void multi_cast_transpose(const std::vector<Tensor*> input_list, std::vector<Ten ...@@ -258,6 +258,7 @@ void multi_cast_transpose(const std::vector<Tensor*> input_list, std::vector<Ten
multi_cast_transpose_kernel<nvec_in, nvec_out, true, fp32, InputType, OutputType> multi_cast_transpose_kernel<nvec_in, nvec_out, true, fp32, InputType, OutputType>
<<<n_blocks, threads_per_block, 0, stream>>>(kernel_args_aligned);); // NOLINT(*) <<<n_blocks, threads_per_block, 0, stream>>>(kernel_args_aligned);); // NOLINT(*)
); // NOLINT(*) ); // NOLINT(*)
NVTE_CHECK_CUDA(cudaGetLastError());
kernel_args_aligned.num_tensors = 0; kernel_args_aligned.num_tensors = 0;
} }
if (kernel_args_unaligned.num_tensors == kMaxTensorsPerKernel) { if (kernel_args_unaligned.num_tensors == kMaxTensorsPerKernel) {
...@@ -271,6 +272,7 @@ void multi_cast_transpose(const std::vector<Tensor*> input_list, std::vector<Ten ...@@ -271,6 +272,7 @@ void multi_cast_transpose(const std::vector<Tensor*> input_list, std::vector<Ten
multi_cast_transpose_kernel<nvec_in, nvec_out, false, fp32, InputType, OutputType> multi_cast_transpose_kernel<nvec_in, nvec_out, false, fp32, InputType, OutputType>
<<<n_blocks, threads_per_block, 0, stream>>>(kernel_args_unaligned);); // NOLINT(*) <<<n_blocks, threads_per_block, 0, stream>>>(kernel_args_unaligned);); // NOLINT(*)
); // NOLINT(*) ); // NOLINT(*)
NVTE_CHECK_CUDA(cudaGetLastError());
kernel_args_unaligned.num_tensors = 0; kernel_args_unaligned.num_tensors = 0;
} }
...@@ -311,6 +313,7 @@ void multi_cast_transpose(const std::vector<Tensor*> input_list, std::vector<Ten ...@@ -311,6 +313,7 @@ void multi_cast_transpose(const std::vector<Tensor*> input_list, std::vector<Ten
multi_cast_transpose_kernel<nvec_in, nvec_out, true, fp32, InputType, OutputType> multi_cast_transpose_kernel<nvec_in, nvec_out, true, fp32, InputType, OutputType>
<<<n_blocks, threads_per_block, 0, stream>>>(kernel_args_aligned);); // NOLINT(*) <<<n_blocks, threads_per_block, 0, stream>>>(kernel_args_aligned);); // NOLINT(*)
); // NOLINT(*) ); // NOLINT(*)
NVTE_CHECK_CUDA(cudaGetLastError());
} }
if (kernel_args_unaligned.num_tensors > 0) { if (kernel_args_unaligned.num_tensors > 0) {
TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
...@@ -323,6 +326,7 @@ void multi_cast_transpose(const std::vector<Tensor*> input_list, std::vector<Ten ...@@ -323,6 +326,7 @@ void multi_cast_transpose(const std::vector<Tensor*> input_list, std::vector<Ten
multi_cast_transpose_kernel<nvec_in, nvec_out, false, fp32, InputType, OutputType> multi_cast_transpose_kernel<nvec_in, nvec_out, false, fp32, InputType, OutputType>
<<<n_blocks, threads_per_block, 0, stream>>>(kernel_args_unaligned);); // NOLINT(*) <<<n_blocks, threads_per_block, 0, stream>>>(kernel_args_unaligned);); // NOLINT(*)
); // NOLINT(*) ); // NOLINT(*)
NVTE_CHECK_CUDA(cudaGetLastError());
} }
} }
......
...@@ -279,6 +279,7 @@ void transpose(const Tensor &input, const Tensor &noop, Tensor *output_, cudaStr ...@@ -279,6 +279,7 @@ void transpose(const Tensor &input, const Tensor &noop, Tensor *output_, cudaStr
static_cast<const fp32 *>(noop.data.dptr), static_cast<const fp32 *>(noop.data.dptr),
static_cast<Type *>(output.data.dptr), static_cast<Type *>(output.data.dptr),
row_length, num_rows); row_length, num_rows);
NVTE_CHECK_CUDA(cudaGetLastError());
}); // NOLINT(*) }); // NOLINT(*)
} }
......
...@@ -416,6 +416,7 @@ void reduce_dbias(const Tensor &workspace, Tensor *dbias, const size_t row_lengt ...@@ -416,6 +416,7 @@ void reduce_dbias(const Tensor &workspace, Tensor *dbias, const size_t row_lengt
reinterpret_cast<BiasType *>(dbias->data.dptr), reinterpret_cast<BiasType *>(dbias->data.dptr),
reinterpret_cast<const fp32 *>(workspace.data.dptr), reduce_dbias_row_length, reinterpret_cast<const fp32 *>(workspace.data.dptr), reduce_dbias_row_length,
reduce_dbias_num_rows); reduce_dbias_num_rows);
NVTE_CHECK_CUDA(cudaGetLastError());
} }
void fp8_transpose_dbias(const Tensor &input, Tensor *transposed_output, Tensor *dbias, void fp8_transpose_dbias(const Tensor &input, Tensor *transposed_output, Tensor *dbias,
...@@ -472,17 +473,21 @@ void fp8_transpose_dbias(const Tensor &input, Tensor *transposed_output, Tensor ...@@ -472,17 +473,21 @@ void fp8_transpose_dbias(const Tensor &input, Tensor *transposed_output, Tensor
param.workspace = reinterpret_cast<ComputeType *>(workspace->data.dptr); param.workspace = reinterpret_cast<ComputeType *>(workspace->data.dptr);
if (full_tile) { if (full_tile) {
cudaFuncSetAttribute(transpose_dbias_kernel<nvec_in, nvec_out, Param>, NVTE_CHECK_CUDA(cudaFuncSetAttribute(transpose_dbias_kernel<nvec_in, nvec_out, Param>,
cudaFuncAttributePreferredSharedMemoryCarveout, 100); cudaFuncAttributePreferredSharedMemoryCarveout,
100));
transpose_dbias_kernel<nvec_in, nvec_out, Param> transpose_dbias_kernel<nvec_in, nvec_out, Param>
<<<n_blocks, cast_transpose_num_threads, shared_size_transpose, stream>>>( <<<n_blocks, cast_transpose_num_threads, shared_size_transpose, stream>>>(
param, row_length, num_rows, n_tiles); param, row_length, num_rows, n_tiles);
NVTE_CHECK_CUDA(cudaGetLastError());
} else { } else {
NVTE_CHECK_CUDA(
cudaFuncSetAttribute(transpose_dbias_kernel_notaligned<nvec_in, nvec_out, Param>, cudaFuncSetAttribute(transpose_dbias_kernel_notaligned<nvec_in, nvec_out, Param>,
cudaFuncAttributePreferredSharedMemoryCarveout, 100); cudaFuncAttributePreferredSharedMemoryCarveout, 100));
transpose_dbias_kernel_notaligned<nvec_in, nvec_out, Param> transpose_dbias_kernel_notaligned<nvec_in, nvec_out, Param>
<<<n_blocks, cast_transpose_num_threads, shared_size_transpose, stream>>>( <<<n_blocks, cast_transpose_num_threads, shared_size_transpose, stream>>>(
param, row_length, num_rows, n_tiles); param, row_length, num_rows, n_tiles);
NVTE_CHECK_CUDA(cudaGetLastError());
} }
reduce_dbias<BiasType>(*workspace, dbias, row_length, num_rows, nvec_out, reduce_dbias<BiasType>(*workspace, dbias, row_length, num_rows, nvec_out,
......
...@@ -950,15 +950,16 @@ void cast_fp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *outpu ...@@ -950,15 +950,16 @@ void cast_fp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *outpu
const size_t shmem_size = grad_mem + (in_act_mem + in_gate_mem) + const size_t shmem_size = grad_mem + (in_act_mem + in_gate_mem) +
(out_act_mem + out_gate_mem) + TMA_SHMEM_ALIGNMENT; (out_act_mem + out_gate_mem) + TMA_SHMEM_ALIGNMENT;
cudaFuncSetAttribute( NVTE_CHECK_CUDA(cudaFuncSetAttribute(
cast_fp8_gated_kernel<IS_DGATED, ParamOP, ActOP, DActOP, IType, OType>, cast_fp8_gated_kernel<IS_DGATED, ParamOP, ActOP, DActOP, IType, OType>,
cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_size); cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_size));
cast_fp8_gated_kernel<IS_DGATED, ParamOP, ActOP, DActOP, IType, OType> cast_fp8_gated_kernel<IS_DGATED, ParamOP, ActOP, DActOP, IType, OType>
<<<grid_dim, block_dim, shmem_size, stream>>>( <<<grid_dim, block_dim, shmem_size, stream>>>(
tensor_map_grad, tensor_map_input_act, tensor_map_input_gate, tensor_map_output_act, tensor_map_grad, tensor_map_input_act, tensor_map_input_gate, tensor_map_output_act,
tensor_map_output_gate, amax_ptr, scale_inv_ptr, scale_ptr, rows, tensor_map_output_gate, amax_ptr, scale_inv_ptr, scale_ptr, rows,
cols);); // NOLINT(*) cols);
NVTE_CHECK_CUDA(cudaGetLastError());); // NOLINT(*)
); // NOLINT(*) ); // NOLINT(*)
} }
...@@ -1082,11 +1083,11 @@ void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *out ...@@ -1082,11 +1083,11 @@ void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *out
switch (scaling_type) { switch (scaling_type) {
case ScalingType::ROWWISE: case ScalingType::ROWWISE:
cudaFuncSetAttribute( NVTE_CHECK_CUDA(cudaFuncSetAttribute(
mxfp8_kernel::cast_mxfp8_gated_kernel<IS_DGATED, ParamOP, ActOP, DActOP, IType, mxfp8_kernel::cast_mxfp8_gated_kernel<IS_DGATED, ParamOP, ActOP, DActOP, IType,
OType, true, false, OType, true, false,
THREADS_PER_CHUNK_NON_COLWISE>, THREADS_PER_CHUNK_NON_COLWISE>,
cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_size); cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_size));
mxfp8_kernel::cast_mxfp8_gated_kernel<IS_DGATED, ParamOP, ActOP, DActOP, IType, OType, mxfp8_kernel::cast_mxfp8_gated_kernel<IS_DGATED, ParamOP, ActOP, DActOP, IType, OType,
true, false, THREADS_PER_CHUNK_NON_COLWISE> true, false, THREADS_PER_CHUNK_NON_COLWISE>
...@@ -1096,13 +1097,14 @@ void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *out ...@@ -1096,13 +1097,14 @@ void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *out
tensor_map_output_act_colwise, tensor_map_output_gate_colwise, tensor_map_output_act_colwise, tensor_map_output_gate_colwise,
scales_rowwise_ptr, scales_colwise_ptr, rows, cols, scale_stride_rowwise, scales_rowwise_ptr, scales_colwise_ptr, rows, cols, scale_stride_rowwise,
scale_stride_colwise); scale_stride_colwise);
NVTE_CHECK_CUDA(cudaGetLastError());
break; break;
case ScalingType::COLWISE: case ScalingType::COLWISE:
cudaFuncSetAttribute( NVTE_CHECK_CUDA(cudaFuncSetAttribute(
mxfp8_kernel::cast_mxfp8_gated_kernel<IS_DGATED, ParamOP, ActOP, DActOP, IType, mxfp8_kernel::cast_mxfp8_gated_kernel<IS_DGATED, ParamOP, ActOP, DActOP, IType,
OType, false, true, OType, false, true,
THREADS_PER_CHUNK_COLWISE>, THREADS_PER_CHUNK_COLWISE>,
cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_size); cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_size));
mxfp8_kernel::cast_mxfp8_gated_kernel<IS_DGATED, ParamOP, ActOP, DActOP, IType, OType, mxfp8_kernel::cast_mxfp8_gated_kernel<IS_DGATED, ParamOP, ActOP, DActOP, IType, OType,
false, true, THREADS_PER_CHUNK_COLWISE> false, true, THREADS_PER_CHUNK_COLWISE>
...@@ -1112,13 +1114,14 @@ void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *out ...@@ -1112,13 +1114,14 @@ void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *out
tensor_map_output_act_colwise, tensor_map_output_gate_colwise, tensor_map_output_act_colwise, tensor_map_output_gate_colwise,
scales_rowwise_ptr, scales_colwise_ptr, rows, cols, scale_stride_rowwise, scales_rowwise_ptr, scales_colwise_ptr, rows, cols, scale_stride_rowwise,
scale_stride_colwise); scale_stride_colwise);
NVTE_CHECK_CUDA(cudaGetLastError());
break; break;
case ScalingType::BIDIMENSIONAL: case ScalingType::BIDIMENSIONAL:
cudaFuncSetAttribute( NVTE_CHECK_CUDA(cudaFuncSetAttribute(
mxfp8_kernel::cast_mxfp8_gated_kernel<IS_DGATED, ParamOP, ActOP, DActOP, IType, mxfp8_kernel::cast_mxfp8_gated_kernel<IS_DGATED, ParamOP, ActOP, DActOP, IType,
OType, true, true, OType, true, true,
THREADS_PER_CHUNK_NON_COLWISE>, THREADS_PER_CHUNK_NON_COLWISE>,
cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_size); cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_size));
mxfp8_kernel::cast_mxfp8_gated_kernel<IS_DGATED, ParamOP, ActOP, DActOP, IType, OType, mxfp8_kernel::cast_mxfp8_gated_kernel<IS_DGATED, ParamOP, ActOP, DActOP, IType, OType,
true, true, THREADS_PER_CHUNK_NON_COLWISE> true, true, THREADS_PER_CHUNK_NON_COLWISE>
...@@ -1128,6 +1131,7 @@ void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *out ...@@ -1128,6 +1131,7 @@ void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *out
tensor_map_output_act_colwise, tensor_map_output_gate_colwise, tensor_map_output_act_colwise, tensor_map_output_gate_colwise,
scales_rowwise_ptr, scales_colwise_ptr, rows, cols, scale_stride_rowwise, scales_rowwise_ptr, scales_colwise_ptr, rows, cols, scale_stride_rowwise,
scale_stride_colwise); scale_stride_colwise);
NVTE_CHECK_CUDA(cudaGetLastError());
break; break;
}); // NOLINT(*) }); // NOLINT(*)
); // NOLINT(*) ); // NOLINT(*)
......
...@@ -894,6 +894,7 @@ void reduce_dbias(const float *workspace_ptr, Tensor *dbias, const size_t rows, ...@@ -894,6 +894,7 @@ void reduce_dbias(const float *workspace_ptr, Tensor *dbias, const size_t rows,
reduce_dbias_kernel<reduce_dbias_nvec, IType> reduce_dbias_kernel<reduce_dbias_nvec, IType>
<<<reduce_dbias_num_blocks, DBIAS_THREADS_PER_BLOCK, 0, stream>>>( <<<reduce_dbias_num_blocks, DBIAS_THREADS_PER_BLOCK, 0, stream>>>(
reinterpret_cast<IType *>(dbias->data.dptr), workspace_ptr, rows, cols); reinterpret_cast<IType *>(dbias->data.dptr), workspace_ptr, rows, cols);
NVTE_CHECK_CUDA(cudaGetLastError());
} }
template <bool IS_ACT, typename ParamOP, float (*OP)(float, const ParamOP &)> template <bool IS_ACT, typename ParamOP, float (*OP)(float, const ParamOP &)>
...@@ -925,6 +926,7 @@ static void cast_fp8_1D(const Tensor &input, Tensor *output, cudaStream_t stream ...@@ -925,6 +926,7 @@ static void cast_fp8_1D(const Tensor &input, Tensor *output, cudaStream_t stream
cast_fp8_1D_kernel<IS_ACT, ParamOP, OP, IType, OType><<<grid, block, 0, stream>>>( cast_fp8_1D_kernel<IS_ACT, ParamOP, OP, IType, OType><<<grid, block, 0, stream>>>(
input_ptr, output_ptr, amax_ptr, scale_inv_ptr, scale_ptr, N);); // NOLINT(*) input_ptr, output_ptr, amax_ptr, scale_inv_ptr, scale_ptr, N);); // NOLINT(*)
); // NOLINT(*) ); // NOLINT(*)
NVTE_CHECK_CUDA(cudaGetLastError());
} }
template <bool IS_DBIAS, bool IS_DACT, typename ParamOP, float (*OP)(float, const ParamOP &)> template <bool IS_DBIAS, bool IS_DACT, typename ParamOP, float (*OP)(float, const ParamOP &)>
...@@ -988,6 +990,7 @@ void cast_fp8_2D(const Tensor &input, const Tensor *act_input, Tensor *output, T ...@@ -988,6 +990,7 @@ void cast_fp8_2D(const Tensor &input, const Tensor *act_input, Tensor *output, T
<<<grid, block, 0, stream>>>(tensor_map_input, tensor_map_act_input, tensor_map_output, <<<grid, block, 0, stream>>>(tensor_map_input, tensor_map_act_input, tensor_map_output,
workspace_ptr, amax_ptr, scale_inv_ptr, scale_ptr, rows, workspace_ptr, amax_ptr, scale_inv_ptr, scale_ptr, rows,
cols); cols);
NVTE_CHECK_CUDA(cudaGetLastError());
if constexpr (IS_DBIAS) { if constexpr (IS_DBIAS) {
reduce_dbias<IType>(workspace_ptr, dbias, dbias_rows, dbias_cols, stream); reduce_dbias<IType>(workspace_ptr, dbias, dbias_rows, dbias_cols, stream);
...@@ -1124,10 +1127,10 @@ void mxfp8_quantize(const Tensor &input, const Tensor *act_input, ...@@ -1124,10 +1127,10 @@ void mxfp8_quantize(const Tensor &input, const Tensor *act_input,
switch (scaling_type) { switch (scaling_type) {
case ScalingType::ROWWISE: case ScalingType::ROWWISE:
cudaFuncSetAttribute( NVTE_CHECK_CUDA(cudaFuncSetAttribute(
cast_mxfp8_2D_kernel<IS_DBIAS, IS_DACT, IS_ACT, ParamOP, OP, IType, OType, true, cast_mxfp8_2D_kernel<IS_DBIAS, IS_DACT, IS_ACT, ParamOP, OP, IType, OType, true,
false, CHUNK_DIM_Y, CHUNK_DIM_X, THREADS_PER_CHUNK>, false, CHUNK_DIM_Y, CHUNK_DIM_X, THREADS_PER_CHUNK>,
cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size); cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size));
cast_mxfp8_2D_kernel<IS_DBIAS, IS_DACT, IS_ACT, ParamOP, OP, IType, OType, true, cast_mxfp8_2D_kernel<IS_DBIAS, IS_DACT, IS_ACT, ParamOP, OP, IType, OType, true,
false, CHUNK_DIM_Y, CHUNK_DIM_X, THREADS_PER_CHUNK> false, CHUNK_DIM_Y, CHUNK_DIM_X, THREADS_PER_CHUNK>
...@@ -1136,12 +1139,13 @@ void mxfp8_quantize(const Tensor &input, const Tensor *act_input, ...@@ -1136,12 +1139,13 @@ void mxfp8_quantize(const Tensor &input, const Tensor *act_input,
tensor_map_output_colwise, scales_rowwise_ptr, scales_colwise_ptr, noop_ptr, tensor_map_output_colwise, scales_rowwise_ptr, scales_colwise_ptr, noop_ptr,
workspace_ptr, amax_ptr, rows, cols, scale_stride_rowwise, workspace_ptr, amax_ptr, rows, cols, scale_stride_rowwise,
scale_stride_colwise); scale_stride_colwise);
NVTE_CHECK_CUDA(cudaGetLastError());
break; break;
case ScalingType::COLWISE: case ScalingType::COLWISE:
cudaFuncSetAttribute( NVTE_CHECK_CUDA(cudaFuncSetAttribute(
cast_mxfp8_2D_kernel<IS_DBIAS, IS_DACT, IS_ACT, ParamOP, OP, IType, OType, false, cast_mxfp8_2D_kernel<IS_DBIAS, IS_DACT, IS_ACT, ParamOP, OP, IType, OType, false,
true, CHUNK_DIM_Y, CHUNK_DIM_X, THREADS_PER_CHUNK>, true, CHUNK_DIM_Y, CHUNK_DIM_X, THREADS_PER_CHUNK>,
cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size); cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size));
cast_mxfp8_2D_kernel<IS_DBIAS, IS_DACT, IS_ACT, ParamOP, OP, IType, OType, false, cast_mxfp8_2D_kernel<IS_DBIAS, IS_DACT, IS_ACT, ParamOP, OP, IType, OType, false,
true, CHUNK_DIM_Y, CHUNK_DIM_X, THREADS_PER_CHUNK> true, CHUNK_DIM_Y, CHUNK_DIM_X, THREADS_PER_CHUNK>
...@@ -1150,12 +1154,13 @@ void mxfp8_quantize(const Tensor &input, const Tensor *act_input, ...@@ -1150,12 +1154,13 @@ void mxfp8_quantize(const Tensor &input, const Tensor *act_input,
tensor_map_output_colwise, scales_rowwise_ptr, scales_colwise_ptr, noop_ptr, tensor_map_output_colwise, scales_rowwise_ptr, scales_colwise_ptr, noop_ptr,
workspace_ptr, amax_ptr, rows, cols, scale_stride_rowwise, workspace_ptr, amax_ptr, rows, cols, scale_stride_rowwise,
scale_stride_colwise); scale_stride_colwise);
NVTE_CHECK_CUDA(cudaGetLastError());
break; break;
case ScalingType::BIDIMENSIONAL: case ScalingType::BIDIMENSIONAL:
cudaFuncSetAttribute( NVTE_CHECK_CUDA(cudaFuncSetAttribute(
cast_mxfp8_2D_kernel<IS_DBIAS, IS_DACT, IS_ACT, ParamOP, OP, IType, OType, true, cast_mxfp8_2D_kernel<IS_DBIAS, IS_DACT, IS_ACT, ParamOP, OP, IType, OType, true,
true, CHUNK_DIM_Y, CHUNK_DIM_X, THREADS_PER_CHUNK>, true, CHUNK_DIM_Y, CHUNK_DIM_X, THREADS_PER_CHUNK>,
cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size); cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size));
cast_mxfp8_2D_kernel<IS_DBIAS, IS_DACT, IS_ACT, ParamOP, OP, IType, OType, true, true, cast_mxfp8_2D_kernel<IS_DBIAS, IS_DACT, IS_ACT, ParamOP, OP, IType, OType, true, true,
CHUNK_DIM_Y, CHUNK_DIM_X, THREADS_PER_CHUNK> CHUNK_DIM_Y, CHUNK_DIM_X, THREADS_PER_CHUNK>
...@@ -1164,6 +1169,7 @@ void mxfp8_quantize(const Tensor &input, const Tensor *act_input, ...@@ -1164,6 +1169,7 @@ void mxfp8_quantize(const Tensor &input, const Tensor *act_input,
tensor_map_output_colwise, scales_rowwise_ptr, scales_colwise_ptr, noop_ptr, tensor_map_output_colwise, scales_rowwise_ptr, scales_colwise_ptr, noop_ptr,
workspace_ptr, amax_ptr, rows, cols, scale_stride_rowwise, workspace_ptr, amax_ptr, rows, cols, scale_stride_rowwise,
scale_stride_colwise); scale_stride_colwise);
NVTE_CHECK_CUDA(cudaGetLastError());
break; break;
} }
......
...@@ -329,6 +329,7 @@ static void mxfp8_dequantize(const Tensor &input, Tensor *output, cudaStream_t s ...@@ -329,6 +329,7 @@ static void mxfp8_dequantize(const Tensor &input, Tensor *output, cudaStream_t s
); // NOLINT(*) ); // NOLINT(*)
); // NOLINT(*) ); // NOLINT(*)
); // NOLINT(*) ); // NOLINT(*)
NVTE_CHECK_CUDA(cudaGetLastError());
} }
} // namespace dequantization } // namespace dequantization
......
...@@ -248,6 +248,7 @@ void multi_padding(const std::vector<Tensor*> input_list, std::vector<Tensor*> o ...@@ -248,6 +248,7 @@ void multi_padding(const std::vector<Tensor*> input_list, std::vector<Tensor*> o
const int n_blocks = kernel_args.block_range[kernel_args.num_tensors]; const int n_blocks = kernel_args.block_range[kernel_args.num_tensors];
multi_padding_kernel<nvec, Type> multi_padding_kernel<nvec, Type>
<<<n_blocks, threads_per_block, 0, stream>>>(kernel_args);); // NOLINT(*) <<<n_blocks, threads_per_block, 0, stream>>>(kernel_args);); // NOLINT(*)
NVTE_CHECK_CUDA(cudaGetLastError());
kernel_args.num_tensors = 0; kernel_args.num_tensors = 0;
} }
...@@ -277,6 +278,7 @@ void multi_padding(const std::vector<Tensor*> input_list, std::vector<Tensor*> o ...@@ -277,6 +278,7 @@ void multi_padding(const std::vector<Tensor*> input_list, std::vector<Tensor*> o
const int n_blocks = kernel_args.block_range[kernel_args.num_tensors]; const int n_blocks = kernel_args.block_range[kernel_args.num_tensors];
multi_padding_kernel<nvec, Type> multi_padding_kernel<nvec, Type>
<<<n_blocks, threads_per_block, 0, stream>>>(kernel_args);); // NOLINT(*) <<<n_blocks, threads_per_block, 0, stream>>>(kernel_args);); // NOLINT(*)
NVTE_CHECK_CUDA(cudaGetLastError());
} }
} }
...@@ -322,6 +324,7 @@ void multi_unpadding(const std::vector<Tensor*> input_list, std::vector<Tensor*> ...@@ -322,6 +324,7 @@ void multi_unpadding(const std::vector<Tensor*> input_list, std::vector<Tensor*>
const int n_blocks = kernel_args.block_range[kernel_args.num_tensors]; const int n_blocks = kernel_args.block_range[kernel_args.num_tensors];
multi_unpadding_kernel<nvec, Type> multi_unpadding_kernel<nvec, Type>
<<<n_blocks, threads_per_block, 0, stream>>>(kernel_args);); // NOLINT(*) <<<n_blocks, threads_per_block, 0, stream>>>(kernel_args);); // NOLINT(*)
NVTE_CHECK_CUDA(cudaGetLastError());
kernel_args.num_tensors = 0; kernel_args.num_tensors = 0;
} }
...@@ -349,6 +352,7 @@ void multi_unpadding(const std::vector<Tensor*> input_list, std::vector<Tensor*> ...@@ -349,6 +352,7 @@ void multi_unpadding(const std::vector<Tensor*> input_list, std::vector<Tensor*>
const int n_blocks = kernel_args.block_range[kernel_args.num_tensors]; const int n_blocks = kernel_args.block_range[kernel_args.num_tensors];
multi_unpadding_kernel<nvec, Type> multi_unpadding_kernel<nvec, Type>
<<<n_blocks, threads_per_block, 0, stream>>>(kernel_args);); // NOLINT(*) <<<n_blocks, threads_per_block, 0, stream>>>(kernel_args);); // NOLINT(*)
NVTE_CHECK_CUDA(cudaGetLastError());
} }
} }
......
...@@ -364,6 +364,7 @@ void VectorizedUnaryKernelLauncher(const InputType *input, const fp32 *noop, Out ...@@ -364,6 +364,7 @@ void VectorizedUnaryKernelLauncher(const InputType *input, const fp32 *noop, Out
break; break;
} }
} }
NVTE_CHECK_CUDA(cudaGetLastError());
} }
} }
...@@ -398,6 +399,7 @@ void VectorizedUnaryGradKernelLauncher(const InputTypeGrad *grad, const InputTyp ...@@ -398,6 +399,7 @@ void VectorizedUnaryGradKernelLauncher(const InputTypeGrad *grad, const InputTyp
break; break;
} }
} }
NVTE_CHECK_CUDA(cudaGetLastError());
} }
} }
...@@ -491,6 +493,7 @@ void GatedActivationKernelLauncher(const InputType *input, OutputType *output, c ...@@ -491,6 +493,7 @@ void GatedActivationKernelLauncher(const InputType *input, OutputType *output, c
break; break;
} }
} }
NVTE_CHECK_CUDA(cudaGetLastError());
} }
} }
...@@ -602,6 +605,7 @@ void DGatedActivationKernelLauncher(const InputType *grad, const InputType *inpu ...@@ -602,6 +605,7 @@ void DGatedActivationKernelLauncher(const InputType *grad, const InputType *inpu
break; break;
} }
} }
NVTE_CHECK_CUDA(cudaGetLastError());
} }
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment