Unverified Commit 4285874d authored by Xin Yao, committed by GitHub

[Common] Add checks to CUDA kernel launch and CUDA API calls (#2074)



* Add checks to CUDA kernel launches and CUDA API calls
Signed-off-by: Xin Yao <xiny@nvidia.com>

* Remove exceptions from destructors
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Fix weird dispatch in ln/rmsnorm
Signed-off-by: Xin Yao <xiny@nvidia.com>

---------
Signed-off-by: Xin Yao <xiny@nvidia.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Co-authored-by: Tim Moon <tmoon@nvidia.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
parent 715c3bb8
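For context, the pattern applied throughout the diff below is roughly the following. This is a minimal sketch with hypothetical names (CHECK_CUDA, dummy_kernel, checked_launch_example), not the actual NVTE_CHECK_CUDA definition from Transformer Engine: CUDA runtime API calls return a cudaError_t that the macro inspects, while kernel launches report launch failures only through the last-error state, so each launch is followed by an explicit cudaGetLastError() check.

    // Minimal sketch of a CUDA error-checking macro (hypothetical name CHECK_CUDA);
    // the real NVTE_CHECK_CUDA may format and report errors differently.
    #include <cuda_runtime.h>
    #include <cstdio>
    #include <cstdlib>

    #define CHECK_CUDA(call)                                            \
      do {                                                              \
        cudaError_t err_ = (call);                                      \
        if (err_ != cudaSuccess) {                                      \
          std::fprintf(stderr, "CUDA error: %s at %s:%d\n",             \
                       cudaGetErrorString(err_), __FILE__, __LINE__);   \
          std::abort();                                                 \
        }                                                               \
      } while (0)

    __global__ void dummy_kernel(int *out) { *out = 1; }  // illustrative kernel

    void checked_launch_example(int *d_out, cudaStream_t stream) {
      // Runtime API call: the return value carries the error, so wrap it directly.
      CHECK_CUDA(cudaMemsetAsync(d_out, 0, sizeof(int), stream));
      // Kernel launch: the <<<...>>> expression has no return value, so query the
      // launch status explicitly right after the launch.
      dummy_kernel<<<1, 1, 0, stream>>>(d_out);
      CHECK_CUDA(cudaGetLastError());
    }

Checking at the launch site surfaces configuration errors (bad grid or block dimensions, missing kernel image) immediately, instead of having them reported later by an unrelated synchronizing call.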
@@ -101,10 +101,10 @@ CommOverlapCore::CommOverlapCore(int myrank, int numranks, int mylocal, int numl
                       DType::kInt32);
   }

   // CUDA event creation
-  cudaEventCreateWithFlags(&_start_compute, 0);
-  cudaEventCreateWithFlags(&_stop_compute, 0);
-  cudaEventCreateWithFlags(&_start_comm, 0);
-  cudaEventCreateWithFlags(&_stop_comm, 0);
+  NVTE_CHECK_CUDA(cudaEventCreateWithFlags(&_start_compute, 0));
+  NVTE_CHECK_CUDA(cudaEventCreateWithFlags(&_stop_compute, 0));
+  NVTE_CHECK_CUDA(cudaEventCreateWithFlags(&_start_comm, 0));
+  NVTE_CHECK_CUDA(cudaEventCreateWithFlags(&_stop_comm, 0));

   /*
     Defining the launcher order between the communication and GEMM kernels
@@ -114,11 +114,11 @@ CommOverlapCore::CommOverlapCore(int myrank, int numranks, int mylocal, int numl
   */
   int max_connection = transformer_engine::getenv<int>("CUDA_DEVICE_MAX_CONNECTIONS", 8);
   int runtime_version = 0;
-  cudaRuntimeGetVersion(&runtime_version);
+  NVTE_CHECK_CUDA(cudaRuntimeGetVersion(&runtime_version));
   cudaDeviceProp deviceProp;
-  cudaGetDeviceProperties(&deviceProp, 0);
+  NVTE_CHECK_CUDA(cudaGetDeviceProperties(&deviceProp, 0));
   if (runtime_version >= 12030 && deviceProp.major == 9 && max_connection > 1) {
-    cudaEventCreateWithFlags(&_comm_launch_event, cudaEventDisableTiming);
+    NVTE_CHECK_CUDA(cudaEventCreateWithFlags(&_comm_launch_event, cudaEventDisableTiming));
   } else {
     _comm_launch_event = 0;
   }
@@ -129,9 +129,13 @@ CommOverlapCore::~CommOverlapCore() {
   cudaEventDestroy(_start_comm);
   cudaEventDestroy(_stop_compute);
   cudaEventDestroy(_start_compute);
-  if (_comm_launch_event) cudaEventDestroy(_comm_launch_event);
+  if (_comm_launch_event) {
+    cudaEventDestroy(_comm_launch_event);
+  }

-  if (_atomic_gemm) cudaFree(_counter.dptr());
+  if (_atomic_gemm) {
+    cudaFree(_counter.dptr());
+  }

   for (size_t i = 0; i < _stream_compute.size(); i++) {
     cudaStreamSynchronize(_stream_compute[i]);
@@ -698,7 +702,9 @@ CommOverlapP2PBase::~CommOverlapP2PBase() {
   cudaEventDestroy(_stop_recv);
   cudaEventDestroy(_stop_send);
   cudaStreamDestroy(_stream_recv);
-  for (size_t i = 0; i < _stream_send.size(); i++) cudaStreamDestroy(_stream_send[i]);
+  for (size_t i = 0; i < _stream_send.size(); i++) {
+    cudaStreamDestroy(_stream_send[i]);
+  }
 }

 TensorWrapper CommOverlapP2PBase::get_buffer_chunk_by_id(const TensorWrapper &source,
...
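Note on the destructor hunks above: the cudaEventDestroy, cudaStreamDestroy, and cudaFree calls are deliberately left unchecked (only reformatted into braced blocks), matching the "Remove exceptions from destructors" item in the commit message, because a check macro that throws could terminate the program if the destructor runs during stack unwinding. A hypothetical illustration (EventHolder is not a type from this repository; CHECK_CUDA refers to the sketch above):

    #include <cuda_runtime.h>

    struct EventHolder {
      cudaEvent_t event = nullptr;

      EventHolder() {
        // Construction may propagate errors, so a checked (aborting/throwing) call is fine here.
        CHECK_CUDA(cudaEventCreateWithFlags(&event, 0));
      }

      ~EventHolder() {
        // Destructors must not throw: call the API directly and ignore (or just log) the result.
        if (event != nullptr) {
          cudaEventDestroy(event);
        }
      }
    };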
@@ -2319,6 +2319,7 @@ void userbuffers_send(const int srchandler, const size_t srcoffset, const int ds
   if (comm->push == 0) {
     kuserbuffers_pullsend<<<1, 1, 0, stream>>>(comm->myrank, peer, &(comm->send_id[peer]),
                                                reinterpret_cast<int *>(flagptr));
+    NVTE_CHECK_CUDA(cudaGetLastError());
   } else {
     void *srcptr = reinterpret_cast<char *>(comm->mem_ptr[srchandler]) + srcoffset;
     void *dstptr = reinterpret_cast<char *>(comm->peer_ptr[dsthandler][peerlocal]) + dstoffset;
@@ -2516,8 +2517,11 @@ void userbuffers_recv(const int srchandler, const size_t srcoffset, const int ds
         &(comm->recv_id[peer * NVTE_MAX_REGIONS + dsthandler]), reinterpret_cast<int *>(flagptr),
         reinterpret_cast<int4 *>(srcptr), reinterpret_cast<int4 *>(dstptr),
         signalonly ? 0 : bytes / 16, comm->ub_timeout);
-    if (!signalonly)
-      kuserbuffers_inc<<<1, 1, 0, stream>>>(&(comm->recv_id[peer * NVTE_MAX_REGIONS + dsthandler]));
+    NVTE_CHECK_CUDA(cudaGetLastError());
+    if (!signalonly) {
+      kuserbuffers_inc<<<1, 1, 0, stream>>>(&(comm->recv_id[peer * NVTE_MAX_REGIONS + dsthandler]));
+      NVTE_CHECK_CUDA(cudaGetLastError());
+    }
     if (comm->use_ce) {
       NVTE_CHECK_CUDA(cudaMemcpyAsync(dstptr, srcptr, bytes, cudaMemcpyDeviceToDevice, stream));
     }
@@ -2532,6 +2536,7 @@ void userbuffers_recv(const int srchandler, const size_t srcoffset, const int ds
         reinterpret_cast<int *>(0 ?  // temporary disable
                                     GET_RECV_PTR_BY_INDEX(peer, comm, dsthandler, 2)
                                   : nullptr));
+    NVTE_CHECK_CUDA(cudaGetLastError());
   }
 }
@@ -2612,24 +2617,28 @@ void producer(void *atomic_ptr, int chunk_i, cudaStream_t stream) {
   dim3 block(1);
   dim3 grid(1);
   producer_kernel<<<grid, block, 0, stream>>>(atomic_ptr, chunk_i);
+  NVTE_CHECK_CUDA(cudaGetLastError());
 }

 void consumer(void *atomic_ptr, int chunk_i, cudaStream_t stream) {
   dim3 block(1);
   dim3 grid(1);
   consumer_kernel<<<grid, block, 0, stream>>>(atomic_ptr, chunk_i);
+  NVTE_CHECK_CUDA(cudaGetLastError());
 }

 void consumer_batch(void *atomic_ptr, int first_chunk_i, int num_chunks, cudaStream_t stream) {
   dim3 block(1);
   dim3 grid(1);
   consumer_batch_kernel<<<grid, block, 0, stream>>>(atomic_ptr, first_chunk_i, num_chunks);
+  NVTE_CHECK_CUDA(cudaGetLastError());
 }

 void reset_counters(void *atomic_ptr, int num_chunks, bool allgather, cudaStream_t stream) {
   dim3 block(1);
   dim3 grid(1);
   reset_counters_kernel<<<grid, block, 0, stream>>>(atomic_ptr, num_chunks, allgather);
+  NVTE_CHECK_CUDA(cudaGetLastError());
 }

 template <typename fp8type, int nvec>
@@ -2683,6 +2692,7 @@ void reduce_fp8_in_bf16_out(void *inputs, void *output, float *scale, int num_in
   reduce_fp8_in_bf16_out_cuda<fp8type, nvec>
       <<<grid, block, 0, stream>>>(inputs, output, scale, num_inputs, input_size,
                                    num_aligned_elements_per_input, tot_input_size);
+  NVTE_CHECK_CUDA(cudaGetLastError());
 }

 template void reduce_fp8_in_bf16_out<__nv_fp8_e4m3>(void *inputs, void *output, float *scale,
@@ -2738,4 +2748,5 @@ void reduce_bf16(void *inputs, void *output, int num_inputs, int input_size, cud
   dim3 grid(num_blocks);
   reduce_bf16_cuda<nvec><<<grid, block, 0, stream>>>(
       inputs, output, num_inputs, input_size, num_aligned_elements_per_input, tot_input_size);
+  NVTE_CHECK_CUDA(cudaGetLastError());
 }
@@ -50,6 +50,7 @@ void update_tensor_scale_inv(Tensor *t, cudaStream_t stream) {
     update_tensor_scale_inv_kernel<<<1, 1, 0, stream>>>(
         reinterpret_cast<const float *>(t->scale.dptr),
         reinterpret_cast<float *>(t->scale_inv.dptr));
+    NVTE_CHECK_CUDA(cudaGetLastError());
   }
 }
@@ -91,6 +92,7 @@ __global__ void __launch_bounds__(kThreadsPerBlock)
     dim3 grid(numBlocks, 1, 1);                                             \
     memset_kernel<vectorizedType>                                           \
         <<<grid, kThreadsPerBlock, 0, stream>>>(ptr, value, size_in_bytes); \
+    NVTE_CHECK_CUDA(cudaGetLastError());                                    \
     return;                                                                 \
   }
@@ -101,7 +103,7 @@ void nvte_memset(void *ptr, int value, size_t size_in_bytes, cudaStream_t stream
   if (size_in_bytes > 4096) {
     // Use cudaMemsetAsync for larger sizes.
-    cudaMemsetAsync(ptr, value, size_in_bytes, stream);
+    NVTE_CHECK_CUDA(cudaMemsetAsync(ptr, value, size_in_bytes, stream));
     return;
   }
...
@@ -341,6 +341,7 @@ void thd_read_half_tensor(const Tensor &tensor, const Tensor &cu_seqlens, Tensor
   thd_read_half_tensor_kernel<<<grid, block, sizeof(int) * (batch + 1), stream>>>(
       half.data.dptr, tensor.data.dptr, reinterpret_cast<int *>(cu_seqlens.data.dptr), batch,
       hidden_size_in_bytes, half_idx, tensor_shape[seq_dim]);
+  NVTE_CHECK_CUDA(cudaGetLastError());
 }

 /***************************************************************************************************
@@ -397,11 +398,13 @@ void thd_second_half_lse_correction(Tensor lse, const Tensor &lse_per_step,
         reinterpret_cast<float *>(lse.data.dptr), reinterpret_cast<float *>(lse_per_step.data.dptr),
         reinterpret_cast<int *>(cu_seqlens.data.dptr), batch, num_heads, lse_seqlen,
         second_half_lse_seqlen);
+    NVTE_CHECK_CUDA(cudaGetLastError());
   } else {
     thd_lse_kernel<false, LseCorrectionFunctor><<<grid, block, sizeof(int) * (batch + 1), stream>>>(
         reinterpret_cast<float *>(lse.data.dptr), reinterpret_cast<float *>(lse_per_step.data.dptr),
         reinterpret_cast<int *>(cu_seqlens.data.dptr), batch, num_heads, lse_seqlen,
         second_half_lse_seqlen);
+    NVTE_CHECK_CUDA(cudaGetLastError());
   }
 }
@@ -446,11 +449,13 @@ void thd_read_second_half_lse(const Tensor &lse, const Tensor &cu_seqlens, Tenso
         reinterpret_cast<float *>(lse.data.dptr), reinterpret_cast<float *>(half_lse.data.dptr),
         reinterpret_cast<int *>(cu_seqlens.data.dptr), batch, num_heads, lse_seqlen,
         second_half_lse_seqlen);
+    NVTE_CHECK_CUDA(cudaGetLastError());
   } else {
     thd_lse_kernel<false, ReadLseFunctor><<<grid, block, sizeof(int) * (batch + 1), stream>>>(
         reinterpret_cast<float *>(lse.data.dptr), reinterpret_cast<float *>(half_lse.data.dptr),
         reinterpret_cast<int *>(cu_seqlens.data.dptr), batch, num_heads, lse_seqlen,
         second_half_lse_seqlen);
+    NVTE_CHECK_CUDA(cudaGetLastError());
   }
 }
@@ -519,6 +524,7 @@ static void thd_out_correction_helper(Tensor out, const Tensor &out_per_step, co
             reinterpret_cast<float *>(lse_per_step.data.dptr),
             reinterpret_cast<int *>(cu_seqlens.data.dptr), batch, num_heads, dim_per_head,
             lse_seqlen, lse_per_step_seqlen);
+    NVTE_CHECK_CUDA(cudaGetLastError());
   } else {
     thd_out_correction_kernel<dtype, only_second_half, tile, false>
         <<<grid, block, sizeof(int) * (batch + 1), stream>>>(
@@ -528,6 +534,7 @@ static void thd_out_correction_helper(Tensor out, const Tensor &out_per_step, co
             reinterpret_cast<float *>(lse_per_step.data.dptr),
             reinterpret_cast<int *>(cu_seqlens.data.dptr), batch, num_heads, dim_per_head,
             lse_seqlen, lse_per_step_seqlen);
+    NVTE_CHECK_CUDA(cudaGetLastError());
   }
 }
@@ -602,6 +609,7 @@ static void thd_grad_correction_helper(Tensor grad, const Tensor &grad_per_step,
       reinterpret_cast<dtype *>(grad.data.dptr),
       reinterpret_cast<dtype *>(grad_per_step.data.dptr),
       reinterpret_cast<int *>(cu_seqlens.data.dptr), batch, hidden_size, total_tokens);
+  NVTE_CHECK_CUDA(cudaGetLastError());
 }

 template <typename dtype>
@@ -667,6 +675,7 @@ void thd_get_partitioned_indices(const Tensor &cu_seqlens, Tensor output, int to
   thd_partition_indices_kernel<<<grid, block, sizeof(int) * (batch + 1), stream>>>(
       reinterpret_cast<int *>(output.data.dptr), reinterpret_cast<int *>(cu_seqlens.data.dptr),
       batch, total_tokens, world_size, rank);
+  NVTE_CHECK_CUDA(cudaGetLastError());
 }

 }  // namespace context_parallel
...
@@ -91,6 +91,7 @@ void prepare_flash_attn_fwd(Tensor qkvi, Tensor qkv, cudaStream_t stream) {
       prepare_kernel_fwd<dtype><<<grid, threads, 0, stream>>>(
           reinterpret_cast<dtype *>(qkvi.data.dptr), reinterpret_cast<dtype *>(qkv.data.dptr),
           shape[1], shape[2], shape[3], shape[4]););
+  NVTE_CHECK_CUDA(cudaGetLastError());
 }

 void prepare_flash_attn_bwd(Tensor q, Tensor k, Tensor v, Tensor qkv, cudaStream_t stream) {
@@ -129,6 +130,7 @@ void prepare_flash_attn_bwd(Tensor q, Tensor k, Tensor v, Tensor qkv, cudaStream
           reinterpret_cast<dtype *>(q.data.dptr), reinterpret_cast<dtype *>(k.data.dptr),
           reinterpret_cast<dtype *>(v.data.dptr), reinterpret_cast<dtype *>(qkv.data.dptr),
           q_shape[0], q_shape[1], q_shape[2], q_shape[3]););
+  NVTE_CHECK_CUDA(cudaGetLastError());
 }

 }  // namespace flash_attention
...
@@ -416,6 +416,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
         actual_b, b, static_cast<const int32_t *>(devPtrCuSeqlensQ),
         static_cast<const int32_t *>(devPtrCuSeqlensKV), static_cast<int32_t *>(devActualSeqlenQ),
         static_cast<int32_t *>(devActualSeqlenKV));
+    NVTE_CHECK_CUDA(cudaGetLastError());
     variant_pack[seq_q] = devActualSeqlenQ;
     variant_pack[seq_kv] = devActualSeqlenKV;
   }
@@ -454,6 +455,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
         layout_group, actual_b, b, h, hg, d_qk, d_v, static_cast<int32_t *>(devPtrSeqOffsetsQ),
         static_cast<int32_t *>(devPtrSeqOffsetsKV), ragged_offset_type, devOffsetsQ, devOffsetsK,
         devOffsetsV, devOffsetsO, devOffsetsS);
+    NVTE_CHECK_CUDA(cudaGetLastError());
     if (is_ragged_q) {
       variant_pack[offset_q] = devOffsetsQ;
       variant_pack[offset_o] = devOffsetsO;
@@ -883,6 +885,7 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
         actual_b, b, static_cast<const int32_t *>(devPtrCuSeqlensQ),
         static_cast<const int32_t *>(devPtrCuSeqlensKV), static_cast<int32_t *>(devActualSeqlenQ),
         static_cast<int32_t *>(devActualSeqlenKV));
+    NVTE_CHECK_CUDA(cudaGetLastError());
     variant_pack[seq_q] = devActualSeqlenQ;
     variant_pack[seq_kv] = devActualSeqlenKV;
   }
@@ -916,6 +919,7 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
         layout_group, actual_b, b, h, hg, d_qk, d_v, static_cast<int32_t *>(devPtrSeqOffsetsQ),
         static_cast<int32_t *>(devPtrSeqOffsetsKV), ragged_offset_type, devOffsetsQ, devOffsetsK,
         devOffsetsV, devOffsetsO, devOffsetsS);
+    NVTE_CHECK_CUDA(cudaGetLastError());
     if (is_ragged_q) {
       variant_pack[offset_q] = devOffsetsQ;
       variant_pack[offset_o] = devOffsetsO;
...
@@ -1111,6 +1111,7 @@ void fused_attn_fp8_fwd_impl(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, in
     cu_seqlens_to_offsets<<<gridDims, blockDims, 0, stream>>>(
         b, h, d, reinterpret_cast<int32_t*>(devPtrcuSeqlensQ), actual_seqlens_q, qkv_ragged_offset,
         o_ragged_offset);
+    NVTE_CHECK_CUDA(cudaGetLastError());
     void* devPtrQKVRaggedOffset = reinterpret_cast<void*>(qkv_ragged_offset);
     void* devPtrORaggedOffset = reinterpret_cast<void*>(o_ragged_offset);
     void* devPtrMNKOverride = reinterpret_cast<void*>(actual_seqlens_q);
@@ -1577,6 +1578,7 @@ void fused_attn_fp8_bwd_impl(
     cu_seqlens_to_offsets<<<gridDims, blockDims, 0, stream>>>(
         b, h, d, reinterpret_cast<int32_t*>(devPtrcuSeqlensQ), actual_seqlens_q, qkv_ragged_offset,
         o_ragged_offset);
+    NVTE_CHECK_CUDA(cudaGetLastError());
     void* devPtrQKVRaggedOffset = reinterpret_cast<void*>(qkv_ragged_offset);
     void* devPtrORaggedOffset = reinterpret_cast<void*>(o_ragged_offset);
     void* devPtrMNKOverride = reinterpret_cast<void*>(actual_seqlens_q);
@@ -1933,6 +1935,7 @@ void fused_attn_fp8_fwd_impl_v1(
         b, b, static_cast<const int32_t*>(devPtrcuSeqlensQ),  // TODO(pass max_b)
         static_cast<const int32_t*>(devPtrcuSeqlensKV), static_cast<int32_t*>(devActualSeqlenQ),
         static_cast<int32_t*>(devActualSeqlenKV));
+    NVTE_CHECK_CUDA(cudaGetLastError());
     variant_pack[seq_q] = devActualSeqlenQ;
     variant_pack[seq_kv] = devActualSeqlenKV;
   }
@@ -2329,6 +2332,7 @@ void fused_attn_fp8_bwd_impl_v1(
         b, b, static_cast<const int32_t*>(devPtrcuSeqlensQ),  // TODO(pass max_b)
         static_cast<const int32_t*>(devPtrcuSeqlensKV), static_cast<int32_t*>(devActualSeqlenQ),
         static_cast<int32_t*>(devActualSeqlenKV));
+    NVTE_CHECK_CUDA(cudaGetLastError());
     variant_pack[seq_q] = devActualSeqlenQ;
     variant_pack[seq_kv] = devActualSeqlenKV;
   }
...
@@ -157,6 +157,7 @@ void copy_to_kv_cache_launcher(Tensor new_k, Tensor new_v, Tensor k_cache, Tenso
         reinterpret_cast<int *>(page_table.data.dptr),
         reinterpret_cast<int *>(cu_new_lens.data.dptr),
         reinterpret_cast<int *>(cu_cached_lens.data.dptr), h_kv, d_k, d_v, b, max_seq_len);
+    NVTE_CHECK_CUDA(cudaGetLastError());
   }
   dim3 grid_size(b, max_ctx_len);
   copy_to_kv_cache_kernel<<<grid_size, block_size, 0, stream>>>(
@@ -166,6 +167,7 @@ void copy_to_kv_cache_launcher(Tensor new_k, Tensor new_v, Tensor k_cache, Tenso
       reinterpret_cast<int *>(cu_new_lens.data.dptr),
       reinterpret_cast<int *>(cu_cached_lens.data.dptr), qkv_format, h_kv, d_k, d_v, b,
       max_ctx_len, max_seq_len, max_pages_per_seq, is_non_paged);
+  NVTE_CHECK_CUDA(cudaGetLastError());
 }
 }
@@ -215,6 +217,7 @@ void convert_thd_to_bshd_launcher(Tensor tensor, Tensor new_tensor, Tensor cu_se
       reinterpret_cast<scalar_t *>(tensor.data.dptr),
       reinterpret_cast<scalar_t *>(new_tensor.data.dptr),
       reinterpret_cast<int *>(cu_seqlens.data.dptr), b, max_seq_len, h, d);
+  NVTE_CHECK_CUDA(cudaGetLastError());
 }

 void convert_thd_to_bshd(Tensor tensor, Tensor cu_seqlens, Tensor new_tensor, int b,
@@ -254,6 +257,7 @@ void convert_bshd_to_thd_launcher(Tensor tensor, Tensor new_tensor, Tensor cu_se
       reinterpret_cast<scalar_t *>(tensor.data.dptr),
       reinterpret_cast<scalar_t *>(new_tensor.data.dptr),
       reinterpret_cast<int *>(cu_seqlens.data.dptr), b, max_seq_len, h, d);
+  NVTE_CHECK_CUDA(cudaGetLastError());
 }

 void convert_bshd_to_thd(Tensor tensor, Tensor cu_seqlens, Tensor new_tensor, int t,
...
@@ -600,13 +600,14 @@ uint32_t GetRuntimeNumSegments(void *cu_seqlen, void *workspace, size_t len, cud
   // workspace size requires 4 bytes
   uint32_t *dout = static_cast<uint32_t *>(workspace);
   uint32_t hout{};
-  cudaMemsetAsync(dout, 0, sizeof(uint32_t), stream);
+  NVTE_CHECK_CUDA(cudaMemsetAsync(dout, 0, sizeof(uint32_t), stream));
   constexpr int threads = 128;
   const int blocks = (len - 1) / threads + 1;
   get_runtime_num_segments_kernel<<<blocks, threads, 0, stream>>>(static_cast<int32_t *>(cu_seqlen),
                                                                   len, dout);
-  cudaMemcpyAsync(&hout, dout, sizeof(uint32_t), cudaMemcpyDeviceToHost, stream);
-  cudaStreamSynchronize(stream);
+  NVTE_CHECK_CUDA(cudaGetLastError());
+  NVTE_CHECK_CUDA(cudaMemcpyAsync(&hout, dout, sizeof(uint32_t), cudaMemcpyDeviceToHost, stream));
+  NVTE_CHECK_CUDA(cudaStreamSynchronize(stream));
   return hout;
 }
@@ -633,4 +634,5 @@ void nvte_extract_seed_and_offset(int64_t *rng_state_ptr, int captured, int64_t
   fused_attn::extract_seed_and_offset<<<1, 1, 0, stream>>>(
       rng_state_ptr, captured, seed_ptr, seed_val, offset_ptr, offset_val, offset_intragraph);
+  NVTE_CHECK_CUDA(cudaGetLastError());
 }
@@ -177,9 +177,9 @@ void fused_moe_aux_loss_forward_kernel_launcher(const DataType* probs,
     config.stream = stream;

     // Update the max cluster size based on the device
-    cudaOccupancyMaxPotentialClusterSize(
-        &cluster_size,
-        reinterpret_cast<void*>(fused_moe_aux_loss_forward_kernel<DataType, IndexType>), &config);
+    NVTE_CHECK_CUDA(cudaOccupancyMaxPotentialClusterSize(
+        &cluster_size,
+        reinterpret_cast<void*>(fused_moe_aux_loss_forward_kernel<DataType, IndexType>), &config));

     cudaLaunchAttribute attribute[1];
     attribute[0].id = cudaLaunchAttributeClusterDimension;
@@ -189,14 +189,15 @@ void fused_moe_aux_loss_forward_kernel_launcher(const DataType* probs,
     config.numAttrs = 1;
     config.attrs = attribute;

-    cudaLaunchKernelEx(&config, fused_moe_aux_loss_forward_kernel<DataType, IndexType>, probs,
-                       tokens_per_expert, total_num_tokens, num_experts, num_rows, num_cols, topk,
-                       coeff, aux_loss, Const_buf);
+    NVTE_CHECK_CUDA(cudaLaunchKernelEx(
+        &config, fused_moe_aux_loss_forward_kernel<DataType, IndexType>, probs, tokens_per_expert,
+        total_num_tokens, num_experts, num_rows, num_cols, topk, coeff, aux_loss, Const_buf));
   } else {
     size_t smem_size = sizeof(CompType) * num_cols;
     fused_moe_aux_loss_forward_kernel<DataType, IndexType>
         <<<1, 1024, smem_size, stream>>>(probs, tokens_per_expert, total_num_tokens, num_experts,
                                          num_rows, num_cols, topk, coeff, aux_loss, Const_buf);
+    NVTE_CHECK_CUDA(cudaGetLastError());
   }
 }
@@ -247,6 +248,7 @@ void fused_moe_aux_loss_backward_kernel_launcher(const float* Const_buf,
   int grid_size = (num_rows + block_size - 1) / block_size;
   fused_moe_aux_loss_backward_kernel<DataType, IndexType><<<grid_size, block_size, 0, stream>>>(
       Const_buf, tokens_per_expert, num_rows, num_cols, grad_aux_loss, grad_probs);
+  NVTE_CHECK_CUDA(cudaGetLastError());
 }

 void fused_moe_aux_loss_backward(const Tensor& Const_buf, const Tensor& tokens_per_expert,
...
@@ -151,6 +151,7 @@ void fused_score_for_moe_aux_loss_forward_kernel_launcher(
       <<<grid_size, kThreadsPerBlock, shared_memory_size, stream>>>(
           logits, num_tokens, num_experts, topk, score_function, scores, routing_map,
           intermediate_output);
+  NVTE_CHECK_CUDA(cudaGetLastError());
 }

 void fused_score_for_moe_aux_loss_forward(const Tensor &logits, int num_tokens, int num_experts,
@@ -286,6 +287,7 @@ void fused_score_for_moe_aux_loss_backward_kernel_launcher(
       <<<grid_size, kThreadsPerBlock, shared_memory_size, stream>>>(
           intermediate_output, grad_scores, num_tokens, num_experts, topk, score_function,
           grad_logits);
+  NVTE_CHECK_CUDA(cudaGetLastError());
 }

 void fused_score_for_moe_aux_loss_backward(const Tensor &intermediate_output,
...
@@ -257,6 +257,7 @@ void fused_topk_with_score_function_forward_kernel_launcher(
       <<<grid_size, kThreadsPerBlock, shared_memory_size, stream>>>(
           logits, num_tokens, num_experts, topk, use_pre_softmax, num_groups, group_topk,
           scaling_factor, score_function, expert_bias, probs, routing_map, intermediate_output);
+  NVTE_CHECK_CUDA(cudaGetLastError());
 }

 void fused_topk_with_score_function_forward(const Tensor logits, int num_tokens, int num_experts,
@@ -447,6 +448,7 @@ void fused_topk_with_score_function_backward_kernel_launcher(
       <<<grid_size, kThreadsPerBlock, shared_memory_size, stream>>>(
           routing_map, intermediate_output, grad_probs, num_tokens, num_experts, topk,
           use_pre_softmax, scaling_factor, score_function, grad_logits);
+  NVTE_CHECK_CUDA(cudaGetLastError());
 }

 void fused_topk_with_score_function_backward(const Tensor &routing_map,
...
@@ -353,6 +353,7 @@ void call_kernel_scaled_aligned_causal_masked_softmax_forward(
   scaled_aligned_causal_masked_softmax_warp_forward<input_t, output_t, acc_t, log2_elements>
       <<<grid_size, block_size, shmem_size, stream>>>(dst, src, scale, microbatches, query_seq_len,
                                                       key_seq_len);
+  NVTE_CHECK_CUDA(cudaGetLastError());
 }

 template <typename input_t, typename output_t, typename acc_t, int log2_elements>
@@ -363,6 +364,7 @@ void call_kernel_scaled_aligned_causal_masked_softmax_backward(
   scaled_aligned_causal_masked_softmax_warp_backward<input_t, output_t, acc_t, log2_elements>
       <<<grid_size, block_size, 0, stream>>>(gradInput, grad, output, scale, microbatches,
                                              query_seq_len, key_seq_len);
+  NVTE_CHECK_CUDA(cudaGetLastError());
 }

 template <typename input_t, typename output_t, typename acc_t>
...
@@ -513,6 +513,7 @@ void dispatch_scaled_softmax_forward(output_t *dst, const input_t *src, const in
       default:
         break;
     }
+    NVTE_CHECK_CUDA(cudaGetLastError());
   }
 }
@@ -625,6 +626,7 @@ void dispatch_scaled_masked_softmax_forward(output_t *dst, const input_t *src, c
       default:
         break;
     }
+    NVTE_CHECK_CUDA(cudaGetLastError());
   }
 }
@@ -736,6 +738,7 @@ void dispatch_scaled_masked_softmax_backward(output_t *grad_input, const input_t
       default:
         break;
     }
+    NVTE_CHECK_CUDA(cudaGetLastError());
   }
 }
...
@@ -445,6 +445,7 @@ void dispatch_scaled_upper_triang_masked_softmax_forward(output_t *dst, const in
       default:
         break;
     }
+    NVTE_CHECK_CUDA(cudaGetLastError());
   }
 }
@@ -561,6 +562,7 @@ void dispatch_scaled_upper_triang_masked_softmax_backward(output_t *grad_input,
       default:
         break;
     }
+    NVTE_CHECK_CUDA(cudaGetLastError());
   }
 }
...
@@ -413,6 +413,7 @@ void multi_tensor_l2norm_cuda(int chunk_size, Tensor noop_flag,
       reinterpret_cast<float *>(ret.data.dptr),
       per_tensor ? reinterpret_cast<float *>(ret_per_tensor.data.dptr) : nullptr, per_tensor,
       max_chunks_per_tensor);
+  NVTE_CHECK_CUDA(cudaGetLastError());
 }

 void multi_tensor_unscale_l2norm_cuda(int chunk_size, Tensor noop_flag,
@@ -440,6 +441,7 @@ void multi_tensor_unscale_l2norm_cuda(int chunk_size, Tensor noop_flag,
       reinterpret_cast<float *>(ret.data.dptr),
       per_tensor ? reinterpret_cast<float *>(ret_per_tensor.data.dptr) : nullptr, per_tensor,
       max_chunks_per_tensor);
+  NVTE_CHECK_CUDA(cudaGetLastError());
 }

 }  // namespace multi_tensor_l2norm
...
@@ -138,8 +138,8 @@ void TeNormalizationPlan<KernelParamsType>::_set_workspace() {
     if (_launch_params.barrier_bytes > 0) {
       _launch_params.params.barrier =
           reinterpret_cast<int*>(workspace_dptr + _launch_params.workspace_bytes);
-      cudaMemsetAsync(_launch_params.params.barrier, 0, _launch_params.barrier_bytes,
-                      _launch_params.stream);
+      NVTE_CHECK_CUDA(cudaMemsetAsync(_launch_params.params.barrier, 0,
+                                      _launch_params.barrier_bytes, _launch_params.stream));
     }
     if constexpr (std::is_same_v<KernelParamsType, BackwardKernelParams>) {
       _launch_params.params.dgamma_part =
...
@@ -14,7 +14,7 @@ using namespace transformer_engine::normalization;
 template <typename weight_t, typename input_t, typename output_t, typename compute_t,
           typename index_t, int HIDDEN_SIZE, int CTAS_PER_ROW, int WARPS_M, int WARPS_N,
           int BYTES_PER_LDG_MAIN, int BYTES_PER_LDG_FINAL>
-void launch_tuned_(LaunchParams<BackwardKernelParams> &launch_params,
+void launch_ln_bwd_tuned_(LaunchParams<BackwardKernelParams> &launch_params,
                    const bool configure_params) {  // NOLINT(*)
   using Kernel_traits = Kernel_traits<weight_t, input_t, output_t, compute_t, index_t, HIDDEN_SIZE,
                                       CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG_MAIN>;
@@ -22,8 +22,8 @@ void launch_tuned_(LaunchParams<BackwardKernelParams> &launch_params,
   if (configure_params) {
     int ctas_per_sm;
-    cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
-        &ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES);
+    NVTE_CHECK_CUDA(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
+        &ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES));
     launch_params.params.ctas_per_row = CTAS_PER_ROW;
     launch_params.params.ctas_per_col =
         launch_params.multiprocessorCount * ctas_per_sm / launch_params.params.ctas_per_row;
@@ -49,13 +49,14 @@ void launch_tuned_(LaunchParams<BackwardKernelParams> &launch_params,
   if (ctas_per_row == 1) {
     kernel<<<ctas_per_col, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES, stream>>>(
         launch_params.params);
+    NVTE_CHECK_CUDA(cudaGetLastError());
   } else {
     dim3 grid(ctas_per_row * ctas_per_col);
     dim3 block(Kernel_traits::THREADS_PER_CTA);
     void *params_ = reinterpret_cast<void *>(&launch_params.params);
-    cudaLaunchCooperativeKernel(reinterpret_cast<void *>(kernel), grid, block,
-                                reinterpret_cast<void **>(&params_), Kernel_traits::SMEM_BYTES,
-                                stream);
+    NVTE_CHECK_CUDA(cudaLaunchCooperativeKernel(reinterpret_cast<void *>(kernel), grid, block,
+                                                reinterpret_cast<void **>(&params_),
+                                                Kernel_traits::SMEM_BYTES, stream));
   }

   using Kernel_traits_f =
@@ -66,12 +67,13 @@ void launch_tuned_(LaunchParams<BackwardKernelParams> &launch_params,
   auto kernel_f = &ln_bwd_finalize_tuned_kernel<Kernel_traits_f>;
   kernel_f<<<Kernel_traits_f::CTAS, Kernel_traits_f::THREADS_PER_CTA, 0, stream>>>(
       launch_params.params);
+  NVTE_CHECK_CUDA(cudaGetLastError());
 }
 template <typename weight_t, typename input_t, typename output_t, typename compute_t,
           typename index_t, int HIDDEN_SIZE, int WARPS_M, int WARPS_N, int BYTES_PER_LDG_MAIN,
           int BYTES_PER_LDG_FINAL>
-void launch_general_(LaunchParams<BackwardKernelParams> &launch_params,
+void launch_ln_bwd_general_(LaunchParams<BackwardKernelParams> &launch_params,
                      const bool configure_params) {  // NOLINT(*)
   auto ceil_div = [](int x, int y) -> int { return (x + y - 1) / y; };
@@ -87,8 +89,8 @@ void launch_general_(LaunchParams<BackwardKernelParams> &launch_params,
   int ctas_per_row = launch_params.params.ctas_per_row;
   if (configure_params) {
     int ctas_per_sm;
-    cudaOccupancyMaxActiveBlocksPerMultiprocessor(&ctas_per_sm, kernel,
-                                                  Kernel_traits::THREADS_PER_CTA, 0);
+    NVTE_CHECK_CUDA(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
+        &ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, 0));
     const int max_ctas = launch_params.multiprocessorCount * ctas_per_sm;
     ctas_per_row = ceil_div(cols, HIDDEN_SIZE);
     ctas_per_col = std::min(ceil_div(rows, WARPS_M), max_ctas / ctas_per_row);
@@ -109,10 +111,11 @@ void launch_general_(LaunchParams<BackwardKernelParams> &launch_params,
   dim3 block(Kernel_traits::THREADS_PER_CTA);
   if (ctas_per_row == 1) {
     kernel<<<grid, block, 0, stream>>>(launch_params.params);
+    NVTE_CHECK_CUDA(cudaGetLastError());
   } else {
     void *params_ = reinterpret_cast<void *>(&launch_params.params);
-    cudaLaunchCooperativeKernel(reinterpret_cast<void *>(kernel), grid, block,
-                                reinterpret_cast<void **>(&params_), 0, stream);
+    NVTE_CHECK_CUDA(cudaLaunchCooperativeKernel(reinterpret_cast<void *>(kernel), grid, block,
+                                                reinterpret_cast<void **>(&params_), 0, stream));
   }

   // Launch finalization kernel
@@ -126,6 +129,7 @@ void launch_general_(LaunchParams<BackwardKernelParams> &launch_params,
   dim3 block_final(Kernel_traits::THREADS_PER_WARP * WARPS_N_FINAL, WARPS_M_FINAL);
   dim3 grid_final(ceil_div(cols, ELTS_N_PER_CTA_FINAL), 1);
   kernel_final<<<grid_final, block_final, 0, stream>>>(launch_params.params);
+  NVTE_CHECK_CUDA(cudaGetLastError());
 }

 #define REGISTER_NORM_LAUNCHER(NORM_TYPE, NORM_STAGE, LAUNCH_TYPE, HIDDEN_SIZE, WTYPE, ITYPE, \
@@ -134,8 +138,8 @@ void launch_general_(LaunchParams<BackwardKernelParams> &launch_params,
   void                                                                                        \
   norm_##NORM_TYPE##_##NORM_STAGE##_##LAUNCH_TYPE##_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \
       LaunchParams<NORM_STAGE##KernelParams> &launch_params, const bool configure_params) {   \
-    launch_##LAUNCH_TYPE##_<WTYPE, ITYPE, OTYPE, CTYPE, uint32_t, HIDDEN_SIZE, __VA_ARGS__>(  \
-        launch_params, configure_params);                                                     \
+    launch_ln_bwd_##LAUNCH_TYPE##_<WTYPE, ITYPE, OTYPE, CTYPE, uint32_t, HIDDEN_SIZE,         \
+                                   __VA_ARGS__>(launch_params, configure_params);             \
   }                                                                                           \
   REGISTER_NORM_BASE(                                                                         \
       NORM_TYPE, NORM_STAGE, LAUNCH_TYPE, HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE,            \
...
@@ -13,15 +13,15 @@ using namespace transformer_engine::normalization;
 template <typename weight_t, typename input_t, typename output_t, typename compute_t,
           typename index_t, int HIDDEN_SIZE, int CTAS_PER_ROW, int WARPS_M, int WARPS_N,
           int BYTES_PER_LDG>
-void launch_tuned_(LaunchParams<ForwardKernelParams> &launch_params,
+void launch_ln_fwd_tuned_(LaunchParams<ForwardKernelParams> &launch_params,
                    const bool configure_params) {  // NOLINT(*)
   using Kernel_traits = Kernel_traits<weight_t, input_t, output_t, compute_t, index_t, HIDDEN_SIZE,
                                       CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG>;
   auto kernel = &ln_fwd_tuned_kernel<Kernel_traits>;
   if (configure_params) {
     int ctas_per_sm;
-    cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
-        &ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES_FWD);
+    NVTE_CHECK_CUDA(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
+        &ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES_FWD));
     launch_params.params.ctas_per_row = CTAS_PER_ROW;
     launch_params.params.ctas_per_col =
         launch_params.multiprocessorCount * ctas_per_sm / launch_params.params.ctas_per_row;
@@ -45,18 +45,20 @@ void launch_tuned_(LaunchParams<ForwardKernelParams> &launch_params,
   if (ctas_per_row == 1) {
     kernel<<<ctas_per_col, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES_FWD, stream>>>(
         launch_params.params);
+    NVTE_CHECK_CUDA(cudaGetLastError());
   } else {
     dim3 grid(ctas_per_row * ctas_per_col);
     dim3 block(Kernel_traits::THREADS_PER_CTA);
     void *params_ = reinterpret_cast<void *>(&launch_params.params);
-    cudaLaunchCooperativeKernel((void *)kernel, grid, block, (void **)&params_,  // NOLINT(*)
-                                Kernel_traits::SMEM_BYTES_FWD, stream);
+    NVTE_CHECK_CUDA(cudaLaunchCooperativeKernel(reinterpret_cast<void *>(kernel), grid, block,
+                                                reinterpret_cast<void **>(&params_),
+                                                Kernel_traits::SMEM_BYTES_FWD, stream));
   }
 }
 template <typename weight_t, typename input_t, typename output_t, typename compute_t,
           typename index_t, int HIDDEN_SIZE, int WARPS_M, int WARPS_N, int BYTES_PER_LDG>
-void launch_general_(LaunchParams<ForwardKernelParams> &launch_params,
+void launch_ln_fwd_general_(LaunchParams<ForwardKernelParams> &launch_params,
                      const bool configure_params) {  // NOLINT(*)
   using Kernel_traits = Kernel_traits<weight_t, input_t, output_t, compute_t, index_t, HIDDEN_SIZE,
                                       1, WARPS_M, WARPS_N, BYTES_PER_LDG>;
@@ -70,8 +72,8 @@ void launch_general_(LaunchParams<ForwardKernelParams> &launch_params,
   int ctas_per_row = launch_params.params.ctas_per_row;
   if (configure_params) {
     int ctas_per_sm;
-    cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
-        &ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, 0);
+    NVTE_CHECK_CUDA(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
+        &ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, 0));
     const int max_ctas = launch_params.multiprocessorCount * ctas_per_sm;
     ctas_per_row = ceil_div(cols, HIDDEN_SIZE);
     ctas_per_col = std::min(ceil_div(rows, WARPS_M), max_ctas / ctas_per_row);
@@ -91,10 +93,11 @@ void launch_general_(LaunchParams<ForwardKernelParams> &launch_params,
   dim3 block(Kernel_traits::THREADS_PER_CTA);
   if (ctas_per_row == 1) {
     kernel<<<grid, block, 0, stream>>>(launch_params.params);
+    NVTE_CHECK_CUDA(cudaGetLastError());
   } else {
     void *params_ = reinterpret_cast<void *>(&launch_params.params);
-    cudaLaunchCooperativeKernel(reinterpret_cast<void *>(kernel), grid, block,
-                                reinterpret_cast<void **>(&params_), 0, stream);
+    NVTE_CHECK_CUDA(cudaLaunchCooperativeKernel(reinterpret_cast<void *>(kernel), grid, block,
+                                                reinterpret_cast<void **>(&params_), 0, stream));
   }
 }
@@ -104,8 +107,8 @@ void launch_general_(LaunchParams<ForwardKernelParams> &launch_params,
   void                                                                                        \
   norm_##NORM_TYPE##_##NORM_STAGE##_##LAUNCH_TYPE##_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \
       LaunchParams<NORM_STAGE##KernelParams> &launch_params, const bool configure_params) {   \
-    launch_##LAUNCH_TYPE##_<WTYPE, ITYPE, OTYPE, CTYPE, uint32_t, HIDDEN_SIZE, __VA_ARGS__>(  \
-        launch_params, configure_params);                                                     \
+    launch_ln_fwd_##LAUNCH_TYPE##_<WTYPE, ITYPE, OTYPE, CTYPE, uint32_t, HIDDEN_SIZE,         \
+                                   __VA_ARGS__>(launch_params, configure_params);             \
  }                                                                                            \
  REGISTER_NORM_BASE(                                                                          \
      NORM_TYPE, NORM_STAGE, LAUNCH_TYPE, HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE,             \
...
@@ -13,7 +13,7 @@ using namespace transformer_engine::normalization;
 template <typename weight_t, typename input_t, typename output_t, typename compute_t,
           typename index_t, int HIDDEN_SIZE, int CTAS_PER_ROW, int WARPS_M, int WARPS_N,
           int BYTES_PER_LDG_MAIN, int BYTES_PER_LDG_FINAL, bool FUSED_ADD = false>
-void launch_tuned_(LaunchParams<BackwardKernelParams> &launch_params,
+void launch_rmsnorm_bwd_tuned_(LaunchParams<BackwardKernelParams> &launch_params,
                    const bool configure_params) {  // NOLINT(*)
   using Kernel_traits = Kernel_traits<weight_t, input_t, output_t, compute_t, index_t, HIDDEN_SIZE,
                                       CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG_MAIN>;
@@ -48,6 +48,7 @@ void launch_tuned_(LaunchParams<BackwardKernelParams> &launch_params,
   if (ctas_per_row == 1) {
     kernel<<<ctas_per_col, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES, stream>>>(
         launch_params.params);
+    NVTE_CHECK_CUDA(cudaGetLastError());
   } else {
     dim3 grid(ctas_per_row * ctas_per_col);
     dim3 block(Kernel_traits::THREADS_PER_CTA);
@@ -65,12 +66,13 @@ void launch_tuned_(LaunchParams<BackwardKernelParams> &launch_params,
   auto kernel_f = &rmsnorm_bwd_finalize_tuned_kernel<Kernel_traits_f>;
   kernel_f<<<Kernel_traits_f::CTAS, Kernel_traits_f::THREADS_PER_CTA, 0, stream>>>(
       launch_params.params);
+  NVTE_CHECK_CUDA(cudaGetLastError());
 }

 template <typename weight_t, typename input_t, typename output_t, typename compute_t,
           typename index_t, int HIDDEN_SIZE, int WARPS_M, int WARPS_N, int BYTES_PER_LDG_MAIN,
           int BYTES_PER_LDG_FINAL, bool FUSED_ADD = false>
-void launch_general_(LaunchParams<BackwardKernelParams> &launch_params,
+void launch_rmsnorm_bwd_general_(LaunchParams<BackwardKernelParams> &launch_params,
                      const bool configure_params) {  // NOLINT(*)
   auto ceil_div = [](int x, int y) -> int { return (x + y - 1) / y; };
@@ -110,6 +112,7 @@ void launch_general_(LaunchParams<BackwardKernelParams> &launch_params,
   dim3 block(Kernel_traits::THREADS_PER_CTA);
   if (ctas_per_row == 1) {
     kernel<<<grid, block, 0, stream>>>(launch_params.params);
+    NVTE_CHECK_CUDA(cudaGetLastError());
   } else {
     void *params_ = reinterpret_cast<void *>(&launch_params.params);
     NVTE_CHECK_CUDA(cudaLaunchCooperativeKernel(reinterpret_cast<void *>(kernel), grid, block,
@@ -127,6 +130,7 @@ void launch_general_(LaunchParams<BackwardKernelParams> &launch_params,
   dim3 block_final(Kernel_traits::THREADS_PER_WARP * WARPS_N_FINAL, WARPS_M_FINAL);
   dim3 grid_final(ceil_div(cols, ELTS_N_PER_CTA_FINAL), 1);
   kernel_final<<<grid_final, block_final, 0, stream>>>(launch_params.params);
+  NVTE_CHECK_CUDA(cudaGetLastError());
 }

 #define REGISTER_NORM_LAUNCHER(NORM_TYPE, NORM_STAGE, LAUNCH_TYPE, HIDDEN_SIZE, WTYPE, ITYPE, \
@@ -135,8 +139,8 @@ void launch_general_(LaunchParams<BackwardKernelParams> &launch_params,
   void                                                                                        \
   norm_##NORM_TYPE##_##NORM_STAGE##_##LAUNCH_TYPE##_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \
       LaunchParams<NORM_STAGE##KernelParams> &launch_params, const bool configure_params) {   \
-    launch_##LAUNCH_TYPE##_<WTYPE, ITYPE, OTYPE, CTYPE, uint32_t, HIDDEN_SIZE, __VA_ARGS__>(  \
-        launch_params, configure_params);                                                     \
+    launch_rmsnorm_bwd_##LAUNCH_TYPE##_<WTYPE, ITYPE, OTYPE, CTYPE, uint32_t, HIDDEN_SIZE,    \
+                                        __VA_ARGS__>(launch_params, configure_params);        \
   }                                                                                           \
   REGISTER_NORM_BASE(                                                                         \
       NORM_TYPE, NORM_STAGE, LAUNCH_TYPE, HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE,            \
...