Commit 2b05e121
Authored Jun 17, 2025 by yuguo
Merge commit 'a69692ac' of https://github.com/NVIDIA/TransformerEngine
Parents: 0fd441c2, a69692ac
Changes: 245
Showing 20 changed files with 789 additions and 339 deletions (+789, -339)
transformer_engine/common/fused_attn/context_parallel.cu  (+16, -16)
transformer_engine/common/fused_attn/flash_attn.cu  (+5, -5)
transformer_engine/common/fused_attn/fused_attn.cpp  (+139, -124)
transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu  (+40, -40)
transformer_engine/common/fused_attn/fused_attn_f16_max512_seqlen.cu  (+6, -6)
transformer_engine/common/fused_attn/fused_attn_fp8.cu  (+26, -26)
transformer_engine/common/fused_attn/kv_cache.cu  (+12, -12)
transformer_engine/common/fused_rope/fused_rope.cu  (+8, -10)
transformer_engine/common/fused_softmax/scaled_aligned_causal_masked_softmax.cu  (+4, -4)
transformer_engine/common/fused_softmax/scaled_masked_softmax.cu  (+10, -11)
transformer_engine/common/fused_softmax/scaled_upper_triang_masked_softmax.cu  (+5, -5)
transformer_engine/common/gemm/cublaslt_gemm.cu  (+36, -38)
transformer_engine/common/include/transformer_engine/cast.h  (+11, -0)
transformer_engine/common/include/transformer_engine/cast_transpose_noop.h  (+7, -9)
transformer_engine/common/include/transformer_engine/fused_attn.h  (+148, -4)
transformer_engine/common/include/transformer_engine/gemm.h  (+0, -4)
transformer_engine/common/include/transformer_engine/multi_stream.h  (+25, -0)
transformer_engine/common/include/transformer_engine/multi_tensor.h  (+204, -0)
transformer_engine/common/include/transformer_engine/transformer_engine.h  (+78, -14)
transformer_engine/common/multi_tensor/adam.cu  (+9, -11)
transformer_engine/common/fused_attn/context_parallel.cu
@@ -325,7 +325,7 @@ void thd_read_half_tensor(const Tensor &tensor, const Tensor &cu_seqlens, Tensor
   int batch = cu_seqlens_shape[0] - 1;
   int num_heads = tensor_shape[seq_dim + 1];
   int dim_per_head = tensor_shape[seq_dim + 2];
-  int hidden_size_in_bytes = num_heads * dim_per_head * typeToSize(tensor.dtype());
+  int hidden_size_in_bytes = (num_heads * dim_per_head * typeToNumBits(tensor.dtype())) / 8;

   // For 128-bits load/store
   NVTE_CHECK(hidden_size_in_bytes % 16 == 0);
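The change above is one instance of a pattern repeated throughout this commit: byte sizes are now derived from a bit-width helper (typeToSize is replaced by typeToNumBits(...) / 8). Below is a minimal self-contained sketch of the idea; the DType enum and the helper body are illustrative stand-ins, not Transformer Engine's actual definitions.

#include <cassert>
#include <cstddef>

// Stand-in element-type enum and bit-width helper (assumed for illustration only).
enum class DType { kFloat32, kBFloat16, kFloat8E4M3 };

constexpr std::size_t typeToNumBits(DType t) {
  switch (t) {
    case DType::kFloat32:    return 32;
    case DType::kBFloat16:   return 16;
    case DType::kFloat8E4M3: return 8;
  }
  return 0;
}

int main() {
  const std::size_t num_heads = 16, dim_per_head = 64;
  // Mirror of the replaced expression: accumulate in bits, then divide by 8 for bytes.
  const std::size_t hidden_size_in_bytes =
      (num_heads * dim_per_head * typeToNumBits(DType::kBFloat16)) / 8;
  // The kernel relies on 128-bit (16-byte) loads/stores, hence the alignment check.
  assert(hidden_size_in_bytes % 16 == 0);
  return 0;
}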
@@ -582,7 +582,7 @@ static void thd_grad_correction_helper(Tensor grad, const Tensor &grad_per_step,
   NVTE_CHECK(grad_per_step_shape[seq_dim + 2] == dim_per_head);
   size_t hidden_size = num_heads * dim_per_head;
-  NVTE_CHECK((hidden_size * typeToSize(grad.dtype())) % 16 == 0);
+  NVTE_CHECK(((hidden_size * typeToNumBits(grad.dtype())) / 8) % 16 == 0);

   constexpr unsigned int block = 256;
   unsigned int grid_x;
@@ -677,9 +677,9 @@ void nvte_cp_thd_read_half_tensor(const NVTETensor &tensor, const NVTETensor &cu
   NVTE_API_CALL(nvte_thd_read_half_tensor);
   using namespace transformer_engine;

-  context_parallel::thd_read_half_tensor(*reinterpret_cast<Tensor *>(tensor),
-                                         *reinterpret_cast<Tensor *>(cu_seqlens),
-                                         *reinterpret_cast<Tensor *>(half), half_idx, stream);
+  context_parallel::thd_read_half_tensor(*convertNVTETensorCheck(tensor),
+                                         *convertNVTETensorCheck(cu_seqlens),
+                                         *convertNVTETensorCheck(half), half_idx, stream);
 }

 void nvte_cp_thd_second_half_lse_correction(NVTETensor lse, const NVTETensor &lse_per_step,
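Most hunks in this commit make the same swap at the C API boundary: the raw reinterpret_cast<Tensor *>(handle) becomes convertNVTETensorCheck(handle) (and convertNVTETensor(workspace) for workspace tensors). The helper itself is declared in transformer_engine.h, which is also changed by this commit but not shown in this excerpt; the sketch below is only an assumed illustration of what a checked handle-to-Tensor conversion looks like, not the library's implementation.

#include <stdexcept>

namespace transformer_engine {
struct Tensor;  // opaque internal tensor type (details not shown in this diff)
}

typedef void *NVTETensor;  // opaque C handle, as in the NVTE C API

// Hypothetical checked conversion: reject invalid handles before casting.
inline transformer_engine::Tensor *convertNVTETensorCheckSketch(NVTETensor handle) {
  if (handle == nullptr) {
    throw std::invalid_argument("convertNVTETensorCheck: null NVTETensor handle");
  }
  return reinterpret_cast<transformer_engine::Tensor *>(handle);
}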
@@ -689,8 +689,8 @@ void nvte_cp_thd_second_half_lse_correction(NVTETensor lse, const NVTETensor &ls
   using namespace transformer_engine;

   context_parallel::thd_second_half_lse_correction(
-      *reinterpret_cast<Tensor *>(lse), *reinterpret_cast<Tensor *>(lse_per_step),
-      *reinterpret_cast<Tensor *>(cu_seqlens), lse_packed, stream);
+      *convertNVTETensorCheck(lse), *convertNVTETensorCheck(lse_per_step),
+      *convertNVTETensorCheck(cu_seqlens), lse_packed, stream);
 }

 void nvte_cp_thd_read_second_half_lse(const NVTETensor &lse, const NVTETensor &cu_seqlens,
@@ -700,8 +700,8 @@ void nvte_cp_thd_read_second_half_lse(const NVTETensor &lse, const NVTETensor &c
   using namespace transformer_engine;

   context_parallel::thd_read_second_half_lse(
-      *reinterpret_cast<Tensor *>(lse), *reinterpret_cast<Tensor *>(cu_seqlens),
-      *reinterpret_cast<Tensor *>(half_lse), lse_packed, second_half_lse_seqlen, stream);
+      *convertNVTETensorCheck(lse), *convertNVTETensorCheck(cu_seqlens),
+      *convertNVTETensorCheck(half_lse), lse_packed, second_half_lse_seqlen, stream);
 }

 void nvte_cp_thd_out_correction(NVTETensor out, const NVTETensor &out_per_step,
@@ -712,9 +712,9 @@ void nvte_cp_thd_out_correction(NVTETensor out, const NVTETensor &out_per_step,
   using namespace transformer_engine;

   context_parallel::thd_out_correction(
-      *reinterpret_cast<Tensor *>(out), *reinterpret_cast<Tensor *>(out_per_step),
-      *reinterpret_cast<Tensor *>(lse), *reinterpret_cast<Tensor *>(lse_per_step),
-      *reinterpret_cast<Tensor *>(cu_seqlens), only_second_half, lse_packed, stream);
+      *convertNVTETensorCheck(out), *convertNVTETensorCheck(out_per_step),
+      *convertNVTETensorCheck(lse), *convertNVTETensorCheck(lse_per_step),
+      *convertNVTETensorCheck(cu_seqlens), only_second_half, lse_packed, stream);
 }

 void nvte_cp_thd_grad_correction(NVTETensor grad, const NVTETensor &grad_per_step,
@@ -727,8 +727,8 @@ void nvte_cp_thd_grad_correction(NVTETensor grad, const NVTETensor &grad_per_ste
   std::string second_half_str(second_half);

   context_parallel::thd_grad_correction(
-      *reinterpret_cast<Tensor *>(grad), *reinterpret_cast<Tensor *>(grad_per_step),
-      *reinterpret_cast<Tensor *>(cu_seqlens), first_half_str, second_half_str, stream);
+      *convertNVTETensorCheck(grad), *convertNVTETensorCheck(grad_per_step),
+      *convertNVTETensorCheck(cu_seqlens), first_half_str, second_half_str, stream);
 }

 void nvte_cp_thd_get_partitioned_indices(const NVTETensor &cu_seqlens, NVTETensor output,
@@ -737,7 +737,7 @@ void nvte_cp_thd_get_partitioned_indices(const NVTETensor &cu_seqlens, NVTETenso
   NVTE_API_CALL(nvte_thd_get_partitioned_indices);
   using namespace transformer_engine;

-  context_parallel::thd_get_partitioned_indices(*reinterpret_cast<Tensor *>(cu_seqlens),
-                                                *reinterpret_cast<Tensor *>(output), total_tokens,
+  context_parallel::thd_get_partitioned_indices(*convertNVTETensorCheck(cu_seqlens),
+                                                *convertNVTETensorCheck(output), total_tokens,
                                                 world_size, rank, stream);
 }
transformer_engine/common/fused_attn/flash_attn.cu
@@ -138,8 +138,8 @@ void nvte_prepare_flash_attn_fwd(NVTETensor qkvi, NVTETensor qkv, cudaStream_t s
   NVTE_API_CALL(nvte_prepare_flash_attn_fwd);
   using namespace transformer_engine;

-  flash_attention::prepare_flash_attn_fwd(*reinterpret_cast<Tensor *>(qkvi),
-                                          *reinterpret_cast<Tensor *>(qkv), stream);
+  flash_attention::prepare_flash_attn_fwd(*convertNVTETensorCheck(qkvi),
+                                          *convertNVTETensorCheck(qkv), stream);
 }

 void nvte_prepare_flash_attn_bwd(NVTETensor q, NVTETensor k, NVTETensor v, NVTETensor qkv,
@@ -147,7 +147,7 @@ void nvte_prepare_flash_attn_bwd(NVTETensor q, NVTETensor k, NVTETensor v, NVTET
   NVTE_API_CALL(nvte_prepare_flash_attn_bwd);
   using namespace transformer_engine;

   flash_attention::prepare_flash_attn_bwd(
-      *reinterpret_cast<Tensor *>(q), *reinterpret_cast<Tensor *>(k),
-      *reinterpret_cast<Tensor *>(v), *reinterpret_cast<Tensor *>(qkv),
+      *convertNVTETensorCheck(q), *convertNVTETensorCheck(k),
+      *convertNVTETensorCheck(v), *convertNVTETensorCheck(qkv),
       stream);
 }
transformer_engine/common/fused_attn/fused_attn.cpp
@@ -134,10 +134,10 @@ NVTE_QKV_Format nvte_get_kv_format(NVTE_QKV_Layout qkv_layout) {
 // select a backend for fused attention
 NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
-    NVTEDType q_dtype, NVTEDType kv_dtype, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
-    NVTE_Mask_Type attn_mask_type, float dropout, size_t num_attn_heads, size_t num_gqa_groups,
-    size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v,
-    int64_t window_size_left, int64_t window_size_right) {
+    bool is_training, NVTEDType q_dtype, NVTEDType kv_dtype, NVTE_QKV_Layout qkv_layout,
+    NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, float dropout, size_t num_attn_heads,
+    size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk,
+    size_t head_dim_v, int64_t window_size_left, int64_t window_size_right) {
   using namespace transformer_engine;
   NVTE_Fused_Attn_Backend backend = NVTE_Fused_Attn_Backend::NVTE_No_Backend;
   const int device_id = cuda::current_device();
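With the new signature, is_training leads the argument list, ahead of the dtypes, layout, and per-head/sequence sizes. A usage sketch under the new API is shown below; it assumes the Transformer Engine headers are available, and the specific enum values and sizes are placeholder choices for illustration, not part of the commit.

#include <transformer_engine/fused_attn.h>

// Ask which fused-attention backend would serve this (placeholder) configuration.
// The argument order follows the updated nvte_get_fused_attn_backend signature.
NVTE_Fused_Attn_Backend pick_backend() {
  return nvte_get_fused_attn_backend(
      /*is_training=*/true, kNVTEFloat16, kNVTEFloat16, NVTE_QKV_Layout::NVTE_BS3HD,
      NVTE_Bias_Type::NVTE_NO_BIAS, NVTE_Mask_Type::NVTE_CAUSAL_MASK,
      /*dropout=*/0.0f, /*num_attn_heads=*/16, /*num_gqa_groups=*/16,
      /*max_seqlen_q=*/2048, /*max_seqlen_kv=*/2048,
      /*head_dim_qk=*/64, /*head_dim_v=*/64,
      /*window_size_left=*/-1, /*window_size_right=*/-1);
}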
@@ -216,12 +216,10 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
   }
   if (
       // TODO(cyang): replace with cudnn-frontend check_support for cleaner logic and better error messaging
-      // special conditions for blackwell
-      // TODO: enable THD max_t in f16_arbitrary_seqlen when support becomes available in 9.7
-      !(sm_arch_ >= 100 && (head_dim_qk > 128 || head_dim_v > 128)) &&
       // architecture
-      ((cudnn_runtime_version >= 8903 && sm_arch_ >= 80) ||
-       (cudnn_runtime_version < 8903 && (sm_arch_ == 80 || sm_arch_ == 90))) &&
+      ((cudnn_runtime_version < 8903 && (sm_arch_ == 80 || sm_arch_ == 90)) ||
+       (cudnn_runtime_version >= 8903 && sm_arch_ >= 80 && sm_arch_ < 100) ||
+       (cudnn_runtime_version >= 90700 && sm_arch_ >= 80)) &&
       // sequence length
       ((cudnn_runtime_version < 90000 && max_seqlen_q % 64 == 0 && max_seqlen_kv % 64 == 0) ||
        (cudnn_runtime_version >= 90000)) &&
@@ -229,11 +227,28 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
       ((cudnn_runtime_version < 8907 && num_attn_heads == num_gqa_groups) ||
        (cudnn_runtime_version >= 8907)) &&
       // head dimension
-      // TODO (cyang): add is_training to nvte_get_fused_attn_backend
-      ((head_dim_qk <= 128 && head_dim_qk % 8 == 0 && head_dim_v <= 128 && head_dim_v % 8 == 0) ||
-       // d=256 only supported for forward
-       (sm_arch_ >= 90 && cudnn_runtime_version >= 90000 && head_dim_qk <= 256 &&
-        head_dim_qk % 8 == 0 && head_dim_v <= 256 && head_dim_v % 8 == 0)) &&
+      // multiples of 8
+      (head_dim_qk % 8 == 0 && head_dim_v % 8 == 0 &&
+       // <= 128
+       ((head_dim_qk <= 128 && head_dim_v <= 128) ||
+        // 9.1: <= 256 + Hopper + fprop
+        // 9.5: <= 256 + Hopper + bprop
+        (head_dim_qk <= 256 && head_dim_v <= 256 &&
+         ((!is_training && sm_arch_ == 90 && cudnn_runtime_version >= 90100) ||
+          (is_training && sm_arch_ == 90 && cudnn_runtime_version >= 90500))) ||
+        // 9.9: any head_dim + Blackwell + fprop + non_paged + sq > 1
+        (!is_training && sm_arch_ >= 100 && cudnn_runtime_version >= 90900 && max_seqlen_q > 1 &&
+         layout_group != NVTE_QKV_Layout_Group::NVTE_Paged_KV_HD_HD_HD) ||
+        // 9.10: any head_dim + any arch + fprop + paged
+        // 9.10: any head_dim + any arch + fprop + non_paged + sq > 1
+        // 9.10: any head_dim + any arch + fprop + non_paged + sq = 1 + {no_mask, padding, BRCM, padding_BRCM}
+        (!is_training && cudnn_runtime_version >= 91000 &&
+         (layout_group == NVTE_QKV_Layout_Group::NVTE_Paged_KV_HD_HD_HD || max_seqlen_q > 1 ||
+          (max_seqlen_q == 1 && attn_mask_type != NVTE_Mask_Type::NVTE_CAUSAL_MASK &&
+           attn_mask_type != NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK))) ||
+        // 9.11: d_qk = 192, d_v = 128 + Blackwell + bprop + non-paged
+        (head_dim_qk == 192 && head_dim_v == 128 && is_training && sm_arch_ >= 100 &&
+         cudnn_runtime_version >= 91100))) &&
       // bias type
       ((cudnn_runtime_version < 8906 && bias_type == NVTE_Bias_Type::NVTE_NO_BIAS) ||
        (cudnn_runtime_version >= 8906 &&
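Because the new head-dimension gate is one long boolean expression, the same logic is paraphrased below as a free function for readability. This is not part of the commit: the paged-KV layout and mask checks are abstracted into boolean parameters, and cuDNN versions are encoded as integers (e.g. 90900 = 9.9.0), mirroring the condition reconstructed above.

#include <cstddef>

// Readability paraphrase of the head-dimension gate (not the library's code).
bool head_dim_supported(bool is_training, int sm_arch_, int cudnn_runtime_version,
                        std::size_t head_dim_qk, std::size_t head_dim_v,
                        std::size_t max_seqlen_q, bool is_paged_kv,
                        bool is_causal_or_padding_causal_mask) {
  // multiples of 8
  if (head_dim_qk % 8 != 0 || head_dim_v % 8 != 0) return false;
  // <= 128: supported on all covered architectures
  if (head_dim_qk <= 128 && head_dim_v <= 128) return true;
  // 9.1: <= 256 + Hopper + fprop;  9.5: <= 256 + Hopper + bprop
  if (head_dim_qk <= 256 && head_dim_v <= 256 && sm_arch_ == 90 &&
      cudnn_runtime_version >= (is_training ? 90500 : 90100))
    return true;
  // 9.9: any head_dim + Blackwell + fprop + non-paged + sq > 1
  if (!is_training && sm_arch_ >= 100 && cudnn_runtime_version >= 90900 &&
      max_seqlen_q > 1 && !is_paged_kv)
    return true;
  // 9.10: any head_dim + any arch + fprop, for paged KV, sq > 1, or sq == 1 with a
  // mask other than causal / padding-causal
  if (!is_training && cudnn_runtime_version >= 91000 &&
      (is_paged_kv || max_seqlen_q > 1 ||
       (max_seqlen_q == 1 && !is_causal_or_padding_causal_mask)))
    return true;
  // 9.11: d_qk = 192, d_v = 128 + Blackwell + bprop
  if (head_dim_qk == 192 && head_dim_v == 128 && is_training && sm_arch_ >= 100 &&
      cudnn_runtime_version >= 91100)
    return true;
  return false;
}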
@@ -392,14 +407,14 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias,
   NVTE_API_CALL(nvte_flash_attn_fwd_qkvpacked);
   using namespace transformer_engine;
-  const Tensor *input_cu_seqlens = reinterpret_cast<const Tensor *>(cu_seqlens);
-  const Tensor *input_cu_seqlens_padded = reinterpret_cast<const Tensor *>(cu_seqlens_padded);
-  const Tensor *input_rng_state = reinterpret_cast<const Tensor *>(rng_state);
-  const Tensor *input_QKV = reinterpret_cast<const Tensor *>(QKV);
-  const Tensor *input_Bias = reinterpret_cast<const Tensor *>(Bias);
-  Tensor *input_output_S = reinterpret_cast<Tensor *>(S);
-  Tensor *output_O = reinterpret_cast<Tensor *>(O);
-  Tensor *wkspace = reinterpret_cast<Tensor *>(workspace);
+  const Tensor *input_cu_seqlens = convertNVTETensorCheck(cu_seqlens);
+  const Tensor *input_cu_seqlens_padded = convertNVTETensorCheck(cu_seqlens_padded);
+  const Tensor *input_rng_state = convertNVTETensorCheck(rng_state);
+  const Tensor *input_QKV = convertNVTETensorCheck(QKV);
+  const Tensor *input_Bias = convertNVTETensorCheck(Bias);
+  Tensor *input_output_S = convertNVTETensorCheck(S);
+  Tensor *output_O = convertNVTETensorCheck(O);
+  Tensor *wkspace = convertNVTETensor(workspace);
   auto ndim = input_QKV->data.shape.size();
   size_t b = input_cu_seqlens->data.shape[0] - 1;
@@ -423,8 +438,8 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias,
   const NVTEDType QKV_type = static_cast<NVTEDType>(input_QKV->data.dtype);

   NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
-      QKV_type, QKV_type, qkv_layout, bias_type, attn_mask_type, dropout, h, h, max_seqlen,
-      max_seqlen, d, d, window_size_left, window_size_right);
+      is_training, QKV_type, QKV_type, qkv_layout, bias_type, attn_mask_type, dropout, h, h,
+      max_seqlen, max_seqlen, d, d, window_size_left, window_size_right);

   if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
 #if (CUDNN_VERSION >= 8901)
@@ -472,16 +487,16 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con
   NVTE_API_CALL(nvte_flash_attn_bwd_qkvpacked);
   using namespace transformer_engine;
-  const Tensor *input_cu_seqlens = reinterpret_cast<const Tensor *>(cu_seqlens);
-  const Tensor *input_cu_seqlens_padded = reinterpret_cast<const Tensor *>(cu_seqlens_padded);
-  const Tensor *input_QKV = reinterpret_cast<const Tensor *>(QKV);
-  const Tensor *input_O = reinterpret_cast<const Tensor *>(O);
-  const Tensor *input_dO = reinterpret_cast<const Tensor *>(dO);
-  const Tensor *input_S = reinterpret_cast<const Tensor *>(S);
-  Tensor *input_output_dP = reinterpret_cast<Tensor *>(dP);
-  Tensor *output_dQKV = reinterpret_cast<Tensor *>(dQKV);
-  Tensor *output_dBias = reinterpret_cast<Tensor *>(dBias);
-  Tensor *wkspace = reinterpret_cast<Tensor *>(workspace);
+  const Tensor *input_cu_seqlens = convertNVTETensorCheck(cu_seqlens);
+  const Tensor *input_cu_seqlens_padded = convertNVTETensorCheck(cu_seqlens_padded);
+  const Tensor *input_QKV = convertNVTETensorCheck(QKV);
+  const Tensor *input_O = convertNVTETensorCheck(O);
+  const Tensor *input_dO = convertNVTETensorCheck(dO);
+  const Tensor *input_S = convertNVTETensorCheck(S);
+  Tensor *input_output_dP = convertNVTETensorCheck(dP);
+  Tensor *output_dQKV = convertNVTETensorCheck(dQKV);
+  Tensor *output_dBias = convertNVTETensorCheck(dBias);
+  Tensor *wkspace = convertNVTETensor(workspace);
   auto ndim = input_QKV->data.shape.size();
   size_t b = input_cu_seqlens->data.shape[0] - 1;
@@ -505,12 +520,12 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con
   const NVTEDType QKV_type = static_cast<NVTEDType>(input_QKV->data.dtype);

   NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
-      QKV_type, QKV_type, qkv_layout, bias_type, attn_mask_type, dropout, h, h, max_seqlen,
-      max_seqlen, d, d, window_size_left, window_size_right);
+      true, QKV_type, QKV_type, qkv_layout, bias_type, attn_mask_type, dropout, h, h, max_seqlen,
+      max_seqlen, d, d, window_size_left, window_size_right);

   if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
 #if (CUDNN_VERSION >= 8901)
-    Tensor *output_S = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]);
+    Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]);
     fused_attn_max_512_bwd_qkvpacked(
         b, h, max_seqlen, d, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, input_QKV,
         input_dO, output_S, output_dQKV, output_dBias, input_cu_seqlens, wkspace, stream, handle);
@@ -519,13 +534,13 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con
 #endif
   } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) {
 #if (CUDNN_VERSION >= 8900)
-    Tensor *output_S = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]);
+    Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]);
     Tensor *input_Bias, *input_rng_state;
     if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) {
-      input_rng_state = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[1]);
-      input_Bias = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[2]);
+      input_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]);
+      input_Bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]);
     } else {
-      input_rng_state = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[1]);
+      input_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]);
     }
     fused_attn_arbitrary_seqlen_bwd_qkvpacked(
         b, h, max_seqlen, d, t, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type,
@@ -540,9 +555,9 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con
 #endif
   } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_FP8) {
 #if (CUDNN_VERSION >= 8900)
-    const Tensor *input_M = reinterpret_cast<const Tensor *>(Aux_CTX_Tensors->tensors[0]);
-    const Tensor *input_ZInv = reinterpret_cast<const Tensor *>(Aux_CTX_Tensors->tensors[1]);
-    const Tensor *input_rng_state = reinterpret_cast<const Tensor *>(Aux_CTX_Tensors->tensors[2]);
+    const Tensor *input_M = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]);
+    const Tensor *input_ZInv = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]);
+    const Tensor *input_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]);
     fused_attn_fp8_bwd_qkvpacked(b, h, max_seqlen, d, attn_scale, dropout, qkv_layout, bias_type,
                                  attn_mask_type, input_QKV, input_O, input_dO, input_M, input_ZInv,
                                  input_S, input_output_dP, output_dQKV, input_cu_seqlens,
@@ -566,19 +581,19 @@ void nvte_fused_attn_fwd_kvpacked(
     cudaStream_t stream) {
   NVTE_API_CALL(nvte_flash_attn_fwd_kvpacked);
   using namespace transformer_engine;
-  const Tensor *input_cu_seqlens_q = reinterpret_cast<const Tensor *>(cu_seqlens_q);
-  const Tensor *input_cu_seqlens_kv = reinterpret_cast<const Tensor *>(cu_seqlens_kv);
-  const Tensor *input_cu_seqlens_q_padded = reinterpret_cast<const Tensor *>(cu_seqlens_q_padded);
-  const Tensor *input_cu_seqlens_kv_padded = reinterpret_cast<const Tensor *>(cu_seqlens_kv_padded);
-  const Tensor *input_page_table_k = reinterpret_cast<const Tensor *>(page_table_k);
-  const Tensor *input_page_table_v = reinterpret_cast<const Tensor *>(page_table_v);
-  const Tensor *input_rng_state = reinterpret_cast<const Tensor *>(rng_state);
-  const Tensor *input_Q = reinterpret_cast<const Tensor *>(Q);
-  const Tensor *input_KV = reinterpret_cast<const Tensor *>(KV);
-  const Tensor *input_Bias = reinterpret_cast<const Tensor *>(Bias);
-  Tensor *input_output_S = reinterpret_cast<Tensor *>(S);
-  Tensor *output_O = reinterpret_cast<Tensor *>(O);
-  Tensor *wkspace = reinterpret_cast<Tensor *>(workspace);
+  const Tensor *input_cu_seqlens_q = convertNVTETensorCheck(cu_seqlens_q);
+  const Tensor *input_cu_seqlens_kv = convertNVTETensorCheck(cu_seqlens_kv);
+  const Tensor *input_cu_seqlens_q_padded = convertNVTETensorCheck(cu_seqlens_q_padded);
+  const Tensor *input_cu_seqlens_kv_padded = convertNVTETensorCheck(cu_seqlens_kv_padded);
+  const Tensor *input_page_table_k = convertNVTETensorCheck(page_table_k);
+  const Tensor *input_page_table_v = convertNVTETensorCheck(page_table_v);
+  const Tensor *input_rng_state = convertNVTETensorCheck(rng_state);
+  const Tensor *input_Q = convertNVTETensorCheck(Q);
+  const Tensor *input_KV = convertNVTETensorCheck(KV);
+  const Tensor *input_Bias = convertNVTETensorCheck(Bias);
+  Tensor *input_output_S = convertNVTETensorCheck(S);
+  Tensor *output_O = convertNVTETensorCheck(O);
+  Tensor *wkspace = convertNVTETensor(workspace);
   size_t b = input_cu_seqlens_q->data.shape[0] - 1;
   auto ndim = input_Q->data.shape.size();
@@ -636,8 +651,8 @@ void nvte_fused_attn_fwd_kvpacked(
   const NVTEDType KV_type = static_cast<NVTEDType>(input_KV->data.dtype);

   NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
-      Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, dropout, h_q, h_kv, max_seqlen_q,
-      max_seqlen_kv, d, d, window_size_left, window_size_right);
+      is_training, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, dropout, h_q, h_kv,
+      max_seqlen_q, max_seqlen_kv, d, d, window_size_left, window_size_right);

   if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
 #if (CUDNN_VERSION >= 8901)
@@ -686,20 +701,20 @@ void nvte_fused_attn_bwd_kvpacked(
     cudaStream_t stream) {
   NVTE_API_CALL(nvte_flash_attn_bwd_kvpacked);
   using namespace transformer_engine;
-  const Tensor *input_cu_seqlens_q = reinterpret_cast<const Tensor *>(cu_seqlens_q);
-  const Tensor *input_cu_seqlens_kv = reinterpret_cast<const Tensor *>(cu_seqlens_kv);
-  const Tensor *input_cu_seqlens_q_padded = reinterpret_cast<const Tensor *>(cu_seqlens_q_padded);
-  const Tensor *input_cu_seqlens_kv_padded = reinterpret_cast<const Tensor *>(cu_seqlens_kv_padded);
-  const Tensor *input_Q = reinterpret_cast<const Tensor *>(Q);
-  const Tensor *input_KV = reinterpret_cast<const Tensor *>(KV);
-  const Tensor *input_O = reinterpret_cast<const Tensor *>(O);
-  const Tensor *input_dO = reinterpret_cast<const Tensor *>(dO);
-  const Tensor *input_S = reinterpret_cast<const Tensor *>(S);
-  Tensor *input_output_dP = reinterpret_cast<Tensor *>(dP);
-  Tensor *output_dQ = reinterpret_cast<Tensor *>(dQ);
-  Tensor *output_dKV = reinterpret_cast<Tensor *>(dKV);
-  Tensor *output_dBias = reinterpret_cast<Tensor *>(dBias);
-  Tensor *wkspace = reinterpret_cast<Tensor *>(workspace);
+  const Tensor *input_cu_seqlens_q = convertNVTETensorCheck(cu_seqlens_q);
+  const Tensor *input_cu_seqlens_kv = convertNVTETensorCheck(cu_seqlens_kv);
+  const Tensor *input_cu_seqlens_q_padded = convertNVTETensorCheck(cu_seqlens_q_padded);
+  const Tensor *input_cu_seqlens_kv_padded = convertNVTETensorCheck(cu_seqlens_kv_padded);
+  const Tensor *input_Q = convertNVTETensorCheck(Q);
+  const Tensor *input_KV = convertNVTETensorCheck(KV);
+  const Tensor *input_O = convertNVTETensorCheck(O);
+  const Tensor *input_dO = convertNVTETensorCheck(dO);
+  const Tensor *input_S = convertNVTETensorCheck(S);
+  Tensor *input_output_dP = convertNVTETensorCheck(dP);
+  Tensor *output_dQ = convertNVTETensorCheck(dQ);
+  Tensor *output_dKV = convertNVTETensorCheck(dKV);
+  Tensor *output_dBias = convertNVTETensorCheck(dBias);
+  Tensor *wkspace = convertNVTETensor(workspace);
   size_t b = input_cu_seqlens_q->data.shape[0] - 1;
   auto ndim = input_Q->data.shape.size();
@@ -731,12 +746,12 @@ void nvte_fused_attn_bwd_kvpacked(
   const NVTEDType KV_type = static_cast<NVTEDType>(input_KV->data.dtype);

   NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
-      Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, dropout, h_q, h_kv, max_seqlen_q,
-      max_seqlen_kv, d, d, window_size_left, window_size_right);
+      true, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, dropout, h_q, h_kv,
+      max_seqlen_q, max_seqlen_kv, d, d, window_size_left, window_size_right);

   if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
 #if (CUDNN_VERSION >= 8901)
-    Tensor *output_S = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]);
+    Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]);
     fused_attn_max_512_bwd_kvpacked(
         b, h_q, max_seqlen_q, max_seqlen_kv, d, attn_scale, dropout, qkv_layout, bias_type,
         attn_mask_type, input_Q, input_KV, input_dO, output_S, output_dQ, output_dKV, output_dBias,
@@ -746,13 +761,13 @@ void nvte_fused_attn_bwd_kvpacked(
 #endif
   } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) {
 #if (CUDNN_VERSION >= 8903)
-    Tensor *output_S = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]);
+    Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]);
     Tensor *input_Bias, *input_rng_state;
     if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) {
-      input_rng_state = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[1]);
-      input_Bias = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[2]);
+      input_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]);
+      input_Bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]);
     } else {
-      input_rng_state = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[1]);
+      input_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]);
     }
     fused_attn_arbitrary_seqlen_bwd_kvpacked(
         b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, t_q, t_kv, attn_scale, dropout, qkv_layout,
@@ -768,9 +783,9 @@ void nvte_fused_attn_bwd_kvpacked(
 #endif
   } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_FP8) {
 #if (CUDNN_VERSION >= 8900)
-    const Tensor *input_M = reinterpret_cast<const Tensor *>(Aux_CTX_Tensors->tensors[0]);
-    const Tensor *input_ZInv = reinterpret_cast<const Tensor *>(Aux_CTX_Tensors->tensors[1]);
-    const Tensor *input_rng_state = reinterpret_cast<const Tensor *>(Aux_CTX_Tensors->tensors[2]);
+    const Tensor *input_M = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]);
+    const Tensor *input_ZInv = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]);
+    const Tensor *input_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]);
     fused_attn_fp8_bwd_kvpacked(b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, attn_scale, dropout,
                                 qkv_layout, bias_type, attn_mask_type, input_Q, input_KV, input_O,
                                 input_dO, input_M, input_ZInv, input_S, input_output_dP, output_dQ,
@@ -797,20 +812,20 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
     cudaStream_t stream) {
   NVTE_API_CALL(nvte_flash_attn_fwd);
   using namespace transformer_engine;
-  const Tensor *input_cu_seqlens_q = reinterpret_cast<const Tensor *>(cu_seqlens_q);
-  const Tensor *input_cu_seqlens_kv = reinterpret_cast<const Tensor *>(cu_seqlens_kv);
-  const Tensor *input_cu_seqlens_q_padded = reinterpret_cast<const Tensor *>(cu_seqlens_q_padded);
-  const Tensor *input_cu_seqlens_kv_padded = reinterpret_cast<const Tensor *>(cu_seqlens_kv_padded);
-  const Tensor *input_page_table_k = reinterpret_cast<const Tensor *>(page_table_k);
-  const Tensor *input_page_table_v = reinterpret_cast<const Tensor *>(page_table_v);
-  const Tensor *input_rng_state = reinterpret_cast<const Tensor *>(rng_state);
-  const Tensor *input_Q = reinterpret_cast<const Tensor *>(Q);
-  const Tensor *input_K = reinterpret_cast<const Tensor *>(K);
-  const Tensor *input_V = reinterpret_cast<const Tensor *>(V);
-  const Tensor *input_Bias = reinterpret_cast<const Tensor *>(Bias);
-  Tensor *input_output_S = reinterpret_cast<Tensor *>(S);
-  Tensor *output_O = reinterpret_cast<Tensor *>(O);
-  Tensor *wkspace = reinterpret_cast<Tensor *>(workspace);
+  const Tensor *input_cu_seqlens_q = convertNVTETensorCheck(cu_seqlens_q);
+  const Tensor *input_cu_seqlens_kv = convertNVTETensorCheck(cu_seqlens_kv);
+  const Tensor *input_cu_seqlens_q_padded = convertNVTETensorCheck(cu_seqlens_q_padded);
+  const Tensor *input_cu_seqlens_kv_padded = convertNVTETensorCheck(cu_seqlens_kv_padded);
+  const Tensor *input_page_table_k = convertNVTETensorCheck(page_table_k);
+  const Tensor *input_page_table_v = convertNVTETensorCheck(page_table_v);
+  const Tensor *input_rng_state = convertNVTETensorCheck(rng_state);
+  const Tensor *input_Q = convertNVTETensorCheck(Q);
+  const Tensor *input_K = convertNVTETensorCheck(K);
+  const Tensor *input_V = convertNVTETensorCheck(V);
+  const Tensor *input_Bias = convertNVTETensorCheck(Bias);
+  Tensor *input_output_S = convertNVTETensorCheck(S);
+  Tensor *output_O = convertNVTETensorCheck(O);
+  Tensor *wkspace = convertNVTETensor(workspace);
   auto ndim = input_Q->data.shape.size();
   auto ndim_kv = input_K->data.shape.size();
@@ -862,8 +877,8 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
   const NVTEDType KV_type = static_cast<NVTEDType>(input_K->data.dtype);

   NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
-      Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, dropout, h_q, h_kv, max_seqlen_q,
-      max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right);
+      is_training, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, dropout, h_q, h_kv,
+      max_seqlen_q, max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right);

   if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
 #if (CUDNN_VERSION >= 8901)
@@ -914,22 +929,22 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
     cudaStream_t stream) {
   NVTE_API_CALL(nvte_flash_attn_bwd);
   using namespace transformer_engine;
-  const Tensor *input_cu_seqlens_q = reinterpret_cast<const Tensor *>(cu_seqlens_q);
-  const Tensor *input_cu_seqlens_kv = reinterpret_cast<const Tensor *>(cu_seqlens_kv);
-  const Tensor *input_cu_seqlens_q_padded = reinterpret_cast<const Tensor *>(cu_seqlens_q_padded);
-  const Tensor *input_cu_seqlens_kv_padded = reinterpret_cast<const Tensor *>(cu_seqlens_kv_padded);
-  const Tensor *input_Q = reinterpret_cast<const Tensor *>(Q);
-  const Tensor *input_K = reinterpret_cast<const Tensor *>(K);
-  const Tensor *input_V = reinterpret_cast<const Tensor *>(V);
-  const Tensor *input_O = reinterpret_cast<const Tensor *>(O);
-  const Tensor *input_dO = reinterpret_cast<const Tensor *>(dO);
-  const Tensor *input_S = reinterpret_cast<const Tensor *>(S);
-  Tensor *input_output_dP = reinterpret_cast<Tensor *>(dP);
-  Tensor *output_dQ = reinterpret_cast<Tensor *>(dQ);
-  Tensor *output_dK = reinterpret_cast<Tensor *>(dK);
-  Tensor *output_dV = reinterpret_cast<Tensor *>(dV);
-  Tensor *output_dBias = reinterpret_cast<Tensor *>(dBias);
-  Tensor *wkspace = reinterpret_cast<Tensor *>(workspace);
+  const Tensor *input_cu_seqlens_q = convertNVTETensorCheck(cu_seqlens_q);
+  const Tensor *input_cu_seqlens_kv = convertNVTETensorCheck(cu_seqlens_kv);
+  const Tensor *input_cu_seqlens_q_padded = convertNVTETensorCheck(cu_seqlens_q_padded);
+  const Tensor *input_cu_seqlens_kv_padded = convertNVTETensorCheck(cu_seqlens_kv_padded);
+  const Tensor *input_Q = convertNVTETensorCheck(Q);
+  const Tensor *input_K = convertNVTETensorCheck(K);
+  const Tensor *input_V = convertNVTETensorCheck(V);
+  const Tensor *input_O = convertNVTETensorCheck(O);
+  const Tensor *input_dO = convertNVTETensorCheck(dO);
+  const Tensor *input_S = convertNVTETensorCheck(S);
+  Tensor *input_output_dP = convertNVTETensorCheck(dP);
+  Tensor *output_dQ = convertNVTETensorCheck(dQ);
+  Tensor *output_dK = convertNVTETensorCheck(dK);
+  Tensor *output_dV = convertNVTETensorCheck(dV);
+  Tensor *output_dBias = convertNVTETensorCheck(dBias);
+  Tensor *wkspace = convertNVTETensor(workspace);
   auto ndim = input_Q->data.shape.size();
   auto ndim_kv = input_K->data.shape.size();
@@ -954,12 +969,12 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
   const NVTEDType KV_type = static_cast<NVTEDType>(input_K->data.dtype);

   NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
-      Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, dropout, h_q, h_kv, max_seqlen_q,
-      max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right);
+      true, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, dropout, h_q, h_kv,
+      max_seqlen_q, max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right);

   if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
 #if (CUDNN_VERSION >= 8901)
-    Tensor *output_S = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]);
+    Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]);
     fused_attn_max_512_bwd(b, h_q, max_seqlen_q, max_seqlen_kv, d_qk, attn_scale, dropout,
                            qkv_layout, bias_type, attn_mask_type, input_Q, input_K, input_V,
                            input_dO, output_S, output_dQ, output_dK, output_dV, output_dBias,
@@ -969,13 +984,13 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
 #endif
   } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) {
 #if (CUDNN_VERSION >= 8900)
-    Tensor *output_S = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]);
+    Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]);
     Tensor *input_Bias, *input_rng_state;
     if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) {
-      input_rng_state = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[1]);
-      input_Bias = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[2]);
+      input_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]);
+      input_Bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]);
     } else {
-      input_rng_state = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[1]);
+      input_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]);
     }
     fused_attn_arbitrary_seqlen_bwd(
         b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, t_q, t_kv, attn_scale, dropout,
@@ -991,9 +1006,9 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
 #endif
   } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_FP8) {
 #if (CUDNN_VERSION >= 8900)
-    const Tensor *input_M = reinterpret_cast<const Tensor *>(Aux_CTX_Tensors->tensors[0]);
-    const Tensor *input_ZInv = reinterpret_cast<const Tensor *>(Aux_CTX_Tensors->tensors[1]);
-    const Tensor *input_rng_state = reinterpret_cast<const Tensor *>(Aux_CTX_Tensors->tensors[2]);
+    const Tensor *input_M = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]);
+    const Tensor *input_ZInv = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]);
+    const Tensor *input_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]);
     fused_attn_fp8_bwd(b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, attn_scale, dropout,
                        qkv_layout, bias_type, attn_mask_type, input_Q, input_K, input_V, input_O,
                        input_dO, input_M, input_ZInv, input_S, input_output_dP, output_dQ,
transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu
@@ -377,7 +377,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
   const size_t num_bytes_per_seqlen = alignTo<16>(b * sizeof(int32_t));
   const size_t actual_seqlen_workspace_size = is_padding ? 2 * num_bytes_per_seqlen : 0;
   const size_t num_bytes_per_ragged_offset =
-      alignTo<16>((b + 1) * typeToSize(ragged_offset_type));
+      alignTo<16>(((b + 1) * typeToNumBits(ragged_offset_type)) / 8);
   size_t seqlen_offsets_workspace_size = 0;
   if (is_ragged_q || is_ragged_kv) {
     size_t count = 2 * (static_cast<size_t>(is_ragged_q) + static_cast<size_t>(is_ragged_kv));
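Here the ragged-offset workspace size is converted from bits to bytes and then padded to a 16-byte boundary. A minimal sketch of an align-up helper in the spirit of the alignTo<16>(...) used above is shown below; the real Transformer Engine helper may be implemented differently.

#include <cstddef>

// Round n up to the next multiple of Alignment (sketch of an alignTo-style helper).
template <std::size_t Alignment>
constexpr std::size_t alignToSketch(std::size_t n) {
  return ((n + Alignment - 1) / Alignment) * Alignment;
}

// Example: with b = 7 batches and 64-bit ragged offsets,
// ((b + 1) * 64) / 8 = 64 bytes, already a multiple of 16, so no padding is added.
static_assert(alignToSketch<16>(((7 + 1) * 64) / 8) == 64, "no padding needed");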
@@ -831,7 +831,7 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
   const size_t num_bytes_per_seqlen = alignTo<16>(b * sizeof(int32_t));
   const size_t actual_seqlen_workspace_size = is_padding ? 2 * num_bytes_per_seqlen : 0;
   const size_t num_bytes_per_ragged_offset =
-      alignTo<16>((b + 1) * typeToSize(ragged_offset_type));
+      alignTo<16>(((b + 1) * typeToNumBits(ragged_offset_type)) / 8);
   size_t seqlen_offsets_workspace_size = 0;
   if (is_ragged_q || is_ragged_kv) {
     size_t count = 2 * (static_cast<size_t>(is_ragged_q) + static_cast<size_t>(is_ragged_kv));
@@ -957,9 +957,9 @@ void fused_attn_arbitrary_seqlen_fwd_qkvpacked(
   NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout);
   size_t stride = 0;
   if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) {
-    stride = typeToSize(QKV_type) * num_attn_heads * head_dim;
+    stride = (typeToNumBits(QKV_type) * num_attn_heads * head_dim) / 8;
   } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_H3D) {
-    stride = typeToSize(QKV_type) * head_dim;
+    stride = (typeToNumBits(QKV_type) * head_dim) / 8;
   }
   void *devPtrQ = static_cast<void *>(devPtrQKV);
   void *devPtrK = static_cast<void *>(static_cast<int8_t *>(devPtrQKV) + stride);
@@ -990,7 +990,7 @@ void fused_attn_arbitrary_seqlen_fwd_qkvpacked(
   const auto cudnn_runtime_version = cudnnGetVersion();
   if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) {
     Aux_CTX_Tensors->size = 3;
-    Tensor *output_S = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]);
+    Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]);
     output_S->data.dptr = nullptr;
     if (qkv_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) {
       output_S->data.shape = {max_tokens, num_attn_heads, 1};
@@ -998,17 +998,17 @@ void fused_attn_arbitrary_seqlen_fwd_qkvpacked(
       output_S->data.shape = {batch, num_attn_heads, max_seqlen, 1};
     }
     output_S->data.dtype = DType::kFloat32;
-    Tensor *output_rng_state = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[1]);
+    Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]);
     output_rng_state->data.dptr = nullptr;
     output_rng_state->data.shape = {2};
     output_rng_state->data.dtype = DType::kInt64;
-    Tensor *output_bias = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[2]);
+    Tensor *output_bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]);
     output_bias->data.dptr = nullptr;
     output_bias->data.shape = {bias_b, bias_h, max_seqlen, max_seqlen};
     output_bias->data.dtype = QKV_type;
   } else {
     Aux_CTX_Tensors->size = 2;
-    Tensor *output_S = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]);
+    Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]);
     output_S->data.dptr = nullptr;
     if (qkv_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) {
       output_S->data.shape = {max_tokens, num_attn_heads, 1};
@@ -1016,22 +1016,22 @@ void fused_attn_arbitrary_seqlen_fwd_qkvpacked(
       output_S->data.shape = {batch, num_attn_heads, max_seqlen, 1};
     }
     output_S->data.dtype = DType::kFloat32;
-    Tensor *output_rng_state = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[1]);
+    Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]);
     output_rng_state->data.dptr = nullptr;
     output_rng_state->data.shape = {2};
     output_rng_state->data.dtype = DType::kInt64;
     }
   } else if (Aux_CTX_Tensors->size == 2) {
-    Tensor *output_S = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]);
+    Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]);
     devPtrS = output_S->data.dptr;
-    Tensor *output_rng_state = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[1]);
+    Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]);
     output_rng_state->data.dptr = rng_state->data.dptr;
   } else if (Aux_CTX_Tensors->size == 3) {
-    Tensor *output_S = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]);
+    Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]);
     devPtrS = output_S->data.dptr;
-    Tensor *output_rng_state = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[1]);
+    Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]);
     output_rng_state->data.dptr = rng_state->data.dptr;
-    Tensor *output_bias = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[2]);
+    Tensor *output_bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]);
     output_bias->data.dptr = devPtrBias;
   } else {
     NVTE_ERROR("Unexpected Aux_CTX_Tensors->size.");
...
@@ -1082,9 +1082,9 @@ void fused_attn_arbitrary_seqlen_bwd_qkvpacked(
...
@@ -1082,9 +1082,9 @@ void fused_attn_arbitrary_seqlen_bwd_qkvpacked(
NVTE_QKV_Layout_Group
layout_group
=
nvte_get_qkv_layout_group
(
qkv_layout
);
NVTE_QKV_Layout_Group
layout_group
=
nvte_get_qkv_layout_group
(
qkv_layout
);
size_t
stride
=
0
;
size_t
stride
=
0
;
if
(
layout_group
==
NVTE_QKV_Layout_Group
::
NVTE_3HD
)
{
if
(
layout_group
==
NVTE_QKV_Layout_Group
::
NVTE_3HD
)
{
stride
=
typeTo
Size
(
QKV_type
)
*
num_attn_heads
*
head_dim
;
stride
=
(
typeTo
NumBits
(
QKV_type
)
*
num_attn_heads
*
head_dim
)
/
8
;
}
else
if
(
layout_group
==
NVTE_QKV_Layout_Group
::
NVTE_H3D
)
{
}
else
if
(
layout_group
==
NVTE_QKV_Layout_Group
::
NVTE_H3D
)
{
stride
=
typeTo
Size
(
QKV_type
)
*
head_dim
;
stride
=
(
typeTo
NumBits
(
QKV_type
)
*
head_dim
)
/
8
;
}
}
void
*
devPtrQ
=
devPtrQKV
;
void
*
devPtrQ
=
devPtrQKV
;
void
*
devPtrK
=
static_cast
<
void
*>
(
static_cast
<
int8_t
*>
(
devPtrQKV
)
+
stride
);
void
*
devPtrK
=
static_cast
<
void
*>
(
static_cast
<
int8_t
*>
(
devPtrQKV
)
+
stride
);
...
@@ -1173,9 +1173,9 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked(
...
@@ -1173,9 +1173,9 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked(
NVTE_QKV_Format
kv_format
=
nvte_get_kv_format
(
qkv_layout
);
NVTE_QKV_Format
kv_format
=
nvte_get_kv_format
(
qkv_layout
);
size_t
stride
=
0
;
size_t
stride
=
0
;
if
(
layout_group
==
NVTE_QKV_Layout_Group
::
NVTE_HD_2HD
)
{
if
(
layout_group
==
NVTE_QKV_Layout_Group
::
NVTE_HD_2HD
)
{
stride
=
typeTo
Size
(
QKV_type
)
*
num_gqa_groups
*
head_dim
;
stride
=
(
typeTo
NumBits
(
QKV_type
)
*
num_gqa_groups
*
head_dim
)
/
8
;
}
else
if
(
layout_group
==
NVTE_QKV_Layout_Group
::
NVTE_HD_H2D
)
{
}
else
if
(
layout_group
==
NVTE_QKV_Layout_Group
::
NVTE_HD_H2D
)
{
stride
=
typeTo
Size
(
QKV_type
)
*
head_dim
;
stride
=
(
typeTo
NumBits
(
QKV_type
)
*
head_dim
)
/
8
;
}
}
void
*
devPtrK
=
devPtrKV
;
void
*
devPtrK
=
devPtrKV
;
void
*
devPtrV
=
static_cast
<
void
*>
(
static_cast
<
int8_t
*>
(
devPtrKV
)
+
stride
);
void
*
devPtrV
=
static_cast
<
void
*>
(
static_cast
<
int8_t
*>
(
devPtrKV
)
+
stride
);
...
@@ -1216,7 +1216,7 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked(
...
@@ -1216,7 +1216,7 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked(
const
auto
cudnn_runtime_version
=
cudnnGetVersion
();
const
auto
cudnn_runtime_version
=
cudnnGetVersion
();
if
((
bias_type
!=
NVTE_NO_BIAS
)
&&
(
bias_type
!=
NVTE_ALIBI
))
{
if
((
bias_type
!=
NVTE_NO_BIAS
)
&&
(
bias_type
!=
NVTE_ALIBI
))
{
Aux_CTX_Tensors
->
size
=
3
;
Aux_CTX_Tensors
->
size
=
3
;
Tensor
*
output_S
=
reinterpret_cast
<
Tensor
*>
(
Aux_CTX_Tensors
->
tensors
[
0
]);
Tensor
*
output_S
=
convertNVTE
Tensor
Check
(
Aux_CTX_Tensors
->
tensors
[
0
]);
output_S
->
data
.
dptr
=
nullptr
;
output_S
->
data
.
dptr
=
nullptr
;
if
(
q_format
==
NVTE_QKV_Format
::
NVTE_THD
&&
cudnn_runtime_version
>=
90600
)
{
if
(
q_format
==
NVTE_QKV_Format
::
NVTE_THD
&&
cudnn_runtime_version
>=
90600
)
{
output_S
->
data
.
shape
=
{
max_tokens_q
,
num_attn_heads
,
1
};
output_S
->
data
.
shape
=
{
max_tokens_q
,
num_attn_heads
,
1
};
...
@@ -1224,17 +1224,17 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked(
...
@@ -1224,17 +1224,17 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked(
output_S
->
data
.
shape
=
{
batch
,
num_attn_heads
,
max_seqlen_q
,
1
};
output_S
->
data
.
shape
=
{
batch
,
num_attn_heads
,
max_seqlen_q
,
1
};
}
}
output_S
->
data
.
dtype
=
DType
::
kFloat32
;
output_S
->
data
.
dtype
=
DType
::
kFloat32
;
Tensor
*
output_rng_state
=
reinterpret_cast
<
Tensor
*>
(
Aux_CTX_Tensors
->
tensors
[
1
]);
Tensor
*
output_rng_state
=
convertNVTE
Tensor
Check
(
Aux_CTX_Tensors
->
tensors
[
1
]);
output_rng_state
->
data
.
dptr
=
nullptr
;
output_rng_state
->
data
.
dptr
=
nullptr
;
output_rng_state
->
data
.
shape
=
{
2
};
output_rng_state
->
data
.
shape
=
{
2
};
output_rng_state
->
data
.
dtype
=
DType
::
kInt64
;
output_rng_state
->
data
.
dtype
=
DType
::
kInt64
;
Tensor
*
output_bias
=
reinterpret_cast
<
Tensor
*>
(
Aux_CTX_Tensors
->
tensors
[
2
]);
Tensor
*
output_bias
=
convertNVTE
Tensor
Check
(
Aux_CTX_Tensors
->
tensors
[
2
]);
output_bias
->
data
.
dptr
=
nullptr
;
output_bias
->
data
.
dptr
=
nullptr
;
output_bias
->
data
.
shape
=
{
bias_b
,
bias_h
,
max_seqlen_q
,
max_seqlen_kv
};
output_bias
->
data
.
shape
=
{
bias_b
,
bias_h
,
max_seqlen_q
,
max_seqlen_kv
};
output_bias
->
data
.
dtype
=
QKV_type
;
output_bias
->
data
.
dtype
=
QKV_type
;
}
else
{
}
else
{
Aux_CTX_Tensors
->
size
=
2
;
Aux_CTX_Tensors
->
size
=
2
;
Tensor
*
output_S
=
reinterpret_cast
<
Tensor
*>
(
Aux_CTX_Tensors
->
tensors
[
0
]);
Tensor
*
output_S
=
convertNVTE
Tensor
Check
(
Aux_CTX_Tensors
->
tensors
[
0
]);
output_S
->
data
.
dptr
=
nullptr
;
output_S
->
data
.
dptr
=
nullptr
;
if
(
q_format
==
NVTE_QKV_Format
::
NVTE_THD
&&
cudnn_runtime_version
>=
90600
)
{
if
(
q_format
==
NVTE_QKV_Format
::
NVTE_THD
&&
cudnn_runtime_version
>=
90600
)
{
output_S
->
data
.
shape
=
{
max_tokens_q
,
num_attn_heads
,
1
};
output_S
->
data
.
shape
=
{
max_tokens_q
,
num_attn_heads
,
1
};
...
@@ -1242,22 +1242,22 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked(
...
@@ -1242,22 +1242,22 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked(
output_S
->
data
.
shape
=
{
batch
,
num_attn_heads
,
max_seqlen_q
,
1
};
output_S
->
data
.
shape
=
{
batch
,
num_attn_heads
,
max_seqlen_q
,
1
};
}
}
output_S
->
data
.
dtype
=
DType
::
kFloat32
;
output_S
->
data
.
dtype
=
DType
::
kFloat32
;
Tensor
*
output_rng_state
=
reinterpret_cast
<
Tensor
*>
(
Aux_CTX_Tensors
->
tensors
[
1
]);
Tensor
*
output_rng_state
=
convertNVTE
Tensor
Check
(
Aux_CTX_Tensors
->
tensors
[
1
]);
output_rng_state
->
data
.
dptr
=
nullptr
;
output_rng_state
->
data
.
dptr
=
nullptr
;
output_rng_state
->
data
.
shape
=
{
2
};
output_rng_state
->
data
.
shape
=
{
2
};
output_rng_state
->
data
.
dtype
=
DType
::
kInt64
;
output_rng_state
->
data
.
dtype
=
DType
::
kInt64
;
}
}
}
else
if
(
Aux_CTX_Tensors
->
size
==
2
)
{
}
else
if
(
Aux_CTX_Tensors
->
size
==
2
)
{
Tensor
*
output_S
=
reinterpret_cast
<
Tensor
*>
(
Aux_CTX_Tensors
->
tensors
[
0
]);
Tensor
*
output_S
=
convertNVTE
Tensor
Check
(
Aux_CTX_Tensors
->
tensors
[
0
]);
devPtrS
=
output_S
->
data
.
dptr
;
devPtrS
=
output_S
->
data
.
dptr
;
Tensor
*
output_rng_state
=
reinterpret_cast
<
Tensor
*>
(
Aux_CTX_Tensors
->
tensors
[
1
]);
Tensor
*
output_rng_state
=
convertNVTE
Tensor
Check
(
Aux_CTX_Tensors
->
tensors
[
1
]);
output_rng_state
->
data
.
dptr
=
rng_state
->
data
.
dptr
;
output_rng_state
->
data
.
dptr
=
rng_state
->
data
.
dptr
;
}
else
if
(
Aux_CTX_Tensors
->
size
==
3
)
{
}
else
if
(
Aux_CTX_Tensors
->
size
==
3
)
{
Tensor
*
output_S
=
reinterpret_cast
<
Tensor
*>
(
Aux_CTX_Tensors
->
tensors
[
0
]);
Tensor
*
output_S
=
convertNVTE
Tensor
Check
(
Aux_CTX_Tensors
->
tensors
[
0
]);
devPtrS
=
output_S
->
data
.
dptr
;
devPtrS
=
output_S
->
data
.
dptr
;
Tensor
*
output_rng_state
=
reinterpret_cast
<
Tensor
*>
(
Aux_CTX_Tensors
->
tensors
[
1
]);
Tensor
*
output_rng_state
=
convertNVTE
Tensor
Check
(
Aux_CTX_Tensors
->
tensors
[
1
]);
output_rng_state
->
data
.
dptr
=
rng_state
->
data
.
dptr
;
output_rng_state
->
data
.
dptr
=
rng_state
->
data
.
dptr
;
Tensor
*
output_bias
=
reinterpret_cast
<
Tensor
*>
(
Aux_CTX_Tensors
->
tensors
[
2
]);
Tensor
*
output_bias
=
convertNVTE
Tensor
Check
(
Aux_CTX_Tensors
->
tensors
[
2
]);
output_bias
->
data
.
dptr
=
devPtrBias
;
output_bias
->
data
.
dptr
=
devPtrBias
;
}
else
{
}
else
{
NVTE_ERROR
(
"Unexpected Aux_CTX_Tensors->size."
);
NVTE_ERROR
(
"Unexpected Aux_CTX_Tensors->size."
);
...
@@ -1313,9 +1313,9 @@ void fused_attn_arbitrary_seqlen_bwd_kvpacked(
...
@@ -1313,9 +1313,9 @@ void fused_attn_arbitrary_seqlen_bwd_kvpacked(
NVTE_QKV_Layout_Group
layout_group
=
nvte_get_qkv_layout_group
(
qkv_layout
);
NVTE_QKV_Layout_Group
layout_group
=
nvte_get_qkv_layout_group
(
qkv_layout
);
size_t
stride
=
0
;
size_t
stride
=
0
;
if
(
layout_group
==
NVTE_QKV_Layout_Group
::
NVTE_HD_2HD
)
{
if
(
layout_group
==
NVTE_QKV_Layout_Group
::
NVTE_HD_2HD
)
{
stride
=
typeTo
Size
(
QKV_type
)
*
num_gqa_groups
*
head_dim
;
stride
=
(
typeTo
NumBits
(
QKV_type
)
*
num_gqa_groups
*
head_dim
)
/
8
;
}
else
if
(
layout_group
==
NVTE_QKV_Layout_Group
::
NVTE_HD_H2D
)
{
}
else
if
(
layout_group
==
NVTE_QKV_Layout_Group
::
NVTE_HD_H2D
)
{
stride
=
typeTo
Size
(
QKV_type
)
*
head_dim
;
stride
=
(
typeTo
NumBits
(
QKV_type
)
*
head_dim
)
/
8
;
}
}
void
*
devPtrK
=
devPtrKV
;
void
*
devPtrK
=
devPtrKV
;
void
*
devPtrV
=
static_cast
<
void
*>
(
static_cast
<
int8_t
*>
(
devPtrKV
)
+
stride
);
void
*
devPtrV
=
static_cast
<
void
*>
(
static_cast
<
int8_t
*>
(
devPtrKV
)
+
stride
);
...
@@ -1446,7 +1446,7 @@ void fused_attn_arbitrary_seqlen_fwd(
...
@@ -1446,7 +1446,7 @@ void fused_attn_arbitrary_seqlen_fwd(
const
auto
cudnn_runtime_version
=
cudnnGetVersion
();
const
auto
cudnn_runtime_version
=
cudnnGetVersion
();
if
((
bias_type
!=
NVTE_NO_BIAS
)
&&
(
bias_type
!=
NVTE_ALIBI
))
{
if
((
bias_type
!=
NVTE_NO_BIAS
)
&&
(
bias_type
!=
NVTE_ALIBI
))
{
Aux_CTX_Tensors
->
size
=
3
;
Aux_CTX_Tensors
->
size
=
3
;
Tensor
*
output_S
=
reinterpret_cast
<
Tensor
*>
(
Aux_CTX_Tensors
->
tensors
[
0
]);
Tensor
*
output_S
=
convertNVTE
Tensor
Check
(
Aux_CTX_Tensors
->
tensors
[
0
]);
output_S
->
data
.
dptr
=
nullptr
;
output_S
->
data
.
dptr
=
nullptr
;
if
(
q_format
==
NVTE_QKV_Format
::
NVTE_THD
&&
cudnn_runtime_version
>=
90600
)
{
if
(
q_format
==
NVTE_QKV_Format
::
NVTE_THD
&&
cudnn_runtime_version
>=
90600
)
{
output_S
->
data
.
shape
=
{
max_tokens_q
,
num_attn_heads
,
1
};
output_S
->
data
.
shape
=
{
max_tokens_q
,
num_attn_heads
,
1
};
...
@@ -1454,17 +1454,17 @@ void fused_attn_arbitrary_seqlen_fwd(
...
@@ -1454,17 +1454,17 @@ void fused_attn_arbitrary_seqlen_fwd(
output_S
->
data
.
shape
=
{
batch
,
num_attn_heads
,
max_seqlen_q
,
1
};
output_S
->
data
.
shape
=
{
batch
,
num_attn_heads
,
max_seqlen_q
,
1
};
}
}
output_S
->
data
.
dtype
=
DType
::
kFloat32
;
output_S
->
data
.
dtype
=
DType
::
kFloat32
;
Tensor
*
output_rng_state
=
reinterpret_cast
<
Tensor
*>
(
Aux_CTX_Tensors
->
tensors
[
1
]);
Tensor
*
output_rng_state
=
convertNVTE
Tensor
Check
(
Aux_CTX_Tensors
->
tensors
[
1
]);
output_rng_state
->
data
.
dptr
=
nullptr
;
output_rng_state
->
data
.
dptr
=
nullptr
;
output_rng_state
->
data
.
shape
=
{
2
};
output_rng_state
->
data
.
shape
=
{
2
};
output_rng_state
->
data
.
dtype
=
DType
::
kInt64
;
output_rng_state
->
data
.
dtype
=
DType
::
kInt64
;
Tensor
*
output_bias
=
reinterpret_cast
<
Tensor
*>
(
Aux_CTX_Tensors
->
tensors
[
2
]);
Tensor
*
output_bias
=
convertNVTE
Tensor
Check
(
Aux_CTX_Tensors
->
tensors
[
2
]);
output_bias
->
data
.
dptr
=
nullptr
;
output_bias
->
data
.
dptr
=
nullptr
;
output_bias
->
data
.
shape
=
{
bias_b
,
bias_h
,
max_seqlen_q
,
max_seqlen_kv
};
output_bias
->
data
.
shape
=
{
bias_b
,
bias_h
,
max_seqlen_q
,
max_seqlen_kv
};
output_bias
->
data
.
dtype
=
QKV_type
;
output_bias
->
data
.
dtype
=
QKV_type
;
}
else
{
}
else
{
Aux_CTX_Tensors
->
size
=
2
;
Aux_CTX_Tensors
->
size
=
2
;
Tensor
*
output_S
=
reinterpret_cast
<
Tensor
*>
(
Aux_CTX_Tensors
->
tensors
[
0
]);
Tensor
*
output_S
=
convertNVTE
Tensor
Check
(
Aux_CTX_Tensors
->
tensors
[
0
]);
output_S
->
data
.
dptr
=
nullptr
;
output_S
->
data
.
dptr
=
nullptr
;
if
(
q_format
==
NVTE_QKV_Format
::
NVTE_THD
&&
cudnn_runtime_version
>=
90600
)
{
if
(
q_format
==
NVTE_QKV_Format
::
NVTE_THD
&&
cudnn_runtime_version
>=
90600
)
{
output_S
->
data
.
shape
=
{
max_tokens_q
,
num_attn_heads
,
1
};
output_S
->
data
.
shape
=
{
max_tokens_q
,
num_attn_heads
,
1
};
...
@@ -1472,22 +1472,22 @@ void fused_attn_arbitrary_seqlen_fwd(
...
@@ -1472,22 +1472,22 @@ void fused_attn_arbitrary_seqlen_fwd(
output_S
->
data
.
shape
=
{
batch
,
num_attn_heads
,
max_seqlen_q
,
1
};
output_S
->
data
.
shape
=
{
batch
,
num_attn_heads
,
max_seqlen_q
,
1
};
}
}
output_S
->
data
.
dtype
=
DType
::
kFloat32
;
output_S
->
data
.
dtype
=
DType
::
kFloat32
;
Tensor
*
output_rng_state
=
reinterpret_cast
<
Tensor
*>
(
Aux_CTX_Tensors
->
tensors
[
1
]);
Tensor
*
output_rng_state
=
convertNVTE
Tensor
Check
(
Aux_CTX_Tensors
->
tensors
[
1
]);
output_rng_state
->
data
.
dptr
=
nullptr
;
output_rng_state
->
data
.
dptr
=
nullptr
;
output_rng_state
->
data
.
shape
=
{
2
};
output_rng_state
->
data
.
shape
=
{
2
};
output_rng_state
->
data
.
dtype
=
DType
::
kInt64
;
output_rng_state
->
data
.
dtype
=
DType
::
kInt64
;
}
}
}
else
if
(
Aux_CTX_Tensors
->
size
==
2
)
{
}
else
if
(
Aux_CTX_Tensors
->
size
==
2
)
{
Tensor
*
output_S
=
reinterpret_cast
<
Tensor
*>
(
Aux_CTX_Tensors
->
tensors
[
0
]);
Tensor
*
output_S
=
convertNVTE
Tensor
Check
(
Aux_CTX_Tensors
->
tensors
[
0
]);
devPtrS
=
output_S
->
data
.
dptr
;
devPtrS
=
output_S
->
data
.
dptr
;
Tensor
*
output_rng_state
=
reinterpret_cast
<
Tensor
*>
(
Aux_CTX_Tensors
->
tensors
[
1
]);
Tensor
*
output_rng_state
=
convertNVTE
Tensor
Check
(
Aux_CTX_Tensors
->
tensors
[
1
]);
output_rng_state
->
data
.
dptr
=
rng_state
->
data
.
dptr
;
output_rng_state
->
data
.
dptr
=
rng_state
->
data
.
dptr
;
}
else
if
(
Aux_CTX_Tensors
->
size
==
3
)
{
}
else
if
(
Aux_CTX_Tensors
->
size
==
3
)
{
Tensor
*
output_S
=
reinterpret_cast
<
Tensor
*>
(
Aux_CTX_Tensors
->
tensors
[
0
]);
Tensor
*
output_S
=
convertNVTE
Tensor
Check
(
Aux_CTX_Tensors
->
tensors
[
0
]);
devPtrS
=
output_S
->
data
.
dptr
;
devPtrS
=
output_S
->
data
.
dptr
;
Tensor
*
output_rng_state
=
reinterpret_cast
<
Tensor
*>
(
Aux_CTX_Tensors
->
tensors
[
1
]);
Tensor
*
output_rng_state
=
convertNVTE
Tensor
Check
(
Aux_CTX_Tensors
->
tensors
[
1
]);
output_rng_state
->
data
.
dptr
=
rng_state
->
data
.
dptr
;
output_rng_state
->
data
.
dptr
=
rng_state
->
data
.
dptr
;
Tensor
*
output_bias
=
reinterpret_cast
<
Tensor
*>
(
Aux_CTX_Tensors
->
tensors
[
2
]);
Tensor
*
output_bias
=
convertNVTE
Tensor
Check
(
Aux_CTX_Tensors
->
tensors
[
2
]);
output_bias
->
data
.
dptr
=
devPtrBias
;
output_bias
->
data
.
dptr
=
devPtrBias
;
}
else
{
}
else
{
NVTE_ERROR
(
"Unexpected Aux_CTX_Tensors->size."
);
NVTE_ERROR
(
"Unexpected Aux_CTX_Tensors->size."
);
...
...
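The stride edits in the hunks above all follow one pattern: byte-based arithmetic with typeToSize is replaced by bit-based arithmetic with typeToNumBits, divided by 8 at the end, presumably so the same formula stays exact for sub-byte element types. A minimal standalone sketch of the new computation (the helper name stride_bytes and the example widths below are illustrative, not part of the library):

#include <cstddef>
#include <cstdio>

// Stride in bytes between the Q, K and V segments of a packed 3HD buffer,
// computed from the element width in bits so sub-byte types divide cleanly.
constexpr size_t stride_bytes(size_t elem_bits, size_t num_heads, size_t head_dim) {
  return (elem_bits * num_heads * head_dim) / 8;
}

int main() {
  std::printf("fp16,  16 heads x 64 dim: %zu bytes\n", stride_bytes(16, 16, 64));
  std::printf("4-bit, 16 heads x 64 dim: %zu bytes\n", stride_bytes(4, 16, 64));
  return 0;
}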
transformer_engine/common/fused_attn/fused_attn_f16_max512_seqlen.cu
View file @ 2b05e121
...
@@ -1239,12 +1239,12 @@ void fused_attn_max_512_fwd_qkvpacked(
   if (Aux_CTX_Tensors->size == 0) {
     Aux_CTX_Tensors->size = 1;
-    Tensor *output_S = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]);
+    Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]);
     output_S->data.dptr = nullptr;
     output_S->data.shape = {batch, num_head, max_seqlen, max_seqlen};
     output_S->data.dtype = input_QKV->data.dtype;
   } else if (Aux_CTX_Tensors->size == 1) {
-    Tensor *output_S = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]);
+    Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]);
     devPtrS = output_S->data.dptr;
   } else {
     NVTE_ERROR("Unexpected Aux_CTX_Tensors->size.");
...
@@ -1317,12 +1317,12 @@ void fused_attn_max_512_fwd_kvpacked(size_t batch, size_t num_head, size_t q_max
   if (Aux_CTX_Tensors->size == 0) {
     Aux_CTX_Tensors->size = 1;
-    Tensor *output_S = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]);
+    Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]);
     output_S->data.dptr = nullptr;
     output_S->data.shape = {batch, num_head, q_max_seqlen, kv_max_seqlen};
     output_S->data.dtype = q_type;
   } else if (Aux_CTX_Tensors->size == 1) {
-    Tensor *output_S = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]);
+    Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]);
     devPtrS = output_S->data.dptr;
   } else {
     NVTE_ERROR("Unexpected Aux_CTX_Tensors->size.");
...
@@ -1386,12 +1386,12 @@ void fused_attn_max_512_fwd(size_t batch, size_t num_head, size_t q_max_seqlen,
   if (Aux_CTX_Tensors->size == 0) {
     Aux_CTX_Tensors->size = 1;
-    Tensor *output_S = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]);
+    Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]);
     output_S->data.dptr = nullptr;
     output_S->data.shape = {batch, num_head, q_max_seqlen, kv_max_seqlen};
     output_S->data.dtype = q_type;
   } else if (Aux_CTX_Tensors->size == 1) {
-    Tensor *output_S = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]);
+    Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]);
     devPtrS = output_S->data.dptr;
   } else {
     NVTE_ERROR("Unexpected Aux_CTX_Tensors->size.");
...
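The other substitution repeated throughout this commit, reinterpret_cast<Tensor *>(handle) to convertNVTETensorCheck(handle), routes every NVTETensor-to-Tensor conversion through one helper that can also reject bad handles. A rough sketch of that idea with simplified stand-in types (these are not the actual TransformerEngine definitions):

#include <stdexcept>

// Simplified stand-ins for the opaque C handle and the internal C++ tensor type.
typedef void *NVTETensor;
struct Tensor {
  void *dptr = nullptr;
};

// Checked conversion: one place to validate handles instead of scattering
// raw reinterpret_casts through every API entry point.
inline Tensor *convertNVTETensorCheck(NVTETensor handle) {
  if (handle == nullptr) throw std::runtime_error("Invalid NVTETensor handle");
  return reinterpret_cast<Tensor *>(handle);
}

int main() {
  Tensor t;
  NVTETensor handle = &t;
  Tensor *checked = convertNVTETensorCheck(handle);  // same pointer, now validated
  return checked == &t ? 0 : 1;
}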
transformer_engine/common/fused_attn/fused_attn_fp8.cu
View file @ 2b05e121
...
@@ -2364,9 +2364,9 @@ void fused_attn_fp8_fwd_qkvpacked(size_t batch, size_t num_attn_heads, size_t ma
   NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout);
   size_t stride = 0;
   if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) {
-    stride = typeToSize(QKV_type) * num_attn_heads * head_dim;
+    stride = (typeToNumBits(QKV_type) * num_attn_heads * head_dim) / 8;
   } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_H3D) {
-    stride = typeToSize(QKV_type) * head_dim;
+    stride = (typeToNumBits(QKV_type) * head_dim) / 8;
   }
   void *devPtrQ = static_cast<void *>(devPtrQKV);
   void *devPtrK = static_cast<void *>(static_cast<int8_t *>(devPtrQKV) + stride);
...
@@ -2383,9 +2383,9 @@ void fused_attn_fp8_fwd_qkvpacked(size_t batch, size_t num_attn_heads, size_t ma
   void *devPtrZInv = nullptr;
   if (Aux_CTX_Tensors->size == 0) {
     Aux_CTX_Tensors->size = 3;
-    Tensor *output_M = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]);
-    Tensor *output_ZInv = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[1]);
-    Tensor *output_rng_state = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[2]);
+    Tensor *output_M = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]);
+    Tensor *output_ZInv = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]);
+    Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]);
     output_M->data.dptr = nullptr;
     output_M->data.shape = {batch, num_attn_heads, max_seqlen, 1};
     output_M->data.dtype = DType::kFloat32;
...
@@ -2396,9 +2396,9 @@ void fused_attn_fp8_fwd_qkvpacked(size_t batch, size_t num_attn_heads, size_t ma
     output_rng_state->data.shape = {2};
     output_rng_state->data.dtype = DType::kInt64;
   } else if (Aux_CTX_Tensors->size == 3) {
-    Tensor *output_M = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]);
-    Tensor *output_ZInv = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[1]);
-    Tensor *output_rng_state = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[2]);
+    Tensor *output_M = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]);
+    Tensor *output_ZInv = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]);
+    Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]);
     devPtrM = output_M->data.dptr;
     devPtrZInv = output_ZInv->data.dptr;
     output_rng_state->data.dptr = rng_state->data.dptr;
...
@@ -2466,9 +2466,9 @@ void fused_attn_fp8_bwd_qkvpacked(
   NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout);
   size_t stride = 0;
   if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) {
-    stride = typeToSize(QKV_type) * num_attn_heads * head_dim;
+    stride = (typeToNumBits(QKV_type) * num_attn_heads * head_dim) / 8;
   } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_H3D) {
-    stride = typeToSize(QKV_type) * head_dim;
+    stride = (typeToNumBits(QKV_type) * head_dim) / 8;
   }
   void *devPtrQ = devPtrQKV;
   void *devPtrK = static_cast<void *>(static_cast<int8_t *>(devPtrQKV) + stride);
...
@@ -2564,9 +2564,9 @@ void fused_attn_fp8_fwd_kvpacked(size_t batch, size_t num_attn_heads, size_t num
   NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout);
   size_t stride = 0;
   if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) {
-    stride = typeToSize(QKV_type) * num_gqa_groups * head_dim;
+    stride = (typeToNumBits(QKV_type) * num_gqa_groups * head_dim) / 8;
   } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_H2D) {
-    stride = typeToSize(QKV_type) * head_dim;
+    stride = (typeToNumBits(QKV_type) * head_dim) / 8;
   }
   void *devPtrK = devPtrKV;
   void *devPtrV = static_cast<void *>(static_cast<int8_t *>(devPtrKV) + stride);
...
@@ -2582,9 +2582,9 @@ void fused_attn_fp8_fwd_kvpacked(size_t batch, size_t num_attn_heads, size_t num
   void *devPtrZInv = nullptr;
   if (Aux_CTX_Tensors->size == 0) {
     Aux_CTX_Tensors->size = 3;
-    Tensor *output_M = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]);
-    Tensor *output_ZInv = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[1]);
-    Tensor *output_rng_state = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[2]);
+    Tensor *output_M = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]);
+    Tensor *output_ZInv = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]);
+    Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]);
     output_M->data.dptr = nullptr;
     output_M->data.shape = {batch, num_attn_heads, max_seqlen_q, 1};
     output_M->data.dtype = DType::kFloat32;
...
@@ -2595,9 +2595,9 @@ void fused_attn_fp8_fwd_kvpacked(size_t batch, size_t num_attn_heads, size_t num
     output_rng_state->data.shape = {2};
     output_rng_state->data.dtype = DType::kInt64;
   } else if (Aux_CTX_Tensors->size == 3) {
-    Tensor *output_M = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]);
-    Tensor *output_ZInv = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[1]);
-    Tensor *output_rng_state = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[2]);
+    Tensor *output_M = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]);
+    Tensor *output_ZInv = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]);
+    Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]);
     devPtrM = output_M->data.dptr;
     devPtrZInv = output_ZInv->data.dptr;
     output_rng_state->data.dptr = rng_state->data.dptr;
...
@@ -2671,9 +2671,9 @@ void fused_attn_fp8_bwd_kvpacked(
   NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout);
   size_t stride = 0;
   if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) {
-    stride = typeToSize(QKV_type) * num_gqa_groups * head_dim;
+    stride = (typeToNumBits(QKV_type) * num_gqa_groups * head_dim) / 8;
   } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_H2D) {
-    stride = typeToSize(QKV_type) * head_dim;
+    stride = (typeToNumBits(QKV_type) * head_dim) / 8;
   }
   void *devPtrK = devPtrKV;
   void *devPtrV = static_cast<void *>(static_cast<int8_t *>(devPtrKV) + stride);
...
@@ -2779,9 +2779,9 @@ void fused_attn_fp8_fwd(size_t batch, size_t num_attn_heads, size_t num_gqa_grou
   void *devPtrZInv = nullptr;
   if (Aux_CTX_Tensors->size == 0) {
     Aux_CTX_Tensors->size = 3;
-    Tensor *output_M = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]);
-    Tensor *output_ZInv = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[1]);
-    Tensor *output_rng_state = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[2]);
+    Tensor *output_M = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]);
+    Tensor *output_ZInv = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]);
+    Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]);
     output_M->data.dptr = nullptr;
     output_M->data.shape = {batch, num_attn_heads, max_seqlen_q, 1};
     output_M->data.dtype = DType::kFloat32;
...
@@ -2792,9 +2792,9 @@ void fused_attn_fp8_fwd(size_t batch, size_t num_attn_heads, size_t num_gqa_grou
     output_rng_state->data.shape = {2};
     output_rng_state->data.dtype = DType::kInt64;
   } else if (Aux_CTX_Tensors->size == 3) {
-    Tensor *output_M = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]);
-    Tensor *output_ZInv = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[1]);
-    Tensor *output_rng_state = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[2]);
+    Tensor *output_M = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]);
+    Tensor *output_ZInv = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]);
+    Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]);
     devPtrM = output_M->data.dptr;
     devPtrZInv = output_ZInv->data.dptr;
     output_rng_state->data.dptr = rng_state->data.dptr;
...
transformer_engine/common/fused_attn/kv_cache.cu
View file @ 2b05e121
...
@@ -260,12 +260,12 @@ void nvte_copy_to_kv_cache(NVTETensor new_k, NVTETensor new_v, NVTETensor k_cach
   NVTE_API_CALL(nvte_copy_to_kv_cache);
   using namespace transformer_engine;
   kv_cache::copy_to_kv_cache(
-      *reinterpret_cast<Tensor *>(new_k), *reinterpret_cast<Tensor *>(new_v),
-      *reinterpret_cast<Tensor *>(k_cache), *reinterpret_cast<Tensor *>(v_cache),
-      *reinterpret_cast<Tensor *>(page_table), *reinterpret_cast<Tensor *>(cu_new_lens),
-      *reinterpret_cast<Tensor *>(cu_cached_lens), qkv_format, b, max_ctx_len, max_seq_len,
-      max_pages_per_seq, is_non_paged, stream);
+      *convertNVTETensorCheck(new_k), *convertNVTETensorCheck(new_v),
+      *convertNVTETensorCheck(k_cache), *convertNVTETensorCheck(v_cache),
+      *convertNVTETensorCheck(page_table), *convertNVTETensorCheck(cu_new_lens),
+      *convertNVTETensorCheck(cu_cached_lens), qkv_format, b, max_ctx_len,
+      max_seq_len, max_pages_per_seq, is_non_paged, stream);
 }
 /***************************************************************************************************
...
@@ -277,9 +277,9 @@ void nvte_convert_thd_to_bshd(NVTETensor tensor, NVTETensor cu_seqlens, NVTETens
   NVTE_API_CALL(nvte_convert_thd_to_bshd);
   using namespace transformer_engine;
-  kv_cache::convert_thd_to_bshd(*reinterpret_cast<Tensor *>(tensor),
-                                *reinterpret_cast<Tensor *>(cu_seqlens),
-                                *reinterpret_cast<Tensor *>(new_tensor), b, max_seq_len, stream);
+  kv_cache::convert_thd_to_bshd(*convertNVTETensorCheck(tensor),
+                                *convertNVTETensorCheck(cu_seqlens),
+                                *convertNVTETensorCheck(new_tensor), b, max_seq_len, stream);
 }
 /***************************************************************************************************
...
@@ -291,7 +291,7 @@ void nvte_convert_bshd_to_thd(NVTETensor tensor, NVTETensor cu_seqlens, NVTETens
   NVTE_API_CALL(nvte_convert_bshd_to_thd);
   using namespace transformer_engine;
-  kv_cache::convert_bshd_to_thd(*reinterpret_cast<Tensor *>(tensor),
-                                *reinterpret_cast<Tensor *>(cu_seqlens),
-                                *reinterpret_cast<Tensor *>(new_tensor), t, stream);
+  kv_cache::convert_bshd_to_thd(*convertNVTETensorCheck(tensor),
+                                *convertNVTETensorCheck(cu_seqlens),
+                                *convertNVTETensorCheck(new_tensor), t, stream);
 }
transformer_engine/common/fused_rope/fused_rope.cu
View file @ 2b05e121
...
@@ -308,11 +308,10 @@ void nvte_fused_rope_forward(const NVTETensor input, const NVTETensor cu_seqlens
                              const int stride_d, cudaStream_t stream) {
   NVTE_API_CALL(nvte_fused_rope_forward);
   using namespace transformer_engine;
-  fused_rope_forward(*reinterpret_cast<const Tensor *>(input), *reinterpret_cast<const Tensor *>(cu_seqlens),
-                     *reinterpret_cast<const Tensor *>(freqs), *reinterpret_cast<const Tensor *>(start_positions),
-                     reinterpret_cast<Tensor *>(output), qkv_format, interleaved, cp_size, cp_rank, s,
+  fused_rope_forward(*convertNVTETensorCheck(input), *convertNVTETensorCheck(cu_seqlens),
+                     *convertNVTETensorCheck(freqs), *convertNVTETensorCheck(start_positions),
+                     convertNVTETensorCheck(output), qkv_format, interleaved, cp_size, cp_rank, s,
                      b, h, d, d2,
                      stride_s_or_t, stride_b, stride_h, stride_d, stream);
 }
 void nvte_fused_rope_backward(const NVTETensor output_grads, const NVTETensor cu_seqlens,
...
@@ -324,9 +323,8 @@ void nvte_fused_rope_backward(const NVTETensor output_grads, const NVTETensor cu
                               cudaStream_t stream) {
   NVTE_API_CALL(nvte_fused_rope_backward);
   using namespace transformer_engine;
-  fused_rope_backward(*reinterpret_cast<const Tensor *>(output_grads),
-                      *reinterpret_cast<const Tensor *>(cu_seqlens),
-                      *reinterpret_cast<const Tensor *>(freqs),
-                      reinterpret_cast<Tensor *>(input_grads), qkv_format, interleaved, cp_size,
-                      cp_rank, s, b, h, d, d2, stride_s_or_t, stride_b, stride_h, stride_d, stream);
+  fused_rope_backward(*convertNVTETensorCheck(output_grads), *convertNVTETensorCheck(cu_seqlens),
+                      *convertNVTETensorCheck(freqs), convertNVTETensorCheck(input_grads),
+                      qkv_format, interleaved, cp_size, cp_rank, s, b, h, d, d2, stride_s_or_t,
+                      stride_b, stride_h, stride_d, stream);
 }
transformer_engine/common/fused_softmax/scaled_aligned_causal_masked_softmax.cu
View file @ 2b05e121
...
@@ -551,8 +551,8 @@ void nvte_scaled_aligned_causal_masked_softmax_forward(const NVTETensor input,
                                                        float scale_factor, cudaStream_t stream) {
   NVTE_API_CALL(nvte_scaled_aligned_causal_masked_softmax_forward);
   using namespace transformer_engine;
-  scaled_aligned_causal_masked_softmax_forward(*reinterpret_cast<const Tensor *>(input),
-                                               reinterpret_cast<Tensor *>(softmax_results),
+  scaled_aligned_causal_masked_softmax_forward(*convertNVTETensorCheck(input),
+                                               convertNVTETensorCheck(softmax_results),
                                                scale_factor, stream);
 }
...
@@ -563,6 +563,6 @@ void nvte_scaled_aligned_causal_masked_softmax_backward(const NVTETensor incomin
   NVTE_API_CALL(nvte_scaled_aligned_causal_masked_softmax_backward);
   using namespace transformer_engine;
   scaled_aligned_causal_masked_softmax_backward(
-      *reinterpret_cast<Tensor *>(output_grads), *reinterpret_cast<const Tensor *>(incoming_grads),
-      *reinterpret_cast<const Tensor *>(softmax_results), scale_factor, stream);
+      *convertNVTETensorCheck(output_grads), *convertNVTETensorCheck(incoming_grads),
+      *convertNVTETensorCheck(softmax_results), scale_factor, stream);
 }
transformer_engine/common/fused_softmax/scaled_masked_softmax.cu
View file @ 2b05e121
...
@@ -815,8 +815,8 @@ void nvte_scaled_softmax_forward(const NVTETensor input, NVTETensor softmax_resu
                                  float scale_factor, cudaStream_t stream) {
   NVTE_API_CALL(nvte_scaled_softmax_forward);
   using namespace transformer_engine;
-  scaled_softmax_forward(*reinterpret_cast<const Tensor *>(input),
-                         reinterpret_cast<Tensor *>(softmax_results),
+  scaled_softmax_forward(*convertNVTETensorCheck(input), convertNVTETensorCheck(softmax_results),
                          scale_factor, stream);
 }
 void nvte_scaled_softmax_backward(const NVTETensor incoming_grads, const NVTETensor softmax_results,
...
@@ -824,9 +824,9 @@ void nvte_scaled_softmax_backward(const NVTETensor incoming_grads, const NVTETen
                                   cudaStream_t stream) {
   NVTE_API_CALL(nvte_scaled_softmax_backward);
   using namespace transformer_engine;
-  scaled_softmax_backward(*reinterpret_cast<Tensor *>(output_grads),
-                          *reinterpret_cast<const Tensor *>(incoming_grads),
-                          *reinterpret_cast<const Tensor *>(softmax_results), scale_factor, stream);
+  scaled_softmax_backward(*convertNVTETensorCheck(output_grads),
+                          *convertNVTETensorCheck(incoming_grads),
+                          *convertNVTETensorCheck(softmax_results), scale_factor, stream);
 }
 void nvte_scaled_masked_softmax_forward(const NVTETensor input, const NVTETensor mask,
...
@@ -834,9 +834,8 @@ void nvte_scaled_masked_softmax_forward(const NVTETensor input, const NVTETensor
                                         cudaStream_t stream) {
   NVTE_API_CALL(nvte_scaled_masked_softmax_forward);
   using namespace transformer_engine;
-  scaled_masked_softmax_forward(*reinterpret_cast<const Tensor *>(input),
-                                *reinterpret_cast<const Tensor *>(mask),
-                                reinterpret_cast<Tensor *>(softmax_results), scale_factor, stream);
+  scaled_masked_softmax_forward(*convertNVTETensorCheck(input), *convertNVTETensorCheck(mask),
+                                convertNVTETensorCheck(softmax_results), scale_factor, stream);
 }
 void nvte_scaled_masked_softmax_backward(const NVTETensor incoming_grads,
...
@@ -844,7 +843,7 @@ void nvte_scaled_masked_softmax_backward(const NVTETensor incoming_grads,
                                          float scale_factor, cudaStream_t stream) {
   NVTE_API_CALL(nvte_scaled_masked_softmax_backward);
   using namespace transformer_engine;
   scaled_masked_softmax_backward(
-      *reinterpret_cast<Tensor *>(output_grads), *reinterpret_cast<const Tensor *>(incoming_grads),
-      *reinterpret_cast<const Tensor *>(softmax_results), scale_factor, stream);
+      *convertNVTETensorCheck(output_grads), *convertNVTETensorCheck(incoming_grads),
+      *convertNVTETensorCheck(softmax_results), scale_factor, stream);
 }
transformer_engine/common/fused_softmax/scaled_upper_triang_masked_softmax.cu
View file @ 2b05e121
...
@@ -599,9 +599,9 @@ void nvte_scaled_upper_triang_masked_softmax_forward(const NVTETensor input,
                                                      NVTETensor softmax_results, float scale_factor,
                                                      cudaStream_t stream) {
   using namespace transformer_engine;
-  scaled_upper_triang_masked_softmax_forward(*reinterpret_cast<const Tensor *>(input),
-                                             reinterpret_cast<Tensor *>(softmax_results),
+  scaled_upper_triang_masked_softmax_forward(*convertNVTETensorCheck(input),
+                                             convertNVTETensorCheck(softmax_results),
                                              scale_factor,
                                              stream);
 }
 void nvte_scaled_upper_triang_masked_softmax_backward(const NVTETensor incoming_grads,
...
@@ -610,6 +610,6 @@ void nvte_scaled_upper_triang_masked_softmax_backward(const NVTETensor incoming_
                                                       cudaStream_t stream) {
   using namespace transformer_engine;
   scaled_upper_triang_masked_softmax_backward(
-      *reinterpret_cast<Tensor *>(output_grads), *reinterpret_cast<const Tensor *>(incoming_grads),
-      *reinterpret_cast<const Tensor *>(softmax_results), scale_factor, stream);
+      *convertNVTETensorCheck(output_grads), *convertNVTETensorCheck(incoming_grads),
+      *convertNVTETensorCheck(softmax_results), scale_factor, stream);
 }
transformer_engine/common/gemm/cublaslt_gemm.cu
View file @
2b05e121
...
@@ -14,6 +14,7 @@
...
@@ -14,6 +14,7 @@
#include "rocm_gemm.hip"
#include "rocm_gemm.hip"
#endif // #ifndef __HIP_PLATFORM_AMD__
#endif // #ifndef __HIP_PLATFORM_AMD__
#include <transformer_engine/gemm.h>
#include <transformer_engine/gemm.h>
#include <transformer_engine/multi_stream.h>
#include <transformer_engine/transformer_engine.h>
#include <transformer_engine/transformer_engine.h>
#include <cstdint>
#include <cstdint>
...
@@ -22,6 +23,7 @@
...
@@ -22,6 +23,7 @@
#include "../common.h"
#include "../common.h"
#include "../util/handle_manager.h"
#include "../util/handle_manager.h"
#include "../util/logging.h"
#include "../util/logging.h"
#include "../util/multi_stream.h"
#include "common/util/cuda_runtime.h"
#include "common/util/cuda_runtime.h"
#ifndef __HIP_PLATFORM_AMD__
#ifndef __HIP_PLATFORM_AMD__
...
@@ -94,7 +96,8 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla
...
@@ -94,7 +96,8 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla
A
.
scaling_mode
==
B
.
scaling_mode
||
A
.
scaling_mode
==
B
.
scaling_mode
||
(
A
.
scaling_mode
==
NVTE_BLOCK_SCALING_1D
&&
B
.
scaling_mode
==
NVTE_BLOCK_SCALING_2D
)
||
(
A
.
scaling_mode
==
NVTE_BLOCK_SCALING_1D
&&
B
.
scaling_mode
==
NVTE_BLOCK_SCALING_2D
)
||
(
A
.
scaling_mode
==
NVTE_BLOCK_SCALING_2D
&&
B
.
scaling_mode
==
NVTE_BLOCK_SCALING_1D
),
(
A
.
scaling_mode
==
NVTE_BLOCK_SCALING_2D
&&
B
.
scaling_mode
==
NVTE_BLOCK_SCALING_1D
),
"Inputs A and B to GEMM need to have compatible scaling modes!"
);
"Inputs A and B to GEMM need to have compatible scaling modes, but got A.scaling_mode = "
+
to_string
(
A
.
scaling_mode
)
+
", B.scaling_mode = "
+
to_string
(
B
.
scaling_mode
));
NVTE_CHECK
(
A
.
has_data
()
||
A
.
has_columnwise_data
(),
"Input A does not hold any data!"
);
NVTE_CHECK
(
A
.
has_data
()
||
A
.
has_columnwise_data
(),
"Input A does not hold any data!"
);
NVTE_CHECK
(
B
.
has_data
()
||
B
.
has_columnwise_data
(),
"Input B does not hold any data!"
);
NVTE_CHECK
(
B
.
has_data
()
||
B
.
has_columnwise_data
(),
"Input B does not hold any data!"
);
GemmParam
ret
;
GemmParam
ret
;
...
@@ -507,7 +510,8 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
...
@@ -507,7 +510,8 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
NVTE_CHECK_CUBLAS
(
cublasLtMatmulDescSetAttribute
(
operationDesc
,
CUBLASLT_MATMUL_DESC_EPILOGUE
,
NVTE_CHECK_CUBLAS
(
cublasLtMatmulDescSetAttribute
(
operationDesc
,
CUBLASLT_MATMUL_DESC_EPILOGUE
,
&
epilogue
,
sizeof
(
epilogue
)));
&
epilogue
,
sizeof
(
epilogue
)));
#if CUDA_VERSION >= 12020 && CUBLAS_VERSION >= 120205
#if CUDA_VERSION >= 12020 && CUBLAS_VERSION >= 120205 && CUDA_VERSION < 13000 && \
CUBLAS_VERSION < 130000
if
(
counter
!=
nullptr
)
{
if
(
counter
!=
nullptr
)
{
if
(
m_split
==
0
)
m_split
=
1
;
if
(
m_split
==
0
)
m_split
=
1
;
if
(
n_split
==
0
)
n_split
=
1
;
if
(
n_split
==
0
)
n_split
=
1
;
...
@@ -536,6 +540,7 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
...
@@ -536,6 +540,7 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
const
auto
B_alignment
=
_getAlignment
(
reinterpret_cast
<
uintptr_t
>
(
param
.
B
));
const
auto
B_alignment
=
_getAlignment
(
reinterpret_cast
<
uintptr_t
>
(
param
.
B
));
const
auto
C_alignment
=
_getAlignment
(
reinterpret_cast
<
uintptr_t
>
(
C
));
const
auto
C_alignment
=
_getAlignment
(
reinterpret_cast
<
uintptr_t
>
(
C
));
const
auto
D_alignment
=
_getAlignment
(
reinterpret_cast
<
uintptr_t
>
(
D
));
const
auto
D_alignment
=
_getAlignment
(
reinterpret_cast
<
uintptr_t
>
(
D
));
const
auto
workspace_alignment
=
_getAlignment
(
reinterpret_cast
<
uintptr_t
>
(
workspace
));
NVTE_CHECK_CUBLAS
(
cublasLtMatmulPreferenceSetAttribute
(
NVTE_CHECK_CUBLAS
(
cublasLtMatmulPreferenceSetAttribute
(
preference
,
CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_A_BYTES
,
&
A_alignment
,
sizeof
(
A_alignment
)));
preference
,
CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_A_BYTES
,
&
A_alignment
,
sizeof
(
A_alignment
)));
NVTE_CHECK_CUBLAS
(
cublasLtMatmulPreferenceSetAttribute
(
NVTE_CHECK_CUBLAS
(
cublasLtMatmulPreferenceSetAttribute
(
...
@@ -544,6 +549,8 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
...
@@ -544,6 +549,8 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
preference
,
CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_C_BYTES
,
&
C_alignment
,
sizeof
(
C_alignment
)));
preference
,
CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_C_BYTES
,
&
C_alignment
,
sizeof
(
C_alignment
)));
NVTE_CHECK_CUBLAS
(
cublasLtMatmulPreferenceSetAttribute
(
NVTE_CHECK_CUBLAS
(
cublasLtMatmulPreferenceSetAttribute
(
preference
,
CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_D_BYTES
,
&
D_alignment
,
sizeof
(
D_alignment
)));
preference
,
CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_D_BYTES
,
&
D_alignment
,
sizeof
(
D_alignment
)));
NVTE_CHECK
(
workspace_alignment
%
256
==
0
,
"cuBLAS workspace pointer must be aligned to 256 bytes, got "
,
workspace_alignment
);
const
auto
status
=
const
auto
status
=
cublasLtMatmulAlgoGetHeuristic
(
handle
,
operationDesc
,
Adesc
,
Bdesc
,
Cdesc
,
Ddesc
,
preference
,
cublasLtMatmulAlgoGetHeuristic
(
handle
,
operationDesc
,
Adesc
,
Bdesc
,
Cdesc
,
Ddesc
,
preference
,
...
@@ -582,18 +589,6 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
...
@@ -582,18 +589,6 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
}
}
#endif // __HIP_PLATFORM_AMD__
#endif // __HIP_PLATFORM_AMD__
static
std
::
once_flag
init_flag
;
static
cudaStream_t
compute_streams
[
num_streams
];
static
cudaEvent_t
cublas_event
[
num_streams
];
// Warning: only call once per device!
static
void
init_streams_and_events
()
{
for
(
int
i
=
0
;
i
<
num_streams
;
i
++
)
{
NVTE_CHECK_CUDA
(
cudaStreamCreateWithPriority
(
&
compute_streams
[
i
],
cudaStreamNonBlocking
,
-
1
));
NVTE_CHECK_CUDA
(
cudaEventCreate
(
&
cublas_event
[
i
]));
}
}
// Add for batchgemm
// Add for batchgemm
static
std
::
once_flag
init_flag_batchgemm
;
static
std
::
once_flag
init_flag_batchgemm
;
static
cudaStream_t
compute_streams_batchgemm
[
num_batchgemm_streams
];
static
cudaStream_t
compute_streams_batchgemm
[
num_batchgemm_streams
];
...
@@ -615,12 +610,12 @@ void nvte_cublas_gemm(const NVTETensor A, const NVTETensor B, NVTETensor D, cons
...
@@ -615,12 +610,12 @@ void nvte_cublas_gemm(const NVTETensor A, const NVTETensor B, NVTETensor D, cons
int
math_sm_count
,
cudaStream_t
stream
,
bool
nvte_use_hipblaslt
,
bool
nvte_use_rocblas
,
int
compute_stream_offset
)
{
int
math_sm_count
,
cudaStream_t
stream
,
bool
nvte_use_hipblaslt
,
bool
nvte_use_rocblas
,
int
compute_stream_offset
)
{
NVTE_API_CALL
(
nvte_cublas_gemm
);
NVTE_API_CALL
(
nvte_cublas_gemm
);
using
namespace
transformer_engine
;
using
namespace
transformer_engine
;
const
Tensor
*
inputA
=
reinterpret_cast
<
const
Tensor
*>
(
A
);
const
Tensor
*
inputA
=
convertNVTE
Tensor
Check
(
A
);
const
Tensor
*
inputB
=
reinterpret_cast
<
const
Tensor
*>
(
B
);
const
Tensor
*
inputB
=
convertNVTE
Tensor
Check
(
B
);
Tensor
*
outputD
=
reinterpret_cast
<
Tensor
*>
(
D
);
Tensor
*
outputD
=
convertNVTE
Tensor
(
D
);
const
Tensor
*
biasTensor
=
reinterpret_cast
<
const
Tensor
*>
(
bias
);
const
Tensor
*
biasTensor
=
convertNVTE
Tensor
(
bias
);
Tensor
*
outputGelu
=
reinterpret_cast
<
Tensor
*>
(
pre_gelu_out
);
Tensor
*
outputGelu
=
convertNVTE
Tensor
(
pre_gelu_out
);
Tensor
*
wspace
=
reinterpret_cast
<
Tensor
*>
(
workspace
);
Tensor
*
wspace
=
convertNVTE
Tensor
(
workspace
);
#ifdef __HIP_PLATFORM_AMD__
#ifdef __HIP_PLATFORM_AMD__
const
size_t
A0
=
inputA
->
flat_first_dim
();
const
size_t
A0
=
inputA
->
flat_first_dim
();
...
@@ -693,18 +688,19 @@ void nvte_cublas_atomic_gemm(const NVTETensor A, const NVTETensor B, NVTETensor
...
@@ -693,18 +688,19 @@ void nvte_cublas_atomic_gemm(const NVTETensor A, const NVTETensor B, NVTETensor
#ifndef __HIP_PLATFORM_AMD__
#ifndef __HIP_PLATFORM_AMD__
int
cudart_version
;
int
cudart_version
;
NVTE_CHECK_CUDA
(
cudaRuntimeGetVersion
(
&
cudart_version
));
NVTE_CHECK_CUDA
(
cudaRuntimeGetVersion
(
&
cudart_version
));
NVTE_CHECK
(
cudart_version
>=
12020
,
"Cuda version 12.2 is required for atomic gemm."
);
NVTE_CHECK
(
cudart_version
>=
12020
&&
cudart_version
<
13000
,
NVTE_CHECK
(
cublasLtGetVersion
()
>=
120205
,
"Cublas version 12.2.5 is required for atomic gemm."
);
"Cuda version >=12.2 and <13.0 is required for atomic gemm."
);
#endif
NVTE_CHECK
(
cublasLtGetVersion
()
>=
120205
&&
cublasLtGetVersion
()
<
130000
,
"Cublas version >=12.2.5 and <13.0 is required for atomic gemm."
);
using
namespace
transformer_engine
;
using
namespace
transformer_engine
;
const
Tensor
*
inputA
=
reinterpret_cast
<
const
Tensor
*>
(
A
);
const
Tensor
*
inputA
=
convertNVTE
Tensor
Check
(
A
);
const
Tensor
*
inputB
=
reinterpret_cast
<
const
Tensor
*>
(
B
);
const
Tensor
*
inputB
=
convertNVTE
Tensor
Check
(
B
);
Tensor
*
outputD
=
reinterpret_cast
<
Tensor
*>
(
D
);
Tensor
*
outputD
=
convertNVTE
Tensor
(
D
);
const
Tensor
*
biasTensor
=
reinterpret_cast
<
const
Tensor
*>
(
bias
);
const
Tensor
*
biasTensor
=
convertNVTE
Tensor
(
bias
);
Tensor
*
outputGelu
=
reinterpret_cast
<
Tensor
*>
(
pre_gelu_out
);
Tensor
*
outputGelu
=
convertNVTE
Tensor
(
pre_gelu_out
);
const
Tensor
*
inputCounter
=
reinterpret_cast
<
const
Tensor
*>
(
counter
);
const
Tensor
*
inputCounter
=
convertNVTE
Tensor
(
counter
);
Tensor
*
wspace
=
reinterpret_cast
<
Tensor
*>
(
workspace
);
Tensor
*
wspace
=
convertNVTE
Tensor
(
workspace
);
NVTE_CHECK
(
is_delayed_tensor_scaling
(
inputA
->
scaling_mode
)
&&
NVTE_CHECK
(
is_delayed_tensor_scaling
(
inputA
->
scaling_mode
)
&&
is_delayed_tensor_scaling
(
inputB
->
scaling_mode
),
is_delayed_tensor_scaling
(
inputB
->
scaling_mode
),
...
@@ -775,14 +771,15 @@ void nvte_multi_stream_cublas_gemm(const NVTETensor *A, const NVTETensor *B, NVT
                                    cudaStream_t stream) {
   NVTE_API_CALL(nvte_multi_stream_cublas_gemm);
   using namespace transformer_engine;
-  // Inits streams and events (once, globally)
-  std::call_once(init_flag, init_streams_and_events);
+  int num_streams = nvte_get_num_compute_streams();
   int num_stream_used = std::min(num_streams, num_gemms);
   // wait for current stream to finish
-  NVTE_CHECK_CUDA(cudaEventRecord(cublas_event[0], stream));
+  NVTE_CHECK_CUDA(cudaEventRecord(detail::get_compute_stream_event(0), stream));
   for (int s = 0; s < num_stream_used; s++) {
-    NVTE_CHECK_CUDA(cudaStreamWaitEvent(compute_streams[s], cublas_event[0]));
+    NVTE_CHECK_CUDA(cudaStreamWaitEvent(detail::get_compute_stream(s), detail::get_compute_stream_event(0)));
   }
   const char *NVTE_BLAS_MULSTREAM = std::getenv("NVTE_FORCE_BLAS_MULSTREAM");
   const char *NVTE_FORCE_ROCM_GEMM = std::getenv("NVTE_FORCE_ROCM_GEMM");
...
@@ -798,23 +795,24 @@ void nvte_multi_stream_cublas_gemm(const NVTETensor *A, const NVTETensor *B, NVT
     for (int i = 0; i < num_gemms; i++) {
       nvte_cublas_gemm(A[i], B[i], D[i], bias[i], pre_gelu_out[i], transa, transb, grad,
                        workspace[i % num_streams], accumulate, use_split_accumulator, math_sm_count,
-                       compute_streams[i % num_streams]);
+                       detail::get_compute_stream(i % num_streams));
     }
   } else {
     for (int i = 0; i < num_gemms; i++) {
       nvte_cublas_gemm(A[i], B[i], D[i], bias[i], pre_gelu_out[i], transa, transb, grad,
                        workspace[i % num_streams], accumulate, use_split_accumulator, math_sm_count,
-                       compute_streams[i % num_streams], 1, 0, i % num_streams);
+                       detail::get_compute_stream(i % num_streams), 1, 0, i % num_streams);
     }
   }
   // record events on compute streams
   for (int s = 0; s < num_stream_used; s++) {
-    NVTE_CHECK_CUDA(cudaEventRecord(cublas_event[s], compute_streams[s]));
+    NVTE_CHECK_CUDA(cudaEventRecord(detail::get_compute_stream_event(s), detail::get_compute_stream(s)));
   }
   // wait for all compute streams to finish
   for (int s = 0; s < num_stream_used; s++) {
-    NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream, cublas_event[s]));
+    NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream, detail::get_compute_stream_event(s)));
   }
 }
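For readers skimming the hunk above: it is a standard fork-join choreography, now routed through the shared compute-stream pool instead of file-local stream/event arrays. The sketch below is a minimal, self-contained illustration of the same pattern using plain CUDA runtime calls only; the per-stream GEMM is replaced by a placeholder memset, and the stream count and buffer sizes are illustrative, not values from the library.

#include <cuda_runtime.h>
#include <cstdio>

#define CHECK(call)                                                     \
  do {                                                                  \
    cudaError_t e_ = (call);                                            \
    if (e_ != cudaSuccess) {                                            \
      std::printf("CUDA error: %s\n", cudaGetErrorString(e_));          \
      return 1;                                                         \
    }                                                                   \
  } while (0)

int main() {
  constexpr int kNumStreams = 4;        // illustrative; TE asks its stream pool instead
  constexpr size_t kChunk = 1024;
  cudaStream_t main_stream, workers[kNumStreams];
  cudaEvent_t fork_event, join_events[kNumStreams];
  CHECK(cudaStreamCreate(&main_stream));
  CHECK(cudaEventCreateWithFlags(&fork_event, cudaEventDisableTiming));
  for (int s = 0; s < kNumStreams; ++s) {
    CHECK(cudaStreamCreate(&workers[s]));
    CHECK(cudaEventCreateWithFlags(&join_events[s], cudaEventDisableTiming));
  }

  void *buf = nullptr;
  CHECK(cudaMalloc(&buf, kNumStreams * kChunk));

  // Fork: workers must not start before work already queued on main_stream.
  CHECK(cudaEventRecord(fork_event, main_stream));
  for (int s = 0; s < kNumStreams; ++s) {
    CHECK(cudaStreamWaitEvent(workers[s], fork_event, 0));
    // Placeholder for the per-stream GEMM launch.
    CHECK(cudaMemsetAsync(static_cast<char *>(buf) + s * kChunk, 0, kChunk, workers[s]));
    CHECK(cudaEventRecord(join_events[s], workers[s]));
  }
  // Join: main_stream resumes only after every worker finished its chunk.
  for (int s = 0; s < kNumStreams; ++s) {
    CHECK(cudaStreamWaitEvent(main_stream, join_events[s], 0));
  }
  CHECK(cudaStreamSynchronize(main_stream));
  CHECK(cudaFree(buf));
  std::printf("fork-join on %d worker streams completed\n", kNumStreams);
  return 0;
}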
...
transformer_engine/common/include/transformer_engine/cast.h
View file @ 2b05e121
...
@@ -259,6 +259,17 @@ void nvte_quantize_dbias_dsrelu(const NVTETensor input, const NVTETensor act_inp
  */
 void nvte_dequantize(const NVTETensor input, NVTETensor output, cudaStream_t stream);
 
+/*! \brief Casts multiple input tensors to quantized output tensors.
+ *
+ *  \param[in]     inputs        List of input tensors to be cast.
+ *  \param[in,out] outputs       List of output quantized tensors.
+ *  \param[in]     quant_config  (Optional) Quantization configurations.
+ *  \param[in]     num_tensors   Number of tensors in the lists.
+ *  \param[in]     stream        CUDA stream used for the operation.
+ */
+void nvte_multi_tensor_quantize(const NVTETensor *inputs, NVTETensor *outputs,
+                                const NVTEQuantizationConfig quant_config,
+                                const size_t num_tensors, cudaStream_t stream);
 
 #ifdef __cplusplus
 }  // extern "C"
 #endif
...
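A minimal sketch of how a caller might use the new declaration above: batching an existing list of NVTETensor handles into one call instead of looping over nvte_quantize. Creation of the tensor handles and of the quantization config is assumed to happen elsewhere.

#include <vector>
#include <cuda_runtime.h>
#include <transformer_engine/cast.h>

// Quantize a batch of tensors with a single call. The handles and quant_config
// are assumed to be valid; this only shows the wiring of the new entry point.
void quantize_batch(const std::vector<NVTETensor> &inputs,
                    std::vector<NVTETensor> &outputs,
                    NVTEQuantizationConfig quant_config,
                    cudaStream_t stream) {
  nvte_multi_tensor_quantize(inputs.data(), outputs.data(), quant_config,
                             inputs.size(), stream);
}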
...
transformer_engine/common/include/transformer_engine/cast_transpose_noop.h
View file @ 2b05e121
...
@@ -17,23 +17,21 @@
 extern "C" {
 #endif
 
-/*! \brief Transposes the input, providing the option to immediately exit the kernel
- *         based on the value of the 'noop' tensor.
+/*! \brief Transposes the input.
  *
-*  \param[in]     input   Input tensor.
-*  \param[in]     noop    Noop tensor.
+*  \param[in]     input   Input tensor to be cast.
+*  \param[in]     noop    If this single element tensor has non-zero value, kernel will exit immediately.
 *  \param[in,out] output  Output tensor.
 *  \param[in]     stream  CUDA stream used for the operation.
 */
 void nvte_transpose_with_noop(const NVTETensor input, const NVTETensor noop, NVTETensor output,
                               cudaStream_t stream);
 
-/*! \brief Casts and transposes the input, providing the option to immediately exit the kernel
- *         based on the value of the 'noop' tensor.
+/*! \brief Casts and transposes the input.
  *
-*  \param[in]     input   Input tensor.
-*  \param[in]     noop    Noop tensor.
+*  \param[in]     input   Input tensor to be cast.
+*  \param[in]     noop    If this single element tensor has non-zero value, kernel will exit immediately.
-*  \param[in,out] output  Output tensor.
+*  \param[in,out] output  Output quantized tensor.
 *  \param[in]     stream  CUDA stream used for the operation.
 */
 void nvte_cast_transpose_with_noop(const NVTETensor input, const NVTETensor noop, NVTETensor output,
...
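The 'noop' convention documented above (a one-element tensor whose non-zero value makes the kernel return immediately, useful inside captured CUDA graphs) can be illustrated with a toy kernel. This is a generic sketch of the pattern, not TE's actual implementation.

#include <cuda_runtime.h>

// Noop-gated kernel: every thread checks a one-element flag in device memory and
// returns immediately when it is non-zero. This mirrors the documented semantics
// of the 'noop' argument; the real TE kernels do real transpose/cast work.
__global__ void scale_with_noop(const float *noop, const float *in, float *out, int n) {
  if (noop != nullptr && noop[0] != 0.0f) return;  // early exit, cheap even in a static graph
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) out[i] = 2.0f * in[i];
}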
transformer_engine/common/include/transformer_engine/fused_attn.h
View file @ 2b05e121
...
@@ -172,6 +172,7 @@ NVTE_QKV_Format nvte_get_kv_format(NVTE_QKV_Layout qkv_layout);
 /*! \brief Get fused attention backend based on input parameters.
  *
+ *  \param[in] is_training        Whether the model is in training mode.
  *  \param[in] q_dtype            The data type of Tensor Q.
  *  \param[in] kv_dtype           The data type of Tensors K, V.
  *  \param[in] qkv_layout         The layout of Tensors Q, K, V.
...
@@ -188,10 +189,10 @@ NVTE_QKV_Format nvte_get_kv_format(NVTE_QKV_Layout qkv_layout);
  *  \param[in] window_size_right  Sliding window size (the right half).
  */
 NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
-    NVTEDType q_dtype, NVTEDType kv_dtype, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
-    NVTE_Mask_Type attn_mask_type, float dropout, size_t num_attn_heads, size_t num_gqa_groups,
-    size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v,
-    int64_t window_size_left, int64_t window_size_right);
+    bool is_training, NVTEDType q_dtype, NVTEDType kv_dtype, NVTE_QKV_Layout qkv_layout,
+    NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, float dropout, size_t num_attn_heads,
+    size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk,
+    size_t head_dim_v, int64_t window_size_left, int64_t window_size_right);
 
 /*! \brief Compute dot product attention with packed QKV input.
  *
...
@@ -580,6 +581,8 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
                          cudaStream_t stream);
 
 /*! \brief Update the RNG state with the seed and calculated offset.
  *
+ *  \warning This API is **experimental** and subject to change.
+ *
  *  \param[in] rng_state_dst  RNG state to store seed and offset.
  *  \param[in] seed           Seed for RNG state.
...
@@ -595,6 +598,8 @@ void nvte_populate_rng_state_async(NVTETensor rng_state_dst, const NVTETensor se
                                    NVTE_Fused_Attn_Backend backend, cudaStream_t stream);
 
 /*! \brief Get KV format for a given QKV layout.
  *
+ *  \warning This API is **experimental** and subject to change.
+ *
  *  \param[in] cu_seqlens  Cumulative sequence lengths, [batch_size + 1].
  *  \param[in] workspace   Workspace tensor.
...
@@ -604,48 +609,187 @@ void nvte_populate_rng_state_async(NVTETensor rng_state_dst, const NVTETensor se
 uint32_t nvte_get_runtime_num_segments(NVTETensor cu_seqlen, NVTETensor workspace, size_t len,
                                        cudaStream_t stream);
 
+/*! \brief Set the seed and offset for RNG state.
+ *
+ *  \warning This API is **experimental** and subject to change.
+ *
+ *  \param[out] rng_state_ptr      A size 2 array storing the RNG's seed and offset respectively.
+ *  \param[in]  captured           Whether a CUDA graph is being captured.
+ *  \param[in]  seed_ptr           Seed pointer.
+ *  \param[in]  seed_val           Seed value.
+ *  \param[in]  offset_ptr         Offset pointer.
+ *  \param[in]  offset_val         Offset value.
+ *  \param[in]  offset_intragraph  Intragraph offset in RNG states. For use with CUDA Graphs.
+ *  \param[in]  stream             CUDA stream used for this operation.
+ */
 void nvte_extract_seed_and_offset(int64_t *rng_state_ptr, int captured, int64_t *seed_ptr,
                                   uint64_t seed_val, int64_t *offset_ptr, uint64_t offset_val,
                                   uint32_t offset_intragraph, cudaStream_t stream);
 
+/*! \brief Copy keys and values into the KV cache.
+ *
+ *  \warning This API is **experimental** and subject to change.
+ *
+ *  \param[in]  new_k              Key tensor.
+ *  \param[in]  new_v              Value tensor.
+ *  \param[out] k_cache            Key cache.
+ *  \param[out] v_cache            Value cache.
+ *  \param[in]  page_table         Page table for K cache, [batch_size, max_pages_per_seq].
+ *  \param[in]  cu_new_lens        Cumulative sequence lengths.
+ *  \param[in]  cu_cached_lens     Cached cumulative sequence lengths.
+ *  \param[in]  qkv_format         QKV format, e.g. sbhd.
+ *  \param[in]  b                  Batch size.
+ *  \param[in]  max_ctx_len        Maximum context length.
+ *  \param[in]  max_seq_len        Maximum sequence length.
+ *  \param[in]  max_pages_per_seq  Maximum number of pages per sequence.
+ *  \param[in]  is_non_paged       Whether the cache is non-paged (as opposed to paged).
+ *  \param[in]  stream             CUDA stream used for this operation.
+ */
 void nvte_copy_to_kv_cache(NVTETensor new_k, NVTETensor new_v, NVTETensor k_cache,
                            NVTETensor v_cache, NVTETensor page_table, NVTETensor cu_new_lens,
                            NVTETensor cu_cached_lens, NVTE_QKV_Format qkv_format, int b,
                            int max_ctx_len, int max_seq_len, int max_pages_per_seq,
                            int is_non_paged, cudaStream_t stream);
 
+/*! \brief Extract the first half (half_idx=0) or second half (half_idx=1) of a THD tensor.
+ *
+ *  \warning This API is **experimental** and subject to change.
+ *
+ *  \param[in]  tensor      Input tensor.
+ *  \param[in]  cu_seqlens  Cumulative sequence lengths, [batch_size + 1].
+ *  \param[out] half        Output tensor.
+ *  \param[in]  half_idx    Whether to read first or second half of input tensor.
+ *  \param[in]  stream      CUDA stream used for this operation.
+ */
 void nvte_cp_thd_read_half_tensor(const NVTETensor &tensor, const NVTETensor &cu_seqlens,
                                   NVTETensor half, int half_idx, cudaStream_t stream);
 
+/*! \brief Correct the second half of the softmax LSE (LogSumExp) for context parallelism.
+ *
+ *  \warning This API is **experimental** and subject to change.
+ *
+ *  \param[out] lse           Output tensor.
+ *  \param[in]  lse_per_step  Input tensor.
+ *  \param[in]  cu_seqlens    Cumulative sequence lengths, [batch_size + 1].
+ *  \param[in]  lse_packed    Whether or not lse_per_step is packed.
+ *  \param[in]  stream        CUDA stream used for this operation.
+ */
 void nvte_cp_thd_second_half_lse_correction(NVTETensor lse, const NVTETensor &lse_per_step,
                                             const NVTETensor &cu_seqlens, int lse_packed,
                                             cudaStream_t stream);
 
+/*! \brief Read the second half of the softmax LSE (LogSumExp) for context parallelism.
+ *
+ *  \warning This API is **experimental** and subject to change.
+ *
+ *  \param[in]  lse                     Input tensor.
+ *  \param[in]  cu_seqlens              Cumulative sequence lengths, [batch_size + 1].
+ *  \param[out] half_lse                Output tensor.
+ *  \param[in]  lse_packed              Whether or not the softmax LSE is in packed format.
+ *  \param[in]  second_half_lse_seqlen  Sequence length.
+ *  \param[in]  stream                  CUDA stream used for this operation.
+ */
 void nvte_cp_thd_read_second_half_lse(const NVTETensor &lse, const NVTETensor &cu_seqlens,
                                       NVTETensor half_lse, int lse_packed,
                                       int second_half_lse_seqlen, cudaStream_t stream);
 
+/*! \brief Correct the THD format output of context parallelism in forward pass.
+ *
+ *  \warning This API is **experimental** and subject to change.
+ *
+ *  \param[out] out               Output tensor.
+ *  \param[in]  out_per_step      THD format output of context parallelism in forward pass.
+ *  \param[in]  lse               Softmax LSE.
+ *  \param[in]  lse_per_step      Softmax LSE per step.
+ *  \param[in]  cu_seqlens        Cumulative sequence lengths, [batch_size + 1].
+ *  \param[in]  only_second_half  Whether or not to correct only second half.
+ *  \param[in]  lse_packed        Whether or not the softmax LSE is in packed format.
+ *  \param[in]  stream            CUDA stream used for this operation.
+ */
 void nvte_cp_thd_out_correction(NVTETensor out, const NVTETensor &out_per_step,
                                 const NVTETensor &lse, const NVTETensor &lse_per_step,
                                 const NVTETensor &cu_seqlens, int only_second_half, int lse_packed,
                                 cudaStream_t stream);
 
+/*! \brief Correct the THD format gradients of context parallelism in backward pass.
+ *
+ *  \warning This API is **experimental** and subject to change.
+ *
+ *  \param[out] grad           Output tensor.
+ *  \param[in]  grad_per_step  THD format gradient of context parallelism.
+ *  \param[in]  cu_seqlens     Cumulative sequence lengths, [batch_size + 1].
+ *  \param[in]  first_half     One of ("add", "copy", "none") correction op for first half.
+ *  \param[in]  second_half    One of ("add", "copy", "none") correction op for second half.
+ *                             Must be different from first_half.
+ *  \param[in]  stream         CUDA stream used for this operation.
+ */
 void nvte_cp_thd_grad_correction(NVTETensor grad, const NVTETensor &grad_per_step,
                                  const NVTETensor &cu_seqlens, const char *first_half,
                                  const char *second_half, cudaStream_t stream);
 
+/*! \brief Generate partitioned indices for inputs in THD format.
+ *
+ *  \warning This API is **experimental** and subject to change.
+ *
+ *  \param[in]  cu_seqlens    Cumulative sequence lengths, [batch_size + 1].
+ *  \param[out] output        Output tensor.
+ *  \param[in]  total_tokens  Total number of tokens.
+ *  \param[in]  world_size    Total number of devices for context parallelism.
+ *  \param[in]  rank          Device ID for current device.
+ *  \param[in]  stream        CUDA stream used for this operation.
+ */
 void nvte_cp_thd_get_partitioned_indices(const NVTETensor &cu_seqlens, NVTETensor output,
                                          int total_tokens, int world_size, int rank,
                                          cudaStream_t stream);
 
+/*! \brief Convert tensor from THD to BSHD format.
+ *
+ *  \warning This API is **experimental** and subject to change.
+ *
+ *  \param[in]  tensor       Input tensor.
+ *  \param[in]  cu_seqlens   Cumulative sequence lengths, [batch_size + 1].
+ *  \param[out] new_tensor   Output tensor.
+ *  \param[in]  b            Batch size.
+ *  \param[in]  max_seq_len  Maximum sequence length.
+ *  \param[in]  stream       CUDA stream used for this operation.
+ */
 void nvte_convert_thd_to_bshd(NVTETensor tensor, NVTETensor cu_seqlens, NVTETensor new_tensor,
                               int b, int max_seq_len, cudaStream_t stream);
 
+/*! \brief Convert tensor from BSHD to THD format.
+ *
+ *  \warning This API is **experimental** and subject to change.
+ *
+ *  \param[in]  tensor       Input tensor.
+ *  \param[in]  cu_seqlens   Cumulative sequence lengths, [batch_size + 1].
+ *  \param[out] new_tensor   Output tensor.
+ *  \param[in]  b            Batch size.
+ *  \param[in]  max_seq_len  Maximum sequence length.
+ *  \param[in]  stream       CUDA stream used for this operation.
+ */
 void nvte_convert_bshd_to_thd(NVTETensor tensor, NVTETensor cu_seqlens, NVTETensor new_tensor,
                               int t, cudaStream_t stream);
 
+/*! \brief Prepare QKV tensor for Flash Attention forward kernel.
+ *
+ *  \warning This API is **experimental** and subject to change.
+ *
+ *  \param[in]  qkvi    Input tensor.
+ *  \param[out] qkv     Output tensor.
+ *  \param[in]  stream  CUDA stream used for this operation.
+ */
 void nvte_prepare_flash_attn_fwd(NVTETensor qkvi, NVTETensor qkv, cudaStream_t stream);
 
+/*! \brief Prepare QKV tensor for Flash Attention backward kernel.
+ *
+ *  \warning This API is **experimental** and subject to change.
+ *
+ *  \param[in]  q       Input query tensor.
+ *  \param[in]  k       Input key tensor.
+ *  \param[in]  v       Input value tensor.
+ *  \param[out] qkv     Output tensor.
+ *  \param[in]  stream  CUDA stream used for this operation.
+ */
 void nvte_prepare_flash_attn_bwd(NVTETensor q, NVTETensor k, NVTETensor v, NVTETensor qkv,
                                  cudaStream_t stream);
...
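Since the signature of nvte_get_fused_attn_backend now takes is_training as its first argument, call sites need to be updated. A hedged sketch of a call with the new ordering follows; the enumerator names are the ones declared in fused_attn.h and transformer_engine.h, while all sizes are purely illustrative.

#include <transformer_engine/fused_attn.h>

// Query which fused-attention backend would be selected for a BF16, BS3HD,
// causal-mask configuration. window_size_left/right of -1 mean no sliding window here.
NVTE_Fused_Attn_Backend pick_backend() {
  return nvte_get_fused_attn_backend(
      /*is_training=*/true, kNVTEBFloat16, kNVTEBFloat16,
      NVTE_BS3HD, NVTE_NO_BIAS, NVTE_CAUSAL_MASK, /*dropout=*/0.0f,
      /*num_attn_heads=*/16, /*num_gqa_groups=*/16,
      /*max_seqlen_q=*/512, /*max_seqlen_kv=*/512,
      /*head_dim_qk=*/64, /*head_dim_v=*/64,
      /*window_size_left=*/-1, /*window_size_right=*/-1);
}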
transformer_engine/common/include/transformer_engine/gemm.h
View file @ 2b05e121
...
@@ -132,12 +132,8 @@ void nvte_cublas_batchgemm(const NVTETensor A, const NVTETensor B, NVTETensor D,
  */
 namespace transformer_engine {
 
 #ifdef __HIP_PLATFORM_AMD__
-// In dcu, 2 stream is more better
-constexpr int num_streams = 2;
 // Add for batchgemm stream
 constexpr int num_batchgemm_streams = 1;
-#else
-constexpr int num_streams = 4;
 #endif
 
 /*! \brief TE/JAX cudaGraph requires the cuBLAS initialization to happen outside of the capturing
...
transformer_engine/common/include/transformer_engine/multi_stream.h
0 → 100644
View file @ 2b05e121
/*************************************************************************
 * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

/*! \file multi_stream.h
 *  \brief Functions for multi-stream executions.
 */

#ifndef TRANSFORMER_ENGINE_MULTI_STREAM_H
#define TRANSFORMER_ENGINE_MULTI_STREAM_H

#ifdef __cplusplus
extern "C" {
#endif

/*! \brief Number of CUDA streams to use in multi-stream operations */
int nvte_get_num_compute_streams();

#ifdef __cplusplus
}  // extern "C"
#endif

#endif  // TRANSFORMER_ENGINE_MULTI_STREAM_H
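The new header replaces the per-file num_streams constants removed from gemm.h above. A small sketch of the intended caller-side use, under the assumption that the caller keeps one round-robin resource (here a hypothetical Workspace type) per compute stream, as the multi-stream GEMM path does with workspace[i % num_streams]:

#include <vector>
#include <transformer_engine/multi_stream.h>

// Hypothetical per-stream buffer; stands in for whatever resource a caller keeps.
struct Workspace { /* user-defined buffer */ };

std::vector<Workspace> make_per_stream_workspaces() {
  const int num_streams = nvte_get_num_compute_streams();
  // One workspace per compute stream; work item i then uses workspace[i % num_streams].
  return std::vector<Workspace>(static_cast<size_t>(num_streams));
}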
transformer_engine/common/include/transformer_engine/multi_tensor.h
View file @ 2b05e121
...
@@ -17,6 +17,25 @@
 extern "C" {
 #endif
 
+/*! \brief Computes L2 norm for a list of tensors.
+ *
+ *  \warning This API is **experimental** and subject to change.
+ *  \warning Argument device_id is deprecated and will be removed in a future release.
+ *
+ *  \param[in]  chunk_size             Number of tensor elements processed by a CUDA block.
+ *  \param[in]  noop_flag              If this single element tensor has non-zero value, kernel will exit immediately.
+ *  \param[in]  tensor_lists           2D array of input tensors.
+ *  \param[in]  num_tensor_lists       Size (dim0) of tensor_lists.
+ *  \param[in]  num_tensors_per_list   Size (dim1) of tensor_lists.
+ *  \param[in]  output                 Scratch space. Required size grows with number of inputs.
+ *  \param[in]  output_per_tensor      Fixed size auxiliary scratch space.
+ *  \param[out] ret                    L2 norm of all inputs.
+ *  \param[out] ret_per_tensor         L2 norm for each tensor.
+ *  \param[in]  per_tensor             Whether to calculate per tensor or cumulative norm.
+ *  \param[in]  max_chunks_per_tensor  Maximum number of chunks in any input tensor.
+ *  \param[in]  device_id              [DEPRECATED] CUDA device ID for this operation.
+ *  \param[in]  stream                 CUDA stream used for this operation.
+ */
 void nvte_multi_tensor_l2norm_cuda(int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists,
                                    const size_t num_tensor_lists, const size_t num_tensors_per_list,
                                    NVTETensor output, NVTETensor output_per_tensor, NVTETensor ret,
...
@@ -24,6 +43,28 @@ void nvte_multi_tensor_l2norm_cuda(int chunk_size, NVTETensor noop_flag, NVTETen
                                    int max_chunks_per_tensor, const int device_id,
                                    cudaStream_t stream);
 
+/*! \brief Computes L2 norm for a list of tensors after unscaling.
+ *
+ *  Unscaling is only done for computing the L2 norm. The tensors themselves are not updated.
+ *
+ *  \warning This API is **experimental** and subject to change.
+ *  \warning Argument device_id is deprecated and will be removed in a future release.
+ *
+ *  \param[in]  chunk_size             Number of tensor elements processed by a CUDA block.
+ *  \param[in]  noop_flag              If this single element tensor has non-zero value, kernel will exit immediately.
+ *  \param[in]  tensor_lists           2D array of input tensors.
+ *  \param[in]  num_tensor_lists       Size (dim0) of tensor_lists.
+ *  \param[in]  num_tensors_per_list   Size (dim1) of tensor_lists.
+ *  \param[in]  output                 Scratch space. Required size grows with number of inputs.
+ *  \param[in]  output_per_tensor      Fixed size auxiliary scratch space.
+ *  \param[out] ret                    L2 norm of all inputs.
+ *  \param[out] ret_per_tensor         L2 norm for each tensor.
+ *  \param[in]  inv_scale              Scalar for the unscaling operation.
+ *  \param[in]  per_tensor             Whether to calculate per tensor or cumulative norm.
+ *  \param[in]  max_chunks_per_tensor  Maximum number of chunks in any input tensor.
+ *  \param[in]  device_id              [DEPRECATED] CUDA device ID for this operation.
+ *  \param[in]  stream                 CUDA stream used for this operation.
+ */
 void nvte_multi_tensor_unscale_l2norm_cuda(int chunk_size, NVTETensor noop_flag,
                                            NVTETensor **tensor_lists, const size_t num_tensor_lists,
                                            const size_t num_tensors_per_list, NVTETensor output,
...
@@ -32,6 +73,27 @@ void nvte_multi_tensor_unscale_l2norm_cuda(int chunk_size, NVTETensor noop_flag,
                                            int per_tensor, int max_chunks_per_tensor,
                                            const int device_id, cudaStream_t stream);
 
+/*! \brief Compute and apply gradient update to parameters for Adam optimizer.
+ *
+ *  \warning This API is **experimental** and subject to change.
+ *  \warning Argument device_id is deprecated and will be removed in a future release.
+ *
+ *  \param[in]     chunk_size            Number of tensor elements processed by a CUDA block.
+ *  \param[in]     noop_flag             If this single element tensor has non-zero value, kernel will exit immediately.
+ *  \param[in,out] tensor_lists          2D array of input tensors.
+ *  \param[in]     num_tensor_lists      Size (dim0) of tensor_lists.
+ *  \param[in]     num_tensors_per_list  Size (dim1) of tensor_lists.
+ *  \param[in]     lr                    Learning rate.
+ *  \param[in]     beta1                 Coefficient for first moment of gradient.
+ *  \param[in]     beta2                 Coefficient for second moment of gradient.
+ *  \param[in]     epsilon               Term added to the denominator for numerical stability.
+ *  \param[in]     step                  Iteration counter.
+ *  \param[in]     mode                  Whether to use AdamW (L2 penalty applied to params).
+ *  \param[in]     bias_correction       Whether to apply correction factor for moment estimates.
+ *  \param[in]     weight_decay          L2 penalty for weight decay.
+ *  \param[in]     device_id             [DEPRECATED] CUDA device ID for this operation.
+ *  \param[in]     stream                CUDA stream used for this operation.
+ */
 void nvte_multi_tensor_adam_cuda(int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists,
                                  const size_t num_tensor_lists, const size_t num_tensors_per_list,
                                  const float lr, const float beta1, const float beta2,
...
@@ -39,12 +101,57 @@ void nvte_multi_tensor_adam_cuda(int chunk_size, NVTETensor noop_flag, NVTETenso
                                  const int bias_correction, const float weight_decay,
                                  const int device_id, cudaStream_t stream);
 
+/*! \brief Compute and apply gradient update to parameters for Adam optimizer
+ *         where the master parameters only store the remainder bits.
+ *
+ *  \warning This API is **experimental** and subject to change.
+ *  \warning Argument device_id is deprecated and will be removed in a future release.
+ *
+ *  \param[in]     chunk_size            Number of tensor elements processed by a CUDA block.
+ *  \param[in]     noop_flag             If this single element tensor has non-zero value, kernel will exit immediately.
+ *  \param[in,out] tensor_lists          2D array of input tensors.
+ *  \param[in]     num_tensor_lists      Size (dim0) of tensor_lists.
+ *  \param[in]     num_tensors_per_list  Size (dim1) of tensor_lists.
+ *  \param[in]     lr                    Learning rate.
+ *  \param[in]     beta1                 Coefficient for first moment of gradient.
+ *  \param[in]     beta2                 Coefficient for second moment of gradient.
+ *  \param[in]     epsilon               Term added to the denominator for numerical stability.
+ *  \param[in]     step                  Iteration counter.
+ *  \param[in]     mode                  Whether to use AdamW (L2 penalty applied to params).
+ *  \param[in]     bias_correction       Whether to apply correction factor for moment estimates.
+ *  \param[in]     weight_decay          L2 penalty for weight decay.
+ *  \param[in]     device_id             [DEPRECATED] CUDA device ID for this operation.
+ *  \param[in]     stream                CUDA stream used for this operation.
+ */
 void nvte_multi_tensor_adam_param_remainder_cuda(
     int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists, const size_t num_tensor_lists,
     const size_t num_tensors_per_list, const float lr, const float beta1, const float beta2,
     const float epsilon, const int step, const int mode, const int bias_correction,
     const float weight_decay, const int device_id, cudaStream_t stream);
 
+/*! \brief Compute and apply gradient update to parameters for Adam optimizer
+ *         when model parameters are in Float8 precision.
+ *
+ *  \warning This API is **experimental** and subject to change.
+ *  \warning Argument device_id is deprecated and will be removed in a future release.
+ *
+ *  \param[in]     chunk_size            Number of tensor elements processed by a CUDA block.
+ *  \param[in]     noop_flag             If this single element tensor has non-zero value, kernel will exit immediately.
+ *  \param[in,out] tensor_lists          2D array of input tensors.
+ *  \param[in]     num_tensor_lists      Size (dim0) of tensor_lists.
+ *  \param[in]     num_tensors_per_list  Size (dim1) of tensor_lists.
+ *  \param[in]     lr                    Learning rate.
+ *  \param[in]     beta1                 Coefficient for first moment of gradient.
+ *  \param[in]     beta2                 Coefficient for second moment of gradient.
+ *  \param[in]     epsilon               Term added to the denominator for numerical stability.
+ *  \param[in]     step                  Iteration counter.
+ *  \param[in]     mode                  Whether to use AdamW (L2 penalty applied to params).
+ *  \param[in]     bias_correction       Whether to apply correction factor for moment estimates.
+ *  \param[in]     weight_decay          L2 penalty for weight decay.
+ *  \param[in]     fp8_dtype             FP8 data type for model parameters.
+ *  \param[in]     device_id             [DEPRECATED] CUDA device ID for this operation.
+ *  \param[in]     stream                CUDA stream used for this operation.
+ */
 void nvte_multi_tensor_adam_fp8_cuda(int chunk_size, NVTETensor noop_flag,
                                      NVTETensor **tensor_lists, const size_t num_tensor_lists,
                                      const size_t num_tensors_per_list, const float lr,
...
@@ -53,28 +160,125 @@ void nvte_multi_tensor_adam_fp8_cuda(int chunk_size, NVTETensor noop_flag,
                                      const float weight_decay, const NVTEDType fp8_dtype,
                                      const int device_id, cudaStream_t stream);
 
+/*! \brief Compute and apply gradient update to parameters for Adam optimizer
+ *         with CUDA graph support and LR scheduling.
+ *
+ *  \warning This API is **experimental** and subject to change.
+ *  \warning Argument device_id is deprecated and will be removed in a future release.
+ *
+ *  \param[in]     chunk_size            Number of tensor elements processed by a CUDA block.
+ *  \param[in]     noop_flag             If this single element tensor has non-zero value, kernel will exit immediately.
+ *  \param[in,out] tensor_lists          2D array of input tensors.
+ *  \param[in]     num_tensor_lists      Size (dim0) of tensor_lists.
+ *  \param[in]     num_tensors_per_list  Size (dim1) of tensor_lists.
+ *  \param[in]     lr                    Learning rate.
+ *  \param[in]     beta1                 Coefficient for first moment of gradient.
+ *  \param[in]     beta2                 Coefficient for second moment of gradient.
+ *  \param[in]     epsilon               Term added to the denominator for numerical stability.
+ *  \param[in]     step                  Iteration counter.
+ *  \param[in]     mode                  Whether to use AdamW (L2 penalty applied to params).
+ *  \param[in]     bias_correction       Whether to apply correction factor for moment estimates.
+ *  \param[in]     weight_decay          L2 penalty for weight decay.
+ *  \param[in]     inv_scale             Scalar for the unscaling operation.
+ *  \param[in]     device_id             [DEPRECATED] CUDA device ID for this operation.
+ *  \param[in]     stream                CUDA stream used for this operation.
+ */
 void nvte_multi_tensor_adam_capturable_cuda(
     int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists, const size_t num_tensor_lists,
     const size_t num_tensors_per_list, NVTETensor lr, const float beta1, const float beta2,
     const float epsilon, NVTETensor step, const int mode, const int bias_correction,
     const float weight_decay, NVTETensor inv_scale, const int device_id, cudaStream_t stream);
 
+/*! \brief Compute and apply gradient update to parameters for Adam optimizer
+ *         with CUDA graph support, LR scheduling, and FP32 master weights.
+ *
+ *  \warning This API is **experimental** and subject to change.
+ *  \warning Argument device_id is deprecated and will be removed in a future release.
+ *
+ *  \param[in]     chunk_size            Number of tensor elements processed by a CUDA block.
+ *  \param[in]     noop_flag             If this single element tensor has non-zero value, kernel will exit immediately.
+ *  \param[in,out] tensor_lists          2D array of input tensors.
+ *  \param[in]     num_tensor_lists      Size (dim0) of tensor_lists.
+ *  \param[in]     num_tensors_per_list  Size (dim1) of tensor_lists.
+ *  \param[in]     lr                    Learning rate.
+ *  \param[in]     beta1                 Coefficient for first moment of gradient.
+ *  \param[in]     beta2                 Coefficient for second moment of gradient.
+ *  \param[in]     epsilon               Term added to the denominator for numerical stability.
+ *  \param[in]     step                  Iteration counter.
+ *  \param[in]     mode                  Whether to use AdamW (L2 penalty applied to params).
+ *  \param[in]     bias_correction       Whether to apply correction factor for moment estimates.
+ *  \param[in]     weight_decay          L2 penalty for weight decay.
+ *  \param[in]     inv_scale             Scalar for the unscaling operation.
+ *  \param[in]     device_id             [DEPRECATED] CUDA device ID for this operation.
+ *  \param[in]     stream                CUDA stream used for this operation.
+ */
 void nvte_multi_tensor_adam_capturable_master_cuda(
     int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists, const size_t num_tensor_lists,
     const size_t num_tensors_per_list, NVTETensor lr, const float beta1, const float beta2,
     const float epsilon, NVTETensor step, const int mode, const int bias_correction,
     const float weight_decay, NVTETensor inv_scale, const int device_id, cudaStream_t stream);
 
+/*! \brief Compute and apply gradient update to parameters for SGD optimizer.
+ *
+ *  \warning This API is **experimental** and subject to change.
+ *  \warning Argument device_id is deprecated and will be removed in a future release.
+ *
+ *  \param[in]     chunk_size            Number of tensor elements processed by a CUDA block.
+ *  \param[in]     noop_flag             If this single element tensor has non-zero value, kernel will exit immediately.
+ *  \param[in,out] tensor_lists          2D array of input tensors.
+ *  \param[in]     num_tensor_lists      Size (dim0) of tensor_lists.
+ *  \param[in]     num_tensors_per_list  Size (dim1) of tensor_lists.
+ *  \param[in]     wd                    Weight decay (L2 penalty).
+ *  \param[in]     momentum              Momentum factor.
+ *  \param[in]     dampening             Dampening factor.
+ *  \param[in]     lr                    Learning rate.
+ *  \param[in]     nesterov              Whether or not to enable nesterov momentum.
+ *  \param[in]     first_run             Whether momentum buffers have been initialized.
+ *  \param[in]     wd_after_momentum     Whether to apply weight decay after momentum update.
+ *  \param[in]     scale                 Scalar for the scaling operation.
+ *  \param[in]     device_id             [DEPRECATED] CUDA device ID for this operation.
+ *  \param[in]     stream                CUDA stream used for this operation.
+ */
 void nvte_multi_tensor_sgd_cuda(int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists,
                                 const size_t num_tensor_lists, const size_t num_tensors_per_list,
                                 float wd, float momentum, float dampening, float lr, int nesterov,
                                 int first_run, int wd_after_momentum, float scale,
                                 const int device_id, cudaStream_t stream);
 
+/*! \brief Check overflow and scale a list of tensors.
+ *
+ *  \warning This API is **experimental** and subject to change.
+ *  \warning Argument device_id is deprecated and will be removed in a future release.
+ *
+ *  \param[in]     chunk_size            Number of tensor elements processed by a CUDA block.
+ *  \param[in]     noop_flag             If this single element tensor has non-zero value, kernel will exit immediately.
+ *  \param[in,out] tensor_lists          2D array of input tensors.
+ *  \param[in]     num_tensor_lists      Size (dim0) of tensor_lists.
+ *  \param[in]     num_tensors_per_list  Size (dim1) of tensor_lists.
+ *  \param[in]     scale                 Scalar for the scaling operation.
+ *  \param[in]     device_id             [DEPRECATED] CUDA device ID for this operation.
+ *  \param[in]     stream                CUDA stream used for this operation.
+ */
 void nvte_multi_tensor_scale_cuda(int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists,
                                   const size_t num_tensor_lists, const size_t num_tensors_per_list,
                                   float scale, const int device_id, cudaStream_t stream);
 
+/*! \brief Compute scale and scale_inv for a list of tensors.
+ *
+ *  \warning This API is **experimental** and subject to change.
+ *  \warning Argument device_id is deprecated and will be removed in a future release.
+ *
+ *  \param[in]     chunk_size            Number of tensor elements processed by a CUDA block.
+ *  \param[in]     noop_flag             If this single element tensor has non-zero value, kernel will exit immediately.
+ *  \param[in,out] tensor_lists          2D array of input tensors.
+ *  \param[in]     num_tensor_lists      Size (dim0) of tensor_lists.
+ *  \param[in]     num_tensors_per_list  Size (dim1) of tensor_lists.
+ *  \param[in]     max_fp8               Maximum representable value in underlying FP8 format.
+ *  \param[in]     force_pow_2_scales    Ensure scaling factors are a power of 2.
+ *  \param[in]     epsilon               Term added to the denominator for numerical stability.
+ *  \param[in]     device_id             [DEPRECATED] CUDA device ID for this operation.
+ *  \param[in]     stream                CUDA stream used for this operation.
+ */
 void nvte_multi_tensor_compute_scale_and_scale_inv_cuda(
     int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists, const size_t num_tensor_lists,
     const size_t num_tensors_per_list, float max_fp8, int force_pow_2_scales, float epsilon,
...
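A hedged sketch of wiring up one of the documented entry points (the fused L2 norm). The NVTETensor handles for inputs, scratch space, results, and the noop flag are assumed to be created and sized by the caller, and the argument order past `ret` follows the documented parameter list rather than the (elided) middle of the declaration; chunk_size 65536 is just an illustrative value.

#include <vector>
#include <cuda_runtime.h>
#include <transformer_engine/multi_tensor.h>

// Fused L2 norm over a flat list of tensors: one tensor list, per-tensor results enabled.
void fused_l2norm(std::vector<NVTETensor> &inputs, NVTETensor noop_flag,
                  NVTETensor output, NVTETensor output_per_tensor,
                  NVTETensor ret, NVTETensor ret_per_tensor,
                  int max_chunks_per_tensor, cudaStream_t stream) {
  NVTETensor *lists[1] = {inputs.data()};  // dim0 = 1 list, dim1 = inputs.size()
  nvte_multi_tensor_l2norm_cuda(/*chunk_size=*/65536, noop_flag, lists,
                                /*num_tensor_lists=*/1,
                                /*num_tensors_per_list=*/inputs.size(),
                                output, output_per_tensor, ret, ret_per_tensor,
                                /*per_tensor=*/1, max_chunks_per_tensor,
                                /*device_id=*/0, stream);
}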
transformer_engine/common/include/transformer_engine/transformer_engine.h
View file @ 2b05e121
...
@@ -22,17 +22,18 @@ extern "C" {
  * \brief TE datatype.
  */
 enum NVTEDType {
   kNVTEByte = 0,       /*!< Byte */
   kNVTEInt16 = 1,      /*!< 16-bit integer */
   kNVTEInt32 = 2,      /*!< 32-bit integer */
   kNVTEInt64 = 3,      /*!< 64-bit integer */
   kNVTEFloat32 = 4,    /*!< 32-bit float */
   kNVTEFloat16 = 5,    /*!< 16-bit float (E5M10) */
   kNVTEBFloat16 = 6,   /*!< 16-bit bfloat (E8M7) */
   kNVTEFloat8E4M3 = 7, /*!< 8-bit float (E4M3) */
   kNVTEFloat8E5M2 = 8, /*!< 8-bit float (E5M2) */
   kNVTEFloat8E8M0 = 9, /*!< 8-bit float (E8M0) */
+  kNVTEFloat4E2M1 = 10, /*!< 4-bit float (E2M1) */
   kNVTENumTypes        /*!< Number of supported types */
 };
 
 /*! \struct NVTEShape
...
@@ -87,6 +88,10 @@ enum NVTEScalingMode {
   */
   NVTE_BLOCK_SCALING_1D = 2,
   NVTE_BLOCK_SCALING_2D = 3,
+  /*! Single NVFP4 scale per block of 16 contiguous elements in forward pass (FWD),
+      and single MXFP8 scale per block of 32 contiguous elements in backward pass (BWD). */
+  NVTE_FWD_NVFP4_BWD_MXFP8_SCALING = 4,
   NVTE_INVALID_SCALING = 100
 };
...
@@ -177,6 +182,14 @@ size_t nvte_tensor_ndims(const NVTETensor tensor);
  */
 size_t nvte_tensor_size(const NVTETensor tensor, const size_t dim);
 
+/*! \brief Get the byte size for the tensor.
+ *
+ *  \param[in] tensor Tensor.
+ *
+ *  \return Byte size of the tensor.
+ */
+size_t nvte_tensor_size_bytes(const NVTETensor tensor);
+
 /*! \brief Get a tensor's total number of elements.
  *
  *  \param[in] tensor Tensor.
...
@@ -193,6 +206,14 @@ size_t nvte_tensor_numel(const NVTETensor tensor);
  */
 size_t nvte_tensor_element_size(const NVTETensor tensor);
 
+/*! \brief Get the bit size for the tensor's data type.
+ *
+ *  \param[in] tensor Tensor.
+ *
+ *  \return Bit size of the tensor's data type.
+ */
+size_t nvte_tensor_element_size_bits(const NVTETensor tensor);
+
 /*! \brief Get a tensor's data type.
  *
  *  \param[in] tensor Tensor.
...
@@ -302,6 +323,13 @@ enum NVTEQuantizationConfigAttribute {
       conditional early even when captured in a static CUDA graph. */
   kNVTEQuantizationConfigNoopTensor = 2,
+  /*! Data format for an FP8 block-scaled tensor
+   *
+   *  This is not the right design since the tensor format is a
+   *  property of the tensor, not the quantization. This enum will
+   *  likely be refactored away in the future.
+   */
+  kNVTEQuantizationConfigFloat8BlockScaleTensorFormat = 3,
   kNVTEQuantizationConfigNumAttributes
 };
...
@@ -383,7 +411,8 @@ enum class DType {
   kFloat8E4M3 = 7,
   kFloat8E5M2 = 8,
   kFloat8E8M0 = 9,
-  kInt8 = 10,
+  kFloat4E2M1 = 10,
+  kInt8 = 11,
   kNumTypes
 };
...
@@ -392,7 +421,16 @@ enum class DType {
 * Return true if TE datatype is FP8
 * \param[in] DType TE Datatype of interest
 */
-bool is_fp8_dtype(const DType t);
+inline bool is_fp8_dtype(const DType t) {
+  return t == DType::kFloat8E4M3 || t == DType::kFloat8E5M2;
+}
+
+/*! \brief Check if TE datatype is FP4
+ *
+ * Return true if TE datatype is FP4
+ * \param[in] DType TE Datatype of interest
+ */
+inline bool is_fp4_dtype(const DType t) { return t == DType::kFloat4E2M1; }
 
 /*! \struct TensorWrapper
  *  \brief C++ wrapper for the NVTETensor class.
...
@@ -621,6 +659,15 @@ class TensorWrapper {
     return nvte_tensor_element_size(tensor_);
   }
 
+  /*! \brief Get the tensor's element size in bits.
+   *
+   *  \return Element size in bits.
+   */
+  size_t element_size_bits() const noexcept {
+    if (tensor_ == nullptr) return 0;
+    return nvte_tensor_element_size_bits(tensor_);
+  }
+
   /*! \brief Get the tensor's allocated size in bytes. This will return 0 for tensors with nullptr
    *         data even if the TensorWrapper has a non-zero shape and valid dtype.
    *
...
@@ -628,7 +675,7 @@ class TensorWrapper {
   */
   size_t bytes() const noexcept {
     if (tensor_ == nullptr || this->dptr() == nullptr) return 0;
-    return nvte_tensor_numel(tensor_) * nvte_tensor_element_size(tensor_);
+    return nvte_tensor_size_bytes(tensor_);
   }
 
   /*! \brief Get the data type of this TensorWrapper.
...
@@ -722,6 +769,16 @@ class TensorWrapper {
   NVTETensor tensor_ = nullptr;
 };
 
+/*! \enum Float8BlockScaleTensorFormat
+ *  \brief Data format for an FP8 block-scaled tensor
+ */
+enum class Float8BlockScaleTensorFormat {
+  /*! FP8 data is transposed if needed and scales are swizzled */
+  GEMM_READY = 0,
+  /*! FP8 data is untransposed and scales are not swizzled or padded */
+  COMPACT = 1
+};
+
 /*! \struct QuantizationConfigWrapper
  *  \brief C++ wrapper for NVTEQuantizationConfigWrapper.
  */
...
@@ -775,6 +832,13 @@ class QuantizationConfigWrapper {
                                            sizeof(NVTETensor));
   }
 
+  /*! \brief Set FP8 block-scaled tensor format */
+  void set_float8_block_scale_tensor_format(Float8BlockScaleTensorFormat format) {
+    nvte_set_quantization_config_attribute(config_,
+                                           kNVTEQuantizationConfigFloat8BlockScaleTensorFormat,
+                                           &format, sizeof(Float8BlockScaleTensorFormat));
+  }
+
  private:
   /*! \brief Wrapped NVTEQuantizationConfig. */
   NVTEQuantizationConfig config_ = nullptr;
...
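The bit-granular element size and the byte-size query added above exist because FP4 packs two elements per byte, so numel * element_size no longer gives the allocation size. A short sketch, using only the C functions declared in this header, of what the byte size works out to for sub-byte dtypes:

#include <transformer_engine/transformer_engine.h>

// Bytes needed to hold a tensor's data when the dtype may be narrower than a byte
// (e.g. kNVTEFloat4E2M1): total bits rounded up to whole bytes. nvte_tensor_size_bytes
// presumably performs an equivalent computation internally.
size_t packed_bytes(const NVTETensor t) {
  const size_t bits = nvte_tensor_element_size_bits(t) * nvte_tensor_numel(t);
  return (bits + 7) / 8;
}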
transformer_engine/common/multi_tensor/adam.cu
View file @ 2b05e121
...
@@ -807,7 +807,7 @@ void nvte_multi_tensor_adam_cuda(int chunk_size, NVTETensor noop_flag, NVTETenso
   using namespace transformer_engine;
 
   multi_tensor_adam::multi_tensor_adam_cuda(
-      chunk_size, *reinterpret_cast<Tensor *>(noop_flag),
+      chunk_size, *convertNVTETensorCheck(noop_flag),
       convert_tensor_array(tensor_lists, num_tensor_lists, num_tensors_per_list), lr, beta1, beta2,
       epsilon, step, mode, bias_correction, weight_decay, device_id, stream);
 }
...
@@ -821,7 +821,7 @@ void nvte_multi_tensor_adam_param_remainder_cuda(
   using namespace transformer_engine;
 
   multi_tensor_adam::multi_tensor_adam_param_remainder_cuda(
-      chunk_size, *reinterpret_cast<Tensor *>(noop_flag),
+      chunk_size, *convertNVTETensorCheck(noop_flag),
       convert_tensor_array(tensor_lists, num_tensor_lists, num_tensors_per_list), lr, beta1, beta2,
       epsilon, step, mode, bias_correction, weight_decay, device_id, stream);
 }
...
@@ -837,7 +837,7 @@ void nvte_multi_tensor_adam_fp8_cuda(int chunk_size, NVTETensor noop_flag,
   using namespace transformer_engine;
 
   multi_tensor_adam::multi_tensor_adam_fp8_cuda(
-      chunk_size, *reinterpret_cast<Tensor *>(noop_flag),
+      chunk_size, *convertNVTETensorCheck(noop_flag),
       convert_tensor_array(tensor_lists, num_tensor_lists, num_tensors_per_list), lr, beta1, beta2,
       epsilon, step, mode, bias_correction, weight_decay, static_cast<DType>(fp8_dtype), device_id,
       stream);
...
@@ -852,11 +852,10 @@ void nvte_multi_tensor_adam_capturable_cuda(
   using namespace transformer_engine;
 
   multi_tensor_adam::multi_tensor_adam_capturable_cuda(
-      chunk_size, *reinterpret_cast<Tensor *>(noop_flag),
+      chunk_size, *convertNVTETensorCheck(noop_flag),
       convert_tensor_array(tensor_lists, num_tensor_lists, num_tensors_per_list),
-      *reinterpret_cast<Tensor *>(lr), beta1, beta2, epsilon, *reinterpret_cast<Tensor *>(step),
-      mode,
+      *convertNVTETensorCheck(lr), beta1, beta2, epsilon, *convertNVTETensorCheck(step), mode,
-      bias_correction, weight_decay, *reinterpret_cast<Tensor *>(inv_scale), device_id,
+      bias_correction, weight_decay, *convertNVTETensorCheck(inv_scale), device_id,
       stream);
 }
 
 void nvte_multi_tensor_adam_capturable_master_cuda(
...
@@ -868,9 +867,8 @@ void nvte_multi_tensor_adam_capturable_master_cuda(
   using namespace transformer_engine;
 
   multi_tensor_adam::multi_tensor_adam_capturable_master_cuda(
-      chunk_size, *reinterpret_cast<Tensor *>(noop_flag),
+      chunk_size, *convertNVTETensorCheck(noop_flag),
       convert_tensor_array(tensor_lists, num_tensor_lists, num_tensors_per_list),
-      *reinterpret_cast<Tensor *>(lr), beta1, beta2, epsilon, *reinterpret_cast<Tensor *>(step),
-      mode,
+      *convertNVTETensorCheck(lr), beta1, beta2, epsilon, *convertNVTETensorCheck(step), mode,
-      bias_correction, weight_decay, *reinterpret_cast<Tensor *>(inv_scale), device_id,
+      bias_correction, weight_decay, *convertNVTETensorCheck(inv_scale), device_id,
       stream);
 }
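The recurring change in this file and in cublaslt_gemm.cu is replacing bare reinterpret_cast of NVTETensor handles with a checked conversion. A generic stand-in for why that matters is sketched below; convert_checked, the Tensor forward declaration, and the error message are all hypothetical, illustrating only the validate-then-cast idea rather than TE's actual convertNVTETensorCheck implementation.

#include <stdexcept>
#include <string>

struct Tensor;              // opaque internal type (stand-in)
using NVTETensor = void *;  // opaque public handle (stand-in)

// Validate the handle before casting, so a null or invalid handle fails with a
// clear error at the API boundary instead of crashing deep inside a kernel launch.
inline Tensor *convert_checked(NVTETensor handle, const char *name) {
  if (handle == nullptr) {
    throw std::runtime_error(std::string("Invalid NVTETensor argument: ") + name);
  }
  return reinterpret_cast<Tensor *>(handle);  // cast only after validation
}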