Commit 27ddce40 authored by wenjh

Merge branch 'nv_main'

parents d262ef4c 5b3092a0
......@@ -341,6 +341,7 @@ void thd_read_half_tensor(const Tensor &tensor, const Tensor &cu_seqlens, Tensor
thd_read_half_tensor_kernel<<<grid, block, sizeof(int) * (batch + 1), stream>>>(
half.data.dptr, tensor.data.dptr, reinterpret_cast<int *>(cu_seqlens.data.dptr), batch,
hidden_size_in_bytes, half_idx, tensor_shape[seq_dim]);
NVTE_CHECK_CUDA(cudaGetLastError());
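// Note: the cudaGetLastError() checks added after each kernel launch in this change surface
// launch-time failures (e.g. an invalid grid/block or shared-memory configuration) right away,
// instead of letting them appear later at an unrelated synchronization point.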
}
/***************************************************************************************************
......@@ -397,11 +398,13 @@ void thd_second_half_lse_correction(Tensor lse, const Tensor &lse_per_step,
reinterpret_cast<float *>(lse.data.dptr), reinterpret_cast<float *>(lse_per_step.data.dptr),
reinterpret_cast<int *>(cu_seqlens.data.dptr), batch, num_heads, lse_seqlen,
second_half_lse_seqlen);
NVTE_CHECK_CUDA(cudaGetLastError());
} else {
thd_lse_kernel<false, LseCorrectionFunctor><<<grid, block, sizeof(int) * (batch + 1), stream>>>(
reinterpret_cast<float *>(lse.data.dptr), reinterpret_cast<float *>(lse_per_step.data.dptr),
reinterpret_cast<int *>(cu_seqlens.data.dptr), batch, num_heads, lse_seqlen,
second_half_lse_seqlen);
NVTE_CHECK_CUDA(cudaGetLastError());
}
}
......@@ -446,11 +449,13 @@ void thd_read_second_half_lse(const Tensor &lse, const Tensor &cu_seqlens, Tenso
reinterpret_cast<float *>(lse.data.dptr), reinterpret_cast<float *>(half_lse.data.dptr),
reinterpret_cast<int *>(cu_seqlens.data.dptr), batch, num_heads, lse_seqlen,
second_half_lse_seqlen);
NVTE_CHECK_CUDA(cudaGetLastError());
} else {
thd_lse_kernel<false, ReadLseFunctor><<<grid, block, sizeof(int) * (batch + 1), stream>>>(
reinterpret_cast<float *>(lse.data.dptr), reinterpret_cast<float *>(half_lse.data.dptr),
reinterpret_cast<int *>(cu_seqlens.data.dptr), batch, num_heads, lse_seqlen,
second_half_lse_seqlen);
NVTE_CHECK_CUDA(cudaGetLastError());
}
}
......@@ -519,6 +524,7 @@ static void thd_out_correction_helper(Tensor out, const Tensor &out_per_step, co
reinterpret_cast<float *>(lse_per_step.data.dptr),
reinterpret_cast<int *>(cu_seqlens.data.dptr), batch, num_heads, dim_per_head,
lse_seqlen, lse_per_step_seqlen);
NVTE_CHECK_CUDA(cudaGetLastError());
} else {
thd_out_correction_kernel<dtype, only_second_half, tile, false>
<<<grid, block, sizeof(int) * (batch + 1), stream>>>(
......@@ -528,6 +534,7 @@ static void thd_out_correction_helper(Tensor out, const Tensor &out_per_step, co
reinterpret_cast<float *>(lse_per_step.data.dptr),
reinterpret_cast<int *>(cu_seqlens.data.dptr), batch, num_heads, dim_per_head,
lse_seqlen, lse_per_step_seqlen);
NVTE_CHECK_CUDA(cudaGetLastError());
}
}
......@@ -602,6 +609,7 @@ static void thd_grad_correction_helper(Tensor grad, const Tensor &grad_per_step,
reinterpret_cast<dtype *>(grad.data.dptr),
reinterpret_cast<dtype *>(grad_per_step.data.dptr),
reinterpret_cast<int *>(cu_seqlens.data.dptr), batch, hidden_size, total_tokens);
NVTE_CHECK_CUDA(cudaGetLastError());
}
template <typename dtype>
......@@ -667,6 +675,7 @@ void thd_get_partitioned_indices(const Tensor &cu_seqlens, Tensor output, int to
thd_partition_indices_kernel<<<grid, block, sizeof(int) * (batch + 1), stream>>>(
reinterpret_cast<int *>(output.data.dptr), reinterpret_cast<int *>(cu_seqlens.data.dptr),
batch, total_tokens, world_size, rank);
NVTE_CHECK_CUDA(cudaGetLastError());
}
} // namespace context_parallel
......
......@@ -91,6 +91,7 @@ void prepare_flash_attn_fwd(Tensor qkvi, Tensor qkv, cudaStream_t stream) {
prepare_kernel_fwd<dtype><<<grid, threads, 0, stream>>>(
reinterpret_cast<dtype *>(qkvi.data.dptr), reinterpret_cast<dtype *>(qkv.data.dptr),
shape[1], shape[2], shape[3], shape[4]););
NVTE_CHECK_CUDA(cudaGetLastError());
}
void prepare_flash_attn_bwd(Tensor q, Tensor k, Tensor v, Tensor qkv, cudaStream_t stream) {
......@@ -129,6 +130,7 @@ void prepare_flash_attn_bwd(Tensor q, Tensor k, Tensor v, Tensor qkv, cudaStream
reinterpret_cast<dtype *>(q.data.dptr), reinterpret_cast<dtype *>(k.data.dptr),
reinterpret_cast<dtype *>(v.data.dptr), reinterpret_cast<dtype *>(qkv.data.dptr),
q_shape[0], q_shape[1], q_shape[2], q_shape[3]););
NVTE_CHECK_CUDA(cudaGetLastError());
}
} // namespace flash_attention
......
......@@ -251,10 +251,11 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
// 9.11: d_qk = 192, d_v = 128 + Blackwell + bprop + non-paged
(head_dim_qk == 192 && head_dim_v == 128 && is_training && sm_arch_ >= 100 &&
cudnn_runtime_version >= 91100)) &&
// 9.11/9.12 bug: 128 < d_qk <= 256, 128 < d_v <= 256 + Hopper + bprop + MLA
(!((cudnn_runtime_version == 91100 || cudnn_runtime_version == 91200) && is_training &&
sm_arch_ == 90 && head_dim_qk >= 128 && head_dim_v >= 128 &&
!(head_dim_qk == 192 && head_dim_v == 128) && head_dim_qk != head_dim_v))) &&
// 9.11+ bug: 128 < d_qk <= 256, 128 < d_v <= 256 + Hopper + bprop + MLA
// Temporarily use a blanket cudnn_runtime_version >= 9.11 condition until the bug is fixed
(!((cudnn_runtime_version >= 91100) && is_training && sm_arch_ == 90 &&
head_dim_qk >= 128 && head_dim_v >= 128 && !(head_dim_qk == 192 && head_dim_v == 128) &&
head_dim_qk != head_dim_v))) &&
// bias type
((cudnn_runtime_version < 8906 && bias_type == NVTE_Bias_Type::NVTE_NO_BIAS) ||
(cudnn_runtime_version >= 8906 &&
......
......@@ -416,6 +416,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
actual_b, b, static_cast<const int32_t *>(devPtrCuSeqlensQ),
static_cast<const int32_t *>(devPtrCuSeqlensKV), static_cast<int32_t *>(devActualSeqlenQ),
static_cast<int32_t *>(devActualSeqlenKV));
NVTE_CHECK_CUDA(cudaGetLastError());
variant_pack[seq_q] = devActualSeqlenQ;
variant_pack[seq_kv] = devActualSeqlenKV;
}
......@@ -454,6 +455,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
layout_group, actual_b, b, h, hg, d_qk, d_v, static_cast<int32_t *>(devPtrSeqOffsetsQ),
static_cast<int32_t *>(devPtrSeqOffsetsKV), ragged_offset_type, devOffsetsQ, devOffsetsK,
devOffsetsV, devOffsetsO, devOffsetsS);
NVTE_CHECK_CUDA(cudaGetLastError());
if (is_ragged_q) {
variant_pack[offset_q] = devOffsetsQ;
variant_pack[offset_o] = devOffsetsO;
......@@ -883,6 +885,7 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
actual_b, b, static_cast<const int32_t *>(devPtrCuSeqlensQ),
static_cast<const int32_t *>(devPtrCuSeqlensKV), static_cast<int32_t *>(devActualSeqlenQ),
static_cast<int32_t *>(devActualSeqlenKV));
NVTE_CHECK_CUDA(cudaGetLastError());
variant_pack[seq_q] = devActualSeqlenQ;
variant_pack[seq_kv] = devActualSeqlenKV;
}
......@@ -916,6 +919,7 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
layout_group, actual_b, b, h, hg, d_qk, d_v, static_cast<int32_t *>(devPtrSeqOffsetsQ),
static_cast<int32_t *>(devPtrSeqOffsetsKV), ragged_offset_type, devOffsetsQ, devOffsetsK,
devOffsetsV, devOffsetsO, devOffsetsS);
NVTE_CHECK_CUDA(cudaGetLastError());
if (is_ragged_q) {
variant_pack[offset_q] = devOffsetsQ;
variant_pack[offset_o] = devOffsetsO;
......
......@@ -1111,6 +1111,7 @@ void fused_attn_fp8_fwd_impl(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, in
cu_seqlens_to_offsets<<<gridDims, blockDims, 0, stream>>>(
b, h, d, reinterpret_cast<int32_t*>(devPtrcuSeqlensQ), actual_seqlens_q, qkv_ragged_offset,
o_ragged_offset);
NVTE_CHECK_CUDA(cudaGetLastError());
void* devPtrQKVRaggedOffset = reinterpret_cast<void*>(qkv_ragged_offset);
void* devPtrORaggedOffset = reinterpret_cast<void*>(o_ragged_offset);
void* devPtrMNKOverride = reinterpret_cast<void*>(actual_seqlens_q);
......@@ -1577,6 +1578,7 @@ void fused_attn_fp8_bwd_impl(
cu_seqlens_to_offsets<<<gridDims, blockDims, 0, stream>>>(
b, h, d, reinterpret_cast<int32_t*>(devPtrcuSeqlensQ), actual_seqlens_q, qkv_ragged_offset,
o_ragged_offset);
NVTE_CHECK_CUDA(cudaGetLastError());
void* devPtrQKVRaggedOffset = reinterpret_cast<void*>(qkv_ragged_offset);
void* devPtrORaggedOffset = reinterpret_cast<void*>(o_ragged_offset);
void* devPtrMNKOverride = reinterpret_cast<void*>(actual_seqlens_q);
......@@ -1933,6 +1935,7 @@ void fused_attn_fp8_fwd_impl_v1(
b, b, static_cast<const int32_t*>(devPtrcuSeqlensQ), // TODO(pass max_b)
static_cast<const int32_t*>(devPtrcuSeqlensKV), static_cast<int32_t*>(devActualSeqlenQ),
static_cast<int32_t*>(devActualSeqlenKV));
NVTE_CHECK_CUDA(cudaGetLastError());
variant_pack[seq_q] = devActualSeqlenQ;
variant_pack[seq_kv] = devActualSeqlenKV;
}
......@@ -2329,6 +2332,7 @@ void fused_attn_fp8_bwd_impl_v1(
b, b, static_cast<const int32_t*>(devPtrcuSeqlensQ), // TODO(pass max_b)
static_cast<const int32_t*>(devPtrcuSeqlensKV), static_cast<int32_t*>(devActualSeqlenQ),
static_cast<int32_t*>(devActualSeqlenKV));
NVTE_CHECK_CUDA(cudaGetLastError());
variant_pack[seq_q] = devActualSeqlenQ;
variant_pack[seq_kv] = devActualSeqlenKV;
}
......
......@@ -157,6 +157,7 @@ void copy_to_kv_cache_launcher(Tensor new_k, Tensor new_v, Tensor k_cache, Tenso
reinterpret_cast<int *>(page_table.data.dptr),
reinterpret_cast<int *>(cu_new_lens.data.dptr),
reinterpret_cast<int *>(cu_cached_lens.data.dptr), h_kv, d_k, d_v, b, max_seq_len);
NVTE_CHECK_CUDA(cudaGetLastError());
}
dim3 grid_size(b, max_ctx_len);
copy_to_kv_cache_kernel<<<grid_size, block_size, 0, stream>>>(
......@@ -166,6 +167,7 @@ void copy_to_kv_cache_launcher(Tensor new_k, Tensor new_v, Tensor k_cache, Tenso
reinterpret_cast<int *>(cu_new_lens.data.dptr),
reinterpret_cast<int *>(cu_cached_lens.data.dptr), qkv_format, h_kv, d_k, d_v, b,
max_ctx_len, max_seq_len, max_pages_per_seq, is_non_paged);
NVTE_CHECK_CUDA(cudaGetLastError());
}
}
......@@ -215,6 +217,7 @@ void convert_thd_to_bshd_launcher(Tensor tensor, Tensor new_tensor, Tensor cu_se
reinterpret_cast<scalar_t *>(tensor.data.dptr),
reinterpret_cast<scalar_t *>(new_tensor.data.dptr),
reinterpret_cast<int *>(cu_seqlens.data.dptr), b, max_seq_len, h, d);
NVTE_CHECK_CUDA(cudaGetLastError());
}
void convert_thd_to_bshd(Tensor tensor, Tensor cu_seqlens, Tensor new_tensor, int b,
......@@ -254,6 +257,7 @@ void convert_bshd_to_thd_launcher(Tensor tensor, Tensor new_tensor, Tensor cu_se
reinterpret_cast<scalar_t *>(tensor.data.dptr),
reinterpret_cast<scalar_t *>(new_tensor.data.dptr),
reinterpret_cast<int *>(cu_seqlens.data.dptr), b, max_seq_len, h, d);
NVTE_CHECK_CUDA(cudaGetLastError());
}
void convert_bshd_to_thd(Tensor tensor, Tensor cu_seqlens, Tensor new_tensor, int t,
......
......@@ -600,13 +600,14 @@ uint32_t GetRuntimeNumSegments(void *cu_seqlen, void *workspace, size_t len, cud
// workspace size requires 4 bytes
uint32_t *dout = static_cast<uint32_t *>(workspace);
uint32_t hout{};
cudaMemsetAsync(dout, 0, sizeof(uint32_t), stream);
NVTE_CHECK_CUDA(cudaMemsetAsync(dout, 0, sizeof(uint32_t), stream));
constexpr int threads = 128;
const int blocks = (len - 1) / threads + 1;
get_runtime_num_segments_kernel<<<blocks, threads, 0, stream>>>(static_cast<int32_t *>(cu_seqlen),
len, dout);
cudaMemcpyAsync(&hout, dout, sizeof(uint32_t), cudaMemcpyDeviceToHost, stream);
cudaStreamSynchronize(stream);
NVTE_CHECK_CUDA(cudaGetLastError());
NVTE_CHECK_CUDA(cudaMemcpyAsync(&hout, dout, sizeof(uint32_t), cudaMemcpyDeviceToHost, stream));
NVTE_CHECK_CUDA(cudaStreamSynchronize(stream));
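// The synchronization is required because hout is read on the host immediately below.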
return hout;
}
......@@ -633,4 +634,5 @@ void nvte_extract_seed_and_offset(int64_t *rng_state_ptr, int captured, int64_t
fused_attn::extract_seed_and_offset<<<1, 1, 0, stream>>>(
rng_state_ptr, captured, seed_ptr, seed_val, offset_ptr, offset_val, offset_intragraph);
NVTE_CHECK_CUDA(cudaGetLastError());
}
......@@ -21,12 +21,21 @@ __device__ void fused_rope_block_forward(const scalar_t *src, const float *freqs
const int h, const int d, const int d2, const int stride_h,
const int stride_d, const int o_stride_h,
const int o_stride_d) {
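// The cos/sin values for this sequence position depend only on d_id, so they are precomputed
// once per block into dynamic shared memory (cos in the first d2 floats, sin in the next d2)
// and reused by every head, presumably to avoid calling sincosf per (head, d_id) pair.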
extern __shared__ float shared_mem_cos_sin[];
float *shared_mem_cos = shared_mem_cos_sin;
float *shared_mem_sin = shared_mem_cos_sin + d2;
int tid = threadIdx.x * blockDim.y + threadIdx.y;
for (int i = tid; i < d2; i += blockDim.x * blockDim.y) {
sincosf(freqs[s_id * d2 + i], &shared_mem_sin[i], &shared_mem_cos[i]);
}
__syncthreads();
#pragma unroll
for (int d_id = threadIdx.x; d_id < d2; d_id += blockDim.x) {
float v_cos, v_sin;
sincosf(freqs[s_id * d2 + d_id], &v_sin, &v_cos);
for (int h_id = threadIdx.y; h_id < h; h_id += blockDim.y) {
#pragma unroll
for (int h_id = threadIdx.y; h_id < h; h_id += blockDim.y) {
for (int d_id = threadIdx.x; d_id < d2; d_id += blockDim.x) {
float v_cos = shared_mem_cos[d_id];
float v_sin = shared_mem_sin[d_id];
int offset_src = offset_block + h_id * stride_h + d_id * stride_d;
int offset_dst = offset_block_dst + h_id * o_stride_h + d_id * o_stride_d;
float v_src = src[offset_src];
......@@ -49,12 +58,12 @@ __device__ void fused_rope_block_forward(const scalar_t *src, const float *freqs
// copy the rest
if (d > d2) {
#pragma unroll
for (int h_id = threadIdx.y; h_id < h; h_id += blockDim.y) {
int offset_head = offset_block + h_id * stride_h;
int offset_head_dst = offset_block_dst + h_id * o_stride_h;
for (int d_id = d2 + threadIdx.x; d_id < d; d_id += blockDim.x) {
#pragma unroll
for (int d_id = d2 + threadIdx.x; d_id < d; d_id += blockDim.x) {
dst[offset_head_dst + d_id * o_stride_d] = src[offset_head + d_id * stride_d];
for (int h_id = threadIdx.y; h_id < h; h_id += blockDim.y) {
int offset_src = offset_block + h_id * stride_h + d_id * stride_d;
int offset_dst = offset_block_dst + h_id * o_stride_h + d_id * o_stride_d;
dst[offset_dst] = src[offset_src];
}
}
}
......@@ -67,47 +76,54 @@ __device__ void fused_rope_block_backward(const scalar_t *src, const float *freq
const int h, const int d, const int d2,
const int stride_h, const int stride_d,
const int o_stride_h, const int o_stride_d) {
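// As in the forward block, cos/sin for this position are precomputed into shared memory
// (cos in the first d2 floats, sin in the next d2). The backward rotation then reads sin from
// the partner index, negating it for the second-half / odd lanes, i.e. it applies the
// transposed rotation to the incoming gradient.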
extern __shared__ float shared_mem_cos_sin[];
float *shared_mem_cos = shared_mem_cos_sin;
float *shared_mem_sin = shared_mem_cos_sin + d2;
int tid = threadIdx.x * blockDim.y + threadIdx.y;
for (int i = tid; i < d2; i += blockDim.x * blockDim.y) {
sincosf(freqs[s_id * d2 + i], &shared_mem_sin[i], &shared_mem_cos[i]);
}
__syncthreads();
#pragma unroll
for (int d_id = threadIdx.x; d_id < d2; d_id += blockDim.x) {
float v_cos = cosf(freqs[s_id * d2 + d_id]);
float v_sin;
if (!interleaved) {
v_sin = (d_id + d2 / 2 < d2) ? sinf(freqs[s_id * d2 + d_id + d2 / 2])
: -sinf(freqs[s_id * d2 + d_id + d2 / 2 - d2]);
} else {
v_sin =
(d_id % 2 == 0) ? sinf(freqs[s_id * d2 + d_id + 1]) : -sinf(freqs[s_id * d2 + d_id - 1]);
}
for (int h_id = threadIdx.y; h_id < h; h_id += blockDim.y) {
#pragma unroll
for (int h_id = threadIdx.y; h_id < h; h_id += blockDim.y) {
for (int d_id = threadIdx.x; d_id < d2; d_id += blockDim.x) {
int offset_src = offset_block + h_id * stride_h + d_id * stride_d;
int offset_dst = offset_block_dst + h_id * o_stride_h + d_id * o_stride_d;
float v_src = src[offset_src];
float v_src_rotate;
float v_cos = shared_mem_cos[d_id];
float v_src_rotate, v_sin;
if (!interleaved) {
v_src_rotate = (d_id + d2 / 2 < d2)
? static_cast<float>(src[offset_src + (d2 / 2) * stride_d])
: static_cast<float>(src[offset_src + (d2 / 2 - d2) * stride_d]);
if (d_id + d2 / 2 < d2) {
v_src_rotate = static_cast<float>(src[offset_src + (d2 / 2) * stride_d]);
v_sin = shared_mem_sin[d_id + d2 / 2];
} else {
v_src_rotate = static_cast<float>(src[offset_src + (d2 / 2 - d2) * stride_d]);
v_sin = -shared_mem_sin[d_id + d2 / 2 - d2];
}
} else {
v_src_rotate = (d_id % 2 == 0)
// d_id + 1
? static_cast<float>(src[offset_src + stride_d])
// d_id - 1
: static_cast<float>(src[offset_src - stride_d]);
if (d_id % 2 == 0) {
v_src_rotate = static_cast<float>(src[offset_src + stride_d]);
v_sin = shared_mem_sin[d_id + 1];
} else {
v_src_rotate = static_cast<float>(src[offset_src - stride_d]);
v_sin = -shared_mem_sin[d_id - 1];
}
}
dst[offset_dst] = v_src * v_cos + v_src_rotate * v_sin;
}
}
// handle the tail
// copy the rest
if (d > d2) {
#pragma unroll
for (int h_id = threadIdx.y; h_id < h; h_id += blockDim.y) {
int offset_head = offset_block + h_id * stride_h;
int offset_head_dst = offset_block_dst + h_id * o_stride_h;
for (int d_id = d2 + threadIdx.x; d_id < d; d_id += blockDim.x) {
#pragma unroll
for (int d_id = d2 + threadIdx.x; d_id < d; d_id += blockDim.x) {
dst[offset_head_dst + d_id * o_stride_d] = src[offset_head + d_id * stride_d];
for (int h_id = threadIdx.y; h_id < h; h_id += blockDim.y) {
int offset_src = offset_block + h_id * stride_h + d_id * stride_d;
int offset_dst = offset_block_dst + h_id * o_stride_h + d_id * o_stride_d;
dst[offset_dst] = src[offset_src];
}
}
}
......@@ -198,6 +214,251 @@ __global__ void fused_rope_backward_kernel(
offset_block_dst, h, d, d2, stride_h, stride_d, o_stride_h, o_stride_d);
}
template <typename scalar_t>
__device__ void fused_qkv_rope_block_forward(const scalar_t *src, const float *freqs, scalar_t *out,
const bool interleaved, const int s_id,
const int offset_block, const int offset_block_dst,
const int h, const int d, const int d2,
const int row_offset, const int in_row_length,
const int out_row_length) {
extern __shared__ float shared_mem_cos_sin_qk[];
// Split the shared memory into cos and sin parts for q or k
float *shared_mem_cos = nullptr;
float *shared_mem_sin = nullptr;
if (row_offset == 0) { // q
shared_mem_cos = shared_mem_cos_sin_qk;
shared_mem_sin = shared_mem_cos_sin_qk + d2;
} else { // k
shared_mem_cos = shared_mem_cos_sin_qk + 2 * d2;
shared_mem_sin = shared_mem_cos_sin_qk + 3 * d2;
}
if (freqs != nullptr) {
int tid = threadIdx.x * blockDim.y + threadIdx.y;
for (int i = tid; i < d2; i += blockDim.x * blockDim.y) {
sincosf(freqs[s_id * d2 + i], &shared_mem_sin[i], &shared_mem_cos[i]);
}
}
__syncthreads();
#pragma unroll
for (int h_id = threadIdx.y; h_id < h; h_id += blockDim.y) {
#pragma unroll
for (int i = 0; i < out_row_length; i += d) {
#pragma unroll
for (int d_id = threadIdx.x; d_id < d2; d_id += blockDim.x) {
int offset_src = offset_block + h_id * in_row_length + (row_offset + i) + d_id;
int offset_dst = offset_block_dst + h_id * out_row_length + i + d_id;
if (freqs != nullptr) {
float v_cos, v_sin;
v_cos = shared_mem_cos[d_id];
v_sin = shared_mem_sin[d_id];
float v_src = src[offset_src];
float v_src_rotate;
if (!interleaved) {
v_src_rotate = (d_id + d2 / 2 < d2)
? -static_cast<float>(src[offset_src + (d2 / 2)])
: static_cast<float>(src[offset_src + (d2 / 2 - d2)]);
} else {
v_src_rotate = (d_id % 2 == 0) ? -static_cast<float>(src[offset_src + 1])
: static_cast<float>(src[offset_src - 1]);
}
out[offset_dst] = v_src * v_cos + v_src_rotate * v_sin;
} else {
out[offset_dst] = src[offset_src];
}
}
}
}
// copy the rest
if (d > d2) {
#pragma unroll
for (int d_id = d2 + threadIdx.x; d_id < d; d_id += blockDim.x) {
#pragma unroll
for (int h_id = threadIdx.y; h_id < h; h_id += blockDim.y) {
#pragma unroll
for (int i = 0; i < out_row_length; i += d) {
int offset_src = offset_block + h_id * in_row_length + (row_offset + i) + d_id;
int offset_dst = offset_block_dst + h_id * out_row_length + i + d_id;
out[offset_dst] = src[offset_src];
}
}
}
}
}
template <typename scalar_t>
__device__ void fused_qkv_rope_block_backward(const scalar_t *grad_out, const float *freqs,
scalar_t *out, const bool interleaved, const int s_id,
const int offset_block, const int offset_block_dst,
const int h, const int d, const int d2,
const int row_offset, const int in_row_length,
const int out_row_length) {
extern __shared__ float shared_mem_cos_sin_qk[];
float *shared_mem_cos = nullptr;
float *shared_mem_sin = nullptr;
// Split the shared memory into cos and sin parts for q or k
if (row_offset == 0) { // q
shared_mem_cos = shared_mem_cos_sin_qk;
shared_mem_sin = shared_mem_cos_sin_qk + d2;
} else { // k
shared_mem_cos = shared_mem_cos_sin_qk + 2 * d2;
shared_mem_sin = shared_mem_cos_sin_qk + 3 * d2;
}
if (freqs != nullptr) {
int tid = threadIdx.x * blockDim.y + threadIdx.y;
for (int i = tid; i < d2; i += blockDim.x * blockDim.y) {
sincosf(freqs[s_id * d2 + i], &shared_mem_sin[i], &shared_mem_cos[i]);
}
}
__syncthreads();
#pragma unroll
for (int h_id = threadIdx.y; h_id < h; h_id += blockDim.y) {
#pragma unroll
for (int i = 0; i < out_row_length; i += d) {
#pragma unroll
for (int d_id = threadIdx.x; d_id < d2; d_id += blockDim.x) {
int offset_dst = offset_block + h_id * in_row_length + (row_offset + i) + d_id;
int offset_src = offset_block_dst + h_id * out_row_length + i + d_id;
float v_src = grad_out[offset_src];
if (freqs != nullptr) {
float v_cos, v_sin;
v_cos = shared_mem_cos[d_id];
float v_src_rotate;
if (!interleaved) {
if (d_id + d2 / 2 < d2) {
v_src_rotate = static_cast<float>(grad_out[offset_src + (d2 / 2)]);
v_sin = shared_mem_sin[d_id + d2 / 2];
} else {
v_src_rotate = static_cast<float>(grad_out[offset_src + (d2 / 2 - d2)]);
v_sin = -shared_mem_sin[d_id + d2 / 2 - d2];
}
} else {
if (d_id % 2 == 0) {
v_src_rotate = static_cast<float>(grad_out[offset_src + 1]);
v_sin = shared_mem_sin[d_id + 1];
} else {
v_src_rotate = static_cast<float>(grad_out[offset_src - 1]);
v_sin = -shared_mem_sin[d_id - 1];
}
}
out[offset_dst] = v_src * v_cos + v_src_rotate * v_sin;
} else {
out[offset_dst] = grad_out[offset_src];
}
}
}
}
// copy the rest
if (d > d2) {
#pragma unroll
for (int h_id = threadIdx.y; h_id < h; h_id += blockDim.y) {
#pragma unroll
for (int i = 0; i < out_row_length; i += d) {
#pragma unroll
for (int d_id = d2 + threadIdx.x; d_id < d; d_id += blockDim.x) {
int offset_dst = offset_block + h_id * in_row_length + (row_offset + i) + d_id;
int offset_src = offset_block_dst + h_id * out_row_length + i + d_id;
out[offset_dst] = grad_out[offset_src];
}
}
}
}
}
template <typename scalar_t>
__global__ void fused_qkv_rope_forward_kernel(
const scalar_t *qkv_input, const float *q_freqs, const float *k_freqs,
const int *start_positions, scalar_t *q_out, scalar_t *k_out, scalar_t *v_out,
const NVTE_QKV_Format qkv_format, const bool interleaved, const int cp_size, const int cp_rank,
const int s, const int b, const int h, const int d, const int d2, const int q_split_arg,
const int k_split_arg, const int v_split_arg) {
int s_id = blockIdx.x, b_id = blockIdx.y;
int cur_seqlens = s;
int total_d = q_split_arg + k_split_arg + v_split_arg;
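// The fused QKV input is assumed to be packed per head as [q | k | v] along the last dimension,
// so each head row is total_d elements wide and the q/k/v sections start at offsets 0,
// q_split_arg and q_split_arg + k_split_arg respectively.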
int offset_block, offset_block_dst_q, offset_block_dst_k, offset_block_dst_v;
if (qkv_format == NVTE_QKV_Format::NVTE_SBHD) {
offset_block = s_id * b * h * total_d + b_id * h * total_d;
offset_block_dst_q = s_id * b * h * q_split_arg + b_id * h * q_split_arg;
offset_block_dst_k = s_id * b * h * k_split_arg + b_id * h * k_split_arg;
offset_block_dst_v = s_id * b * h * v_split_arg + b_id * h * v_split_arg;
} else {
offset_block = b_id * s * h * total_d + s_id * h * total_d;
offset_block_dst_q = b_id * s * h * q_split_arg + s_id * h * q_split_arg;
offset_block_dst_k = b_id * s * h * k_split_arg + s_id * h * k_split_arg;
offset_block_dst_v = b_id * s * h * v_split_arg + s_id * h * v_split_arg;
}
int q_limit = q_split_arg;
int k_limit = q_limit + k_split_arg;
int s_id_for_freqs;
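// With context parallelism, each rank is assumed to hold two chunks of the global sequence
// (chunk cp_rank and chunk 2*cp_size-1-cp_rank, the usual load-balanced split for causal
// attention), so the local index s_id is mapped back to its global position before indexing
// freqs.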
if (cp_size > 1) {
assert(cur_seqlens % 2 == 0);
if (s_id < cur_seqlens / 2) {
s_id_for_freqs = s_id + cp_rank * cur_seqlens / 2;
} else {
s_id_for_freqs =
cur_seqlens * cp_size - (cp_rank + 1) * cur_seqlens / 2 + s_id - cur_seqlens / 2;
}
} else {
int begin_offset = (start_positions == nullptr) ? 0 : start_positions[b_id];
s_id_for_freqs = s_id + begin_offset;
}
fused_qkv_rope_block_forward(qkv_input, q_freqs, q_out, interleaved, s_id_for_freqs, offset_block,
offset_block_dst_q, h, d, d2, 0, total_d, q_split_arg);
fused_qkv_rope_block_forward(qkv_input, k_freqs, k_out, interleaved, s_id_for_freqs, offset_block,
offset_block_dst_k, h, d, d2, q_limit, total_d, k_split_arg);
fused_qkv_rope_block_forward(qkv_input, nullptr, v_out, interleaved, s_id_for_freqs, offset_block,
offset_block_dst_v, h, d, d2, k_limit, total_d, v_split_arg);
}
template <typename scalar_t>
__global__ void fused_qkv_rope_backward_kernel(
const scalar_t *grad_out_q, const scalar_t *grad_out_k, const scalar_t *grad_out_v,
const float *q_freqs, const float *k_freqs, scalar_t *qkv_grad,
const NVTE_QKV_Format qkv_format, const bool interleaved, const int cp_size, const int cp_rank,
const int s, const int b, const int h, const int d, const int d2, const int q_split_arg,
const int k_split_arg, const int v_split_arg) {
int s_id = blockIdx.x, b_id = blockIdx.y;
int cur_seqlens = s;
int offset_block, offset_block_dst_q, offset_block_dst_k, offset_block_dst_v;
int total_d = q_split_arg + k_split_arg + v_split_arg;
if (qkv_format == NVTE_QKV_Format::NVTE_SBHD) {
offset_block = s_id * b * h * total_d + b_id * h * total_d;
offset_block_dst_q = s_id * b * h * q_split_arg + b_id * h * q_split_arg;
offset_block_dst_k = s_id * b * h * k_split_arg + b_id * h * k_split_arg;
offset_block_dst_v = s_id * b * h * v_split_arg + b_id * h * v_split_arg;
} else {
offset_block = b_id * s * h * total_d + s_id * h * total_d;
offset_block_dst_q = b_id * s * h * q_split_arg + s_id * h * q_split_arg;
offset_block_dst_k = b_id * s * h * k_split_arg + s_id * h * k_split_arg;
offset_block_dst_v = b_id * s * h * v_split_arg + s_id * h * v_split_arg;
}
int q_limit = q_split_arg;
int k_limit = q_limit + k_split_arg;
int s_id_for_freqs;
if (cp_size > 1) {
assert(cur_seqlens % 2 == 0);
if (s_id < cur_seqlens / 2) {
s_id_for_freqs = s_id + cp_rank * cur_seqlens / 2;
} else {
s_id_for_freqs =
cur_seqlens * cp_size - (cp_rank + 1) * cur_seqlens / 2 + s_id - cur_seqlens / 2;
}
} else {
s_id_for_freqs = s_id;
}
fused_qkv_rope_block_backward(grad_out_q, q_freqs, qkv_grad, interleaved, s_id_for_freqs,
offset_block, offset_block_dst_q, h, d, d2, 0, total_d,
q_split_arg);
fused_qkv_rope_block_backward(grad_out_k, k_freqs, qkv_grad, interleaved, s_id_for_freqs,
offset_block, offset_block_dst_k, h, d, d2, q_limit, total_d,
k_split_arg);
fused_qkv_rope_block_backward(grad_out_v, nullptr, qkv_grad, interleaved, s_id_for_freqs,
offset_block, offset_block_dst_v, h, d, d2, k_limit, total_d,
v_split_arg);
}
template <typename scalar_t>
void fused_rope_forward_launcher(const scalar_t *input, const int *cu_seqlens, const float *freqs,
const int *start_positions, scalar_t *output,
......@@ -209,6 +470,7 @@ void fused_rope_forward_launcher(const scalar_t *input, const int *cu_seqlens, c
int warps_per_block = h < 16 ? 4 : 8;
dim3 blocks(s, b);
dim3 threads(THREADS_PER_WARP, warps_per_block);
const int shared_mem_size = 2 * d2 * sizeof(float); // cos, sin
int o_stride_s_or_t, o_stride_b;
if (qkv_format == NVTE_QKV_Format::NVTE_THD) {
NVTE_CHECK(cu_seqlens != nullptr, "cu_seqlens is required for THD format");
......@@ -224,7 +486,7 @@ void fused_rope_forward_launcher(const scalar_t *input, const int *cu_seqlens, c
const int o_stride_h = d;
const int o_stride_d = 1;
fused_rope_forward_kernel<<<blocks, threads, 0, stream>>>(
fused_rope_forward_kernel<<<blocks, threads, shared_mem_size, stream>>>(
input, cu_seqlens, freqs, start_positions, output, interleaved, cp_size, cp_rank, s, h, d, d2,
stride_s_or_t, stride_b, stride_h, stride_d, o_stride_s_or_t, o_stride_b, o_stride_h,
o_stride_d);
......@@ -242,6 +504,7 @@ void fused_rope_backward_launcher(const scalar_t *output_grads, const int *cu_se
int warps_per_block = h < 16 ? 4 : 8;
dim3 blocks(s, b);
dim3 threads(THREADS_PER_WARP, warps_per_block);
const int shared_mem_size = 2 * d2 * sizeof(float); // cos, sin
int o_stride_s_or_t, o_stride_b;
if (qkv_format == NVTE_QKV_Format::NVTE_THD) {
NVTE_CHECK(cu_seqlens != nullptr, "cu_seqlens is required for THD format");
......@@ -257,13 +520,58 @@ void fused_rope_backward_launcher(const scalar_t *output_grads, const int *cu_se
const int o_stride_h = d;
const int o_stride_d = 1;
fused_rope_backward_kernel<<<blocks, threads, 0, stream>>>(
fused_rope_backward_kernel<<<blocks, threads, shared_mem_size, stream>>>(
output_grads, cu_seqlens, freqs, input_grads, interleaved, cp_size, cp_rank, s, h, d, d2,
stride_s_or_t, stride_b, stride_h, stride_d, o_stride_s_or_t, o_stride_b, o_stride_h,
o_stride_d);
NVTE_CHECK_CUDA(cudaGetLastError());
}
template <typename scalar_t>
void fused_qkv_rope_forward_launcher(const scalar_t *qkv_input, const float *q_freqs,
const float *k_freqs, const int *start_positions,
scalar_t *q_out, scalar_t *k_out, scalar_t *v_out,
const NVTE_QKV_Format qkv_format, const bool interleaved,
const int cp_size, const int cp_rank, const int s, const int b,
const int h, const int d, const int d2,
const int qkv_split_arg_list_0, const int qkv_split_arg_list_1,
const int qkv_split_arg_list_2, cudaStream_t stream) {
const int THREADS_PER_WARP = 32;
int warps_per_block = (h <= 8) ? h : 8;
dim3 blocks(s, b);
dim3 threads(THREADS_PER_WARP, warps_per_block);
const int shared_mem_size = 4 * d2 * sizeof(float); // cos and sin, for both q and k
fused_qkv_rope_forward_kernel<<<blocks, threads, shared_mem_size, stream>>>(
qkv_input, q_freqs, k_freqs, start_positions, q_out, k_out, v_out, qkv_format, interleaved,
cp_size, cp_rank, s, b, h, d, d2, qkv_split_arg_list_0, qkv_split_arg_list_1,
qkv_split_arg_list_2);
NVTE_CHECK_CUDA(cudaGetLastError());
}
template <typename scalar_t>
void fused_qkv_rope_backward_launcher(const scalar_t *q_grad_out, const scalar_t *k_grad_out,
const scalar_t *v_grad_out, const float *q_freqs,
const float *k_freqs, scalar_t *qkv_grad_input,
const NVTE_QKV_Format qkv_format, const bool interleaved,
const int cp_size, const int cp_rank, const int s,
const int b, const int h, const int d, const int d2,
const int qkv_split_arg_list_0,
const int qkv_split_arg_list_1,
const int qkv_split_arg_list_2, cudaStream_t stream) {
const int THREADS_PER_WARP = 32;
const int warps_per_block = (h <= 8) ? h : 8;
dim3 blocks(s, b);
dim3 threads(THREADS_PER_WARP, warps_per_block);
const int shared_mem_size = 4 * d2 * sizeof(float); // cos and sin, for both q and k
fused_qkv_rope_backward_kernel<<<blocks, threads, shared_mem_size, stream>>>(
q_grad_out, k_grad_out, v_grad_out, q_freqs, k_freqs, qkv_grad_input, qkv_format, interleaved,
cp_size, cp_rank, s, b, h, d, d2, qkv_split_arg_list_0, qkv_split_arg_list_1,
qkv_split_arg_list_2);
NVTE_CHECK_CUDA(cudaGetLastError());
}
void fused_rope_forward(const Tensor &input, const Tensor &cu_seqlens, const Tensor &freqs,
const Tensor &start_positions, Tensor *output,
const NVTE_QKV_Format qkv_format, const bool interleaved, const int cp_size,
......@@ -297,6 +605,46 @@ void fused_rope_backward(const Tensor &output_grads, const Tensor &cu_seqlens, c
stride_b, stride_h, stride_d, stream););
}
void fused_qkv_rope_forward(const Tensor &qkv_input, const Tensor &q_freqs, const Tensor &k_freqs,
const Tensor &start_positions, Tensor *q_out, Tensor *k_out,
Tensor *v_out, const NVTE_QKV_Format qkv_format, const bool interleaved,
const int cp_size, const int cp_rank, const int s, const int b,
const int h, const int d, const int d2, const int qkv_split_arg_list_0,
const int qkv_split_arg_list_1, const int qkv_split_arg_list_2,
cudaStream_t stream) {
TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
qkv_input.data.dtype, scalar_t,
fused_qkv_rope_forward_launcher(reinterpret_cast<const scalar_t *>(qkv_input.data.dptr),
reinterpret_cast<const float *>(q_freqs.data.dptr),
reinterpret_cast<const float *>(k_freqs.data.dptr),
reinterpret_cast<const int *>(start_positions.data.dptr),
reinterpret_cast<scalar_t *>(q_out->data.dptr),
reinterpret_cast<scalar_t *>(k_out->data.dptr),
reinterpret_cast<scalar_t *>(v_out->data.dptr), qkv_format,
interleaved, cp_size, cp_rank, s, b, h, d, d2,
qkv_split_arg_list_0, qkv_split_arg_list_1,
qkv_split_arg_list_2, stream););
}
void fused_qkv_rope_backward(const Tensor &q_grad_out, const Tensor &k_grad_out,
const Tensor &v_grad_out, const Tensor &q_freqs, const Tensor &k_freqs,
Tensor *qkv_grad_input, const NVTE_QKV_Format qkv_format,
const bool interleaved, const int cp_size, const int cp_rank,
const int s, const int b, const int h, const int d, const int d2,
const int qkv_split_arg_list_0, const int qkv_split_arg_list_1,
const int qkv_split_arg_list_2, cudaStream_t stream) {
TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
q_grad_out.data.dtype, scalar_t,
fused_qkv_rope_backward_launcher(reinterpret_cast<const scalar_t *>(q_grad_out.data.dptr),
reinterpret_cast<const scalar_t *>(k_grad_out.data.dptr),
reinterpret_cast<const scalar_t *>(v_grad_out.data.dptr),
reinterpret_cast<const float *>(q_freqs.data.dptr),
reinterpret_cast<const float *>(k_freqs.data.dptr),
reinterpret_cast<scalar_t *>(qkv_grad_input->data.dptr),
qkv_format, interleaved, cp_size, cp_rank, s, b, h, d, d2,
qkv_split_arg_list_0, qkv_split_arg_list_1,
qkv_split_arg_list_2, stream););
}
} // end namespace transformer_engine
void nvte_fused_rope_forward(const NVTETensor input, const NVTETensor cu_seqlens,
......@@ -328,3 +676,38 @@ void nvte_fused_rope_backward(const NVTETensor output_grads, const NVTETensor cu
qkv_format, interleaved, cp_size, cp_rank, s, b, h, d, d2, stride_s_or_t,
stride_b, stride_h, stride_d, stream);
}
void nvte_fused_qkv_rope_forward(const NVTETensor qkv_input, const NVTETensor q_freqs,
const NVTETensor k_freqs, const NVTETensor start_positions,
NVTETensor q_out, NVTETensor k_out, NVTETensor v_out,
const NVTE_QKV_Format qkv_format, const bool interleaved,
const int cp_size, const int cp_rank, const int s, const int b,
const int h, const int d, const int d2,
const int qkv_split_arg_list_0, const int qkv_split_arg_list_1,
const int qkv_split_arg_list_2, cudaStream_t stream) {
NVTE_API_CALL(nvte_fused_qkv_rope_forward);
using namespace transformer_engine;
fused_qkv_rope_forward(*convertNVTETensorCheck(qkv_input), *convertNVTETensorCheck(q_freqs),
*convertNVTETensorCheck(k_freqs), *convertNVTETensorCheck(start_positions),
convertNVTETensorCheck(q_out), convertNVTETensorCheck(k_out),
convertNVTETensorCheck(v_out), qkv_format, interleaved, cp_size, cp_rank,
s, b, h, d, d2, qkv_split_arg_list_0, qkv_split_arg_list_1,
qkv_split_arg_list_2, stream);
}
void nvte_fused_qkv_rope_backward(const NVTETensor q_grad_out, const NVTETensor k_grad_out,
const NVTETensor v_grad_out, const NVTETensor q_freqs,
const NVTETensor k_freqs, NVTETensor qkv_grad_input,
const NVTE_QKV_Format qkv_format, const bool interleaved,
const int cp_size, const int cp_rank, const int s, const int b,
const int h, const int d, const int d2,
const int qkv_split_arg_list_0, const int qkv_split_arg_list_1,
const int qkv_split_arg_list_2, cudaStream_t stream) {
NVTE_API_CALL(nvte_fused_qkv_rope_backward);
using namespace transformer_engine;
fused_qkv_rope_backward(*convertNVTETensorCheck(q_grad_out), *convertNVTETensorCheck(k_grad_out),
*convertNVTETensorCheck(v_grad_out), *convertNVTETensorCheck(q_freqs),
*convertNVTETensorCheck(k_freqs), convertNVTETensorCheck(qkv_grad_input),
qkv_format, interleaved, cp_size, cp_rank, s, b, h, d, d2,
qkv_split_arg_list_0, qkv_split_arg_list_1, qkv_split_arg_list_2, stream);
}
......@@ -178,9 +178,9 @@ void fused_moe_aux_loss_forward_kernel_launcher(const DataType* probs,
config.stream = stream;
// Update the max cluster size based on the device
cudaOccupancyMaxPotentialClusterSize(
NVTE_CHECK_CUDA(cudaOccupancyMaxPotentialClusterSize(
&cluster_size,
reinterpret_cast<void*>(fused_moe_aux_loss_forward_kernel<DataType, IndexType>), &config);
reinterpret_cast<void*>(fused_moe_aux_loss_forward_kernel<DataType, IndexType>), &config));
cudaLaunchAttribute attribute[1];
attribute[0].id = cudaLaunchAttributeClusterDimension;
......@@ -190,15 +190,16 @@ void fused_moe_aux_loss_forward_kernel_launcher(const DataType* probs,
config.numAttrs = 1;
config.attrs = attribute;
cudaLaunchKernelEx(&config, fused_moe_aux_loss_forward_kernel<DataType, IndexType>, probs,
tokens_per_expert, total_num_tokens, num_experts, num_rows, num_cols, topk,
coeff, aux_loss, Const_buf);
NVTE_CHECK_CUDA(cudaLaunchKernelEx(
&config, fused_moe_aux_loss_forward_kernel<DataType, IndexType>, probs, tokens_per_expert,
total_num_tokens, num_experts, num_rows, num_cols, topk, coeff, aux_loss, Const_buf));
} else {
#endif
size_t smem_size = sizeof(CompType) * num_cols;
fused_moe_aux_loss_forward_kernel<DataType, IndexType>
<<<1, 1024, smem_size, stream>>>(probs, tokens_per_expert, total_num_tokens, num_experts,
num_rows, num_cols, topk, coeff, aux_loss, Const_buf);
NVTE_CHECK_CUDA(cudaGetLastError());
#ifndef __HIP_PLATFORM_AMD__
}
#endif
......@@ -232,7 +233,7 @@ __global__ void fused_moe_aux_loss_backward_kernel(const float* Const_buf,
// Loop: for all positions in each row
for (int i = lane_id; i < num_cols; i += kThreadsPerWarp) {
float C_coeff = Const_buf[0];
IndexType tokens_per_expert_i = tokens_per_expert[i];
double tokens_per_expert_i = static_cast<double>(tokens_per_expert[i]);
double grad_aux_loss_value = static_cast<double>(grad_aux_loss[0]);
// Loop: for all rows
for (int j = global_warp_id; j < num_rows; j += global_warp_num) {
......@@ -251,6 +252,7 @@ void fused_moe_aux_loss_backward_kernel_launcher(const float* Const_buf,
int grid_size = (num_rows + block_size - 1) / block_size;
fused_moe_aux_loss_backward_kernel<DataType, IndexType><<<grid_size, block_size, 0, stream>>>(
Const_buf, tokens_per_expert, num_rows, num_cols, grad_aux_loss, grad_probs);
NVTE_CHECK_CUDA(cudaGetLastError());
}
void fused_moe_aux_loss_backward(const Tensor& Const_buf, const Tensor& tokens_per_expert,
......
......@@ -151,6 +151,7 @@ void fused_score_for_moe_aux_loss_forward_kernel_launcher(
<<<grid_size, kThreadsPerBlock, shared_memory_size, stream>>>(
logits, num_tokens, num_experts, topk, score_function, scores, routing_map,
intermediate_output);
NVTE_CHECK_CUDA(cudaGetLastError());
}
void fused_score_for_moe_aux_loss_forward(const Tensor &logits, int num_tokens, int num_experts,
......@@ -286,6 +287,7 @@ void fused_score_for_moe_aux_loss_backward_kernel_launcher(
<<<grid_size, kThreadsPerBlock, shared_memory_size, stream>>>(
intermediate_output, grad_scores, num_tokens, num_experts, topk, score_function,
grad_logits);
NVTE_CHECK_CUDA(cudaGetLastError());
}
void fused_score_for_moe_aux_loss_backward(const Tensor &intermediate_output,
......
......@@ -257,6 +257,7 @@ void fused_topk_with_score_function_forward_kernel_launcher(
<<<grid_size, kThreadsPerBlock, shared_memory_size, stream>>>(
logits, num_tokens, num_experts, topk, use_pre_softmax, num_groups, group_topk,
scaling_factor, score_function, expert_bias, probs, routing_map, intermediate_output);
NVTE_CHECK_CUDA(cudaGetLastError());
}
void fused_topk_with_score_function_forward(const Tensor logits, int num_tokens, int num_experts,
......@@ -447,6 +448,7 @@ void fused_topk_with_score_function_backward_kernel_launcher(
<<<grid_size, kThreadsPerBlock, shared_memory_size, stream>>>(
routing_map, intermediate_output, grad_probs, num_tokens, num_experts, topk,
use_pre_softmax, scaling_factor, score_function, grad_logits);
NVTE_CHECK_CUDA(cudaGetLastError());
}
void fused_topk_with_score_function_backward(const Tensor &routing_map,
......
......@@ -271,6 +271,14 @@ __device__ inline void naive_topk_and_mask(T *scores, int data_size, int topk, i
using type = int64_t; \
{ __VA_ARGS__ } \
} break; \
case DType::kBFloat16: { \
using type = bf16; \
{ __VA_ARGS__ } \
} break; \
case DType::kFloat32: { \
using type = float; \
{ __VA_ARGS__ } \
} break; \
default: \
NVTE_ERROR("Invalid type."); \
}
......
......@@ -353,6 +353,7 @@ void call_kernel_scaled_aligned_causal_masked_softmax_forward(
scaled_aligned_causal_masked_softmax_warp_forward<input_t, output_t, acc_t, log2_elements>
<<<grid_size, block_size, shmem_size, stream>>>(dst, src, scale, microbatches, query_seq_len,
key_seq_len);
NVTE_CHECK_CUDA(cudaGetLastError());
}
template <typename input_t, typename output_t, typename acc_t, int log2_elements>
......@@ -363,6 +364,7 @@ void call_kernel_scaled_aligned_causal_masked_softmax_backward(
scaled_aligned_causal_masked_softmax_warp_backward<input_t, output_t, acc_t, log2_elements>
<<<grid_size, block_size, 0, stream>>>(gradInput, grad, output, scale, microbatches,
query_seq_len, key_seq_len);
NVTE_CHECK_CUDA(cudaGetLastError());
}
template <typename input_t, typename output_t, typename acc_t>
......
......@@ -513,6 +513,7 @@ void dispatch_scaled_softmax_forward(output_t *dst, const input_t *src, const in
default:
break;
}
NVTE_CHECK_CUDA(cudaGetLastError());
}
}
......@@ -625,6 +626,7 @@ void dispatch_scaled_masked_softmax_forward(output_t *dst, const input_t *src, c
default:
break;
}
NVTE_CHECK_CUDA(cudaGetLastError());
}
}
......@@ -736,6 +738,7 @@ void dispatch_scaled_masked_softmax_backward(output_t *grad_input, const input_t
default:
break;
}
NVTE_CHECK_CUDA(cudaGetLastError());
}
}
......
......@@ -445,6 +445,7 @@ void dispatch_scaled_upper_triang_masked_softmax_forward(output_t *dst, const in
default:
break;
}
NVTE_CHECK_CUDA(cudaGetLastError());
}
}
......@@ -561,6 +562,7 @@ void dispatch_scaled_upper_triang_masked_softmax_backward(output_t *grad_input,
default:
break;
}
NVTE_CHECK_CUDA(cudaGetLastError());
}
}
......
......@@ -25,28 +25,11 @@
#include "../util/logging.h"
#include "../util/multi_stream.h"
#include "common/util/cuda_runtime.h"
#include "cutlass_grouped_gemm.cuh"
#ifndef __HIP_PLATFORM_AMD__
namespace {
cudaDataType_t get_cuda_dtype(const transformer_engine::DType t) {
using namespace transformer_engine;
switch (t) {
case DType::kFloat16:
return CUDA_R_16F;
case DType::kFloat32:
return CUDA_R_32F;
case DType::kBFloat16:
return CUDA_R_16BF;
case DType::kFloat8E4M3:
return CUDA_R_8F_E4M3;
case DType::kFloat8E5M2:
return CUDA_R_8F_E5M2;
default:
NVTE_ERROR("Invalid type");
}
}
uint32_t _getAlignment(uintptr_t address) {
// alignment are in bytes
uint32_t alignment = 256;
......@@ -532,22 +515,22 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
&epilogue, sizeof(epilogue)));
if (counter != nullptr) {
#if !(CUDA_VERSION >= 12020 && CUBLAS_VERSION >= 13000)
NVTE_ERROR("Atomic GEMM requires CUDA >=12.2.0 and <13.0.0, but compile-time CUDA verson is ",
#if !(CUDA_VERSION >= 12020 && CUDA_VERSION < 13000)
NVTE_ERROR("Atomic GEMM requires CUDA >=12.2.0 and <13.0.0, but compile-time CUDA version is ",
CUDA_VERSION);
#endif
#if !(CUBLAS_VERSION >= 120205 && CUBLAS_VERSION < 130000)
NVTE_ERROR(
"Atomic GEMM requires cuBLAS >=12.2.5 and <13.0.0, but compile-time cuBLAS verson is ",
"Atomic GEMM requires cuBLAS >=12.2.5 and <13.0.0, but compile-time cuBLAS version is ",
CUBLAS_VERSION);
#endif
#if CUDA_VERSION >= 12020 && CUBLAS_VERSION >= 120205 && CUDA_VERSION < 13000 && \
CUBLAS_VERSION < 130000
NVTE_CHECK(cuda::cudart_version() >= 12020 && cuda::cudart_version() < 13000,
"Atomic GEMM requires CUDA >=12.2.0 and <13.0.0, but run-time CUDA verson is ",
"Atomic GEMM requires CUDA >=12.2.0 and <13.0.0, but run-time CUDA version is ",
cuda::cudart_version());
NVTE_CHECK(cublas_version() >= 120205 && cublas_version() < 130000,
"Atomic GEMM requires cuBLAS >=12.2.5 and <13.0.0, but run-time cuBLAS verson is ",
"Atomic GEMM requires cuBLAS >=12.2.5 and <13.0.0, but run-time cuBLAS version is ",
cublas_version());
if (m_split == 0) m_split = 1;
if (n_split == 0) n_split = 1;
......@@ -850,20 +833,23 @@ void nvte_cublas_atomic_gemm(const NVTETensor A, const NVTETensor B, NVTETensor
#ifndef __HIP_PLATFORM_AMD__
// Check CUDA and cuBLAS versions
#if !(CUDA_VERSION >= 12020 && CUBLAS_VERSION >= 13000)
NVTE_ERROR("Atomic GEMM requires CUDA >=12.2.0 and <13.0.0, but compile-time CUDA verson is ",
#if !(CUDA_VERSION >= 12020 && CUDA_VERSION < 13000)
NVTE_ERROR("Atomic GEMM requires CUDA >=12.2.0 and <13.0.0, but compile-time CUDA version is ",
CUDA_VERSION);
#endif
#if !(CUBLAS_VERSION >= 120205 && CUBLAS_VERSION < 130000)
NVTE_ERROR("Atomic GEMM requires cuBLAS >=12.2.5 and <13.0.0, but compile-time cuBLAS verson is ",
CUBLAS_VERSION);
NVTE_ERROR(
"Atomic GEMM requires cuBLAS >=12.2.5 and <13.0.0, but compile-time cuBLAS version is ",
CUBLAS_VERSION);
#endif
NVTE_CHECK(cuda::cudart_version() >= 12020 && cuda::cudart_version() < 13000,
"Atomic GEMM requires CUDA version >=12.2.0 and <13.0.0, but run-time CUDA verson is ",
cuda::cudart_version());
NVTE_CHECK(
transformer_engine::cuda::cudart_version() >= 12020 &&
transformer_engine::cuda::cudart_version() < 13000,
"Atomic GEMM requires CUDA version >=12.2.0 and <13.0.0, but run-time CUDA version is ",
transformer_engine::cuda::cudart_version());
NVTE_CHECK(
cublas_version() >= 120205 && cublas_version() < 130000,
"Atomic GEMM requires cuBLAS version >=12.2.5 and <13.0.0, but run-time cuBLAS verson is ",
"Atomic GEMM requires cuBLAS version >=12.2.5 and <13.0.0, but run-time cuBLAS version is ",
cublas_version());
#endif
......@@ -934,15 +920,11 @@ void nvte_cublas_atomic_gemm(const NVTETensor A, const NVTETensor B, NVTETensor
#endif //__HIP_PLATFORM_AMD__
}
void nvte_multi_stream_cublas_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor *D,
const NVTETensor *bias, NVTETensor *pre_gelu_out,
const int num_gemms, bool transa, bool transb, bool grad,
NVTETensor *workspace, bool accumulate,
bool use_split_accumulator, int math_sm_count,
cudaStream_t stream) {
NVTE_API_CALL(nvte_multi_stream_cublas_gemm);
void multi_stream_cublas_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor *D,
const NVTETensor *bias, NVTETensor *pre_gelu_out, const int num_gemms,
bool transa, bool transb, bool grad, NVTETensor *workspace,
bool accumulate, bool use_split_accumulator, int math_sm_count,
cudaStream_t stream) {
using namespace transformer_engine;
int num_streams = nvte_get_num_compute_streams();
......@@ -989,6 +971,25 @@ void nvte_multi_stream_cublas_gemm(const NVTETensor *A, const NVTETensor *B, NVT
}
}
void nvte_multi_stream_cublas_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor *D,
const NVTETensor *bias, NVTETensor *pre_gelu_out,
const int num_gemms, bool transa, bool transb, bool grad,
NVTETensor *workspace, bool accumulate,
bool use_split_accumulator, int math_sm_count,
cudaStream_t stream) {
NVTE_API_CALL(nvte_multi_stream_cublas_gemm);
using namespace transformer_engine;
// Deprecation warning
NVTE_WARN(
"nvte_multi_stream_cublas_gemm is deprecated and will be removed in a future release. "
"Please migrate to nvte_multi_tensor_gemm (with CUTLASS Grouped GEMM support when "
"applicable).");
multi_stream_cublas_gemm(A, B, D, bias, pre_gelu_out, num_gemms, transa, transb, grad, workspace,
accumulate, use_split_accumulator, math_sm_count, stream);
}
#ifndef __HIP_PLATFORM_AMD__
namespace transformer_engine {
......@@ -1006,7 +1007,6 @@ void nvte_grouped_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor *D,
NVTETensor *workspace, bool accumulate,
bool use_split_accumulator, int math_sm_count,
cudaStream_t stream) {
NVTE_API_CALL(nvte_grouped_gemm);
using namespace transformer_engine;
std::vector<const Tensor*> inputA;
......@@ -1307,4 +1307,98 @@ void nvte_cublas_batchgemm_v3(const NVTETensor A, const NVTETensor B, const NVTE
handle);
}
#endif
\ No newline at end of file
#endif
void nvte_multi_tensor_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor *D,
const NVTETensor *bias, NVTETensor *pre_gelu_out, const int num_gemms,
bool transa, bool transb, bool grad, NVTETensor *workspace,
bool accumulate, bool use_split_accumulator, int math_sm_count,
cudaStream_t stream) {
NVTE_API_CALL(nvte_multi_tensor_gemm);
#ifdef __HIP_PLATFORM_AMD__
const char *NVTE_USE_HIPBLASLT_GROUPEDGEMM = std::getenv("NVTE_USE_HIPBLASLT_GROUPEDGEMM");
if (NVTE_USE_HIPBLASLT_GROUPEDGEMM != nullptr && NVTE_USE_HIPBLASLT_GROUPEDGEMM[0] == '1') {
nvte_grouped_gemm(A, B, D, bias, pre_gelu_out, num_gemms, transa, transb, grad,
workspace, accumulate, use_split_accumulator, math_sm_count, stream);
} else {
multi_stream_cublas_gemm(A, B, D, bias, pre_gelu_out, num_gemms, transa, transb, grad,
workspace, accumulate, use_split_accumulator, math_sm_count, stream);
}
#else
const int current_device = transformer_engine::cuda::current_device();
const bool is_hopper = (transformer_engine::cuda::sm_arch(current_device) == 90);
const bool use_cutlass = transformer_engine::getenv<bool>("NVTE_USE_CUTLASS_GROUPED_GEMM", false);
const bool warn_fallback =
transformer_engine::getenv<bool>("NVTE_CUTLASS_GROUPED_GEMM_WARN_FALLBACK", false);
auto cublas_path = [&]() {
multi_stream_cublas_gemm(A, B, D, bias, pre_gelu_out, num_gemms, transa, transb, grad,
workspace, accumulate, use_split_accumulator, math_sm_count, stream);
};
// Currently, CUTLASS grouped GEMM is only supported on the Hopper architecture
if (!(is_hopper && use_cutlass)) {
cublas_path();
return;
}
auto is_empty_arr = [&](const NVTETensor *p) -> bool {
if (p == nullptr) return true;
for (int i = 0; i < num_gemms; ++i) {
if (transformer_engine::convertNVTETensor(p[i])->has_data()) return false;
}
return true;
};
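// The CUTLASS path below requires every group to share the same K with K % 128 == 0,
// presumably to match the kernel's 128-wide K tile and TMA alignment requirements.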
auto all_groups_uniform_k128 = [&](const NVTETensor *p, bool trans) -> bool {
int64_t ref_k = -1;
for (size_t i = 0; i < num_gemms; i++) {
const auto tensor = transformer_engine::convertNVTETensorCheck(p[i]);
const int k = trans ? tensor->data.shape[0] : tensor->data.shape[1];
if ((k & 127) != 0) return false;
if (ref_k < 0)
ref_k = k;
else if (k != ref_k)
return false;
}
return true;
};
auto is_supported_dtype = [&]() -> bool {
auto *inputA = transformer_engine::convertNVTETensorCheck(A[0]);
auto *inputB = transformer_engine::convertNVTETensorCheck(B[0]);
auto *OutputD = transformer_engine::convertNVTETensorCheck(D[0]);
auto A_type = get_cuda_dtype(inputA->data.dtype);
auto B_type = get_cuda_dtype(inputB->data.dtype);
auto D_type = get_cuda_dtype(OutputD->data.dtype);
return (A_type == B_type) && (A_type == D_type) &&
((A_type == CUDA_R_16BF) || (A_type == CUDA_R_16F));
};
// CUTLASS Grouped GEMM fast path (SM90/TMA)
// Conditions:
// - No fused epilogue: both bias and pre_gelu_out are empty.
// - Supported dtypes only: FP16/BF16 (FP32 accumulate).
// - Uniform K across groups and K % 128 == 0.
// - use_split_accumulator is ignored for FP16/BF16.
// - grad is irrelevant when bias/pre_gelu_out are empty.
//
// Otherwise, fall back to cuBLAS.
if (is_empty_arr(bias) && is_empty_arr(pre_gelu_out) && is_supported_dtype() &&
all_groups_uniform_k128(B, transb)) {
cutlass_grouped_gemm(A, B, D, num_gemms, transa, transb, grad, workspace, accumulate,
current_device, math_sm_count, stream);
} else {
if (warn_fallback) {
NVTE_WARN("Fallback to cuBLAS grouped GEMM.");
}
cublas_path();
}
#endif
}
/***************************************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
**************************************************************************************************/
#include "cutlass/bfloat16.h"
#include "cutlass/cutlass.h"
#include "cutlass_grouped_gemm.cuh"
namespace transformer_engine {
namespace grouped_gemm {
// Explicit template instantiation to match the template declarations in the .cuh
template void CutlassGroupedGemm<false, false, cutlass::half_t>(const NVTETensor*,
const NVTETensor*, NVTETensor*,
NVTETensor*, float, float, int,
cudaStream_t, int, int);
template void CutlassGroupedGemm<true, false, cutlass::half_t>(const NVTETensor*, const NVTETensor*,
NVTETensor*, NVTETensor*, float,
float, int, cudaStream_t, int, int);
template void CutlassGroupedGemm<false, true, cutlass::half_t>(const NVTETensor*, const NVTETensor*,
NVTETensor*, NVTETensor*, float,
float, int, cudaStream_t, int, int);
template void CutlassGroupedGemm<false, false, cutlass::bfloat16_t>(const NVTETensor*,
const NVTETensor*, NVTETensor*,
NVTETensor*, float, float, int,
cudaStream_t, int, int);
template void CutlassGroupedGemm<true, false, cutlass::bfloat16_t>(const NVTETensor*,
const NVTETensor*, NVTETensor*,
NVTETensor*, float, float, int,
cudaStream_t, int, int);
template void CutlassGroupedGemm<false, true, cutlass::bfloat16_t>(const NVTETensor*,
const NVTETensor*, NVTETensor*,
NVTETensor*, float, float, int,
cudaStream_t, int, int);
} // namespace grouped_gemm
} // namespace transformer_engine
void cutlass_grouped_gemm(const NVTETensor* A, const NVTETensor* B, NVTETensor* D, int num_gemms,
bool transa, bool transb, bool grad, NVTETensor* workspace,
bool accumulate, int device, int math_sm_count, cudaStream_t stream) {
using namespace transformer_engine;
auto* inputA = convertNVTETensorCheck(A[0]);
auto* inputB = convertNVTETensorCheck(B[0]);
float one = 1.0;
float zero = 0.0;
float alpha = one;
float beta = (accumulate) ? one : zero;
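// Note: the dispatch below passes the operands as (B, A) with the transpose flags swapped; this
// is assumed to map TE's column-major GEMM convention onto the row-major CUTLASS grouped kernel
// by computing the transposed product.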
auto dispatch = [&](auto tag) {
using T = decltype(tag);
if (!transa && !transb) {
grouped_gemm::CutlassGroupedGemm<false, false, T>(B, A, D, workspace, alpha, beta, num_gemms,
stream, device, math_sm_count);
} else if (!transb && transa) {
grouped_gemm::CutlassGroupedGemm<false, true, T>(B, A, D, workspace, alpha, beta, num_gemms,
stream, device, math_sm_count);
} else if (transb && !transa) {
grouped_gemm::CutlassGroupedGemm<true, false, T>(B, A, D, workspace, alpha, beta, num_gemms,
stream, device, math_sm_count);
} else {
NVTE_ERROR("Layout 'TT' is not supported by cutlass_grouped_gemm.");
}
};
if (inputA->data.dtype == DType::kBFloat16) {
dispatch(cutlass::bfloat16_t{});
} else if (inputA->data.dtype == DType::kFloat16) {
dispatch(cutlass::half_t{});
} else {
NVTE_ERROR("Unsupported dtype: only BF16(FP16) are supported.");
}
}
/***************************************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
**************************************************************************************************/
//
// Copyright (c) 2025 Shopee Inc. All Rights Reserved.
//
/**
* @file: cutlass_grouped_gemm.cuh
* @author: min.yang@shopee.com, yangfan.bai@shopee.com, finch.li@shopee.com
* @date: 2025-08-08 16:20:00
* @brief: cutlass group gemm kernel.
**/
#pragma once
#include <transformer_engine/transformer_engine.h>
#include <cub/cub.cuh>
#include <type_traits>
#include "../common.h"
#include "../util/logging.h"
#include "common/util/system.h"
#include "cute/tensor.hpp"
#include "cutlass/bfloat16.h"
#include "cutlass/complex.h"
#include "cutlass/cutlass.h"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/default_epilogue.hpp"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/gemm/group_array_problem_shape.hpp"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/util/device_memory.h"
#include "cutlass/util/packed_stride.hpp"
namespace transformer_engine {
namespace grouped_gemm {
template <bool trans_a>
using GroupedGemmInputALayout =
std::conditional_t<trans_a, ::cutlass::layout::ColumnMajor, ::cutlass::layout::RowMajor>;
template <bool trans_b>
using GroupedGemmInputBLayout =
std::conditional_t<trans_b, ::cutlass::layout::ColumnMajor, ::cutlass::layout::RowMajor>;
using ProblemShapeType = cute::Shape<int, int, int>;
using ProblemShape = cutlass::gemm::GroupProblemShape<ProblemShapeType>; // <M,N,K> per group
template <typename ScheduleConfig>
struct GemmGivenSchedule {
using ElementA = typename ScheduleConfig::DataType; // Element type for A matrix operand
using ElementB = typename ScheduleConfig::DataType; // Element type for B matrix operand
using ElementC = typename ScheduleConfig::DataType; // Element type for C and D matrix operands
// A matrix configuration
using LayoutA = typename ScheduleConfig::LayoutA; // Layout type for A matrix operand
static constexpr int AlignmentA =
128 / cutlass::sizeof_bits<
ElementA>::value; // Alignment of A matrix in units of elements (up to 16 bytes)
// B matrix configuration
using LayoutB = typename ScheduleConfig::LayoutB; // Layout type for B matrix operand
static constexpr int AlignmentB =
128 / cutlass::sizeof_bits<
ElementB>::value; // Alignment of B matrix in units of elements (up to 16 bytes)
// C/D matrix configuration
using LayoutC = typename ScheduleConfig::LayoutC; // Layout type for C and D matrix operands
static constexpr int AlignmentC =
128 / cutlass::sizeof_bits<
ElementC>::value; // Alignment of C matrix in units of elements (up to 16 bytes)
// Core kernel configurations
using ElementAccumulator = float; // Element type for internal accumulation
using ArchTag =
cutlass::arch::Sm90; // Tag indicating the minimum SM that supports the intended feature
using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag
using StageCountType =
cutlass::gemm::collective::StageCountAuto; // Stage count maximized based on the tile size
using TileShape = typename ScheduleConfig::TileShape; // Threadblock-level tile size
using ClusterShape =
typename ScheduleConfig::ClusterShape; // Shape of the threadblocks in a cluster
using KernelSchedule = typename ScheduleConfig::KernelSchedule; // Kernel to launch
using EpilogueSchedule = typename ScheduleConfig::EpilogueSchedule; // Epilogue to launch
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, TileShape, ClusterShape,
cutlass::epilogue::collective::EpilogueTileAuto, ElementAccumulator, ElementAccumulator,
ElementC, LayoutC*, AlignmentC, ElementC, LayoutC*, AlignmentC, EpilogueSchedule,
cutlass::epilogue::fusion::LinearCombination<ElementC, ElementAccumulator>>::CollectiveOp;
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag, OperatorClass, ElementA, LayoutA*, AlignmentA, ElementB, LayoutB*, AlignmentB,
ElementAccumulator, TileShape, ClusterShape,
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
sizeof(typename CollectiveEpilogue::SharedStorage))>,
KernelSchedule>::CollectiveOp;
using GemmKernel =
cutlass::gemm::kernel::GemmUniversal<ProblemShape, CollectiveMainloop, CollectiveEpilogue>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
};
template <typename DataType_, bool trans_a, bool trans_b>
struct ScheduleConfig {
using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpong;
using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong;
using TileShape = cute::Shape<cute::_128, cute::_128, cute::_128>;
using ClusterShape = cute::Shape<cute::_1, cute::_2, cute::_1>;
// TODO(Alan): Add tuning for different scenarios to select the optimal configuration,
// as the current configuration may not be the best.
// using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperative;
// using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedCooperative;
// using TileShape = Shape<cute::_256, cute::_128, cute::_128>;
// using ClusterShape = Shape<cute::_1, cute::_2, cute::_1>;
using LayoutA = GroupedGemmInputALayout<trans_a>;
using LayoutB = GroupedGemmInputBLayout<trans_b>;
using LayoutC = cutlass::layout::RowMajor;
using DataType = DataType_;
};
template <typename DataType_, bool trans_a, bool trans_b>
using GemmGrouped = typename GemmGivenSchedule<ScheduleConfig<DataType_, trans_a, trans_b>>::Gemm;
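// Illustrative instantiation (not used elsewhere in this file, added only as an example):
// a BF16 grouped GEMM with non-transposed A and B resolves to the Sm90 ptr-array TMA
// warp-specialized pingpong kernel with row-major A/B and row-major C/D configured above.
using GemmGroupedBf16NN = GemmGrouped<cutlass::bfloat16_t, /*trans_a=*/false, /*trans_b=*/false>;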
template <typename GemmT, typename ElementA, typename ElementB, typename ElementC, typename StrideA,
typename StrideB, typename StrideC>
typename GemmT::Arguments MakeArguments(int num_experts, void* problem_sizes_host,
void* problem_sizes, const ElementA** ptr_A,
StrideA* stride_A, const ElementB** ptr_B,
StrideB* stride_B, ElementC** ptr_C, StrideC* stride_C,
float alpha, float beta, int device, int math_sm_count) {
  // The target GPU and the number of SMs available for math are taken from the caller-provided
  // `device` and `math_sm_count` arguments.
cutlass::KernelHardwareInfo kernel_hw_info =
cutlass::KernelHardwareInfo::make_kernel_hardware_info<typename GemmT::GemmKernel>(
device, math_sm_count);
typename GemmT::Arguments arguments;
decltype(arguments.epilogue.thread) fusion_args;
fusion_args.alpha = alpha;
fusion_args.beta = beta;
fusion_args.alpha_ptr = nullptr;
fusion_args.beta_ptr = nullptr;
fusion_args.alpha_ptr_array = nullptr;
fusion_args.beta_ptr_array = nullptr;
// Single alpha and beta for all groups
fusion_args.dAlpha = {cute::_0{}, cute::_0{}, 0};
fusion_args.dBeta = {cute::_0{}, cute::_0{}, 0};
arguments =
typename GemmT::Arguments{cutlass::gemm::GemmUniversalMode::kGrouped,
{num_experts, reinterpret_cast<ProblemShapeType*>(problem_sizes),
reinterpret_cast<ProblemShapeType const*>(problem_sizes_host)},
{ptr_A, stride_A, ptr_B, stride_B},
{
fusion_args,
(beta > 0.0) ? (const ElementC**)ptr_C : nullptr, // NOLINT(*)
stride_C,
ptr_C,
stride_C,
},
kernel_hw_info};
return arguments;
}
template <typename T>
inline __device__ __host__ T ROUND_UP(T m, T n) {
return (m + n - 1) / n * n;
}
template <typename T>
void debug_type() {
std::cout << typeid(T).name() << std::endl;
}
inline int64_t getGemmCoordSize(int64_t num_gemms) {
  return static_cast<int64_t>(ROUND_UP(num_gemms * sizeof(ProblemShapeType), 128UL));
}
inline int64_t getPtrSize(int64_t num_gemms) {
  return static_cast<int64_t>(ROUND_UP(num_gemms * sizeof(half*), 128UL));
}
inline int64_t getLddSize(int64_t num_gemms) {
  return static_cast<int64_t>(ROUND_UP(num_gemms * sizeof(int64_t), 128UL));
}
// Host-side (CPU) parameter workspace size: 4 MB.
static constexpr size_t kCPUWorkSpaceSize = 4 * 1024 * 1024;
static char* getHostWorkspace() {
static std::once_flag flag;
static std::shared_ptr<char> workspace;
std::call_once(flag, [&]() {
workspace =
std::shared_ptr<char>(reinterpret_cast<char*>(std::malloc(kCPUWorkSpaceSize)), [](char* p) {
if (p) std::free(p);
});
if (!workspace) {
throw std::bad_alloc();
}
});
return workspace.get();
}
template <bool trans_a, bool trans_b, typename Element>
void CutlassGroupedGemm(const NVTETensor* A, const NVTETensor* B, NVTETensor* D,
NVTETensor* workspace, float alpha, float beta, int num_gemms,
cudaStream_t stream, int device, int math_sm_count) {
using Gemm = GemmGrouped<Element, trans_a, trans_b>;
using LayoutA = typename Gemm::LayoutA;
using LayoutB = typename Gemm::LayoutB;
using LayoutC = typename Gemm::LayoutC;
using ElementA = typename Gemm::ElementA;
using ElementB = typename Gemm::ElementB;
using ElementC = typename Gemm::ElementC;
using StrideA = typename Gemm::GemmKernel::InternalStrideA;
using StrideB = typename Gemm::GemmKernel::InternalStrideB;
using StrideC = typename Gemm::GemmKernel::InternalStrideC;
typename Gemm::Arguments arguments;
size_t kernel_workspace_size = Gemm::get_workspace_size(arguments);
auto gemm_coord_size = getGemmCoordSize(num_gemms);
auto ptr_size = getPtrSize(num_gemms);
auto ldd_size = getLddSize(num_gemms);
auto param_workspace_size = 3 * ptr_size + 3 * ldd_size + gemm_coord_size;
  NVTE_CHECK(
      param_workspace_size < kCPUWorkSpaceSize,
      "Insufficient CPU workspace size: required=", static_cast<int64_t>(param_workspace_size),
      ", available=", static_cast<int64_t>(kCPUWorkSpaceSize), " for CUTLASS grouped GEMM.");
auto total_workspace_size = param_workspace_size + kernel_workspace_size;
transformer_engine::Tensor* wspace = transformer_engine::convertNVTETensor(workspace[0]);
NVTE_CHECK(total_workspace_size < wspace->numel(), "Insufficient workspace[0] size: required=",
static_cast<int64_t>(total_workspace_size),
", available=", static_cast<int64_t>(wspace->numel()), " for CUTLASS grouped GEMM.");
char* workspace_ptr = reinterpret_cast<char*>(wspace->data.dptr);
char* kernel_workspace_ptr = nullptr;
char* host_workspace = getHostWorkspace();
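  // Host-side parameter workspace layout (copied to the device workspace below):
  //   [ problem_sizes | ptr_A[] | ptr_B[] | ptr_C(=D)[] | lda[] | ldb[] | ldc[] ]
  // where each region is padded to a 128-byte boundary by the get*Size() helpers above.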
ProblemShapeType* problem_sizes_host = reinterpret_cast<ProblemShapeType*>(host_workspace);
ElementA** ptr_A_host = reinterpret_cast<ElementA**>(host_workspace + gemm_coord_size);
ElementB** ptr_B_host = reinterpret_cast<ElementB**>(host_workspace + gemm_coord_size + ptr_size);
ElementC** ptr_C_host =
reinterpret_cast<ElementC**>(host_workspace + gemm_coord_size + 2 * ptr_size);
int64_t* lda_host =
reinterpret_cast<int64_t*>(host_workspace + gemm_coord_size + 3 * ptr_size + 0 * ldd_size);
int64_t* ldb_host =
reinterpret_cast<int64_t*>(host_workspace + gemm_coord_size + 3 * ptr_size + 1 * ldd_size);
int64_t* ldc_host =
reinterpret_cast<int64_t*>(host_workspace + gemm_coord_size + 3 * ptr_size + 2 * ldd_size);
  for (int i = 0; i < num_gemms; ++i) {
const transformer_engine::Tensor* inputA = transformer_engine::convertNVTETensorCheck(A[i]);
const transformer_engine::Tensor* inputB = transformer_engine::convertNVTETensorCheck(B[i]);
transformer_engine::Tensor* outputD = transformer_engine::convertNVTETensor(D[i]);
const int m = trans_a ? inputA->data.shape[1] : inputA->data.shape[0];
const int k = trans_a ? inputA->data.shape[0] : inputA->data.shape[1];
const int n = trans_b ? inputB->data.shape[0] : inputB->data.shape[1];
auto problem = ProblemShapeType(m, n, k);
problem_sizes_host[i] = problem;
ptr_A_host[i] = reinterpret_cast<ElementA*>(inputA->data.dptr);
ptr_B_host[i] = reinterpret_cast<ElementB*>(inputB->data.dptr);
ptr_C_host[i] = reinterpret_cast<ElementC*>(outputD->data.dptr);
lda_host[i] = LayoutA::packed({m, k}).stride(0);
ldb_host[i] = LayoutB::packed({k, n}).stride(0);
ldc_host[i] = LayoutC::packed({m, n}).stride(0);
}
  NVTE_CHECK_CUDA(cudaMemcpyAsync(workspace_ptr, host_workspace, param_workspace_size,
                                  cudaMemcpyHostToDevice, stream));
char* param_workspace_ptr = workspace_ptr;
ProblemShapeType* problem_sizes_device = reinterpret_cast<ProblemShapeType*>(param_workspace_ptr);
const ElementA** ptr_A = reinterpret_cast<const ElementA**>(
reinterpret_cast<char*>(param_workspace_ptr) + gemm_coord_size);
const ElementB** ptr_B = reinterpret_cast<const ElementB**>(
reinterpret_cast<char*>(param_workspace_ptr) + gemm_coord_size + 1 * ptr_size);
ElementC** ptr_C = reinterpret_cast<ElementC**>(reinterpret_cast<char*>(param_workspace_ptr) +
gemm_coord_size + 2 * ptr_size);
StrideA* lda = reinterpret_cast<StrideA*>(reinterpret_cast<char*>(param_workspace_ptr) +
gemm_coord_size + 3 * ptr_size + 0 * ldd_size);
StrideB* ldb = reinterpret_cast<StrideB*>(reinterpret_cast<char*>(param_workspace_ptr) +
gemm_coord_size + 3 * ptr_size + 1 * ldd_size);
StrideC* ldc = reinterpret_cast<StrideC*>(reinterpret_cast<char*>(param_workspace_ptr) +
gemm_coord_size + 3 * ptr_size + 2 * ldd_size);
kernel_workspace_ptr = workspace_ptr + param_workspace_size;
arguments = MakeArguments<Gemm, ElementA, ElementB, ElementC, StrideA, StrideB, StrideC>(
num_gemms, problem_sizes_host, problem_sizes_device, ptr_A, lda, ptr_B, ldb, ptr_C, ldc,
alpha, beta, device, math_sm_count);
Gemm gemm;
  // Check whether this kernel configuration can implement the given problem.
  NVTE_CHECK(gemm.can_implement(arguments) == cutlass::Status::kSuccess,
             "Failed to implement CUTLASS Grouped GEMM");
  // Initialize the kernel state with the device workspace.
  NVTE_CHECK(gemm.initialize(arguments, kernel_workspace_ptr) == cutlass::Status::kSuccess,
             "Failed to initialize CUTLASS Grouped GEMM");
  // Execute the kernel on the provided stream.
  NVTE_CHECK(gemm.run(stream) == cutlass::Status::kSuccess,
             "Failed to run CUTLASS Grouped GEMM");
}
} // namespace grouped_gemm
} // namespace transformer_engine
void cutlass_grouped_gemm(const NVTETensor* A, const NVTETensor* B, NVTETensor* D, int num_gemms,
bool transa, bool transb, bool grad, NVTETensor* workspace,
bool accumulate, int device, int math_sm_count, cudaStream_t stream);
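// Illustrative call sketch (hypothetical caller, not part of this header): `a_list`, `b_list`
// and `d_list` are assumed to hold `num_gemms` valid NVTETensor handles created elsewhere, and
// `ws` to point to a single workspace tensor large enough for the parameter buffer plus the
// CUTLASS kernel workspace.
//
//   cutlass_grouped_gemm(a_list, b_list, d_list, num_gemms,
//                        /*transa=*/false, /*transb=*/false, /*grad=*/false,
//                        ws, /*accumulate=*/false, /*device=*/0,
//                        /*math_sm_count=*/0, stream);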
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
/*! \file comm_gemm.h
* \brief Functions for distributed (multi-GPU) matrix multiplication.
*
 * This API is a TE-native binding to the cuBLASMp library.
 * Refer to https://docs.nvidia.com/cuda/cublasmp/usage/tp.html for the specific
 * distribution patterns, which allow communication-computation overlap.
 *
 * All GEMM functions here share the same computation semantics, expressed
 * on the global matrices and analogous to the nvte_cublas_gemm call:
* - `D = AB` if both `bias` and `pre_gelu_out` are empty tensors
* - `D = AB + bias` if `pre_gelu_out` is empty and `bias` is not empty
* - `D = GELU(AB + bias)` if both `bias` and `pre_gelu_out` are not empty tensors
*
 * The functions differ in their matrix distribution patterns.
*/
#ifndef TRANSFORMER_ENGINE_COMMON_COMM_GEMM_H_
#define TRANSFORMER_ENGINE_COMMON_COMM_GEMM_H_
#include <nccl.h>
#include <stdint.h>
#include "transformer_engine.h"
#ifdef __cplusplus
extern "C" {
#else
#include <stdbool.h>
#endif
typedef struct NVTECommGemmCtx NVTECommGemmCtx;
enum NVTECommGemmAlgoType {
kNVTECommGemmAlgoDefault = 0,
kNVTECommGemmAlgoSplitP2P = 1,
kNVTECommGemmAlgoSplitMulticast = 2,
kNVTECommGemmAlgoAtomicP2P = 3,
kNVTECommGemmAlgoAtomicMulticast = 4
};
/*! \brief Create a comm-gemm context.
*
* \param[in] comm NCCL communicator.
* \param[in] nranks Number of ranks.
* \param[in] rank Local rank.
*/
NVTECommGemmCtx* nvte_comm_gemm_ctx_create(ncclComm_t comm, int nranks, int rank);
/*! \brief Destroy a comm-gemm context.
*
* \param[in] ctx Context to destroy.
*/
void nvte_comm_gemm_ctx_destroy(NVTECommGemmCtx* ctx);
/*! \brief Perform AllGather communication followed by GEMM
*
* Gathers distributed data from all ranks, then computes matrix multiplication.
*
* \param[in] ctx Comm-GEMM context.
* \param[in] m Global m dimension.
* \param[in] n Global n dimension.
* \param[in] k Global k dimension.
* \param[in] a Local part of A matrix.
* \param[in] b Local part of B matrix.
* \param[in,out] d Local part of D matrix.
* \param[in] bias Bias tensor.
* \param[in,out] pre_act_out Local part of output matrix before GELU activation.
* \param[in] transa Whether A matrix is transposed.
* \param[in] transb Whether B matrix is transposed.
* \param[in] grad Whether this operation is part of gradient computation.
* \param[in] accumulate Whether to accumulate the result into the D matrix.
* \param[in] comm_sm_count Number of GPU SMs to use for communication (default=0: use heuristics)
* \param[in] main_stream CUDA stream used for computation.
* \param[in] algo Algorithm to use.
*/
void nvte_all_gather_gemm(NVTECommGemmCtx* ctx, int64_t m, int64_t n, int64_t k, const NVTETensor a,
const NVTETensor b, const NVTETensor d, const NVTETensor bias,
const NVTETensor pre_act_out, bool transa, bool transb, bool grad,
bool accumulate, int comm_sm_count, cudaStream_t main_stream,
NVTECommGemmAlgoType algo);
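/*
 * Example (illustrative sketch, not part of the API): a hypothetical tensor-parallel
 * AllGather + GEMM call. The tensor handles (a, b, d, bias, pre_act_out) are assumed to be
 * valid NVTETensor objects created elsewhere with the Transformer Engine tensor API, and
 * `comm` an existing NCCL communicator shared by `nranks` processes.
 *
 *   NVTECommGemmCtx* ctx = nvte_comm_gemm_ctx_create(comm, nranks, rank);
 *   // Number of rows/columns of the global dimension owned by this rank.
 *   int64_t local_n = nvte_comm_gemm_numroc(ctx, n);
 *   nvte_all_gather_gemm(ctx, m, n, k, a, b, d, bias, pre_act_out,
 *                        false, false, false,  // transa, transb, grad
 *                        false,                // accumulate
 *                        0,                    // comm_sm_count: use heuristics
 *                        stream, kNVTECommGemmAlgoDefault);
 *   nvte_comm_gemm_ctx_destroy(ctx);
 */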
/*! \brief Perform GEMM followed by ReduceScatter communication
*
* Computes matrix multiplication, then distributes results across ranks with reduction.
*
* \param[in] ctx Comm-GEMM context.
* \param[in] m Global m dimension.
* \param[in] n Global n dimension.
* \param[in] k Global k dimension.
* \param[in] a Local part of A matrix.
* \param[in] b Local part of B matrix.
* \param[in,out] d Local part of D matrix.
* \param[in] bias Bias tensor.
* \param[in,out] pre_act_out Local part of output matrix before GELU activation.
* \param[in] transa Whether A matrix is transposed.
* \param[in] transb Whether B matrix is transposed.
* \param[in] grad Whether this operation is part of gradient computation.
* \param[in] accumulate Whether to accumulate the result into the D matrix.
* \param[in] comm_sm_count Number of GPU SMs to use for communication (default=0: use heuristics)
* \param[in] main_stream CUDA stream used for computation.
* \param[in] algo Algorithm to use.
*/
void nvte_gemm_reduce_scatter(NVTECommGemmCtx* ctx, int64_t m, int64_t n, int64_t k,
const NVTETensor a, const NVTETensor b, const NVTETensor d,
const NVTETensor bias, const NVTETensor pre_act_out, bool transa,
bool transb, bool grad, bool accumulate, int comm_sm_count,
cudaStream_t main_stream, NVTECommGemmAlgoType algo);
/*! \brief Perform GEMM followed by AllReduce communication
*
* Computes matrix multiplication, then reduces results across all ranks.
*
* \param[in] ctx Comm-GEMM context.
* \param[in] m Global m dimension.
* \param[in] n Global n dimension.
* \param[in] k Global k dimension.
* \param[in] a Local part of A matrix.
* \param[in] b Local part of B matrix.
* \param[in,out] d Local part of D matrix.
* \param[in] bias Bias tensor.
* \param[in,out] pre_act_out Local part of output matrix before GELU activation.
* \param[in] transa Whether A matrix is transposed.
* \param[in] transb Whether B matrix is transposed.
* \param[in] grad Whether this operation is part of gradient computation.
* \param[in] accumulate Whether to accumulate the result into the D matrix.
* \param[in] comm_sm_count Number of GPU SMs to use for communication (default=0: use heuristics)
* \param[in] main_stream CUDA stream used for computation.
* \param[in] algo Algorithm to use.
*/
void nvte_gemm_all_reduce(NVTECommGemmCtx* ctx, int64_t m, int64_t n, int64_t k, const NVTETensor a,
const NVTETensor b, const NVTETensor d, const NVTETensor bias,
const NVTETensor pre_act_out, bool transa, bool transb, bool grad,
bool accumulate, int comm_sm_count, cudaStream_t main_stream,
NVTECommGemmAlgoType algo);
/*! \brief Get local number of rows or columns.
*
* Utility function to get local dimension.
 * Block size, number of ranks, and local rank are derived from the context ctx.
*
* \param[in] ctx Comm-GEMM context.
* \param[in] global_size Global dimension.
*/
int64_t nvte_comm_gemm_numroc(NVTECommGemmCtx* ctx, int64_t global_size);
#ifdef __cplusplus
} // extern "C"
#endif
#endif  // TRANSFORMER_ENGINE_COMMON_COMM_GEMM_H_
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
/*! \file dropout.h
* \brief Functions for dropout.
*/
#ifndef TRANSFORMER_ENGINE_DROPOUT_FP8_H_
#define TRANSFORMER_ENGINE_DROPOUT_FP8_H_
#include "transformer_engine.h"
#ifdef __cplusplus
extern "C" {
#endif
/*! \brief Dropout forward kernel.
*
* \param[in] input Input tensor.
* \param[out] output Output tensor.
* \param[out] mask Mask tensor. Each bit corresponds to an
* output tensor entry. Ones indicate kept
* entries and zeros indicate dropped entries.
* \param[in] rng_state RNG engine inputs.
* \param[in] dropout_probability Dropout probability.
* \param[in] stream CUDA stream used for this operation.
*/
void nvte_dropout_fwd(const NVTETensor input, NVTETensor output, NVTETensor mask,
NVTETensor rng_state, float dropout_probability, cudaStream_t stream);
/*! \brief Dropout backward kernel.
*
* \param[in] grad_output Gradient of output tensor.
 * \param[in]     mask                  Mask tensor. Each bit corresponds to an
* output tensor entry. Ones indicate kept
* entries and zeros indicate dropped entries.
* \param[out] grad_input Gradient of input tensor.
* \param[in] dropout_probability Dropout probability.
* \param[in] stream CUDA stream used for this operation.
*/
void nvte_dropout_bwd(const NVTETensor grad_output, const NVTETensor mask, NVTETensor grad_input,
float dropout_probability, cudaStream_t stream);
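/*
 * Example (illustrative sketch, not part of the API): forward + backward dropout for one
 * tensor. The NVTETensor handles (x, y, mask, rng_state, dy, dx) are assumed to be created
 * elsewhere with the Transformer Engine tensor API; `mask` must provide one bit per output
 * element, and `rng_state` carries the RNG engine inputs.
 *
 *   nvte_dropout_fwd(x, y, mask, rng_state, 0.1f, stream);  // drop ~10% of entries
 *   nvte_dropout_bwd(dy, mask, dx, 0.1f, stream);           // reuse the same mask
 */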
#ifdef __cplusplus
} // extern "C"
#endif
#endif  // TRANSFORMER_ENGINE_DROPOUT_FP8_H_