OpenDAS / TransformerEngine

Commit 27ddce40 authored Oct 11, 2025 by wenjh

Merge branch 'nv_main'

Parents: d262ef4c, 5b3092a0
Showing 20 changed files with 1248 additions and 92 deletions.
transformer_engine/common/fused_attn/context_parallel.cu                         +9    -0
transformer_engine/common/fused_attn/flash_attn.cu                               +2    -0
transformer_engine/common/fused_attn/fused_attn.cpp                              +5    -4
transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu          +4    -0
transformer_engine/common/fused_attn/fused_attn_fp8.cu                           +4    -0
transformer_engine/common/fused_attn/kv_cache.cu                                 +4    -0
transformer_engine/common/fused_attn/utils.cu                                    +5    -3
transformer_engine/common/fused_rope/fused_rope.cu                               +420  -37
transformer_engine/common/fused_router/fused_moe_aux_loss.cu                     +8    -6
transformer_engine/common/fused_router/fused_score_for_moe_aux_loss.cu           +2    -0
transformer_engine/common/fused_router/fused_topk_with_score_function.cu         +2    -0
transformer_engine/common/fused_router/utils.h                                   +8    -0
transformer_engine/common/fused_softmax/scaled_aligned_causal_masked_softmax.cu  +2    -0
transformer_engine/common/fused_softmax/scaled_masked_softmax.cu                 +3    -0
transformer_engine/common/fused_softmax/scaled_upper_triang_masked_softmax.cu    +2    -0
transformer_engine/common/gemm/cublaslt_gemm.cu                                  +136  -42
transformer_engine/common/gemm/cutlass_grouped_gemm.cu                           +77   -0
transformer_engine/common/gemm/cutlass_grouped_gemm.cuh                          +348  -0
transformer_engine/common/include/transformer_engine/comm_gemm.h                 +156  -0
transformer_engine/common/include/transformer_engine/dropout.h                   +51   -0
transformer_engine/common/fused_attn/context_parallel.cu

@@ -341,6 +341,7 @@ void thd_read_half_tensor(const Tensor &tensor, const Tensor &cu_seqlens, Tensor
   thd_read_half_tensor_kernel<<<grid, block, sizeof(int) * (batch + 1), stream>>>(
       half.data.dptr, tensor.data.dptr, reinterpret_cast<int *>(cu_seqlens.data.dptr), batch,
       hidden_size_in_bytes, half_idx, tensor_shape[seq_dim]);
+  NVTE_CHECK_CUDA(cudaGetLastError());
 }

 /***************************************************************************************************

@@ -397,11 +398,13 @@ void thd_second_half_lse_correction(Tensor lse, const Tensor &lse_per_step,
         reinterpret_cast<float *>(lse.data.dptr), reinterpret_cast<float *>(lse_per_step.data.dptr),
         reinterpret_cast<int *>(cu_seqlens.data.dptr), batch, num_heads, lse_seqlen,
         second_half_lse_seqlen);
+    NVTE_CHECK_CUDA(cudaGetLastError());
   } else {
     thd_lse_kernel<false, LseCorrectionFunctor><<<grid, block, sizeof(int) * (batch + 1), stream>>>(
         reinterpret_cast<float *>(lse.data.dptr), reinterpret_cast<float *>(lse_per_step.data.dptr),
         reinterpret_cast<int *>(cu_seqlens.data.dptr), batch, num_heads, lse_seqlen,
         second_half_lse_seqlen);
+    NVTE_CHECK_CUDA(cudaGetLastError());
   }
 }

@@ -446,11 +449,13 @@ void thd_read_second_half_lse(const Tensor &lse, const Tensor &cu_seqlens, Tenso
         reinterpret_cast<float *>(lse.data.dptr), reinterpret_cast<float *>(half_lse.data.dptr),
         reinterpret_cast<int *>(cu_seqlens.data.dptr), batch, num_heads, lse_seqlen,
         second_half_lse_seqlen);
+    NVTE_CHECK_CUDA(cudaGetLastError());
   } else {
     thd_lse_kernel<false, ReadLseFunctor><<<grid, block, sizeof(int) * (batch + 1), stream>>>(
         reinterpret_cast<float *>(lse.data.dptr), reinterpret_cast<float *>(half_lse.data.dptr),
         reinterpret_cast<int *>(cu_seqlens.data.dptr), batch, num_heads, lse_seqlen,
         second_half_lse_seqlen);
+    NVTE_CHECK_CUDA(cudaGetLastError());
   }
 }

@@ -519,6 +524,7 @@ static void thd_out_correction_helper(Tensor out, const Tensor &out_per_step, co
         reinterpret_cast<float *>(lse_per_step.data.dptr),
         reinterpret_cast<int *>(cu_seqlens.data.dptr), batch, num_heads, dim_per_head, lse_seqlen,
         lse_per_step_seqlen);
+    NVTE_CHECK_CUDA(cudaGetLastError());
   } else {
     thd_out_correction_kernel<dtype, only_second_half, tile, false>
         <<<grid, block, sizeof(int) * (batch + 1), stream>>>(

@@ -528,6 +534,7 @@ static void thd_out_correction_helper(Tensor out, const Tensor &out_per_step, co
         reinterpret_cast<float *>(lse_per_step.data.dptr),
         reinterpret_cast<int *>(cu_seqlens.data.dptr), batch, num_heads, dim_per_head, lse_seqlen,
         lse_per_step_seqlen);
+    NVTE_CHECK_CUDA(cudaGetLastError());
   }
 }

@@ -602,6 +609,7 @@ static void thd_grad_correction_helper(Tensor grad, const Tensor &grad_per_step,
       reinterpret_cast<dtype *>(grad.data.dptr), reinterpret_cast<dtype *>(grad_per_step.data.dptr),
       reinterpret_cast<int *>(cu_seqlens.data.dptr), batch, hidden_size, total_tokens);
+  NVTE_CHECK_CUDA(cudaGetLastError());
 }

 template <typename dtype>

@@ -667,6 +675,7 @@ void thd_get_partitioned_indices(const Tensor &cu_seqlens, Tensor output, int to
   thd_partition_indices_kernel<<<grid, block, sizeof(int) * (batch + 1), stream>>>(
       reinterpret_cast<int *>(output.data.dptr), reinterpret_cast<int *>(cu_seqlens.data.dptr),
       batch, total_tokens, world_size, rank);
+  NVTE_CHECK_CUDA(cudaGetLastError());
 }

 }  // namespace context_parallel
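The additions in this file all follow one pattern: every kernel launch is now immediately followed by NVTE_CHECK_CUDA(cudaGetLastError()), so a bad launch configuration or missing kernel image is reported at the launch site rather than by some later, unrelated CUDA call. A minimal self-contained sketch of the same idea, using a hypothetical CHECK_CUDA macro and toy kernel rather than the library's own NVTE_CHECK_CUDA:

  #include <cstdio>
  #include <cstdlib>
  #include <cuda_runtime.h>

  // Hypothetical helper: abort with file/line if a CUDA runtime call failed.
  #define CHECK_CUDA(call)                                              \
    do {                                                                \
      cudaError_t err_ = (call);                                        \
      if (err_ != cudaSuccess) {                                        \
        std::fprintf(stderr, "%s:%d: %s\n", __FILE__, __LINE__,         \
                     cudaGetErrorString(err_));                         \
        std::abort();                                                   \
      }                                                                 \
    } while (0)

  __global__ void scale_kernel(float *x, float s, int n) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) x[i] *= s;
  }

  void launch_scale(float *x, float s, int n, cudaStream_t stream) {
    const int block = 256;
    const int grid = (n + block - 1) / block;
    scale_kernel<<<grid, block, 0, stream>>>(x, s, n);
    // A <<<>>> launch has no return value; query its status right away.
    CHECK_CUDA(cudaGetLastError());
  }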
transformer_engine/common/fused_attn/flash_attn.cu

@@ -91,6 +91,7 @@ void prepare_flash_attn_fwd(Tensor qkvi, Tensor qkv, cudaStream_t stream) {
       prepare_kernel_fwd<dtype><<<grid, threads, 0, stream>>>(
           reinterpret_cast<dtype *>(qkvi.data.dptr), reinterpret_cast<dtype *>(qkv.data.dptr),
           shape[1], shape[2], shape[3], shape[4]););
+  NVTE_CHECK_CUDA(cudaGetLastError());
 }

 void prepare_flash_attn_bwd(Tensor q, Tensor k, Tensor v, Tensor qkv, cudaStream_t stream) {

@@ -129,6 +130,7 @@ void prepare_flash_attn_bwd(Tensor q, Tensor k, Tensor v, Tensor qkv, cudaStream
           reinterpret_cast<dtype *>(q.data.dptr), reinterpret_cast<dtype *>(k.data.dptr),
           reinterpret_cast<dtype *>(v.data.dptr), reinterpret_cast<dtype *>(qkv.data.dptr),
           q_shape[0], q_shape[1], q_shape[2], q_shape[3]););
+  NVTE_CHECK_CUDA(cudaGetLastError());
 }

 }  // namespace flash_attention
transformer_engine/common/fused_attn/fused_attn.cpp

@@ -251,10 +251,11 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
           // 9.11: d_qk = 192, d_v = 128 + Blackwell + bprop + non-paged
           (head_dim_qk == 192 && head_dim_v == 128 && is_training && sm_arch_ >= 100 &&
            cudnn_runtime_version >= 91100)) &&
-         // 9.11/9.12 bug: 128 < d_qk <= 256, 128 < d_v <= 256 + Hopper + bprop + MLA
-         (!((cudnn_runtime_version == 91100 || cudnn_runtime_version == 91200) && is_training &&
-            sm_arch_ == 90 && head_dim_qk >= 128 && head_dim_v >= 128 &&
-            !(head_dim_qk == 192 && head_dim_v == 128) && head_dim_qk != head_dim_v)) &&
+         // 9.11+ bug: 128 < d_qk <= 256, 128 < d_v <= 256 + Hopper + bprop + MLA
+         // Conditional to temporarily use blanket cudnn_runtime_version >= 9.11 until fixed
+         (!((cudnn_runtime_version >= 91100) && is_training && sm_arch_ == 90 &&
+            head_dim_qk >= 128 && head_dim_v >= 128 &&
+            !(head_dim_qk == 192 && head_dim_v == 128) && head_dim_qk != head_dim_v)) &&
          // bias type
          ((cudnn_runtime_version < 8906 && bias_type == NVTE_Bias_Type::NVTE_NO_BIAS) ||
           (cudnn_runtime_version >= 8906 &&
transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu

@@ -416,6 +416,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
         actual_b, b, static_cast<const int32_t *>(devPtrCuSeqlensQ),
         static_cast<const int32_t *>(devPtrCuSeqlensKV), static_cast<int32_t *>(devActualSeqlenQ),
         static_cast<int32_t *>(devActualSeqlenKV));
+    NVTE_CHECK_CUDA(cudaGetLastError());
     variant_pack[seq_q] = devActualSeqlenQ;
     variant_pack[seq_kv] = devActualSeqlenKV;
   }

@@ -454,6 +455,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
         layout_group, actual_b, b, h, hg, d_qk, d_v, static_cast<int32_t *>(devPtrSeqOffsetsQ),
         static_cast<int32_t *>(devPtrSeqOffsetsKV), ragged_offset_type, devOffsetsQ, devOffsetsK,
         devOffsetsV, devOffsetsO, devOffsetsS);
+    NVTE_CHECK_CUDA(cudaGetLastError());
     if (is_ragged_q) {
       variant_pack[offset_q] = devOffsetsQ;
       variant_pack[offset_o] = devOffsetsO;

@@ -883,6 +885,7 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
         actual_b, b, static_cast<const int32_t *>(devPtrCuSeqlensQ),
         static_cast<const int32_t *>(devPtrCuSeqlensKV), static_cast<int32_t *>(devActualSeqlenQ),
         static_cast<int32_t *>(devActualSeqlenKV));
+    NVTE_CHECK_CUDA(cudaGetLastError());
     variant_pack[seq_q] = devActualSeqlenQ;
     variant_pack[seq_kv] = devActualSeqlenKV;
   }

@@ -916,6 +919,7 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
         layout_group, actual_b, b, h, hg, d_qk, d_v, static_cast<int32_t *>(devPtrSeqOffsetsQ),
         static_cast<int32_t *>(devPtrSeqOffsetsKV), ragged_offset_type, devOffsetsQ, devOffsetsK,
         devOffsetsV, devOffsetsO, devOffsetsS);
+    NVTE_CHECK_CUDA(cudaGetLastError());
     if (is_ragged_q) {
       variant_pack[offset_q] = devOffsetsQ;
       variant_pack[offset_o] = devOffsetsO;
transformer_engine/common/fused_attn/fused_attn_fp8.cu

@@ -1111,6 +1111,7 @@ void fused_attn_fp8_fwd_impl(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, in
     cu_seqlens_to_offsets<<<gridDims, blockDims, 0, stream>>>(
         b, h, d, reinterpret_cast<int32_t *>(devPtrcuSeqlensQ), actual_seqlens_q,
         qkv_ragged_offset, o_ragged_offset);
+    NVTE_CHECK_CUDA(cudaGetLastError());
     void *devPtrQKVRaggedOffset = reinterpret_cast<void *>(qkv_ragged_offset);
     void *devPtrORaggedOffset = reinterpret_cast<void *>(o_ragged_offset);
     void *devPtrMNKOverride = reinterpret_cast<void *>(actual_seqlens_q);

@@ -1577,6 +1578,7 @@ void fused_attn_fp8_bwd_impl(
     cu_seqlens_to_offsets<<<gridDims, blockDims, 0, stream>>>(
         b, h, d, reinterpret_cast<int32_t *>(devPtrcuSeqlensQ), actual_seqlens_q,
         qkv_ragged_offset, o_ragged_offset);
+    NVTE_CHECK_CUDA(cudaGetLastError());
     void *devPtrQKVRaggedOffset = reinterpret_cast<void *>(qkv_ragged_offset);
     void *devPtrORaggedOffset = reinterpret_cast<void *>(o_ragged_offset);
     void *devPtrMNKOverride = reinterpret_cast<void *>(actual_seqlens_q);

@@ -1933,6 +1935,7 @@ void fused_attn_fp8_fwd_impl_v1(
         b, b, static_cast<const int32_t *>(devPtrcuSeqlensQ),  // TODO(pass max_b)
         static_cast<const int32_t *>(devPtrcuSeqlensKV), static_cast<int32_t *>(devActualSeqlenQ),
         static_cast<int32_t *>(devActualSeqlenKV));
+    NVTE_CHECK_CUDA(cudaGetLastError());
     variant_pack[seq_q] = devActualSeqlenQ;
     variant_pack[seq_kv] = devActualSeqlenKV;
   }

@@ -2329,6 +2332,7 @@ void fused_attn_fp8_bwd_impl_v1(
         b, b, static_cast<const int32_t *>(devPtrcuSeqlensQ),  // TODO(pass max_b)
         static_cast<const int32_t *>(devPtrcuSeqlensKV), static_cast<int32_t *>(devActualSeqlenQ),
         static_cast<int32_t *>(devActualSeqlenKV));
+    NVTE_CHECK_CUDA(cudaGetLastError());
     variant_pack[seq_q] = devActualSeqlenQ;
     variant_pack[seq_kv] = devActualSeqlenKV;
   }
transformer_engine/common/fused_attn/kv_cache.cu

@@ -157,6 +157,7 @@ void copy_to_kv_cache_launcher(Tensor new_k, Tensor new_v, Tensor k_cache, Tenso
         reinterpret_cast<int *>(page_table.data.dptr), reinterpret_cast<int *>(cu_new_lens.data.dptr),
         reinterpret_cast<int *>(cu_cached_lens.data.dptr), h_kv, d_k, d_v, b, max_seq_len);
+    NVTE_CHECK_CUDA(cudaGetLastError());
   }
   dim3 grid_size(b, max_ctx_len);
   copy_to_kv_cache_kernel<<<grid_size, block_size, 0, stream>>>(

@@ -166,6 +167,7 @@ void copy_to_kv_cache_launcher(Tensor new_k, Tensor new_v, Tensor k_cache, Tenso
         reinterpret_cast<int *>(cu_new_lens.data.dptr), reinterpret_cast<int *>(cu_cached_lens.data.dptr),
         qkv_format, h_kv, d_k, d_v, b, max_ctx_len, max_seq_len, max_pages_per_seq, is_non_paged);
+    NVTE_CHECK_CUDA(cudaGetLastError());
   }
 }

@@ -215,6 +217,7 @@ void convert_thd_to_bshd_launcher(Tensor tensor, Tensor new_tensor, Tensor cu_se
       reinterpret_cast<scalar_t *>(tensor.data.dptr), reinterpret_cast<scalar_t *>(new_tensor.data.dptr),
       reinterpret_cast<int *>(cu_seqlens.data.dptr), b, max_seq_len, h, d);
+  NVTE_CHECK_CUDA(cudaGetLastError());
 }

 void convert_thd_to_bshd(Tensor tensor, Tensor cu_seqlens, Tensor new_tensor, int b,

@@ -254,6 +257,7 @@ void convert_bshd_to_thd_launcher(Tensor tensor, Tensor new_tensor, Tensor cu_se
       reinterpret_cast<scalar_t *>(tensor.data.dptr), reinterpret_cast<scalar_t *>(new_tensor.data.dptr),
       reinterpret_cast<int *>(cu_seqlens.data.dptr), b, max_seq_len, h, d);
+  NVTE_CHECK_CUDA(cudaGetLastError());
 }

 void convert_bshd_to_thd(Tensor tensor, Tensor cu_seqlens, Tensor new_tensor, int t,
transformer_engine/common/fused_attn/utils.cu

@@ -600,13 +600,14 @@ uint32_t GetRuntimeNumSegments(void *cu_seqlen, void *workspace, size_t len, cud
   // workspace size requires 4 bytes
   uint32_t *dout = static_cast<uint32_t *>(workspace);
   uint32_t hout{};
-  cudaMemsetAsync(dout, 0, sizeof(uint32_t), stream);
+  NVTE_CHECK_CUDA(cudaMemsetAsync(dout, 0, sizeof(uint32_t), stream));

   constexpr int threads = 128;
   const int blocks = (len - 1) / threads + 1;
   get_runtime_num_segments_kernel<<<blocks, threads, 0, stream>>>(
       static_cast<int32_t *>(cu_seqlen), len, dout);
-  cudaMemcpyAsync(&hout, dout, sizeof(uint32_t), cudaMemcpyDeviceToHost, stream);
-  cudaStreamSynchronize(stream);
+  NVTE_CHECK_CUDA(cudaGetLastError());
+  NVTE_CHECK_CUDA(cudaMemcpyAsync(&hout, dout, sizeof(uint32_t), cudaMemcpyDeviceToHost, stream));
+  NVTE_CHECK_CUDA(cudaStreamSynchronize(stream));
   return hout;
 }

@@ -633,4 +634,5 @@ void nvte_extract_seed_and_offset(int64_t *rng_state_ptr, int captured, int64_t
   fused_attn::extract_seed_and_offset<<<1, 1, 0, stream>>>(
       rng_state_ptr, captured, seed_ptr, seed_val, offset_ptr, offset_val, offset_intragraph);
+  NVTE_CHECK_CUDA(cudaGetLastError());
 }
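The GetRuntimeNumSegments change wraps the whole device-counter round trip (cudaMemsetAsync, the kernel launch, cudaMemcpyAsync, cudaStreamSynchronize) in error checks. A self-contained sketch of that round trip with hypothetical names, reusing the CHECK_CUDA helper sketched earlier:

  #include <cstdint>
  #include <cuda_runtime.h>

  __global__ void count_nonzero_kernel(const int32_t *data, size_t len, uint32_t *count) {
    size_t i = blockIdx.x * static_cast<size_t>(blockDim.x) + threadIdx.x;
    if (i < len && data[i] != 0) atomicAdd(count, 1u);
  }

  // Returns the number of non-zero entries; `workspace` is 4 bytes of device memory.
  uint32_t count_nonzero(const int32_t *data, size_t len, void *workspace, cudaStream_t stream) {
    uint32_t *dout = static_cast<uint32_t *>(workspace);
    uint32_t hout{};
    CHECK_CUDA(cudaMemsetAsync(dout, 0, sizeof(uint32_t), stream));  // zero the device counter
    const int threads = 128;
    const int blocks = static_cast<int>((len + threads - 1) / threads);
    count_nonzero_kernel<<<blocks, threads, 0, stream>>>(data, len, dout);
    CHECK_CUDA(cudaGetLastError());                                  // catch launch errors
    CHECK_CUDA(cudaMemcpyAsync(&hout, dout, sizeof(uint32_t), cudaMemcpyDeviceToHost, stream));
    CHECK_CUDA(cudaStreamSynchronize(stream));                       // hout is valid only after this
    return hout;
  }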
transformer_engine/common/fused_rope/fused_rope.cu

@@ -21,12 +21,21 @@ __device__ void fused_rope_block_forward(const scalar_t *src, const float *freqs
                                          const int h, const int d, const int d2,
                                          const int stride_h, const int stride_d,
                                          const int o_stride_h, const int o_stride_d) {
+  extern __shared__ float shared_mem_cos_sin[];
+  float *shared_mem_cos = shared_mem_cos_sin;
+  float *shared_mem_sin = shared_mem_cos_sin + d2;
+  int tid = threadIdx.x * blockDim.y + threadIdx.y;
+  for (int i = tid; i < d2; i += blockDim.x * blockDim.y) {
+    sincosf(freqs[s_id * d2 + i], &shared_mem_sin[i], &shared_mem_cos[i]);
+  }
+  __syncthreads();
+
 #pragma unroll
-  for (int d_id = threadIdx.x; d_id < d2; d_id += blockDim.x) {
-    float v_cos, v_sin;
-    sincosf(freqs[s_id * d2 + d_id], &v_sin, &v_cos);
-    for (int h_id = threadIdx.y; h_id < h; h_id += blockDim.y) {
+  for (int h_id = threadIdx.y; h_id < h; h_id += blockDim.y) {
+    for (int d_id = threadIdx.x; d_id < d2; d_id += blockDim.x) {
+      float v_cos = shared_mem_cos[d_id];
+      float v_sin = shared_mem_sin[d_id];
       int offset_src = offset_block + h_id * stride_h + d_id * stride_d;
       int offset_dst = offset_block_dst + h_id * o_stride_h + d_id * o_stride_d;
       float v_src = src[offset_src];

@@ -49,12 +58,12 @@ __device__ void fused_rope_block_forward(const scalar_t *src, const float *freqs
   // copy the rest
   if (d > d2) {
 #pragma unroll
-    for (int h_id = threadIdx.y; h_id < h; h_id += blockDim.y) {
-      int offset_head = offset_block + h_id * stride_h;
-      int offset_head_dst = offset_block_dst + h_id * o_stride_h;
-      for (int d_id = d2 + threadIdx.x; d_id < d; d_id += blockDim.x) {
-        dst[offset_head_dst + d_id * o_stride_d] = src[offset_head + d_id * stride_d];
+    for (int d_id = d2 + threadIdx.x; d_id < d; d_id += blockDim.x) {
+      for (int h_id = threadIdx.y; h_id < h; h_id += blockDim.y) {
+        int offset_src = offset_block + h_id * stride_h + d_id * stride_d;
+        int offset_dst = offset_block_dst + h_id * o_stride_h + d_id * o_stride_d;
+        dst[offset_dst] = src[offset_src];
       }
     }
   }

@@ -67,47 +76,54 @@ __device__ void fused_rope_block_backward(const scalar_t *src, const float *freq
                                           const int h, const int d, const int d2,
                                           const int stride_h, const int stride_d,
                                           const int o_stride_h, const int o_stride_d) {
+  extern __shared__ float shared_mem_cos_sin[];
+  float *shared_mem_cos = shared_mem_cos_sin;
+  float *shared_mem_sin = shared_mem_cos_sin + d2;
+  int tid = threadIdx.x * blockDim.y + threadIdx.y;
+  for (int i = tid; i < d2; i += blockDim.x * blockDim.y) {
+    sincosf(freqs[s_id * d2 + i], &shared_mem_sin[i], &shared_mem_cos[i]);
+  }
+  __syncthreads();
+
 #pragma unroll
-  for (int d_id = threadIdx.x; d_id < d2; d_id += blockDim.x) {
-    float v_cos = cosf(freqs[s_id * d2 + d_id]);
-    float v_sin;
-    if (!interleaved) {
-      v_sin = (d_id + d2 / 2 < d2) ? sinf(freqs[s_id * d2 + d_id + d2 / 2])
-                                   : -sinf(freqs[s_id * d2 + d_id + d2 / 2 - d2]);
-    } else {
-      v_sin = (d_id % 2 == 0) ? sinf(freqs[s_id * d2 + d_id + 1])
-                              : -sinf(freqs[s_id * d2 + d_id - 1]);
-    }
-    for (int h_id = threadIdx.y; h_id < h; h_id += blockDim.y) {
+  for (int h_id = threadIdx.y; h_id < h; h_id += blockDim.y) {
+    for (int d_id = threadIdx.x; d_id < d2; d_id += blockDim.x) {
       int offset_src = offset_block + h_id * stride_h + d_id * stride_d;
       int offset_dst = offset_block_dst + h_id * o_stride_h + d_id * o_stride_d;
       float v_src = src[offset_src];
-      float v_src_rotate;
+      float v_cos = shared_mem_cos[d_id];
+      float v_src_rotate, v_sin;
       if (!interleaved) {
-        v_src_rotate = (d_id + d2 / 2 < d2)
-                           ? static_cast<float>(src[offset_src + (d2 / 2) * stride_d])
-                           : static_cast<float>(src[offset_src + (d2 / 2 - d2) * stride_d]);
+        if (d_id + d2 / 2 < d2) {
+          v_src_rotate = static_cast<float>(src[offset_src + (d2 / 2) * stride_d]);
+          v_sin = shared_mem_sin[d_id + d2 / 2];
+        } else {
+          v_src_rotate = static_cast<float>(src[offset_src + (d2 / 2 - d2) * stride_d]);
+          v_sin = -shared_mem_sin[d_id + d2 / 2 - d2];
+        }
       } else {
-        v_src_rotate = (d_id % 2 == 0)
-                           // d_id + 1
-                           ? static_cast<float>(src[offset_src + stride_d])
-                           // d_id - 1
-                           : static_cast<float>(src[offset_src - stride_d]);
+        if (d_id % 2 == 0) {
+          v_src_rotate = static_cast<float>(src[offset_src + stride_d]);
+          v_sin = shared_mem_sin[d_id + 1];
+        } else {
+          v_src_rotate = static_cast<float>(src[offset_src - stride_d]);
+          v_sin = -shared_mem_sin[d_id - 1];
+        }
       }
       dst[offset_dst] = v_src * v_cos + v_src_rotate * v_sin;
     }
   }
-  // handle the tail
+  // copy the rest
   if (d > d2) {
 #pragma unroll
-    for (int h_id = threadIdx.y; h_id < h; h_id += blockDim.y) {
-      int offset_head = offset_block + h_id * stride_h;
-      int offset_head_dst = offset_block_dst + h_id * o_stride_h;
-      for (int d_id = d2 + threadIdx.x; d_id < d; d_id += blockDim.x) {
-        dst[offset_head_dst + d_id * o_stride_d] = src[offset_head + d_id * stride_d];
+    for (int d_id = d2 + threadIdx.x; d_id < d; d_id += blockDim.x) {
+      for (int h_id = threadIdx.y; h_id < h; h_id += blockDim.y) {
+        int offset_src = offset_block + h_id * stride_h + d_id * stride_d;
+        int offset_dst = offset_block_dst + h_id * o_stride_h + d_id * o_stride_d;
+        dst[offset_dst] = src[offset_src];
       }
     }
   }
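The restructuring above replaces per-thread, per-head sincosf calls with a single cooperative precompute: each block evaluates sincosf once per rotary dimension into dynamic shared memory, synchronizes, and every subsequent (head, dim) iteration just reads the cached values. A stripped-down sketch of that pattern with a hypothetical kernel (not the library's own), applying a plain rotate-half RoPE to one row per head:

  // Launch with shared_mem_bytes = 2 * d2 * sizeof(float): [0, d2) holds cos, [d2, 2*d2) holds sin.
  __global__ void rope_cached_kernel(const float *freqs, const float *src, float *dst,
                                     int h, int d2) {
    extern __shared__ float smem[];
    float *s_cos = smem;
    float *s_sin = smem + d2;
    int s_id = blockIdx.x;  // one sequence position per block
    // Cooperative precompute: each thread fills a strided subset of the d2 angles.
    int tid = threadIdx.x * blockDim.y + threadIdx.y;
    for (int i = tid; i < d2; i += blockDim.x * blockDim.y) {
      sincosf(freqs[s_id * d2 + i], &s_sin[i], &s_cos[i]);
    }
    __syncthreads();
    // Heads in the outer loop now reuse the cached cos/sin instead of recomputing them per head.
    for (int h_id = threadIdx.y; h_id < h; h_id += blockDim.y) {
      for (int d_id = threadIdx.x; d_id < d2; d_id += blockDim.x) {
        int base = (s_id * h + h_id) * d2;
        float x = src[base + d_id];
        // rotate-half pairing: element d_id pairs with the element d2/2 away
        float x_rot = (d_id < d2 / 2) ? -src[base + d_id + d2 / 2] : src[base + d_id - d2 / 2];
        dst[base + d_id] = x * s_cos[d_id] + x_rot * s_sin[d_id];
      }
    }
  }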
@@ -198,6 +214,251 @@ __global__ void fused_rope_backward_kernel(
                            offset_block_dst, h, d, d2, stride_h, stride_d, o_stride_h, o_stride_d);
 }

+template <typename scalar_t>
+__device__ void fused_qkv_rope_block_forward(const scalar_t *src, const float *freqs, scalar_t *out,
+                                             const bool interleaved, const int s_id,
+                                             const int offset_block, const int offset_block_dst,
+                                             const int h, const int d, const int d2,
+                                             const int row_offset, const int in_row_length,
+                                             const int out_row_length) {
+  extern __shared__ float shared_mem_cos_sin_qk[];
+  // Split the shared memory into cos and sin parts for q or k
+  float *shared_mem_cos = nullptr;
+  float *shared_mem_sin = nullptr;
+  if (row_offset == 0) {  // q
+    shared_mem_cos = shared_mem_cos_sin_qk;
+    shared_mem_sin = shared_mem_cos_sin_qk + d2;
+  } else {  // k
+    shared_mem_cos = shared_mem_cos_sin_qk + 2 * d2;
+    shared_mem_sin = shared_mem_cos_sin_qk + 3 * d2;
+  }
+  if (freqs != nullptr) {
+    int tid = threadIdx.x * blockDim.y + threadIdx.y;
+    for (int i = tid; i < d2; i += blockDim.x * blockDim.y) {
+      sincosf(freqs[s_id * d2 + i], &shared_mem_sin[i], &shared_mem_cos[i]);
+    }
+  }
+  __syncthreads();
+#pragma unroll
+  for (int h_id = threadIdx.y; h_id < h; h_id += blockDim.y) {
+#pragma unroll
+    for (int i = 0; i < out_row_length; i += d) {
+#pragma unroll
+      for (int d_id = threadIdx.x; d_id < d2; d_id += blockDim.x) {
+        int offset_src = offset_block + h_id * in_row_length + (row_offset + i) + d_id;
+        int offset_dst = offset_block_dst + h_id * out_row_length + i + d_id;
+        if (freqs != nullptr) {
+          float v_cos, v_sin;
+          v_cos = shared_mem_cos[d_id];
+          v_sin = shared_mem_sin[d_id];
+          float v_src = src[offset_src];
+          float v_src_rotate;
+          if (!interleaved) {
+            v_src_rotate = (d_id + d2 / 2 < d2)
+                               ? -static_cast<float>(src[offset_src + (d2 / 2)])
+                               : static_cast<float>(src[offset_src + (d2 / 2 - d2)]);
+          } else {
+            v_src_rotate = (d_id % 2 == 0) ? -static_cast<float>(src[offset_src + 1])
+                                           : static_cast<float>(src[offset_src - 1]);
+          }
+          out[offset_dst] = v_src * v_cos + v_src_rotate * v_sin;
+        } else {
+          out[offset_dst] = src[offset_src];
+        }
+      }
+    }
+  }
+  // copy the rest
+  if (d > d2) {
+#pragma unroll
+    for (int d_id = d2 + threadIdx.x; d_id < d; d_id += blockDim.x) {
+#pragma unroll
+      for (int h_id = threadIdx.y; h_id < h; h_id += blockDim.y) {
+#pragma unroll
+        for (int i = 0; i < out_row_length; i += d) {
+          int offset_src = offset_block + h_id * in_row_length + (row_offset + i) + d_id;
+          int offset_dst = offset_block_dst + h_id * out_row_length + i + d_id;
+          out[offset_dst] = src[offset_src];
+        }
+      }
+    }
+  }
+}
+
+template <typename scalar_t>
+__device__ void fused_qkv_rope_block_backward(const scalar_t *grad_out, const float *freqs,
+                                              scalar_t *out, const bool interleaved, const int s_id,
+                                              const int offset_block, const int offset_block_dst,
+                                              const int h, const int d, const int d2,
+                                              const int row_offset, const int in_row_length,
+                                              const int out_row_length) {
+  extern __shared__ float shared_mem_cos_sin_qk[];
+  float *shared_mem_cos = nullptr;
+  float *shared_mem_sin = nullptr;
+  // Split the shared memory into cos and sin parts for q or k
+  if (row_offset == 0) {  // q
+    shared_mem_cos = shared_mem_cos_sin_qk;
+    shared_mem_sin = shared_mem_cos_sin_qk + d2;
+  } else {  // k
+    shared_mem_cos = shared_mem_cos_sin_qk + 2 * d2;
+    shared_mem_sin = shared_mem_cos_sin_qk + 3 * d2;
+  }
+  if (freqs != nullptr) {
+    int tid = threadIdx.x * blockDim.y + threadIdx.y;
+    for (int i = tid; i < d2; i += blockDim.x * blockDim.y) {
+      sincosf(freqs[s_id * d2 + i], &shared_mem_sin[i], &shared_mem_cos[i]);
+    }
+  }
+  __syncthreads();
+#pragma unroll
+  for (int h_id = threadIdx.y; h_id < h; h_id += blockDim.y) {
+#pragma unroll
+    for (int i = 0; i < out_row_length; i += d) {
+#pragma unroll
+      for (int d_id = threadIdx.x; d_id < d2; d_id += blockDim.x) {
+        int offset_dst = offset_block + h_id * in_row_length + (row_offset + i) + d_id;
+        int offset_src = offset_block_dst + h_id * out_row_length + i + d_id;
+        float v_src = grad_out[offset_src];
+        if (freqs != nullptr) {
+          float v_cos, v_sin;
+          v_cos = shared_mem_cos[d_id];
+          float v_src_rotate;
+          if (!interleaved) {
+            if (d_id + d2 / 2 < d2) {
+              v_src_rotate = static_cast<float>(grad_out[offset_src + (d2 / 2)]);
+              v_sin = shared_mem_sin[d_id + d2 / 2];
+            } else {
+              v_src_rotate = static_cast<float>(grad_out[offset_src + (d2 / 2 - d2)]);
+              v_sin = -shared_mem_sin[d_id + d2 / 2 - d2];
+            }
+          } else {
+            if (d_id % 2 == 0) {
+              v_src_rotate = static_cast<float>(grad_out[offset_src + 1]);
+              v_sin = shared_mem_sin[d_id + 1];
+            } else {
+              v_src_rotate = static_cast<float>(grad_out[offset_src - 1]);
+              v_sin = -shared_mem_sin[d_id - 1];
+            }
+          }
+          out[offset_dst] = v_src * v_cos + v_src_rotate * v_sin;
+        } else {
+          out[offset_dst] = grad_out[offset_src];
+        }
+      }
+    }
+  }
+  // copy the rest
+  if (d > d2) {
+#pragma unroll
+    for (int h_id = threadIdx.y; h_id < h; h_id += blockDim.y) {
+#pragma unroll
+      for (int i = 0; i < out_row_length; i += d) {
+#pragma unroll
+        for (int d_id = d2 + threadIdx.x; d_id < d; d_id += blockDim.x) {
+          int offset_dst = offset_block + h_id * in_row_length + (row_offset + i) + d_id;
+          int offset_src = offset_block_dst + h_id * out_row_length + i + d_id;
+          out[offset_dst] = grad_out[offset_src];
+        }
+      }
+    }
+  }
+}
+
+template <typename scalar_t>
+__global__ void fused_qkv_rope_forward_kernel(const scalar_t *qkv_input, const float *q_freqs,
+                                              const float *k_freqs, const int *start_positions,
+                                              scalar_t *q_out, scalar_t *k_out, scalar_t *v_out,
+                                              const NVTE_QKV_Format qkv_format, const bool interleaved,
+                                              const int cp_size, const int cp_rank, const int s,
+                                              const int b, const int h, const int d, const int d2,
+                                              const int q_split_arg, const int k_split_arg,
+                                              const int v_split_arg) {
+  int s_id = blockIdx.x, b_id = blockIdx.y;
+  int cur_seqlens = s;
+  int total_d = q_split_arg + k_split_arg + v_split_arg;
+  int offset_block, offset_block_dst_q, offset_block_dst_k, offset_block_dst_v;
+  if (qkv_format == NVTE_QKV_Format::NVTE_SBHD) {
+    offset_block = s_id * b * h * total_d + b_id * h * total_d;
+    offset_block_dst_q = s_id * b * h * q_split_arg + b_id * h * q_split_arg;
+    offset_block_dst_k = s_id * b * h * k_split_arg + b_id * h * k_split_arg;
+    offset_block_dst_v = s_id * b * h * v_split_arg + b_id * h * v_split_arg;
+  } else {
+    offset_block = b_id * s * h * total_d + s_id * h * total_d;
+    offset_block_dst_q = b_id * s * h * q_split_arg + s_id * h * q_split_arg;
+    offset_block_dst_k = b_id * s * h * k_split_arg + s_id * h * k_split_arg;
+    offset_block_dst_v = b_id * s * h * v_split_arg + s_id * h * v_split_arg;
+  }
+  int q_limit = q_split_arg;
+  int k_limit = q_limit + k_split_arg;
+  int s_id_for_freqs;
+  if (cp_size > 1) {
+    assert(cur_seqlens % 2 == 0);
+    if (s_id < cur_seqlens / 2) {
+      s_id_for_freqs = s_id + cp_rank * cur_seqlens / 2;
+    } else {
+      s_id_for_freqs =
+          cur_seqlens * cp_size - (cp_rank + 1) * cur_seqlens / 2 + s_id - cur_seqlens / 2;
+    }
+  } else {
+    int begin_offset = (start_positions == nullptr) ? 0 : start_positions[b_id];
+    s_id_for_freqs = s_id + begin_offset;
+  }
+  fused_qkv_rope_block_forward(qkv_input, q_freqs, q_out, interleaved, s_id_for_freqs, offset_block,
+                               offset_block_dst_q, h, d, d2, 0, total_d, q_split_arg);
+  fused_qkv_rope_block_forward(qkv_input, k_freqs, k_out, interleaved, s_id_for_freqs, offset_block,
+                               offset_block_dst_k, h, d, d2, q_limit, total_d, k_split_arg);
+  fused_qkv_rope_block_forward(qkv_input, nullptr, v_out, interleaved, s_id_for_freqs, offset_block,
+                               offset_block_dst_v, h, d, d2, k_limit, total_d, v_split_arg);
+}
+
+template <typename scalar_t>
+__global__ void fused_qkv_rope_backward_kernel(const scalar_t *grad_out_q, const scalar_t *grad_out_k,
+                                               const scalar_t *grad_out_v, const float *q_freqs,
+                                               const float *k_freqs, scalar_t *qkv_grad,
+                                               const NVTE_QKV_Format qkv_format, const bool interleaved,
+                                               const int cp_size, const int cp_rank, const int s,
+                                               const int b, const int h, const int d, const int d2,
+                                               const int q_split_arg, const int k_split_arg,
+                                               const int v_split_arg) {
+  int s_id = blockIdx.x, b_id = blockIdx.y;
+  int cur_seqlens = s;
+  int offset_block, offset_block_dst_q, offset_block_dst_k, offset_block_dst_v;
+  int total_d = q_split_arg + k_split_arg + v_split_arg;
+  if (qkv_format == NVTE_QKV_Format::NVTE_SBHD) {
+    offset_block = s_id * b * h * total_d + b_id * h * total_d;
+    offset_block_dst_q = s_id * b * h * q_split_arg + b_id * h * q_split_arg;
+    offset_block_dst_k = s_id * b * h * k_split_arg + b_id * h * k_split_arg;
+    offset_block_dst_v = s_id * b * h * v_split_arg + b_id * h * v_split_arg;
+  } else {
+    offset_block = b_id * s * h * total_d + s_id * h * total_d;
+    offset_block_dst_q = b_id * s * h * q_split_arg + s_id * h * q_split_arg;
+    offset_block_dst_k = b_id * s * h * k_split_arg + s_id * h * k_split_arg;
+    offset_block_dst_v = b_id * s * h * v_split_arg + s_id * h * v_split_arg;
+  }
+  int q_limit = q_split_arg;
+  int k_limit = q_limit + k_split_arg;
+  int s_id_for_freqs;
+  if (cp_size > 1) {
+    assert(cur_seqlens % 2 == 0);
+    if (s_id < cur_seqlens / 2) {
+      s_id_for_freqs = s_id + cp_rank * cur_seqlens / 2;
+    } else {
+      s_id_for_freqs =
+          cur_seqlens * cp_size - (cp_rank + 1) * cur_seqlens / 2 + s_id - cur_seqlens / 2;
+    }
+  } else {
+    s_id_for_freqs = s_id;
+  }
+  fused_qkv_rope_block_backward(grad_out_q, q_freqs, qkv_grad, interleaved, s_id_for_freqs,
+                                offset_block, offset_block_dst_q, h, d, d2, 0, total_d, q_split_arg);
+  fused_qkv_rope_block_backward(grad_out_k, k_freqs, qkv_grad, interleaved, s_id_for_freqs,
+                                offset_block, offset_block_dst_k, h, d, d2, q_limit, total_d,
+                                k_split_arg);
+  fused_qkv_rope_block_backward(grad_out_v, nullptr, qkv_grad, interleaved, s_id_for_freqs,
+                                offset_block, offset_block_dst_v, h, d, d2, k_limit, total_d,
+                                v_split_arg);
+}
+
 template <typename scalar_t>
 void fused_rope_forward_launcher(const scalar_t *input, const int *cu_seqlens, const float *freqs,
                                  const int *start_positions, scalar_t *output,
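In the new fused QKV path, q, k and v live interleaved in a single input row of length in_row_length = q_split + k_split + v_split per head, while each output row holds only its own projection (out_row_length); row_offset selects the q, k or v slice inside the fused row. The device functions above address both sides with offset_src = offset_block + h_id * in_row_length + row_offset + i + d_id and offset_dst = offset_block_dst + h_id * out_row_length + i + d_id. A tiny host-side illustration of that index math, with made-up sizes (not taken from the library):

  #include <cstdio>

  // Hypothetical check: map one (h_id, i, d_id) element of the k slice from the fused buffer.
  int main() {
    const int q_split = 128, k_split = 64, v_split = 64;
    const int in_row_length = q_split + k_split + v_split;  // 256 values per head in qkv_input
    const int out_row_length = k_split;                     // 64 values per head in k_out
    const int row_offset = q_split;                         // k starts right after the q slice
    const int h_id = 2, i = 0, d_id = 5;
    const int offset_block = 0, offset_block_dst = 0;       // single token, single batch entry
    int offset_src = offset_block + h_id * in_row_length + row_offset + i + d_id;  // 2*256 + 128 + 5 = 645
    int offset_dst = offset_block_dst + h_id * out_row_length + i + d_id;          // 2*64 + 5 = 133
    std::printf("src %d -> dst %d\n", offset_src, offset_dst);
    return 0;
  }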
@@ -209,6 +470,7 @@ void fused_rope_forward_launcher(const scalar_t *input, const int *cu_seqlens, c
   int warps_per_block = h < 16 ? 4 : 8;
   dim3 blocks(s, b);
   dim3 threads(THREADS_PER_WARP, warps_per_block);
+  const int shared_mem_size = 2 * d2 * sizeof(float);  // cos, sin

   int o_stride_s_or_t, o_stride_b;
   if (qkv_format == NVTE_QKV_Format::NVTE_THD) {
     NVTE_CHECK(cu_seqlens != nullptr, "cu_seqlens is required for THD format");

@@ -224,7 +486,7 @@ void fused_rope_forward_launcher(const scalar_t *input, const int *cu_seqlens, c
   const int o_stride_h = d;
   const int o_stride_d = 1;

-  fused_rope_forward_kernel<<<blocks, threads, 0, stream>>>(
+  fused_rope_forward_kernel<<<blocks, threads, shared_mem_size, stream>>>(
       input, cu_seqlens, freqs, start_positions, output, interleaved, cp_size, cp_rank, s, h, d, d2,
       stride_s_or_t, stride_b, stride_h, stride_d, o_stride_s_or_t, o_stride_b, o_stride_h,
       o_stride_d);

@@ -242,6 +504,7 @@ void fused_rope_backward_launcher(const scalar_t *output_grads, const int *cu_se
   int warps_per_block = h < 16 ? 4 : 8;
   dim3 blocks(s, b);
   dim3 threads(THREADS_PER_WARP, warps_per_block);
+  const int shared_mem_size = 2 * d2 * sizeof(float);  // cos, sin

   int o_stride_s_or_t, o_stride_b;
   if (qkv_format == NVTE_QKV_Format::NVTE_THD) {
     NVTE_CHECK(cu_seqlens != nullptr, "cu_seqlens is required for THD format");

@@ -257,13 +520,58 @@ void fused_rope_backward_launcher(const scalar_t *output_grads, const int *cu_se
   const int o_stride_h = d;
   const int o_stride_d = 1;

-  fused_rope_backward_kernel<<<blocks, threads, 0, stream>>>(
+  fused_rope_backward_kernel<<<blocks, threads, shared_mem_size, stream>>>(
       output_grads, cu_seqlens, freqs, input_grads, interleaved, cp_size, cp_rank, s, h, d, d2,
       stride_s_or_t, stride_b, stride_h, stride_d, o_stride_s_or_t, o_stride_b, o_stride_h,
       o_stride_d);
   NVTE_CHECK_CUDA(cudaGetLastError());
 }

+template <typename scalar_t>
+void fused_qkv_rope_forward_launcher(const scalar_t *qkv_input, const float *q_freqs,
+                                     const float *k_freqs, const int *start_positions,
+                                     scalar_t *q_out, scalar_t *k_out, scalar_t *v_out,
+                                     const NVTE_QKV_Format qkv_format, const bool interleaved,
+                                     const int cp_size, const int cp_rank, const int s, const int b,
+                                     const int h, const int d, const int d2,
+                                     const int qkv_split_arg_list_0, const int qkv_split_arg_list_1,
+                                     const int qkv_split_arg_list_2, cudaStream_t stream) {
+  const int THREADS_PER_WARP = 32;
+  int warps_per_block = (h <= 8) ? h : 8;
+  dim3 blocks(s, b);
+  dim3 threads(THREADS_PER_WARP, warps_per_block);
+  const int shared_mem_size = 4 * d2 * sizeof(float);  // cos, sin * q, k
+
+  fused_qkv_rope_forward_kernel<<<blocks, threads, shared_mem_size, stream>>>(
+      qkv_input, q_freqs, k_freqs, start_positions, q_out, k_out, v_out, qkv_format, interleaved,
+      cp_size, cp_rank, s, b, h, d, d2, qkv_split_arg_list_0, qkv_split_arg_list_1,
+      qkv_split_arg_list_2);
+  NVTE_CHECK_CUDA(cudaGetLastError());
+}
+
+template <typename scalar_t>
+void fused_qkv_rope_backward_launcher(const scalar_t *q_grad_out, const scalar_t *k_grad_out,
+                                      const scalar_t *v_grad_out, const float *q_freqs,
+                                      const float *k_freqs, scalar_t *qkv_grad_input,
+                                      const NVTE_QKV_Format qkv_format, const bool interleaved,
+                                      const int cp_size, const int cp_rank, const int s, const int b,
+                                      const int h, const int d, const int d2,
+                                      const int qkv_split_arg_list_0, const int qkv_split_arg_list_1,
+                                      const int qkv_split_arg_list_2, cudaStream_t stream) {
+  const int THREADS_PER_WARP = 32;
+  const int warps_per_block = (h <= 8) ? h : 8;
+  dim3 blocks(s, b);
+  dim3 threads(THREADS_PER_WARP, warps_per_block);
+  const int shared_mem_size = 4 * d2 * sizeof(float);  // cos, sin * q, k
+
+  fused_qkv_rope_backward_kernel<<<blocks, threads, shared_mem_size, stream>>>(
+      q_grad_out, k_grad_out, v_grad_out, q_freqs, k_freqs, qkv_grad_input, qkv_format, interleaved,
+      cp_size, cp_rank, s, b, h, d, d2, qkv_split_arg_list_0, qkv_split_arg_list_1,
+      qkv_split_arg_list_2);
+  NVTE_CHECK_CUDA(cudaGetLastError());
+}
+
 void fused_rope_forward(const Tensor &input, const Tensor &cu_seqlens, const Tensor &freqs,
                         const Tensor &start_positions, Tensor *output,
                         const NVTE_QKV_Format qkv_format, const bool interleaved, const int cp_size,

@@ -297,6 +605,46 @@ void fused_rope_backward(const Tensor &output_grads, const Tensor &cu_seqlens, c
                             stride_b, stride_h, stride_d, stream););
 }

+void fused_qkv_rope_forward(const Tensor &qkv_input, const Tensor &q_freqs, const Tensor &k_freqs,
+                            const Tensor &start_positions, Tensor *q_out, Tensor *k_out,
+                            Tensor *v_out, const NVTE_QKV_Format qkv_format, const bool interleaved,
+                            const int cp_size, const int cp_rank, const int s, const int b,
+                            const int h, const int d, const int d2, const int qkv_split_arg_list_0,
+                            const int qkv_split_arg_list_1, const int qkv_split_arg_list_2,
+                            cudaStream_t stream) {
+  TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
+      qkv_input.data.dtype, scalar_t,
+      fused_qkv_rope_forward_launcher(
+          reinterpret_cast<const scalar_t *>(qkv_input.data.dptr),
+          reinterpret_cast<const float *>(q_freqs.data.dptr),
+          reinterpret_cast<const float *>(k_freqs.data.dptr),
+          reinterpret_cast<const int *>(start_positions.data.dptr),
+          reinterpret_cast<scalar_t *>(q_out->data.dptr),
+          reinterpret_cast<scalar_t *>(k_out->data.dptr),
+          reinterpret_cast<scalar_t *>(v_out->data.dptr), qkv_format, interleaved, cp_size, cp_rank,
+          s, b, h, d, d2, qkv_split_arg_list_0, qkv_split_arg_list_1, qkv_split_arg_list_2,
+          stream););
+}
+
+void fused_qkv_rope_backward(const Tensor &q_grad_out, const Tensor &k_grad_out,
+                             const Tensor &v_grad_out, const Tensor &q_freqs, const Tensor &k_freqs,
+                             Tensor *qkv_grad_input, const NVTE_QKV_Format qkv_format,
+                             const bool interleaved, const int cp_size, const int cp_rank,
+                             const int s, const int b, const int h, const int d, const int d2,
+                             const int qkv_split_arg_list_0, const int qkv_split_arg_list_1,
+                             const int qkv_split_arg_list_2, cudaStream_t stream) {
+  TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
+      q_grad_out.data.dtype, scalar_t,
+      fused_qkv_rope_backward_launcher(
+          reinterpret_cast<const scalar_t *>(q_grad_out.data.dptr),
+          reinterpret_cast<const scalar_t *>(k_grad_out.data.dptr),
+          reinterpret_cast<const scalar_t *>(v_grad_out.data.dptr),
+          reinterpret_cast<const float *>(q_freqs.data.dptr),
+          reinterpret_cast<const float *>(k_freqs.data.dptr),
+          reinterpret_cast<scalar_t *>(qkv_grad_input->data.dptr), qkv_format, interleaved, cp_size,
+          cp_rank, s, b, h, d, d2, qkv_split_arg_list_0, qkv_split_arg_list_1, qkv_split_arg_list_2,
+          stream););
+}
+
 }  // end namespace transformer_engine

 void nvte_fused_rope_forward(const NVTETensor input, const NVTETensor cu_seqlens,

@@ -328,3 +676,38 @@ void nvte_fused_rope_backward(const NVTETensor output_grads, const NVTETensor cu
                       qkv_format, interleaved, cp_size, cp_rank, s, b, h, d, d2, stride_s_or_t,
                       stride_b, stride_h, stride_d, stream);
 }
+
+void nvte_fused_qkv_rope_forward(const NVTETensor qkv_input, const NVTETensor q_freqs,
+                                 const NVTETensor k_freqs, const NVTETensor start_positions,
+                                 NVTETensor q_out, NVTETensor k_out, NVTETensor v_out,
+                                 const NVTE_QKV_Format qkv_format, const bool interleaved,
+                                 const int cp_size, const int cp_rank, const int s, const int b,
+                                 const int h, const int d, const int d2,
+                                 const int qkv_split_arg_list_0, const int qkv_split_arg_list_1,
+                                 const int qkv_split_arg_list_2, cudaStream_t stream) {
+  NVTE_API_CALL(nvte_fused_qkv_rope_forward);
+  using namespace transformer_engine;
+  fused_qkv_rope_forward(*convertNVTETensorCheck(qkv_input), *convertNVTETensorCheck(q_freqs),
+                         *convertNVTETensorCheck(k_freqs), *convertNVTETensorCheck(start_positions),
+                         convertNVTETensorCheck(q_out), convertNVTETensorCheck(k_out),
+                         convertNVTETensorCheck(v_out), qkv_format, interleaved, cp_size, cp_rank,
+                         s, b, h, d, d2, qkv_split_arg_list_0, qkv_split_arg_list_1,
+                         qkv_split_arg_list_2, stream);
+}
+
+void nvte_fused_qkv_rope_backward(const NVTETensor q_grad_out, const NVTETensor k_grad_out,
+                                  const NVTETensor v_grad_out, const NVTETensor q_freqs,
+                                  const NVTETensor k_freqs, NVTETensor qkv_grad_input,
+                                  const NVTE_QKV_Format qkv_format, const bool interleaved,
+                                  const int cp_size, const int cp_rank, const int s, const int b,
+                                  const int h, const int d, const int d2,
+                                  const int qkv_split_arg_list_0, const int qkv_split_arg_list_1,
+                                  const int qkv_split_arg_list_2, cudaStream_t stream) {
+  NVTE_API_CALL(nvte_fused_qkv_rope_backward);
+  using namespace transformer_engine;
+  fused_qkv_rope_backward(*convertNVTETensorCheck(q_grad_out), *convertNVTETensorCheck(k_grad_out),
+                          *convertNVTETensorCheck(v_grad_out), *convertNVTETensorCheck(q_freqs),
+                          *convertNVTETensorCheck(k_freqs), convertNVTETensorCheck(qkv_grad_input),
+                          qkv_format, interleaved, cp_size, cp_rank, s, b, h, d, d2,
+                          qkv_split_arg_list_0, qkv_split_arg_list_1, qkv_split_arg_list_2, stream);
+}
transformer_engine/common/fused_router/fused_moe_aux_loss.cu

@@ -178,9 +178,9 @@ void fused_moe_aux_loss_forward_kernel_launcher(const DataType* probs,
     config.stream = stream;

     // Update the max cluster size based on the device
-    cudaOccupancyMaxPotentialClusterSize(
-        &cluster_size,
-        reinterpret_cast<void*>(fused_moe_aux_loss_forward_kernel<DataType, IndexType>), &config);
+    NVTE_CHECK_CUDA(cudaOccupancyMaxPotentialClusterSize(
+        &cluster_size,
+        reinterpret_cast<void*>(fused_moe_aux_loss_forward_kernel<DataType, IndexType>), &config));

     cudaLaunchAttribute attribute[1];
     attribute[0].id = cudaLaunchAttributeClusterDimension;

@@ -190,15 +190,16 @@ void fused_moe_aux_loss_forward_kernel_launcher(const DataType* probs,
     config.numAttrs = 1;
     config.attrs = attribute;

-    cudaLaunchKernelEx(&config, fused_moe_aux_loss_forward_kernel<DataType, IndexType>, probs,
-                       tokens_per_expert, total_num_tokens, num_experts, num_rows, num_cols, topk,
-                       coeff, aux_loss, Const_buf);
+    NVTE_CHECK_CUDA(cudaLaunchKernelEx(
+        &config, fused_moe_aux_loss_forward_kernel<DataType, IndexType>, probs, tokens_per_expert,
+        total_num_tokens, num_experts, num_rows, num_cols, topk, coeff, aux_loss, Const_buf));
   } else {
 #endif
     size_t smem_size = sizeof(CompType) * num_cols;
     fused_moe_aux_loss_forward_kernel<DataType, IndexType>
         <<<1, 1024, smem_size, stream>>>(probs, tokens_per_expert, total_num_tokens, num_experts,
                                          num_rows, num_cols, topk, coeff, aux_loss, Const_buf);
+    NVTE_CHECK_CUDA(cudaGetLastError());
 #ifndef __HIP_PLATFORM_AMD__
   }
 #endif

@@ -232,7 +233,7 @@ __global__ void fused_moe_aux_loss_backward_kernel(const float* Const_buf,
   // Loop: for all positions in each row
   for (int i = lane_id; i < num_cols; i += kThreadsPerWarp) {
     float C_coeff = Const_buf[0];
-    IndexType tokens_per_expert_i = tokens_per_expert[i];
+    double tokens_per_expert_i = static_cast<double>(tokens_per_expert[i]);
     double grad_aux_loss_value = static_cast<double>(grad_aux_loss[0]);
     // Loop: for all rows
     for (int j = global_warp_id; j < num_rows; j += global_warp_num) {

@@ -251,6 +252,7 @@ void fused_moe_aux_loss_backward_kernel_launcher(const float* Const_buf,
   int grid_size = (num_rows + block_size - 1) / block_size;
   fused_moe_aux_loss_backward_kernel<DataType, IndexType><<<grid_size, block_size, 0, stream>>>(
       Const_buf, tokens_per_expert, num_rows, num_cols, grad_aux_loss, grad_probs);
+  NVTE_CHECK_CUDA(cudaGetLastError());
 }

 void fused_moe_aux_loss_backward(const Tensor &Const_buf, const Tensor &tokens_per_expert,
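The forward launcher uses the extended launch API so that the thread-block-cluster configuration can be both queried and checked: cudaOccupancyMaxPotentialClusterSize and cudaLaunchKernelEx return error codes, which the merge now routes through NVTE_CHECK_CUDA. A hedged sketch of a checked cluster launch with a toy kernel (requires SM90+; grid size is assumed to be a multiple of the chosen cluster size):

  #include <cuda_runtime.h>

  __global__ void mark_kernel(float *out) { out[blockIdx.x] = 1.0f; }

  void launch_with_cluster(float *out, int num_blocks, cudaStream_t stream) {
    cudaLaunchConfig_t config = {};
    config.gridDim = dim3(num_blocks);
    config.blockDim = dim3(128);
    config.dynamicSmemBytes = 0;
    config.stream = stream;

    // Ask the runtime for the largest portable cluster size for this kernel.
    int cluster_size = 0;
    CHECK_CUDA(cudaOccupancyMaxPotentialClusterSize(
        &cluster_size, reinterpret_cast<void *>(mark_kernel), &config));

    cudaLaunchAttribute attribute[1];
    attribute[0].id = cudaLaunchAttributeClusterDimension;
    attribute[0].val.clusterDim.x = cluster_size;  // num_blocks must be divisible by this
    attribute[0].val.clusterDim.y = 1;
    attribute[0].val.clusterDim.z = 1;
    config.numAttrs = 1;
    config.attrs = attribute;

    // Unlike <<<>>>, the extended launch returns an error code that can be checked directly.
    CHECK_CUDA(cudaLaunchKernelEx(&config, mark_kernel, out));
  }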
transformer_engine/common/fused_router/fused_score_for_moe_aux_loss.cu

@@ -151,6 +151,7 @@ void fused_score_for_moe_aux_loss_forward_kernel_launcher(
       <<<grid_size, kThreadsPerBlock, shared_memory_size, stream>>>(
           logits, num_tokens, num_experts, topk, score_function, scores, routing_map,
           intermediate_output);
+  NVTE_CHECK_CUDA(cudaGetLastError());
 }

 void fused_score_for_moe_aux_loss_forward(const Tensor &logits, int num_tokens, int num_experts,

@@ -286,6 +287,7 @@ void fused_score_for_moe_aux_loss_backward_kernel_launcher(
       <<<grid_size, kThreadsPerBlock, shared_memory_size, stream>>>(
           intermediate_output, grad_scores, num_tokens, num_experts, topk, score_function,
           grad_logits);
+  NVTE_CHECK_CUDA(cudaGetLastError());
 }

 void fused_score_for_moe_aux_loss_backward(const Tensor &intermediate_output,
transformer_engine/common/fused_router/fused_topk_with_score_function.cu

@@ -257,6 +257,7 @@ void fused_topk_with_score_function_forward_kernel_launcher(
       <<<grid_size, kThreadsPerBlock, shared_memory_size, stream>>>(
           logits, num_tokens, num_experts, topk, use_pre_softmax, num_groups, group_topk,
           scaling_factor, score_function, expert_bias, probs, routing_map, intermediate_output);
+  NVTE_CHECK_CUDA(cudaGetLastError());
 }

 void fused_topk_with_score_function_forward(const Tensor logits, int num_tokens, int num_experts,

@@ -447,6 +448,7 @@ void fused_topk_with_score_function_backward_kernel_launcher(
       <<<grid_size, kThreadsPerBlock, shared_memory_size, stream>>>(
           routing_map, intermediate_output, grad_probs, num_tokens, num_experts, topk,
           use_pre_softmax, scaling_factor, score_function, grad_logits);
+  NVTE_CHECK_CUDA(cudaGetLastError());
 }

 void fused_topk_with_score_function_backward(const Tensor &routing_map,
transformer_engine/common/fused_router/utils.h

@@ -271,6 +271,14 @@ __device__ inline void naive_topk_and_mask(T *scores, int data_size, int topk, i
       using type = int64_t;            \
       { __VA_ARGS__ }                  \
     } break;                           \
+    case DType::kBFloat16: {           \
+      using type = bf16;               \
+      { __VA_ARGS__ }                  \
+    } break;                           \
+    case DType::kFloat32: {            \
+      using type = float;              \
+      { __VA_ARGS__ }                  \
+    } break;                           \
     default:                           \
       NVTE_ERROR("Invalid type.");     \
   }
transformer_engine/common/fused_softmax/scaled_aligned_causal_masked_softmax.cu

@@ -353,6 +353,7 @@ void call_kernel_scaled_aligned_causal_masked_softmax_forward(
   scaled_aligned_causal_masked_softmax_warp_forward<input_t, output_t, acc_t, log2_elements>
       <<<grid_size, block_size, shmem_size, stream>>>(dst, src, scale, microbatches,
                                                       query_seq_len, key_seq_len);
+  NVTE_CHECK_CUDA(cudaGetLastError());
 }

 template <typename input_t, typename output_t, typename acc_t, int log2_elements>

@@ -363,6 +364,7 @@ void call_kernel_scaled_aligned_causal_masked_softmax_backward(
   scaled_aligned_causal_masked_softmax_warp_backward<input_t, output_t, acc_t, log2_elements>
       <<<grid_size, block_size, 0, stream>>>(gradInput, grad, output, scale, microbatches,
                                              query_seq_len, key_seq_len);
+  NVTE_CHECK_CUDA(cudaGetLastError());
 }

 template <typename input_t, typename output_t, typename acc_t>
transformer_engine/common/fused_softmax/scaled_masked_softmax.cu

@@ -513,6 +513,7 @@ void dispatch_scaled_softmax_forward(output_t *dst, const input_t *src, const in
       default:
         break;
     }
+    NVTE_CHECK_CUDA(cudaGetLastError());
   }
 }

@@ -625,6 +626,7 @@ void dispatch_scaled_masked_softmax_forward(output_t *dst, const input_t *src, c
       default:
        break;
     }
+    NVTE_CHECK_CUDA(cudaGetLastError());
   }
 }

@@ -736,6 +738,7 @@ void dispatch_scaled_masked_softmax_backward(output_t *grad_input, const input_t
       default:
         break;
     }
+    NVTE_CHECK_CUDA(cudaGetLastError());
   }
 }
transformer_engine/common/fused_softmax/scaled_upper_triang_masked_softmax.cu

@@ -445,6 +445,7 @@ void dispatch_scaled_upper_triang_masked_softmax_forward(output_t *dst, const in
       default:
         break;
     }
+    NVTE_CHECK_CUDA(cudaGetLastError());
   }
 }

@@ -561,6 +562,7 @@ void dispatch_scaled_upper_triang_masked_softmax_backward(output_t *grad_input,
      default:
        break;
     }
+    NVTE_CHECK_CUDA(cudaGetLastError());
   }
 }
transformer_engine/common/gemm/cublaslt_gemm.cu
View file @
27ddce40
...
...
@@ -25,28 +25,11 @@
#include "../util/logging.h"
#include "../util/multi_stream.h"
#include "common/util/cuda_runtime.h"
#include "cutlass_grouped_gemm.cuh"
#ifndef __HIP_PLATFORM_AMD__
namespace
{
cudaDataType_t
get_cuda_dtype
(
const
transformer_engine
::
DType
t
)
{
using
namespace
transformer_engine
;
switch
(
t
)
{
case
DType
::
kFloat16
:
return
CUDA_R_16F
;
case
DType
::
kFloat32
:
return
CUDA_R_32F
;
case
DType
::
kBFloat16
:
return
CUDA_R_16BF
;
case
DType
::
kFloat8E4M3
:
return
CUDA_R_8F_E4M3
;
case
DType
::
kFloat8E5M2
:
return
CUDA_R_8F_E5M2
;
default:
NVTE_ERROR
(
"Invalid type"
);
}
}
uint32_t
_getAlignment
(
uintptr_t
address
)
{
// alignment are in bytes
uint32_t
alignment
=
256
;
...
...
@@ -532,22 +515,22 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
&
epilogue
,
sizeof
(
epilogue
)));
if
(
counter
!=
nullptr
)
{
#if !(CUDA_VERSION >= 12020 && CU
BLAS
_VERSION
>=
13000)
NVTE_ERROR
(
"Atomic GEMM requires CUDA >=12.2.0 and <13.0.0, but compile-time CUDA verson is "
,
#if !(CUDA_VERSION >= 12020 && CU
DA
_VERSION
<
13000)
NVTE_ERROR
(
"Atomic GEMM requires CUDA >=12.2.0 and <13.0.0, but compile-time CUDA vers
i
on is "
,
CUDA_VERSION
);
#endif
#if !(CUBLAS_VERSION >= 120205 && CUBLAS_VERSION < 130000)
NVTE_ERROR
(
"Atomic GEMM requires cuBLAS >=12.2.5 and <13.0.0, but compile-time cuBLAS verson is "
,
"Atomic GEMM requires cuBLAS >=12.2.5 and <13.0.0, but compile-time cuBLAS vers
i
on is "
,
CUBLAS_VERSION
);
#endif
#if CUDA_VERSION >= 12020 && CUBLAS_VERSION >= 120205 && CUDA_VERSION < 13000 && \
CUBLAS_VERSION < 130000
NVTE_CHECK
(
cuda
::
cudart_version
()
>=
12020
&&
cuda
::
cudart_version
()
<
13000
,
"Atomic GEMM requires CUDA >=12.2.0 and <13.0.0, but run-time CUDA verson is "
,
"Atomic GEMM requires CUDA >=12.2.0 and <13.0.0, but run-time CUDA vers
i
on is "
,
cuda
::
cudart_version
());
NVTE_CHECK
(
cublas_version
()
>=
120205
&&
cublas_version
()
<
130000
,
"Atomic GEMM requires cuBLAS >=12.2.5 and <13.0.0, but run-time cuBLAS verson is "
,
"Atomic GEMM requires cuBLAS >=12.2.5 and <13.0.0, but run-time cuBLAS vers
i
on is "
,
cublas_version
());
if
(
m_split
==
0
)
m_split
=
1
;
if
(
n_split
==
0
)
n_split
=
1
;
...
...
@@ -850,20 +833,23 @@ void nvte_cublas_atomic_gemm(const NVTETensor A, const NVTETensor B, NVTETensor
#ifndef __HIP_PLATFORM_AMD__
// Check CUDA and cuBLAS versions
#if !(CUDA_VERSION >= 12020 && CU
BLAS
_VERSION
>=
13000)
NVTE_ERROR
(
"Atomic GEMM requires CUDA >=12.2.0 and <13.0.0, but compile-time CUDA verson is "
,
#if !(CUDA_VERSION >= 12020 && CU
DA
_VERSION
<
13000)
NVTE_ERROR
(
"Atomic GEMM requires CUDA >=12.2.0 and <13.0.0, but compile-time CUDA vers
i
on is "
,
CUDA_VERSION
);
#endif
#if !(CUBLAS_VERSION >= 120205 && CUBLAS_VERSION < 130000)
NVTE_ERROR
(
"Atomic GEMM requires cuBLAS >=12.2.5 and <13.0.0, but compile-time cuBLAS verson is "
,
CUBLAS_VERSION
);
NVTE_ERROR
(
"Atomic GEMM requires cuBLAS >=12.2.5 and <13.0.0, but compile-time cuBLAS version is "
,
CUBLAS_VERSION
);
#endif
NVTE_CHECK
(
cuda
::
cudart_version
()
>=
12020
&&
cuda
::
cudart_version
()
<
13000
,
"Atomic GEMM requires CUDA version >=12.2.0 and <13.0.0, but run-time CUDA verson is "
,
cuda
::
cudart_version
());
NVTE_CHECK
(
transformer_engine
::
cuda
::
cudart_version
()
>=
12020
&&
transformer_engine
::
cuda
::
cudart_version
()
<
13000
,
"Atomic GEMM requires CUDA version >=12.2.0 and <13.0.0, but run-time CUDA version is "
,
transformer_engine
::
cuda
::
cudart_version
());
NVTE_CHECK
(
cublas_version
()
>=
120205
&&
cublas_version
()
<
130000
,
"Atomic GEMM requires cuBLAS version >=12.2.5 and <13.0.0, but run-time cuBLAS verson is "
,
"Atomic GEMM requires cuBLAS version >=12.2.5 and <13.0.0, but run-time cuBLAS vers
i
on is "
,
cublas_version
());
#endif
...
...
@@ -934,15 +920,11 @@ void nvte_cublas_atomic_gemm(const NVTETensor A, const NVTETensor B, NVTETensor
#endif //__HIP_PLATFORM_AMD__
}
void
nvte_multi_stream_cublas_gemm
(
const
NVTETensor
*
A
,
const
NVTETensor
*
B
,
NVTETensor
*
D
,
const
NVTETensor
*
bias
,
NVTETensor
*
pre_gelu_out
,
const
int
num_gemms
,
bool
transa
,
bool
transb
,
bool
grad
,
NVTETensor
*
workspace
,
bool
accumulate
,
bool
use_split_accumulator
,
int
math_sm_count
,
cudaStream_t
stream
)
{
NVTE_API_CALL
(
nvte_multi_stream_cublas_gemm
);
void
multi_stream_cublas_gemm
(
const
NVTETensor
*
A
,
const
NVTETensor
*
B
,
NVTETensor
*
D
,
const
NVTETensor
*
bias
,
NVTETensor
*
pre_gelu_out
,
const
int
num_gemms
,
bool
transa
,
bool
transb
,
bool
grad
,
NVTETensor
*
workspace
,
bool
accumulate
,
bool
use_split_accumulator
,
int
math_sm_count
,
cudaStream_t
stream
)
{
using
namespace
transformer_engine
;
int
num_streams
=
nvte_get_num_compute_streams
();
...
...
@@ -989,6 +971,25 @@ void nvte_multi_stream_cublas_gemm(const NVTETensor *A, const NVTETensor *B, NVT
}
}
void
nvte_multi_stream_cublas_gemm
(
const
NVTETensor
*
A
,
const
NVTETensor
*
B
,
NVTETensor
*
D
,
const
NVTETensor
*
bias
,
NVTETensor
*
pre_gelu_out
,
const
int
num_gemms
,
bool
transa
,
bool
transb
,
bool
grad
,
NVTETensor
*
workspace
,
bool
accumulate
,
bool
use_split_accumulator
,
int
math_sm_count
,
cudaStream_t
stream
)
{
NVTE_API_CALL
(
nvte_multi_stream_cublas_gemm
);
using
namespace
transformer_engine
;
// Deprecation warning
NVTE_WARN
(
"nvte_multi_stream_cublas_gemm is deprecated and will be removed in a future release. "
"Please migrate to nvte_multi_tensor_gemm (with CUTLASS Grouped GEMM support when "
"applicable)."
);
multi_stream_cublas_gemm
(
A
,
B
,
D
,
bias
,
pre_gelu_out
,
num_gemms
,
transa
,
transb
,
grad
,
workspace
,
accumulate
,
use_split_accumulator
,
math_sm_count
,
stream
);
}
#ifndef __HIP_PLATFORM_AMD__
namespace
transformer_engine
{
...
...
@@ -1006,7 +1007,6 @@ void nvte_grouped_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor *D,
NVTETensor
*
workspace
,
bool
accumulate
,
bool
use_split_accumulator
,
int
math_sm_count
,
cudaStream_t
stream
)
{
NVTE_API_CALL
(
nvte_grouped_gemm
);
using
namespace
transformer_engine
;
std
::
vector
<
const
Tensor
*>
inputA
;
...
...
@@ -1307,4 +1307,98 @@ void nvte_cublas_batchgemm_v3(const NVTETensor A, const NVTETensor B, const NVTE
                              handle);
}
#endif
\ No newline at end of file
#endif

void nvte_multi_tensor_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor *D,
                            const NVTETensor *bias, NVTETensor *pre_gelu_out, const int num_gemms,
                            bool transa, bool transb, bool grad, NVTETensor *workspace,
                            bool accumulate, bool use_split_accumulator, int math_sm_count,
                            cudaStream_t stream) {
  NVTE_API_CALL(nvte_multi_tensor_gemm);
#ifdef __HIP_PLATFORM_AMD__
  const char *NVTE_USE_HIPBLASLT_GROUPEDGEMM = std::getenv("NVTE_USE_HIPBLASLT_GROUPEDGEMM");
  if (NVTE_USE_HIPBLASLT_GROUPEDGEMM != nullptr && NVTE_USE_HIPBLASLT_GROUPEDGEMM[0] == '1') {
    nvte_grouped_gemm(A, B, D, bias, pre_gelu_out, num_gemms, transa, transb, grad, workspace,
                      accumulate, use_split_accumulator, math_sm_count, stream);
  } else {
    multi_stream_cublas_gemm(A, B, D, bias, pre_gelu_out, num_gemms, transa, transb, grad,
                             workspace, accumulate, use_split_accumulator, math_sm_count, stream);
  }
#else
  const int current_device = transformer_engine::cuda::current_device();
  const bool is_hopper = (transformer_engine::cuda::sm_arch(current_device) == 90);
  const bool use_cutlass =
      transformer_engine::getenv<bool>("NVTE_USE_CUTLASS_GROUPED_GEMM", false);
  const bool warn_fallback =
      transformer_engine::getenv<bool>("NVTE_CUTLASS_GROUPED_GEMM_WARN_FALLBACK", false);

  auto cublas_path = [&]() {
    multi_stream_cublas_gemm(A, B, D, bias, pre_gelu_out, num_gemms, transa, transb, grad,
                             workspace, accumulate, use_split_accumulator, math_sm_count, stream);
  };

  // Currently only support cutlass group gemm on Hopper Arch
  if (!(is_hopper && use_cutlass)) {
    cublas_path();
    return;
  }

  auto is_empty_arr = [&](const NVTETensor *p) -> bool {
    if (p == nullptr) return true;
    for (int i = 0; i < num_gemms; ++i) {
      if (transformer_engine::convertNVTETensor(p[i])->has_data()) return false;
    }
    return true;
  };

  auto all_groups_uniform_k128 = [&](const NVTETensor *p, bool trans) -> bool {
    int64_t ref_k = -1;
    for (size_t i = 0; i < num_gemms; i++) {
      const auto tensor = transformer_engine::convertNVTETensorCheck(p[i]);
      const int k = trans ? tensor->data.shape[0] : tensor->data.shape[1];
      if ((k & 127) != 0) return false;
      if (ref_k < 0)
        ref_k = k;
      else if (k != ref_k)
        return false;
    }
    return true;
  };

  auto is_supported_dtype = [&]() -> bool {
    auto *inputA = transformer_engine::convertNVTETensorCheck(A[0]);
    auto *inputB = transformer_engine::convertNVTETensorCheck(B[0]);
    auto *OutputD = transformer_engine::convertNVTETensorCheck(D[0]);
    auto A_type = get_cuda_dtype(inputA->data.dtype);
    auto B_type = get_cuda_dtype(inputB->data.dtype);
    auto D_type = get_cuda_dtype(OutputD->data.dtype);
    return (A_type == B_type) && (A_type == D_type) &&
           ((A_type == CUDA_R_16BF) || (A_type == CUDA_R_16F));
  };

  // CUTLASS Grouped GEMM fast path (SM90/TMA)
  // Conditions:
  //   - No fused epilogue: both bias and pre_gelu_out are empty.
  //   - Supported dtypes only: FP16/BF16 (FP32 accumulate).
  //   - Uniform K across groups and K % 128 == 0.
  //   - use_split_accumulator is ignored for FP16/BF16.
  //   - grad is irrelevant when bias/pre_gelu_out are empty.
  //
  // Otherwise, fall back to cuBLAS.
  if (is_empty_arr(bias) && is_empty_arr(pre_gelu_out) && is_supported_dtype() &&
      all_groups_uniform_k128(B, transb)) {
    cutlass_grouped_gemm(A, B, D, num_gemms, transa, transb, grad, workspace, accumulate,
                         current_device, math_sm_count, stream);
  } else {
    if (warn_fallback) {
      NVTE_WARN("Fallback to cuBLAS grouped GEMM.");
    }
    cublas_path();
  }
#endif
}
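For callers, the CUTLASS path is purely opt-in via environment variables; when any of the preconditions above fails, the call falls back to the multi-stream cuBLAS path (with an optional warning). Below is a minimal sketch of such a call, assuming the NVTETensor arrays (BF16/FP16, uniform K with K % 128 == 0) and the workspace tensor were prepared elsewhere; the exact public header name is an assumption.

#include <cstdlib>
#include <cuda_runtime.h>
#include <transformer_engine/gemm.h>  // assumed to declare nvte_multi_tensor_gemm

// Sketch only: A, B, D, bias, pre_gelu_out and workspace are arrays of num_gemms
// valid NVTETensor handles prepared by the caller.
void launch_grouped_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor *D,
                         const NVTETensor *bias, NVTETensor *pre_gelu_out, NVTETensor *workspace,
                         int num_gemms, cudaStream_t stream) {
  setenv("NVTE_USE_CUTLASS_GROUPED_GEMM", "1", 1);            // opt into the SM90 CUTLASS path
  setenv("NVTE_CUTLASS_GROUPED_GEMM_WARN_FALLBACK", "1", 1);  // warn if cuBLAS is used instead
  nvte_multi_tensor_gemm(A, B, D, bias, pre_gelu_out, num_gemms,
                         /*transa=*/false, /*transb=*/false, /*grad=*/false, workspace,
                         /*accumulate=*/false, /*use_split_accumulator=*/false,
                         /*math_sm_count=*/0, stream);
}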
transformer_engine/common/gemm/cutlass_grouped_gemm.cu
0 → 100644
View file @
27ddce40
/***************************************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
**************************************************************************************************/
#include "cutlass/bfloat16.h"
#include "cutlass/cutlass.h"
#include "cutlass_grouped_gemm.cuh"
namespace transformer_engine {
namespace grouped_gemm {

// Explicit template instantiation to match the template declarations in the .cuh
template void CutlassGroupedGemm<false, false, cutlass::half_t>(const NVTETensor *,
                                                                const NVTETensor *, NVTETensor *,
                                                                NVTETensor *, float, float, int,
                                                                cudaStream_t, int, int);
template void CutlassGroupedGemm<true, false, cutlass::half_t>(const NVTETensor *,
                                                               const NVTETensor *, NVTETensor *,
                                                               NVTETensor *, float, float, int,
                                                               cudaStream_t, int, int);
template void CutlassGroupedGemm<false, true, cutlass::half_t>(const NVTETensor *,
                                                               const NVTETensor *, NVTETensor *,
                                                               NVTETensor *, float, float, int,
                                                               cudaStream_t, int, int);
template void CutlassGroupedGemm<false, false, cutlass::bfloat16_t>(const NVTETensor *,
                                                                    const NVTETensor *,
                                                                    NVTETensor *, NVTETensor *,
                                                                    float, float, int,
                                                                    cudaStream_t, int, int);
template void CutlassGroupedGemm<true, false, cutlass::bfloat16_t>(const NVTETensor *,
                                                                   const NVTETensor *,
                                                                   NVTETensor *, NVTETensor *,
                                                                   float, float, int,
                                                                   cudaStream_t, int, int);
template void CutlassGroupedGemm<false, true, cutlass::bfloat16_t>(const NVTETensor *,
                                                                   const NVTETensor *,
                                                                   NVTETensor *, NVTETensor *,
                                                                   float, float, int,
                                                                   cudaStream_t, int, int);

}  // namespace grouped_gemm
}  // namespace transformer_engine
void cutlass_grouped_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor *D, int num_gemms,
                          bool transa, bool transb, bool grad, NVTETensor *workspace,
                          bool accumulate, int device, int math_sm_count, cudaStream_t stream) {
  using namespace transformer_engine;
  auto *inputA = convertNVTETensorCheck(A[0]);
  auto *inputB = convertNVTETensorCheck(B[0]);

  float one = 1.0;
  float zero = 0.0;
  float alpha = one;
  float beta = (accumulate) ? one : zero;

  auto dispatch = [&](auto tag) {
    using T = decltype(tag);
    if (!transa && !transb) {
      grouped_gemm::CutlassGroupedGemm<false, false, T>(B, A, D, workspace, alpha, beta, num_gemms,
                                                        stream, device, math_sm_count);
    } else if (!transb && transa) {
      grouped_gemm::CutlassGroupedGemm<false, true, T>(B, A, D, workspace, alpha, beta, num_gemms,
                                                       stream, device, math_sm_count);
    } else if (transb && !transa) {
      grouped_gemm::CutlassGroupedGemm<true, false, T>(B, A, D, workspace, alpha, beta, num_gemms,
                                                       stream, device, math_sm_count);
    } else {
      NVTE_ERROR("Layout 'TT' is not supported by cutlass_grouped_gemm.");
    }
  };

  if (inputA->data.dtype == DType::kBFloat16) {
    dispatch(cutlass::bfloat16_t{});
  } else if (inputA->data.dtype == DType::kFloat16) {
    dispatch(cutlass::half_t{});
  } else {
    NVTE_ERROR("Unsupported dtype: only BF16/FP16 are supported.");
  }
}
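Semantically, each group i computes an independent D_i = alpha * op(A_i) * op(B_i) + beta * D_i with FP32 accumulation, where alpha is fixed at 1 and beta is 1 only when accumulate is set. The following is a small host-side FP32 reference for a single group (a hypothetical helper, plain row-major buffers, no transposes) that a unit test might compare against; it is not part of the library.

#include <vector>

// Hypothetical reference: D = alpha * A(m x k) * B(k x n) + beta * D, all row-major.
// Intended only for validating small problem sizes against the grouped kernel.
void reference_gemm(const std::vector<float> &A, const std::vector<float> &B,
                    std::vector<float> &D, int m, int n, int k, float alpha, float beta) {
  for (int i = 0; i < m; ++i) {
    for (int j = 0; j < n; ++j) {
      float acc = 0.f;
      for (int p = 0; p < k; ++p) acc += A[i * k + p] * B[p * n + j];
      D[i * n + j] = alpha * acc + beta * D[i * n + j];
    }
  }
}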
transformer_engine/common/gemm/cutlass_grouped_gemm.cuh
0 → 100644
View file @
27ddce40
/***************************************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
**************************************************************************************************/
//
// Copyright (c) 2025 Shopee Inc. All Rights Reserved.
//
/**
* @file: cutlass_grouped_gemm.cuh
* @author: min.yang@shopee.com, yangfan.bai@shopee.com, finch.li@shopee.com
* @date: 2025-08-08 16:20:00
* @brief: cutlass group gemm kernel.
**/
#pragma once
#include <transformer_engine/transformer_engine.h>
#include <cub/cub.cuh>
#include <type_traits>
#include "../common.h"
#include "../util/logging.h"
#include "common/util/system.h"
#include "cute/tensor.hpp"
#include "cutlass/bfloat16.h"
#include "cutlass/complex.h"
#include "cutlass/cutlass.h"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/default_epilogue.hpp"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/gemm/group_array_problem_shape.hpp"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/util/device_memory.h"
#include "cutlass/util/packed_stride.hpp"
namespace transformer_engine {
namespace grouped_gemm {

template <bool trans_a>
using GroupedGemmInputALayout =
    std::conditional_t<trans_a, ::cutlass::layout::ColumnMajor, ::cutlass::layout::RowMajor>;

template <bool trans_b>
using GroupedGemmInputBLayout =
    std::conditional_t<trans_b, ::cutlass::layout::ColumnMajor, ::cutlass::layout::RowMajor>;

using ProblemShapeType = cute::Shape<int, int, int>;
using ProblemShape = cutlass::gemm::GroupProblemShape<ProblemShapeType>;  // <M,N,K> per group

template <typename ScheduleConfig>
struct GemmGivenSchedule {
  using ElementA = typename ScheduleConfig::DataType;  // Element type for A matrix operand
  using ElementB = typename ScheduleConfig::DataType;  // Element type for B matrix operand
  using ElementC = typename ScheduleConfig::DataType;  // Element type for C and D matrix operands

  // A matrix configuration
  using LayoutA = typename ScheduleConfig::LayoutA;  // Layout type for A matrix operand
  static constexpr int AlignmentA =
      128 / cutlass::sizeof_bits<ElementA>::value;  // Alignment of A matrix in units of elements (up to 16 bytes)

  // B matrix configuration
  using LayoutB = typename ScheduleConfig::LayoutB;  // Layout type for B matrix operand
  static constexpr int AlignmentB =
      128 / cutlass::sizeof_bits<ElementB>::value;  // Alignment of B matrix in units of elements (up to 16 bytes)

  // C/D matrix configuration
  using LayoutC = typename ScheduleConfig::LayoutC;  // Layout type for C and D matrix operands
  static constexpr int AlignmentC =
      128 / cutlass::sizeof_bits<ElementC>::value;  // Alignment of C matrix in units of elements (up to 16 bytes)

  // Core kernel configurations
  using ElementAccumulator = float;      // Element type for internal accumulation
  using ArchTag = cutlass::arch::Sm90;   // Tag indicating the minimum SM that supports the intended feature
  using OperatorClass = cutlass::arch::OpClassTensorOp;  // Operator class tag
  using StageCountType =
      cutlass::gemm::collective::StageCountAuto;  // Stage count maximized based on the tile size
  using TileShape = typename ScheduleConfig::TileShape;        // Threadblock-level tile size
  using ClusterShape = typename ScheduleConfig::ClusterShape;  // Shape of the threadblocks in a cluster
  using KernelSchedule = typename ScheduleConfig::KernelSchedule;      // Kernel to launch
  using EpilogueSchedule = typename ScheduleConfig::EpilogueSchedule;  // Epilogue to launch

  using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
      cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, TileShape, ClusterShape,
      cutlass::epilogue::collective::EpilogueTileAuto, ElementAccumulator, ElementAccumulator,
      ElementC, LayoutC *, AlignmentC, ElementC, LayoutC *, AlignmentC, EpilogueSchedule,
      cutlass::epilogue::fusion::LinearCombination<ElementC, ElementAccumulator>>::CollectiveOp;

  using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
      ArchTag, OperatorClass, ElementA, LayoutA *, AlignmentA, ElementB, LayoutB *, AlignmentB,
      ElementAccumulator, TileShape, ClusterShape,
      cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
          sizeof(typename CollectiveEpilogue::SharedStorage))>,
      KernelSchedule>::CollectiveOp;

  using GemmKernel =
      cutlass::gemm::kernel::GemmUniversal<ProblemShape, CollectiveMainloop, CollectiveEpilogue>;

  using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
};

template <typename DataType_, bool trans_a, bool trans_b>
struct ScheduleConfig {
  using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpong;
  using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong;
  using TileShape = cute::Shape<cute::_128, cute::_128, cute::_128>;
  using ClusterShape = cute::Shape<cute::_1, cute::_2, cute::_1>;

  // TODO(Alan): Add tuning for different scenarios to select the optimal configuration,
  // as the current configuration may not be the best.
  // using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperative;
  // using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedCooperative;
  // using TileShape = Shape<cute::_256, cute::_128, cute::_128>;
  // using ClusterShape = Shape<cute::_1, cute::_2, cute::_1>;

  using LayoutA = GroupedGemmInputALayout<trans_a>;
  using LayoutB = GroupedGemmInputBLayout<trans_b>;
  using LayoutC = cutlass::layout::RowMajor;
  using DataType = DataType_;
};

template <typename DataType_, bool trans_a, bool trans_b>
using GemmGrouped = typename GemmGivenSchedule<ScheduleConfig<DataType_, trans_a, trans_b>>::Gemm;

template <typename GemmT, typename ElementA, typename ElementB, typename ElementC,
          typename StrideA, typename StrideB, typename StrideC>
typename GemmT::Arguments MakeArguments(int num_experts, void *problem_sizes_host,
                                        void *problem_sizes, const ElementA **ptr_A,
                                        StrideA *stride_A, const ElementB **ptr_B,
                                        StrideB *stride_B, ElementC **ptr_C, StrideC *stride_C,
                                        float alpha, float beta, int device, int math_sm_count) {
  // Change device_id to another value if you are running on a machine with multiple GPUs and wish
  // to use a GPU other than that with device ID 0.
  cutlass::KernelHardwareInfo kernel_hw_info =
      cutlass::KernelHardwareInfo::make_kernel_hardware_info<typename GemmT::GemmKernel>(
          device, math_sm_count);

  typename GemmT::Arguments arguments;
  decltype(arguments.epilogue.thread) fusion_args;

  fusion_args.alpha = alpha;
  fusion_args.beta = beta;
  fusion_args.alpha_ptr = nullptr;
  fusion_args.beta_ptr = nullptr;
  fusion_args.alpha_ptr_array = nullptr;
  fusion_args.beta_ptr_array = nullptr;
  // Single alpha and beta for all groups
  fusion_args.dAlpha = {cute::_0{}, cute::_0{}, 0};
  fusion_args.dBeta = {cute::_0{}, cute::_0{}, 0};

  arguments = typename GemmT::Arguments{
      cutlass::gemm::GemmUniversalMode::kGrouped,
      {num_experts, reinterpret_cast<ProblemShapeType *>(problem_sizes),
       reinterpret_cast<ProblemShapeType const *>(problem_sizes_host)},
      {ptr_A, stride_A, ptr_B, stride_B},
      {fusion_args, (beta > 0.0) ? (const ElementC **)ptr_C : nullptr,  // NOLINT(*)
       stride_C, ptr_C, stride_C},
      kernel_hw_info};

  return arguments;
}

template <typename T>
inline __device__ __host__ T ROUND_UP(T m, T n) {
  return (m + n - 1) / n * n;
}

template <typename T>
void debug_type() {
  std::cout << typeid(T).name() << std::endl;
}

int64_t inline getGemmCoordSize(int64_t num_gemms) {
  return (int64_t)(ROUND_UP(num_gemms * sizeof(ProblemShapeType), 128UL));
}

int64_t inline getPtrSize(int64_t num_gemms) {
  return (int64_t)(ROUND_UP(num_gemms * sizeof(half *), 128UL));
}

int64_t inline getLddSize(int64_t num_gemms) {
  return (int64_t)(ROUND_UP(num_gemms * sizeof(int64_t), 128UL));
}

// cpu workspace size is 4MB
static constexpr size_t kCPUWorkSpaceSize = 4 * 1024 * 1024;

static char *getHostWorkspace() {
  static std::once_flag flag;
  static std::shared_ptr<char> workspace;

  std::call_once(flag, [&]() {
    workspace = std::shared_ptr<char>(reinterpret_cast<char *>(std::malloc(kCPUWorkSpaceSize)),
                                      [](char *p) {
                                        if (p) std::free(p);
                                      });
    if (!workspace) {
      throw std::bad_alloc();
    }
  });

  return workspace.get();
}

template <bool trans_a, bool trans_b, typename Element>
void CutlassGroupedGemm(const NVTETensor *A, const NVTETensor *B, NVTETensor *D,
                        NVTETensor *workspace, float alpha, float beta, int num_gemms,
                        cudaStream_t stream, int device, int math_sm_count) {
  using Gemm = GemmGrouped<Element, trans_a, trans_b>;
  using LayoutA = typename Gemm::LayoutA;
  using LayoutB = typename Gemm::LayoutB;
  using LayoutC = typename Gemm::LayoutC;

  using ElementA = typename Gemm::ElementA;
  using ElementB = typename Gemm::ElementB;
  using ElementC = typename Gemm::ElementC;

  using StrideA = typename Gemm::GemmKernel::InternalStrideA;
  using StrideB = typename Gemm::GemmKernel::InternalStrideB;
  using StrideC = typename Gemm::GemmKernel::InternalStrideC;

  typename Gemm::Arguments arguments;
  size_t kernel_workspace_size = Gemm::get_workspace_size(arguments);

  auto gemm_coord_size = getGemmCoordSize(num_gemms);
  auto ptr_size = getPtrSize(num_gemms);
  auto ldd_size = getLddSize(num_gemms);
  auto param_workspace_size = 3 * ptr_size + 3 * ldd_size + gemm_coord_size;

  NVTE_CHECK(param_workspace_size < kCPUWorkSpaceSize,
             "Insufficient kCPUWorkSpaceSize size: required=",
             static_cast<int64_t>(param_workspace_size),
             ", available=", static_cast<int64_t>(kCPUWorkSpaceSize),
             " for CUTLASS grouped GEMM.");

  auto total_workspace_size = param_workspace_size + kernel_workspace_size;
  transformer_engine::Tensor *wspace = transformer_engine::convertNVTETensor(workspace[0]);
  NVTE_CHECK(total_workspace_size < wspace->numel(),
             "Insufficient workspace[0] size: required=",
             static_cast<int64_t>(total_workspace_size),
             ", available=", static_cast<int64_t>(wspace->numel()),
             " for CUTLASS grouped GEMM.");

  char *workspace_ptr = reinterpret_cast<char *>(wspace->data.dptr);
  char *kernel_workspace_ptr = nullptr;

  char *host_workspace = getHostWorkspace();
  ProblemShapeType *problem_sizes_host = reinterpret_cast<ProblemShapeType *>(host_workspace);
  ElementA **ptr_A_host = reinterpret_cast<ElementA **>(host_workspace + gemm_coord_size);
  ElementB **ptr_B_host =
      reinterpret_cast<ElementB **>(host_workspace + gemm_coord_size + ptr_size);
  ElementC **ptr_C_host =
      reinterpret_cast<ElementC **>(host_workspace + gemm_coord_size + 2 * ptr_size);
  int64_t *lda_host = reinterpret_cast<int64_t *>(host_workspace + gemm_coord_size +
                                                  3 * ptr_size + 0 * ldd_size);
  int64_t *ldb_host = reinterpret_cast<int64_t *>(host_workspace + gemm_coord_size +
                                                  3 * ptr_size + 1 * ldd_size);
  int64_t *ldc_host = reinterpret_cast<int64_t *>(host_workspace + gemm_coord_size +
                                                  3 * ptr_size + 2 * ldd_size);

  for (size_t i = 0; i < num_gemms; i++) {
    const transformer_engine::Tensor *inputA = transformer_engine::convertNVTETensorCheck(A[i]);
    const transformer_engine::Tensor *inputB = transformer_engine::convertNVTETensorCheck(B[i]);
    transformer_engine::Tensor *outputD = transformer_engine::convertNVTETensor(D[i]);

    const int m = trans_a ? inputA->data.shape[1] : inputA->data.shape[0];
    const int k = trans_a ? inputA->data.shape[0] : inputA->data.shape[1];
    const int n = trans_b ? inputB->data.shape[0] : inputB->data.shape[1];

    auto problem = ProblemShapeType(m, n, k);
    problem_sizes_host[i] = problem;

    ptr_A_host[i] = reinterpret_cast<ElementA *>(inputA->data.dptr);
    ptr_B_host[i] = reinterpret_cast<ElementB *>(inputB->data.dptr);
    ptr_C_host[i] = reinterpret_cast<ElementC *>(outputD->data.dptr);

    lda_host[i] = LayoutA::packed({m, k}).stride(0);
    ldb_host[i] = LayoutB::packed({k, n}).stride(0);
    ldc_host[i] = LayoutC::packed({m, n}).stride(0);
  }

  cudaMemcpyAsync(workspace_ptr, host_workspace, param_workspace_size, cudaMemcpyHostToDevice,
                  stream);

  char *param_workspace_ptr = workspace_ptr;
  ProblemShapeType *problem_sizes_device =
      reinterpret_cast<ProblemShapeType *>(param_workspace_ptr);
  const ElementA **ptr_A = reinterpret_cast<const ElementA **>(
      reinterpret_cast<char *>(param_workspace_ptr) + gemm_coord_size);
  const ElementB **ptr_B = reinterpret_cast<const ElementB **>(
      reinterpret_cast<char *>(param_workspace_ptr) + gemm_coord_size + 1 * ptr_size);
  ElementC **ptr_C = reinterpret_cast<ElementC **>(reinterpret_cast<char *>(param_workspace_ptr) +
                                                   gemm_coord_size + 2 * ptr_size);
  StrideA *lda = reinterpret_cast<StrideA *>(reinterpret_cast<char *>(param_workspace_ptr) +
                                             gemm_coord_size + 3 * ptr_size + 0 * ldd_size);
  StrideB *ldb = reinterpret_cast<StrideB *>(reinterpret_cast<char *>(param_workspace_ptr) +
                                             gemm_coord_size + 3 * ptr_size + 1 * ldd_size);
  StrideC *ldc = reinterpret_cast<StrideC *>(reinterpret_cast<char *>(param_workspace_ptr) +
                                             gemm_coord_size + 3 * ptr_size + 2 * ldd_size);

  kernel_workspace_ptr = workspace_ptr + param_workspace_size;

  arguments = MakeArguments<Gemm, ElementA, ElementB, ElementC, StrideA, StrideB, StrideC>(
      num_gemms, problem_sizes_host, problem_sizes_device, ptr_A, lda, ptr_B, ldb, ptr_C, ldc,
      alpha, beta, device, math_sm_count);

  Gemm gemm;

  // Check can implement the kernel.
  if (gemm.can_implement(arguments) != cutlass::Status::kSuccess) {
    NVTE_CHECK(false, "Failed to implement CUTLASS Grouped GEMM");
  }

  // Initialize the kernel.
  if (gemm.initialize(arguments, kernel_workspace_ptr) != cutlass::Status::kSuccess) {
    NVTE_CHECK(false, "Failed to initialize CUTLASS Grouped GEMM");
  }

  // Execute the kernel in the current stream.
  if (gemm.run(stream) != cutlass::Status::kSuccess) {
    NVTE_CHECK(false, "Failed to run CUTLASS Grouped GEMM");
  }
}

}  // namespace grouped_gemm
}  // namespace transformer_engine

void cutlass_grouped_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor *D, int num_gemms,
                          bool transa, bool transb, bool grad, NVTETensor *workspace,
                          bool accumulate, int device, int math_sm_count, cudaStream_t stream);
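The parameter workspace packed at the front of workspace[0] holds, in order, the per-group problem shapes, three device pointer arrays (A/B/D) and three stride arrays, each section rounded up to a 128-byte boundary; the CUTLASS kernel workspace starts immediately after. A quick back-of-the-envelope check of that layout, assuming num_gemms = 8, 64-bit pointers and a 12-byte cute::Shape<int, int, int> (sizes may differ per build), yields 896 bytes:

#include <cstdint>
#include <cstdio>

// Mirrors getGemmCoordSize/getPtrSize/getLddSize above; standalone sanity check only.
static int64_t round_up(int64_t m, int64_t n) { return (m + n - 1) / n * n; }

int main() {
  const int64_t num_gemms = 8;
  const int64_t gemm_coord_size = round_up(num_gemms * 12, 128);  // assumed sizeof(Shape<int,int,int>) == 12
  const int64_t ptr_size = round_up(num_gemms * 8, 128);          // sizeof(void *) == 8
  const int64_t ldd_size = round_up(num_gemms * 8, 128);          // sizeof(int64_t) == 8
  const int64_t param_ws = 3 * ptr_size + 3 * ldd_size + gemm_coord_size;
  std::printf("param workspace = %lld bytes\n", static_cast<long long>(param_ws));  // prints 896
  return 0;
}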
transformer_engine/common/include/transformer_engine/comm_gemm.h
0 → 100644
View file @
27ddce40
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
/*! \file comm_gemm.h
* \brief Functions for distributed (multi-GPU) matrix multiplication.
*
* This API is a TE-native binding to cuBLASMp library.
* Refer here: https://docs.nvidia.com/cuda/cublasmp/usage/tp.html for specific
* patterns, which allow communication-computation overlap.
*
* All GEMM functions here have the same computation semantic, as expressed
* on global matrices, similar to nvte_cublas_gemm call:
* - `D = AB` if both `bias` and `pre_gelu_out` are empty tensors
* - `D = AB + bias` if `pre_gelu_out` is empty and `bias` is not empty
* - `D = GELU(AB + bias)` if both `bias` and `pre_gelu_out` are not empty tensors
*
* Functions differ in matrix distribution patterns
*/
#ifndef TRANSFORMER_ENGINE_COMMON_COMM_GEMM_H_
#define TRANSFORMER_ENGINE_COMMON_COMM_GEMM_H_
#include <nccl.h>
#include <stdint.h>
#include "transformer_engine.h"
#ifdef __cplusplus
extern "C" {
#else
#include <stdbool.h>
#endif

typedef struct NVTECommGemmCtx NVTECommGemmCtx;
enum NVTECommGemmAlgoType {
  kNVTECommGemmAlgoDefault = 0,
  kNVTECommGemmAlgoSplitP2P = 1,
  kNVTECommGemmAlgoSplitMulticast = 2,
  kNVTECommGemmAlgoAtomicP2P = 3,
  kNVTECommGemmAlgoAtomicMulticast = 4
};
/*! \brief Create a comm-gemm context.
*
* \param[in] comm NCCL communicator.
* \param[in] nranks Number of ranks.
* \param[in] rank Local rank.
*/
NVTECommGemmCtx *nvte_comm_gemm_ctx_create(ncclComm_t comm, int nranks, int rank);
/*! \brief Destroy a comm-gemm context.
*
* \param[in] ctx Context to destroy.
*/
void nvte_comm_gemm_ctx_destroy(NVTECommGemmCtx *ctx);
/*! \brief Perform AllGather communication followed by GEMM
*
* Gathers distributed data from all ranks, then computes matrix multiplication.
*
* \param[in] ctx Comm-GEMM context.
* \param[in] m Global m dimension.
* \param[in] n Global n dimension.
* \param[in] k Global k dimension.
* \param[in] a Local part of A matrix.
* \param[in] b Local part of B matrix.
* \param[in,out] d Local part of D matrix.
* \param[in] bias Bias tensor.
* \param[in,out] pre_act_out Local part of output matrix before GELU activation.
* \param[in] transa Whether A matrix is transposed.
* \param[in] transb Whether B matrix is transposed.
* \param[in] grad Whether this operation is part of gradient computation.
* \param[in] accumulate Whether to accumulate the result into the D matrix.
* \param[in] comm_sm_count Number of GPU SMs to use for communication (default=0: use heuristics)
* \param[in] main_stream CUDA stream used for computation.
* \param[in] algo Algorithm to use.
*/
void nvte_all_gather_gemm(NVTECommGemmCtx *ctx, int64_t m, int64_t n, int64_t k,
                          const NVTETensor a, const NVTETensor b, const NVTETensor d,
                          const NVTETensor bias, const NVTETensor pre_act_out, bool transa,
                          bool transb, bool grad, bool accumulate, int comm_sm_count,
                          cudaStream_t main_stream, NVTECommGemmAlgoType algo);
/*! \brief Perform GEMM followed by ReduceScatter communication
*
* Computes matrix multiplication, then distributes results across ranks with reduction.
*
* \param[in] ctx Comm-GEMM context.
* \param[in] m Global m dimension.
* \param[in] n Global n dimension.
* \param[in] k Global k dimension.
* \param[in] a Local part of A matrix.
* \param[in] b Local part of B matrix.
* \param[in,out] d Local part of D matrix.
* \param[in] bias Bias tensor.
* \param[in,out] pre_act_out Local part of output matrix before GELU activation.
* \param[in] transa Whether A matrix is transposed.
* \param[in] transb Whether B matrix is transposed.
* \param[in] grad Whether this operation is part of gradient computation.
* \param[in] accumulate Whether to accumulate the result into the D matrix.
* \param[in] comm_sm_count Number of GPU SMs to use for communication (default=0: use heuristics)
* \param[in] main_stream CUDA stream used for computation.
* \param[in] algo Algorithm to use.
*/
void nvte_gemm_reduce_scatter(NVTECommGemmCtx *ctx, int64_t m, int64_t n, int64_t k,
                              const NVTETensor a, const NVTETensor b, const NVTETensor d,
                              const NVTETensor bias, const NVTETensor pre_act_out, bool transa,
                              bool transb, bool grad, bool accumulate, int comm_sm_count,
                              cudaStream_t main_stream, NVTECommGemmAlgoType algo);
/*! \brief Perform GEMM followed by AllReduce communication
*
* Computes matrix multiplication, then reduces results across all ranks.
*
* \param[in] ctx Comm-GEMM context.
* \param[in] m Global m dimension.
* \param[in] n Global n dimension.
* \param[in] k Global k dimension.
* \param[in] a Local part of A matrix.
* \param[in] b Local part of B matrix.
* \param[in,out] d Local part of D matrix.
* \param[in] bias Bias tensor.
* \param[in,out] pre_act_out Local part of output matrix before GELU activation.
* \param[in] transa Whether A matrix is transposed.
* \param[in] transb Whether B matrix is transposed.
* \param[in] grad Whether this operation is part of gradient computation.
* \param[in] accumulate Whether to accumulate the result into the D matrix.
* \param[in] comm_sm_count Number of GPU SMs to use for communication (default=0: use heuristics)
* \param[in] main_stream CUDA stream used for computation.
* \param[in] algo Algorithm to use.
*/
void nvte_gemm_all_reduce(NVTECommGemmCtx *ctx, int64_t m, int64_t n, int64_t k,
                          const NVTETensor a, const NVTETensor b, const NVTETensor d,
                          const NVTETensor bias, const NVTETensor pre_act_out, bool transa,
                          bool transb, bool grad, bool accumulate, int comm_sm_count,
                          cudaStream_t main_stream, NVTECommGemmAlgoType algo);
/*! \brief Get local number of rows or columns.
*
* Utility function to get local dimension.
 * Block size, nranks, and local rank are derived from the context ctx.
*
* \param[in] ctx Comm-GEMM context.
* \param[in] global_size Global dimension.
*/
int64_t nvte_comm_gemm_numroc(NVTECommGemmCtx *ctx, int64_t global_size);
#ifdef __cplusplus
}  // extern "C"
#endif

#endif  // TRANSFORMER_ENGINE_COMMON_COMM_GEMM_H_
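Putting the pieces together, a typical tensor-parallel call site creates one context per NCCL communicator, queries local extents with nvte_comm_gemm_numroc, and issues one of the fused calls. A minimal sketch, assuming the communicator, the local NVTETensor shards (a, b, d) and empty bias/pre_act_out handles (no data, so D = AB) were created elsewhere:

/* Sketch only: comm, nranks, rank, k, m, n, stream and all NVTETensor handles are assumed to exist. */
NVTECommGemmCtx *ctx = nvte_comm_gemm_ctx_create(comm, nranks, rank);

/* Local extent of the globally k-sized dimension on this rank, e.g. to size the local shard. */
int64_t local_k = nvte_comm_gemm_numroc(ctx, k);

/* AllGather + GEMM with no fused epilogue. */
nvte_all_gather_gemm(ctx, m, n, k, a, b, d, empty_bias, empty_pre_act_out,
                     /*transa=*/false, /*transb=*/false, /*grad=*/false,
                     /*accumulate=*/false, /*comm_sm_count=*/0, stream,
                     kNVTECommGemmAlgoDefault);

nvte_comm_gemm_ctx_destroy(ctx);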
transformer_engine/common/include/transformer_engine/dropout.h
0 → 100644
View file @
27ddce40
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
/*! \file dropout.h
* \brief Functions for dropout.
*/
#ifndef TRANSFORMER_ENGINE_DROPOUT_FP8_H_
#define TRANSFORMER_ENGINE_DROPOUT_FP8_H_
#include "transformer_engine.h"
#ifdef __cplusplus
extern "C" {
#endif
/*! \brief Dropout forward kernel.
*
* \param[in] input Input tensor.
* \param[out] output Output tensor.
* \param[out] mask Mask tensor. Each bit corresponds to an
* output tensor entry. Ones indicate kept
* entries and zeros indicate dropped entries.
* \param[in] rng_state RNG engine inputs.
* \param[in] dropout_probability Dropout probability.
* \param[in] stream CUDA stream used for this operation.
*/
void nvte_dropout_fwd(const NVTETensor input, NVTETensor output, NVTETensor mask,
                      NVTETensor rng_state, float dropout_probability, cudaStream_t stream);
/*! \brief Dropout backward kernel.
*
* \param[in] grad_output Gradient of output tensor.
 * \param[in]  mask                Mask tensor. Each bit corresponds to an
* output tensor entry. Ones indicate kept
* entries and zeros indicate dropped entries.
* \param[out] grad_input Gradient of input tensor.
* \param[in] dropout_probability Dropout probability.
* \param[in] stream CUDA stream used for this operation.
*/
void nvte_dropout_bwd(const NVTETensor grad_output, const NVTETensor mask, NVTETensor grad_input,
                      float dropout_probability, cudaStream_t stream);
#ifdef __cplusplus
}  // extern "C"
#endif

#endif  // TRANSFORMER_ENGINE_DROPOUT_FP8_H_
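A minimal forward/backward call sketch, assuming all NVTETensor handles (including a bit-per-element mask and an RNG state tensor) and the stream were allocated elsewhere:

/* Sketch only: input, output, mask, rng_state, grad_output, grad_input and stream are assumed to exist. */
const float p = 0.1f;  /* dropout probability */

/* Forward: writes the kept/dropped bitmask alongside the scaled output. */
nvte_dropout_fwd(input, output, mask, rng_state, p, stream);

/* Backward: reuses the same mask to route gradients. */
nvte_dropout_bwd(grad_output, mask, grad_input, p, stream);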