OpenDAS / TransformerEngine / Commits

Commit f8c2af4c, authored May 21, 2025 by yuguo

Merge commit '1d903f5e' of https://github.com/NVIDIA/TransformerEngine

Parents: e92773a3, 1d903f5e

Changes: 211
Showing 20 changed files with 1227 additions and 455 deletions (+1227, -455)
transformer_engine/common/fused_attn/utils.cu  (+70, -0)
transformer_engine/common/fused_attn/utils.h  (+32, -0)
transformer_engine/common/fused_rope/fused_rope.cu  (+31, -26)
transformer_engine/common/include/transformer_engine/fused_attn.h  (+77, -8)
transformer_engine/common/include/transformer_engine/fused_rope.h  (+7, -6)
transformer_engine/common/include/transformer_engine/multi_tensor.h  (+87, -0)
transformer_engine/common/include/transformer_engine/permutation.h  (+3, -0)
transformer_engine/common/include/transformer_engine/recipe.h  (+11, -0)
transformer_engine/common/include/transformer_engine/transformer_engine.h  (+39, -27)
transformer_engine/common/multi_tensor/adam.cu  (+203, -128)
transformer_engine/common/multi_tensor/compute_scale.cu  (+34, -19)
transformer_engine/common/multi_tensor/l2norm.cu  (+178, -97)
transformer_engine/common/multi_tensor/multi_tensor_apply.cuh  (+71, -48)
transformer_engine/common/multi_tensor/scale.cu  (+35, -25)
transformer_engine/common/multi_tensor/sgd.cu  (+59, -59)
transformer_engine/common/permutation/permutation.cu  (+10, -0)
transformer_engine/common/recipe/fp8_block_scaling.cu  (+245, -0)
transformer_engine/common/transformer_engine.cpp  (+18, -9)
transformer_engine/common/util/cuda_runtime.cpp  (+9, -3)
transformer_engine/common/util/logging.h  (+8, -0)
transformer_engine/common/fused_attn/utils.cu
@@ -562,5 +562,75 @@ size_t get_max_tokens(size_t num_tokens) {
   return max_t;
 }

 __global__ void populate_rng_state_kernel(int64_t *rng_state_dst, const int64_t *const seed,
                                           int64_t offset) {
   int tid = blockIdx.x * blockDim.x + threadIdx.x;
   if (tid > 0) return;
   rng_state_dst[0] = seed[0];
   rng_state_dst[1] = offset;
 }

 __global__ void get_runtime_num_segments_kernel(int32_t *cu_seqlen, size_t len, uint32_t *out) {
   int tid = blockDim.x * blockIdx.x + threadIdx.x;
   if (tid >= len) return;
   if (cu_seqlen[tid] > 0) {
     // atomicAdd only supports 32-bit dtypes
     atomicAdd(out, 1);
   }
 }

 void PopulateRngStateAsync(void *rng_state_dst, const void *seed, size_t q_max_seqlen,
                            size_t kv_max_seqlen, NVTE_Fused_Attn_Backend backend,
                            cudaStream_t stream) {
   size_t increment = 0;
   if (backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) {
     increment = 16;
   } else {
     constexpr int threads_per_cta = 128;
     increment = (q_max_seqlen * kv_max_seqlen + threads_per_cta - 1) / threads_per_cta;
   }
   auto offset = FusedAttnOffsetManager::Instance().GetAndUpdateOffset(increment);
   populate_rng_state_kernel<<<1, 1, 0, stream>>>(reinterpret_cast<int64_t *>(rng_state_dst),
                                                  reinterpret_cast<const int64_t *>(seed), offset);
   NVTE_CHECK_CUDA(cudaGetLastError());
 }

 uint32_t GetRuntimeNumSegments(void *cu_seqlen, void *workspace, size_t len, cudaStream_t stream) {
   // workspace size requires 4 bytes
   uint32_t *dout = static_cast<uint32_t *>(workspace);
   uint32_t hout{};
   cudaMemsetAsync(dout, 0, sizeof(uint32_t), stream);
   constexpr int threads = 128;
   const int blocks = (len - 1) / threads + 1;
   get_runtime_num_segments_kernel<<<blocks, threads, 0, stream>>>(
       static_cast<int32_t *>(cu_seqlen), len, dout);
   cudaMemcpyAsync(&hout, dout, sizeof(uint32_t), cudaMemcpyDeviceToHost, stream);
   cudaStreamSynchronize(stream);
   return hout;
 }

 __global__ void extract_seed_and_offset(int64_t *rng_state_ptr, bool captured, int64_t *seed_ptr,
                                         uint64_t seed_val, int64_t *offset_ptr,
                                         uint64_t offset_val, uint32_t offset_intragraph) {
   if (captured) {
     rng_state_ptr[0] = *seed_ptr;
     rng_state_ptr[1] =
         static_cast<int64_t>(*offset_ptr + static_cast<int64_t>(offset_intragraph));
   } else {
     rng_state_ptr[0] = static_cast<int64_t>(seed_val);
     rng_state_ptr[1] = static_cast<int64_t>(offset_val);
   }
 }

 }  // namespace fused_attn
 }  // namespace transformer_engine

 void nvte_extract_seed_and_offset(int64_t *rng_state_ptr, int captured, int64_t *seed_ptr,
                                   uint64_t seed_val, int64_t *offset_ptr, uint64_t offset_val,
                                   uint32_t offset_intragraph, cudaStream_t stream) {
   NVTE_API_CALL(nvte_extract_seed_and_offset);
   using namespace transformer_engine;
   fused_attn::extract_seed_and_offset<<<1, 1, 0, stream>>>(
       rng_state_ptr, captured, seed_ptr, seed_val, offset_ptr, offset_val, offset_intragraph);
 }
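For context, a minimal host-side sketch of how GetRuntimeNumSegments could be driven. The buffer setup and example values below are illustrative, not part of this commit; the snippet only links if built against the library object that defines the function. The function itself needs just a 4-byte device workspace and synchronizes internally, per the implementation above.

#include <cstdint>
#include <cstdio>
#include <cuda_runtime.h>

// Declaration from transformer_engine/common/fused_attn/utils.h (see diff below).
namespace transformer_engine { namespace fused_attn {
uint32_t GetRuntimeNumSegments(void *cu_seqlen, void *workspace, size_t len, cudaStream_t stream);
}}  // namespace transformer_engine::fused_attn

int main() {
  // Hypothetical per-position values: positive entries count as live segments.
  int32_t h_vals[6] = {4, 0, 3, 0, 0, 5};
  int32_t *d_vals = nullptr;
  uint32_t *d_workspace = nullptr;  // the kernel accumulates its count into these 4 bytes
  cudaMalloc(&d_vals, sizeof(h_vals));
  cudaMalloc(&d_workspace, sizeof(uint32_t));
  cudaMemcpy(d_vals, h_vals, sizeof(h_vals), cudaMemcpyHostToDevice);

  uint32_t n = transformer_engine::fused_attn::GetRuntimeNumSegments(
      d_vals, d_workspace, 6, /*stream=*/0);  // synchronizes before returning
  printf("segments: %u\n", n);                // prints 3 for the values above

  cudaFree(d_vals);
  cudaFree(d_workspace);
  return 0;
}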
transformer_engine/common/fused_attn/utils.h
@@ -150,6 +150,38 @@ DType get_ragged_offset_dtype(NVTE_QKV_Layout_Group layout_group, int64_t num_at
 size_t get_max_batch_size(size_t batch_size);
 size_t get_max_tokens(size_t num_tokens);

 class FusedAttnOffsetManager {
  public:
   static FusedAttnOffsetManager &Instance() {
     static thread_local FusedAttnOffsetManager instance;
     return instance;
   }

   size_t GetAndUpdateOffset(size_t increment) {
     size_t ret = offset_;
     offset_ += increment;
     return ret;
   }

   FusedAttnOffsetManager(FusedAttnOffsetManager const &) = delete;
   void operator=(FusedAttnOffsetManager const &) = delete;

  private:
   FusedAttnOffsetManager() {}
   size_t offset_ = 0;
 };

 __global__ void populate_rng_state_kernel(int64_t *rng_state_dst, const int64_t *const seed,
                                           int64_t offset);

 __global__ void get_runtime_num_segments_kernel(int32_t *cu_seqlen, size_t len, uint32_t *out);

 void PopulateRngStateAsync(void *rng_state_dst, const void *const seed, size_t q_max_seqlen,
                            size_t kv_max_seqlen, NVTE_Fused_Attn_Backend backend,
                            cudaStream_t stream);

 uint32_t GetRuntimeNumSegments(void *cu_seqlen, void *workspace, size_t len,
                                cudaStream_t stream);

 }  // namespace fused_attn
 }  // namespace transformer_engine
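The offset manager above is a thread-local singleton that hands out monotonically increasing RNG offsets, and GetAndUpdateOffset returns the value before the bump. A standalone restatement of that contract (plain C++, illustrative only):

#include <cassert>
#include <cstddef>

// Minimal re-statement of FusedAttnOffsetManager, for illustration only.
class OffsetManager {
 public:
  static OffsetManager &Instance() {
    static thread_local OffsetManager instance;
    return instance;
  }
  size_t GetAndUpdateOffset(size_t increment) {
    size_t ret = offset_;
    offset_ += increment;
    return ret;  // pre-increment value: the caller's reserved range starts here
  }
 private:
  size_t offset_ = 0;
};

int main() {
  auto &m = OffsetManager::Instance();
  assert(m.GetAndUpdateOffset(16) == 0);   // first caller reserves [0, 16)
  assert(m.GetAndUpdateOffset(16) == 16);  // next caller reserves [16, 32)
  return 0;
}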
transformer_engine/common/fused_rope/fused_rope.cu
@@ -115,10 +115,10 @@ __device__ void fused_rope_block_backward(const scalar_t *src, const float *freq
 template <typename scalar_t>
 __global__ void fused_rope_forward_kernel(const scalar_t *src, const int *cu_seqlens,
-                                          const float *freqs, scalar_t *dst,
-                                          const bool interleaved, const int cp_size,
-                                          const int cp_rank, const int s, const int h,
-                                          const int d, const int d2,
-                                          const int stride_s_or_t, const int stride_b,
+                                          const float *freqs, const int *start_positions,
+                                          scalar_t *dst, const bool interleaved,
+                                          const int cp_size, const int cp_rank, const int s,
+                                          const int h, const int d, const int d2,
+                                          const int stride_s_or_t, const int stride_b,
                                           const int stride_h, const int stride_d,
                                           const int o_stride_s_or_t, const int o_stride_b,
                                           const int o_stride_h, const int o_stride_d) {
@@ -149,7 +149,8 @@ __global__ void fused_rope_forward_kernel(const scalar_t *src, const int *cu_seq
           cur_seqlens * cp_size - (cp_rank + 1) * cur_seqlens / 2 + s_id - cur_seqlens / 2;
     }
   } else {
-    s_id_for_freqs = s_id;
+    int begin_offset = (start_positions == nullptr) ? 0 : start_positions[b_id];
+    s_id_for_freqs = s_id + begin_offset;
   }
   fused_rope_block_forward(src, freqs, dst, interleaved, s_id_for_freqs, offset_block,
@@ -199,11 +200,12 @@ __global__ void fused_rope_backward_kernel(
 template <typename scalar_t>
 void fused_rope_forward_launcher(const scalar_t *input, const int *cu_seqlens,
-                                 const float *freqs, scalar_t *output,
-                                 const NVTE_QKV_Format qkv_format, const bool interleaved,
-                                 const int cp_size, const int cp_rank, const int s, const int b,
-                                 const int h, const int d, const int d2,
-                                 const int stride_s_or_t, const int stride_b,
-                                 const int stride_h, const int stride_d, cudaStream_t stream) {
+                                 const float *freqs, const int *start_positions,
+                                 scalar_t *output, const NVTE_QKV_Format qkv_format,
+                                 const bool interleaved, const int cp_size, const int cp_rank,
+                                 const int s, const int b, const int h, const int d,
+                                 const int d2, const int stride_s_or_t, const int stride_b,
+                                 const int stride_h, const int stride_d, cudaStream_t stream) {
   int warps_per_block = h < 16 ? 4 : 8;
   dim3 blocks(s, b);
   dim3 threads(THREADS_PER_WARP, warps_per_block);
@@ -223,8 +225,9 @@ void fused_rope_forward_launcher(const scalar_t *input, const int *cu_seqlens, c
   const int o_stride_d = 1;

   fused_rope_forward_kernel<<<blocks, threads, 0, stream>>>(
-      input, cu_seqlens, freqs, output, interleaved, cp_size, cp_rank, s, h, d, d2,
-      stride_s_or_t, stride_b, stride_h, stride_d, o_stride_s_or_t, o_stride_b, o_stride_h,
-      o_stride_d);
+      input, cu_seqlens, freqs, start_positions, output, interleaved, cp_size, cp_rank, s, h, d,
+      d2, stride_s_or_t, stride_b, stride_h, stride_d, o_stride_s_or_t, o_stride_b, o_stride_h,
+      o_stride_d);
   NVTE_CHECK_CUDA(cudaGetLastError());
 }
@@ -262,15 +265,17 @@ void fused_rope_backward_launcher(const scalar_t *output_grads, const int *cu_se
 }

 void fused_rope_forward(const Tensor &input, const Tensor &cu_seqlens, const Tensor &freqs,
-                        Tensor *output, const NVTE_QKV_Format qkv_format,
-                        const bool interleaved, const int cp_size, const int cp_rank,
-                        const int s, const int b, const int h, const int d, const int d2,
-                        const int stride_s_or_t, const int stride_b,
+                        const Tensor &start_positions, Tensor *output,
+                        const NVTE_QKV_Format qkv_format, const bool interleaved,
+                        const int cp_size, const int cp_rank, const int s, const int b,
+                        const int h, const int d, const int d2, const int stride_s_or_t,
+                        const int stride_b,
                         const int stride_h, const int stride_d, cudaStream_t stream) {
   TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
       input.data.dtype, scalar_t,
       fused_rope_forward_launcher(reinterpret_cast<const scalar_t *>(input.data.dptr),
                                   reinterpret_cast<const int *>(cu_seqlens.data.dptr),
                                   reinterpret_cast<const float *>(freqs.data.dptr),
+                                  reinterpret_cast<const int *>(start_positions.data.dptr),
                                   reinterpret_cast<scalar_t *>(output->data.dptr), qkv_format,
                                   interleaved, cp_size, cp_rank, s, b, h, d, d2, stride_s_or_t,
                                   stride_b, stride_h, stride_d, stream););
@@ -295,19 +300,19 @@ void fused_rope_backward(const Tensor &output_grads, const Tensor &cu_seqlens, c
 }  // end namespace transformer_engine

 void nvte_fused_rope_forward(const NVTETensor input, const NVTETensor cu_seqlens,
-                             const NVTETensor freqs, NVTETensor output,
-                             const NVTE_QKV_Format qkv_format, const bool interleaved,
-                             const int cp_size, const int cp_rank, const int s, const int b,
-                             const int h, const int d, const int d2, const int stride_s_or_t,
-                             const int stride_b, const int stride_h, const int stride_d,
-                             cudaStream_t stream) {
+                             const NVTETensor freqs, const NVTETensor start_positions,
+                             NVTETensor output, const NVTE_QKV_Format qkv_format,
+                             const bool interleaved, const int cp_size, const int cp_rank,
+                             const int s, const int b, const int h, const int d, const int d2,
+                             const int stride_s_or_t, const int stride_b, const int stride_h,
+                             const int stride_d, cudaStream_t stream) {
   NVTE_API_CALL(nvte_fused_rope_forward);
   using namespace transformer_engine;
-  fused_rope_forward(*reinterpret_cast<const Tensor *>(input),
-                     *reinterpret_cast<const Tensor *>(cu_seqlens),
-                     *reinterpret_cast<const Tensor *>(freqs),
-                     reinterpret_cast<Tensor *>(output), qkv_format, interleaved, cp_size,
-                     cp_rank, s, b, h, d, d2, stride_s_or_t, stride_b, stride_h, stride_d,
-                     stream);
+  fused_rope_forward(*reinterpret_cast<const Tensor *>(input),
+                     *reinterpret_cast<const Tensor *>(cu_seqlens),
+                     *reinterpret_cast<const Tensor *>(freqs),
+                     *reinterpret_cast<const Tensor *>(start_positions),
+                     reinterpret_cast<Tensor *>(output), qkv_format, interleaved, cp_size,
+                     cp_rank, s, b, h, d, d2, stride_s_or_t, stride_b, stride_h, stride_d,
+                     stream);
 }

 void nvte_fused_rope_backward(const NVTETensor output_grads, const NVTETensor cu_seqlens,
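The new else-branch in fused_rope_forward_kernel reduces to simple index arithmetic: when start_positions is provided, the rotary frequency row for position s_id in batch b_id is shifted by that batch's starting offset. A plain restatement, illustrative only (the per-batch offsets are hypothetical values):

#include <assert.h>
#include <stddef.h>

// Frequency-row index for (b_id, s_id) in the non-context-parallel branch above.
static int s_id_for_freqs(int s_id, int b_id, const int *start_positions) {
  int begin_offset = (start_positions == NULL) ? 0 : start_positions[b_id];
  return s_id + begin_offset;
}

int main(void) {
  int starts[2] = {0, 128};                      // hypothetical per-batch offsets
  assert(s_id_for_freqs(5, 0, NULL) == 5);       // no offsets: identity
  assert(s_id_for_freqs(5, 1, starts) == 133);   // batch 1 resumes at position 128
  return 0;
}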
transformer_engine/common/include/transformer_engine/fused_attn.h
@@ -11,8 +11,7 @@
 #ifndef TRANSFORMER_ENGINE_FUSED_ATTN_FP8_H_
 #define TRANSFORMER_ENGINE_FUSED_ATTN_FP8_H_

-#include <cstdint>
+#include "stdint.h"
 #include "transformer_engine.h"

 #ifdef __cplusplus
@@ -245,7 +244,7 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
  * \param[in] stream CUDA stream used for this operation.
  */
 void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, NVTETensor S,
                                    NVTETensor O, NVTETensorPack *Aux_CTX_Tensors,
                                    const NVTETensor cu_seqlens, const NVTETensor cu_seqlens_padded,
                                    const NVTETensor rng_state, size_t max_seqlen, bool is_training,
                                    float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout,
@@ -301,7 +300,7 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias,
  */
 void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, const NVTETensor dO,
                                    const NVTETensor S, NVTETensor dP,
                                    const NVTETensorPack *Aux_CTX_Tensors, NVTETensor dQKV,
                                    NVTETensor dBias, const NVTETensor cu_seqlens,
                                    const NVTETensor cu_seqlens_padded, size_t max_seqlen,
                                    float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout,
@@ -369,7 +368,7 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con
  */
 void nvte_fused_attn_fwd_kvpacked(
     const NVTETensor Q, const NVTETensor KV, const NVTETensor Bias, NVTETensor S, NVTETensor O,
     NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens_q,
     const NVTETensor cu_seqlens_kv, const NVTETensor cu_seqlens_q_padded,
     const NVTETensor cu_seqlens_kv_padded, const NVTETensor page_table_k,
     const NVTETensor page_table_v, const NVTETensor rng_state, size_t max_seqlen_q,
     size_t max_seqlen_kv, bool is_training, float attn_scale, float dropout,
@@ -430,7 +429,7 @@ void nvte_fused_attn_fwd_kvpacked(
  */
 void nvte_fused_attn_bwd_kvpacked(
     const NVTETensor Q, const NVTETensor KV, const NVTETensor O, const NVTETensor dO,
     const NVTETensor S, NVTETensor dP, const NVTETensorPack *Aux_CTX_Tensors, NVTETensor dQ,
     NVTETensor dKV, NVTETensor dBias, const NVTETensor cu_seqlens_q,
     const NVTETensor cu_seqlens_kv, const NVTETensor cu_seqlens_q_padded,
     const NVTETensor cu_seqlens_kv_padded, size_t max_seqlen_q, size_t max_seqlen_kv,
     float attn_scale, float dropout,
@@ -501,7 +500,7 @@ void nvte_fused_attn_bwd_kvpacked(
  */
 void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETensor V,
                          const NVTETensor Bias, NVTETensor S, NVTETensor O,
                          NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens_q,
                          const NVTETensor cu_seqlens_kv, const NVTETensor cu_seqlens_q_padded,
                          const NVTETensor cu_seqlens_kv_padded, const NVTETensor page_table_k,
                          const NVTETensor page_table_v, const NVTETensor rng_state,
@@ -570,7 +569,7 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
  */
 void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETensor V,
                          const NVTETensor O, const NVTETensor dO, const NVTETensor S,
                          NVTETensor dP, const NVTETensorPack *Aux_CTX_Tensors, NVTETensor dQ,
                          NVTETensor dK, NVTETensor dV, NVTETensor dBias,
                          const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv,
                          const NVTETensor cu_seqlens_q_padded,
                          const NVTETensor cu_seqlens_kv_padded, size_t max_seqlen_q,
@@ -580,6 +579,76 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
                          int64_t window_size_right, bool deterministic, NVTETensor workspace,
                          cudaStream_t stream);

 /*! \brief Update the RNG state with the seed and calculated offset.
  *
  *  \param[in]  rng_state_dst  RNG state to store seed and offset.
  *  \param[in]  seed           Seed for RNG state.
  *  \param[in]  q_max_seqlen   Max sequence length used for computing for Q.
  *                             It may be >= max(seqlen_q_i) for i = 0, ..., batch_size - 1.
  *  \param[in]  kv_max_seqlen  Max sequence length used for computing for K and V.
  *                             It may be >= max(seqlen_kv_i) for i = 0, ..., batch_size - 1.
  *  \param[in]  backend        Fused attention backend.
  *  \param[in]  stream         CUDA stream used for this operation.
  */
 void nvte_populate_rng_state_async(NVTETensor rng_state_dst, const NVTETensor seed,
                                    size_t q_max_seqlen, size_t kv_max_seqlen,
                                    NVTE_Fused_Attn_Backend backend, cudaStream_t stream);

 /*! \brief Get the number of sequence segments present at runtime.
  *
  *  \param[in]  cu_seqlens  Cumulative sequence lengths, [batch_size + 1].
  *  \param[in]  workspace   Workspace tensor.
  *  \param[in]  len         batch_size x sequence_length.
  *  \param[in]  stream      CUDA stream used for this operation.
  */
 uint32_t nvte_get_runtime_num_segments(NVTETensor cu_seqlen, NVTETensor workspace, size_t len,
                                        cudaStream_t stream);

 void nvte_extract_seed_and_offset(int64_t *rng_state_ptr, int captured, int64_t *seed_ptr,
                                   uint64_t seed_val, int64_t *offset_ptr, uint64_t offset_val,
                                   uint32_t offset_intragraph, cudaStream_t stream);

 void nvte_copy_to_kv_cache(NVTETensor new_k, NVTETensor new_v, NVTETensor k_cache,
                            NVTETensor v_cache, NVTETensor page_table, NVTETensor cu_new_lens,
                            NVTETensor cu_cached_lens, NVTE_QKV_Format qkv_format, int b,
                            int max_ctx_len, int max_seq_len, int max_pages_per_seq,
                            int is_non_paged, cudaStream_t stream);

 void nvte_cp_thd_read_half_tensor(const NVTETensor &tensor, const NVTETensor &cu_seqlens,
                                   NVTETensor half, int half_idx, cudaStream_t stream);

 void nvte_cp_thd_second_half_lse_correction(NVTETensor lse, const NVTETensor &lse_per_step,
                                             const NVTETensor &cu_seqlens, int lse_packed,
                                             cudaStream_t stream);

 void nvte_cp_thd_read_second_half_lse(const NVTETensor &lse, const NVTETensor &cu_seqlens,
                                       NVTETensor half_lse, int lse_packed,
                                       int second_half_lse_seqlen, cudaStream_t stream);

 void nvte_cp_thd_out_correction(NVTETensor out, const NVTETensor &out_per_step,
                                 const NVTETensor &lse, const NVTETensor &lse_per_step,
                                 const NVTETensor &cu_seqlens, int only_second_half,
                                 int lse_packed, cudaStream_t stream);

 void nvte_cp_thd_grad_correction(NVTETensor grad, const NVTETensor &grad_per_step,
                                  const NVTETensor &cu_seqlens, const char *first_half,
                                  const char *second_half, cudaStream_t stream);

 void nvte_cp_thd_get_partitioned_indices(const NVTETensor &cu_seqlens, NVTETensor output,
                                          int total_tokens, int world_size, int rank,
                                          cudaStream_t stream);

 void nvte_convert_thd_to_bshd(NVTETensor tensor, NVTETensor cu_seqlens, NVTETensor new_tensor,
                               int b, int max_seq_len, cudaStream_t stream);

 void nvte_convert_bshd_to_thd(NVTETensor tensor, NVTETensor cu_seqlens, NVTETensor new_tensor,
                               int t, cudaStream_t stream);

 void nvte_prepare_flash_attn_fwd(NVTETensor qkvi, NVTETensor qkv, cudaStream_t stream);

 void nvte_prepare_flash_attn_bwd(NVTETensor q, NVTETensor k, NVTETensor v, NVTETensor qkv,
                                  cudaStream_t stream);

 #ifdef __cplusplus
 }  // extern "C"
 #endif
transformer_engine/common/include/transformer_engine/fused_rope.h
@@ -20,6 +20,7 @@ extern "C" {
  *  \param[in]   cu_seqlens       The cumulative sum of sequence lengths tensor.
  *                                (Required for the thd format, empty tensor for other formats)
  *  \param[in]   freqs            The freqs tensor.
+ *  \param[in]   start_positions  The beginning offsets for applying RoPE embeddings.
  *  \param[out]  output           Output tensor.
  *  \param[in]   qkv_format       QKV format.
  *  \param[in]   interleaved      Whether to use interleaved rotary position embedding.
@@ -37,12 +38,12 @@ extern "C" {
  *  \param[in]   stream           CUDA stream used for the operation.
  */
 void nvte_fused_rope_forward(const NVTETensor input, const NVTETensor cu_seqlens,
-                             const NVTETensor freqs, NVTETensor output,
-                             const NVTE_QKV_Format qkv_format, const bool interleaved,
-                             const int cp_size, const int cp_rank, const int s, const int b,
-                             const int h, const int d, const int d2, const int stride_s_or_t,
-                             const int stride_b, const int stride_h, const int stride_d,
-                             cudaStream_t stream);
+                             const NVTETensor freqs, const NVTETensor start_positions,
+                             NVTETensor output, const NVTE_QKV_Format qkv_format,
+                             const bool interleaved, const int cp_size, const int cp_rank,
+                             const int s, const int b, const int h, const int d, const int d2,
+                             const int stride_s_or_t, const int stride_b, const int stride_h,
+                             const int stride_d, cudaStream_t stream);

 /*! \brief Compute the backward of the fused rope.
  *
transformer_engine/common/include/transformer_engine/multi_tensor.h  (new file, mode 100644)
/*************************************************************************
 * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

/*! \file multi_tensor.h
 *  \brief Functions handling multi tensor kernels.
 */

#ifndef TRANSFORMER_ENGINE_MULTI_TENSOR_H_
#define TRANSFORMER_ENGINE_MULTI_TENSOR_H_

#include "transformer_engine.h"

#ifdef __cplusplus
extern "C" {
#endif

void nvte_multi_tensor_l2norm_cuda(int chunk_size, NVTETensor noop_flag,
                                   NVTETensor **tensor_lists, const size_t num_tensor_lists,
                                   const size_t num_tensors_per_list, NVTETensor output,
                                   NVTETensor output_per_tensor, NVTETensor ret,
                                   NVTETensor ret_per_tensor, int per_tensor,
                                   int max_chunks_per_tensor, const int device_id,
                                   cudaStream_t stream);

void nvte_multi_tensor_unscale_l2norm_cuda(int chunk_size, NVTETensor noop_flag,
                                           NVTETensor **tensor_lists,
                                           const size_t num_tensor_lists,
                                           const size_t num_tensors_per_list, NVTETensor output,
                                           NVTETensor output_per_tensor, NVTETensor ret,
                                           NVTETensor ret_per_tensor, NVTETensor inv_scale,
                                           int per_tensor, int max_chunks_per_tensor,
                                           const int device_id, cudaStream_t stream);

void nvte_multi_tensor_adam_cuda(int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists,
                                 const size_t num_tensor_lists, const size_t num_tensors_per_list,
                                 const float lr, const float beta1, const float beta2,
                                 const float epsilon, const int step, const int mode,
                                 const int bias_correction, const float weight_decay,
                                 const int device_id, cudaStream_t stream);

void nvte_multi_tensor_adam_param_remainder_cuda(
    int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists,
    const size_t num_tensor_lists, const size_t num_tensors_per_list, const float lr,
    const float beta1, const float beta2, const float epsilon, const int step, const int mode,
    const int bias_correction, const float weight_decay, const int device_id,
    cudaStream_t stream);

void nvte_multi_tensor_adam_fp8_cuda(int chunk_size, NVTETensor noop_flag,
                                     NVTETensor **tensor_lists, const size_t num_tensor_lists,
                                     const size_t num_tensors_per_list, const float lr,
                                     const float beta1, const float beta2, const float epsilon,
                                     const int step, const int mode, const int bias_correction,
                                     const float weight_decay, const NVTEDType fp8_dtype,
                                     const int device_id, cudaStream_t stream);

void nvte_multi_tensor_adam_capturable_cuda(
    int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists,
    const size_t num_tensor_lists, const size_t num_tensors_per_list, NVTETensor lr,
    const float beta1, const float beta2, const float epsilon, NVTETensor step, const int mode,
    const int bias_correction, const float weight_decay, NVTETensor inv_scale,
    const int device_id, cudaStream_t stream);

void nvte_multi_tensor_adam_capturable_master_cuda(
    int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists,
    const size_t num_tensor_lists, const size_t num_tensors_per_list, NVTETensor lr,
    const float beta1, const float beta2, const float epsilon, NVTETensor step, const int mode,
    const int bias_correction, const float weight_decay, NVTETensor inv_scale,
    const int device_id, cudaStream_t stream);

void nvte_multi_tensor_sgd_cuda(int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists,
                                const size_t num_tensor_lists, const size_t num_tensors_per_list,
                                float wd, float momentum, float dampening, float lr, int nesterov,
                                int first_run, int wd_after_momentum, float scale,
                                const int device_id, cudaStream_t stream);

void nvte_multi_tensor_scale_cuda(int chunk_size, NVTETensor noop_flag,
                                  NVTETensor **tensor_lists, const size_t num_tensor_lists,
                                  const size_t num_tensors_per_list, float scale,
                                  const int device_id, cudaStream_t stream);

void nvte_multi_tensor_compute_scale_and_scale_inv_cuda(
    int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists,
    const size_t num_tensor_lists, const size_t num_tensors_per_list, float max_fp8,
    int force_pow_2_scales, float epsilon, const int device_id, cudaStream_t stream);

#ifdef __cplusplus
}  // extern "C"
#endif

#endif  // TRANSFORMER_ENGINE_MULTI_TENSOR_H_
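All of these entry points share one tensor_lists convention: a pointer to num_tensor_lists arrays, each holding num_tensors_per_list NVTETensor handles. A hedged sketch of a call to nvte_multi_tensor_adam_cuda under that reading; the four lists {g, p, m, v} follow the "case 4" comment in adam.cu below, the chunk size and hyperparameters are arbitrary, and tensor construction is assumed done elsewhere:

#include <cuda_runtime.h>
#include <transformer_engine/multi_tensor.h>

// Illustrative wrapper: run one fused Adam step over `n` parameter tensors.
// The NVTETensor handles are assumed to be created by the caller.
void adam_step(NVTETensor noop, NVTETensor *g, NVTETensor *p, NVTETensor *m, NVTETensor *v,
               size_t n, float lr, int step, cudaStream_t stream) {
  NVTETensor *tensor_lists[4] = {g, p, m, v};  // case 4: g, p, m, v (see adam.cu)
  nvte_multi_tensor_adam_cuda(/*chunk_size=*/65536, noop, tensor_lists,
                              /*num_tensor_lists=*/4, /*num_tensors_per_list=*/n, lr,
                              /*beta1=*/0.9f, /*beta2=*/0.999f, /*epsilon=*/1e-8f, step,
                              /*mode=*/0, /*bias_correction=*/1, /*weight_decay=*/0.0f,
                              /*device_id=*/0, stream);
}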
transformer_engine/common/include/transformer_engine/permutation.h
@@ -18,4 +18,7 @@ void nvte_unpermute(const NVTETensor input, NVTETensor output, NVTETensor row_id
                     const NVTETensor prob, const int num_rows, const int topK,
                     const int num_cols, cudaStream_t stream = nullptr);

 void nvte_device_radix_sort_pairs(void *temp_storage, size_t *temp_storage_bytes, int *keys_in,
                                   int *keys_out, int *values_in, int *values_out,
                                   size_t num_items);

 #endif  // TRANSFORMER_ENGINE_PERMUTATION_H_
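The (temp_storage, temp_storage_bytes) pair mirrors cub::DeviceRadixSort::SortPairs, which is normally queried once with a null temp_storage. Whether this wrapper keeps that exact convention is an assumption, not something the header alone confirms; under that assumption, a caller would look like:

#include <cuda_runtime.h>
#include <transformer_engine/permutation.h>

// Sort `num_items` (key, value) pairs on device, CUB-style two-pass call.
void sort_pairs(int *d_keys_in, int *d_keys_out, int *d_vals_in, int *d_vals_out,
                size_t num_items) {
  size_t temp_bytes = 0;
  // Pass 1 (assumed): null temp storage makes the call report the bytes it needs.
  nvte_device_radix_sort_pairs(nullptr, &temp_bytes, d_keys_in, d_keys_out, d_vals_in,
                               d_vals_out, num_items);
  void *d_temp = nullptr;
  cudaMalloc(&d_temp, temp_bytes);
  // Pass 2: perform the actual sort.
  nvte_device_radix_sort_pairs(d_temp, &temp_bytes, d_keys_in, d_keys_out, d_vals_in,
                               d_vals_out, num_items);
  cudaFree(d_temp);
}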
transformer_engine/common/include/transformer_engine/recipe.h
@@ -96,6 +96,17 @@ void nvte_compute_amax(const NVTETensor input, NVTETensor output, cudaStream_t s
 void nvte_compute_scale_from_amax(NVTETensor output, const NVTEQuantizationConfig config,
                                   cudaStream_t stream);

 void nvte_fp8_block_scaling_compute_partial_amax(const NVTETensor inp, NVTETensor amax, size_t h,
                                                  size_t w, size_t amax_stride_h,
                                                  size_t amax_stride_w, size_t start_offset,
                                                  size_t block_len, cudaStream_t stream);

 void nvte_fp8_block_scaling_partial_cast(const NVTETensor inp, NVTETensor out,
                                          const NVTETensor scale, size_t h, size_t w,
                                          size_t scale_stride_h, size_t scale_stride_w,
                                          size_t start_offset, size_t block_len,
                                          const NVTEDType out_dtype, cudaStream_t stream);

 #ifdef __cplusplus
 }  // extern "C"
 #endif
transformer_engine/common/include/transformer_engine/transformer_engine.h
@@ -23,14 +23,15 @@ extern "C" {
  */
 enum NVTEDType {
   kNVTEByte = 0,       /*!< Byte */
-  kNVTEInt32 = 1,      /*!< 32-bit integer */
-  kNVTEInt64 = 2,      /*!< 64-bit integer */
-  kNVTEFloat32 = 3,    /*!< 32-bit float */
-  kNVTEFloat16 = 4,    /*!< 16-bit float (E5M10) */
-  kNVTEBFloat16 = 5,   /*!< 16-bit bfloat (E8M7) */
-  kNVTEFloat8E4M3 = 6, /*!< 8-bit float (E4M3) */
-  kNVTEFloat8E5M2 = 7, /*!< 8-bit float (E5M2) */
-  kNVTEFloat8E8M0 = 8, /*!< 8-bit float (E8M0) */
+  kNVTEInt16 = 1,      /*!< 16-bit integer */
+  kNVTEInt32 = 2,      /*!< 32-bit integer */
+  kNVTEInt64 = 3,      /*!< 64-bit integer */
+  kNVTEFloat32 = 4,    /*!< 32-bit float */
+  kNVTEFloat16 = 5,    /*!< 16-bit float (E5M10) */
+  kNVTEBFloat16 = 6,   /*!< 16-bit bfloat (E8M7) */
+  kNVTEFloat8E4M3 = 7, /*!< 8-bit float (E4M3) */
+  kNVTEFloat8E5M2 = 8, /*!< 8-bit float (E5M2) */
+  kNVTEFloat8E8M0 = 9, /*!< 8-bit float (E8M0) */
   kNVTENumTypes        /*!< Number of supported types */
 };
@@ -38,12 +39,10 @@ enum NVTEDType {
  * \brief Shape of the tensor.
  */
 struct NVTEShape {
-  /*! \brief Shape data, of size ndim. */
-  const size_t *data;
+  /*! \brief Shape data, with ndim valid elements. */
+  size_t data[15];
   /*! \brief Number of dimensions. */
   size_t ndim;
-  /*! \brief Copy of data. Num dims limited to permit fixed struct size.*/
-  size_t owned_data[14];
 };

 /*! \struct NVTEBasicTensor
@@ -343,6 +342,23 @@ void nvte_set_quantization_config_attribute(NVTEQuantizationConfig config,
  */
 void nvte_destroy_quantization_config(NVTEQuantizationConfig config);

 /*! \brief Check if non-TN FP8 Gemm is supported.
  *
  *  \return A flag which indicates whether non-TN FP8 Gemm is supported or not.
  */
 int nvte_is_non_tn_fp8_gemm_supported();

 /*! \brief Performs a memset of the data at the given pointer and size in bytes.
  *
  *  \param[in] ptr            Pointer to the memory to be set.
  *  \param[in] value          Value to set the memory to.
  *  \param[in] size_in_bytes  Size of the memory in bytes.
  *  \param[in] stream         CUDA stream to use for the operation.
  *
  *  This function calls a fill kernel for small sizes and calls cudaMemsetAsync for larger sizes.
  */
 void nvte_memset(void *ptr, int value, size_t size_in_bytes, cudaStream_t stream);

 #ifdef __cplusplus
 }  // extern "C"
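// --- Usage sketch (illustrative, not part of this commit): zero a small
// RNG-state buffer asynchronously. Per the doc comment above, a 16-byte
// memset takes the fill-kernel path rather than cudaMemsetAsync; the exact
// size threshold is not given in this header.
static inline void clear_rng_state(void *d_rng_state, cudaStream_t stream) {
  nvte_memset(d_rng_state, 0, 2 * sizeof(int64_t), stream);
}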
@@ -358,14 +374,15 @@ namespace transformer_engine {
  */
 enum class DType {
   kByte = 0,
-  kInt32 = 1,
-  kInt64 = 2,
-  kFloat32 = 3,
-  kFloat16 = 4,
-  kBFloat16 = 5,
-  kFloat8E4M3 = 6,
-  kFloat8E5M2 = 7,
-  kFloat8E8M0 = 8,
+  kInt16 = 1,
+  kInt32 = 2,
+  kInt64 = 3,
+  kFloat32 = 4,
+  kFloat16 = 5,
+  kBFloat16 = 6,
+  kFloat8E4M3 = 7,
+  kFloat8E5M2 = 8,
+  kFloat8E8M0 = 9,
   kNumTypes
 };
@@ -691,15 +708,10 @@ class TensorWrapper {
   static constexpr size_t defaultData = 1;
-  static constexpr NVTEShape defaultShape = {
-      &defaultData, 1, {defaultData, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}};
+  static constexpr NVTEShape defaultShape = {
+      {defaultData, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, 1};

  private:
-  NVTEShape convertShape(const NVTEShape &s) {
-    NVTEShape ret = s;
-    // Move the ownership rather than pointing to the parent shape.
-    ret.data = ret.owned_data;
-    return ret;
-  }
+  NVTEShape convertShape(const NVTEShape &s) { return s; }

   NVTEShape convertShape(const std::vector<size_t> &s) {
     return nvte_make_shape(s.data(), s.size());
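With the fixed-size layout, an NVTEShape is plain value data: nvte_make_shape (used by convertShape above) copies the dimensions in, and there is no parent buffer to keep alive, which is why the old ownership-moving convertShape collapses to `return s;`. A small illustrative use (the dimension names are arbitrary):

#include <transformer_engine/transformer_engine.h>

// Build a 3-D shape by value; `dims` can go out of scope afterwards, since
// data[15] is an embedded copy rather than a pointer into the caller's array.
NVTEShape make_bsh_shape(size_t batch, size_t seqlen, size_t hidden) {
  const size_t dims[3] = {batch, seqlen, hidden};
  return nvte_make_shape(dims, 3);  // copies dims into shape.data, sets ndim = 3
}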
transformer_engine/pytorch/csrc/extensions/multi_tensor/multi_tensor_adam.cu → transformer_engine/common/multi_tensor/adam.cu
@@ -4,23 +4,16 @@
  * See LICENSE for license information.
  ************************************************************************/

-#include <ATen/ATen.h>
-#include <ATen/AccumulateType.h>
-#include <ATen/cuda/CUDAContext.h>
-#include <ATen/cuda/Exceptions.h>
-#ifdef __HIP_PLATFORM_AMD__
-#include "amd_detail/hip_float8.h"
-#else
-#include <cuda_fp8.h>
-#endif
-// Another possibility:
-// #include <torch/all.h>
 #include <assert.h>
+#include <cuda_fp8.h>
+#include <transformer_engine/multi_tensor.h>
+#include <transformer_engine/transformer_engine.h>

-#include "common/utils.cuh"
+#include "../utils.cuh"
 #include "multi_tensor_apply.cuh"
-#include "type_shim.h"

+namespace transformer_engine {
+namespace multi_tensor_adam {

 #define BLOCK_SIZE 512
 #define ILP 4
@@ -39,7 +32,6 @@ using fp8e5m2 = __nv_fp8_e5m2;
 using fp8e4m3 = te_hip_fp8_e4m3;
 using fp8e5m2 = te_hip_fp8_e5m2;
 #endif
-using transformer_engine::DType;

 template <typename T>
 struct is_fp8 : std::false_type {};
@@ -585,12 +577,13 @@ struct AdamCapturableMasterFunctor {
   }
 };

-void multi_tensor_adam_cuda(int chunk_size, at::Tensor noop_flag,
-                            std::vector<std::vector<at::Tensor>> tensor_lists, const float lr,
+void multi_tensor_adam_cuda(int chunk_size, Tensor noop_flag,
+                            std::vector<std::vector<Tensor *>> tensor_lists, const float lr,
                             const float beta1, const float beta2, const float epsilon,
                             const int step, const int mode, const int bias_correction,
-                            const float weight_decay) {
-  using namespace at;
+                            const float weight_decay, const int device_id,
+                            cudaStream_t stream) {
+  const size_t num_tensor_lists = tensor_lists.size();
+  const size_t num_tensors_per_list = tensor_lists[0].size();

   // Handle bias correction mode
   float bias_correction1 = 1.0f, bias_correction2 = 1.0f;
@@ -601,10 +594,10 @@ void multi_tensor_adam_cuda(int chunk_size, at::Tensor noop_flag,
   size_t max_size = 0;
   bool requires_64bit_indexing = false;
-  for (auto it = tensor_lists.begin(); it != tensor_lists.end(); it++) {
-    for (auto it2 = it->begin(); it2 != it->end(); it2++) {
-      if (it2->numel() > max_size) {
-        max_size = it2->numel();
+  for (size_t i = 0; i < num_tensor_lists; i++) {
+    for (size_t j = 0; j < num_tensors_per_list; j++) {
+      if (tensor_lists[i][j]->numel() > max_size) {
+        max_size = tensor_lists[i][j]->numel();
         if (max_size >= INT_MAX) {
           requires_64bit_indexing = true;
           break;
@@ -616,69 +609,70 @@ void multi_tensor_adam_cuda(int chunk_size, at::Tensor noop_flag,
     }
   }

-  const auto g_in_type = tensor_lists[0][0].scalar_type();
-  const auto p_in_type = tensor_lists[1][0].scalar_type();
-  auto tl_size = tensor_lists.size();
+  const auto g_in_type_te = tensor_lists[0][0]->dtype();
+  const auto p_in_type_te = tensor_lists[1][0]->dtype();

   // case 4: g, p, m, v
   // case 5: g, p, m, v, p_master
-  TORCH_CHECK(tl_size == 4 || tl_size == 5, "tensor list must contain 4 or 5");
+  NVTE_CHECK(num_tensor_lists == 4 || num_tensor_lists == 5, "tensor list must contain 4 or 5");

   if (requires_64bit_indexing) {
-    if (tl_size == 4) {
+    if (num_tensor_lists == 4) {
       // Assume single type across p,g,m1,m2 now
-      DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT(
-          p_in_type, 0, "adam",
-          DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT(
-              g_in_type, 1, "adam",
-              multi_tensor_apply<BLOCK_SIZE, 4>((int64_t)chunk_size, noop_flag, tensor_lists,
-                                                AdamFunctor<scalar_t_0, scalar_t_1, float, int64_t>(),
-                                                beta1, beta2, bias_correction1, bias_correction2,
-                                                epsilon, lr, (adamMode_t)mode, weight_decay);));
+      TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
+          p_in_type_te, p_in_type,
+          TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
+              g_in_type_te, g_in_type,
+              multi_tensor_apply<BLOCK_SIZE, 4>((int64_t)chunk_size, noop_flag, tensor_lists,
+                                                AdamFunctor<p_in_type, g_in_type, float, int64_t>(),
+                                                device_id, stream, beta1, beta2, bias_correction1,
+                                                bias_correction2, epsilon, lr, (adamMode_t)mode,
+                                                weight_decay);));
     } else {
       // g, p, m, v, p_master
-      DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT(
-          p_in_type, 0, "adam",
-          DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT(
-              g_in_type, 1, "adam",
-              multi_tensor_apply<BLOCK_SIZE, 5>(
-                  (int64_t)chunk_size, noop_flag, tensor_lists,
-                  AdamFunctorMaster<scalar_t_0, scalar_t_1, float, int64_t>(), beta1, beta2,
-                  bias_correction1, bias_correction2, epsilon, lr, (adamMode_t)mode,
-                  weight_decay);));
+      TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
+          p_in_type_te, p_in_type,
+          TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
+              g_in_type_te, g_in_type,
+              multi_tensor_apply<BLOCK_SIZE, 5>(
+                  (int64_t)chunk_size, noop_flag, tensor_lists,
+                  AdamFunctorMaster<p_in_type, g_in_type, float, int64_t>(), device_id, stream,
+                  beta1, beta2, bias_correction1, bias_correction2, epsilon, lr, (adamMode_t)mode,
+                  weight_decay);));
     }
   } else {
-    if (tl_size == 4) {
+    if (num_tensor_lists == 4) {
       // Assume single type across p,g,m1,m2 now
-      DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT(
-          p_in_type, 0, "adam",
-          DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT(
-              g_in_type, 1, "adam",
-              multi_tensor_apply<BLOCK_SIZE, 4>(chunk_size, noop_flag, tensor_lists,
-                                                AdamFunctor<scalar_t_0, scalar_t_1, float, int32_t>(),
-                                                beta1, beta2, bias_correction1, bias_correction2,
-                                                epsilon, lr, (adamMode_t)mode, weight_decay);));
+      TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
+          p_in_type_te, p_in_type,
+          TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
+              g_in_type_te, g_in_type,
+              multi_tensor_apply<BLOCK_SIZE, 4>(chunk_size, noop_flag, tensor_lists,
+                                                AdamFunctor<p_in_type, g_in_type, float, int32_t>(),
+                                                device_id, stream, beta1, beta2, bias_correction1,
+                                                bias_correction2, epsilon, lr, (adamMode_t)mode,
+                                                weight_decay);));
     } else {
-      DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT(
-          p_in_type, 0, "adam",
-          DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT(
-              g_in_type, 1, "adam",
-              multi_tensor_apply<BLOCK_SIZE, 5>(
-                  chunk_size, noop_flag, tensor_lists,
-                  AdamFunctorMaster<scalar_t_0, scalar_t_1, float, int32_t>(), beta1, beta2,
-                  bias_correction1, bias_correction2, epsilon, lr, (adamMode_t)mode,
-                  weight_decay);));
+      TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
+          p_in_type_te, p_in_type,
+          TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
+              g_in_type_te, g_in_type,
+              multi_tensor_apply<BLOCK_SIZE, 5>(
+                  chunk_size, noop_flag, tensor_lists,
+                  AdamFunctorMaster<p_in_type, g_in_type, float, int32_t>(), device_id, stream,
+                  beta1, beta2, bias_correction1, bias_correction2, epsilon, lr, (adamMode_t)mode,
+                  weight_decay);));
     }
   }

-  AT_CUDA_CHECK(cudaGetLastError());
+  NVTE_CHECK_CUDA(cudaGetLastError());
 }

-void multi_tensor_adam_param_remainder_cuda(int chunk_size, at::Tensor noop_flag,
-                                            std::vector<std::vector<at::Tensor>> tensor_lists,
+void multi_tensor_adam_param_remainder_cuda(int chunk_size, Tensor noop_flag,
+                                            std::vector<std::vector<Tensor *>> tensor_lists,
                                             const float lr, const float beta1, const float beta2,
                                             const float epsilon, const int step, const int mode,
-                                            const int bias_correction, const float weight_decay) {
-  using namespace at;
+                                            const int bias_correction, const float weight_decay,
+                                            const int device_id, cudaStream_t stream) {
+  const size_t num_tensor_lists = tensor_lists.size();

   // Handle bias correction mode
   float bias_correction1 = 1.0f, bias_correction2 = 1.0f;
@@ -687,34 +681,34 @@ void multi_tensor_adam_param_remainder_cuda(int chunk_size, at::Tensor noop_flag
     bias_correction2 = 1 - std::pow(beta2, step);
   }

-  const auto g_in_type = tensor_lists[0][0].scalar_type();
-  const auto p_in_type = tensor_lists[1][0].scalar_type();
-  auto tl_size = tensor_lists.size();
+  const auto g_in_type_te = tensor_lists[0][0]->dtype();
+  const auto p_in_type_te = tensor_lists[1][0]->dtype();

   // case 5: g, p, m, v, p_master
-  TORCH_CHECK(tl_size == 5, "tensor list must contain 5");
-  TORCH_CHECK(p_in_type == at::ScalarType::BFloat16,
-              "Adam with BF16 param remainders requires BF16 params");
+  NVTE_CHECK(num_tensor_lists == 5, "tensor list must contain 5");
+  NVTE_CHECK(p_in_type_te == DType::kBFloat16,
+             "Adam with BF16 param remainders requires BF16 params");

   // g, p, m, v, p_master
-  DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT(
-      p_in_type, 0, "adam",
-      DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT(
-          g_in_type, 1, "adam",
-          multi_tensor_apply<BLOCK_SIZE, 5>((int64_t)chunk_size, noop_flag, tensor_lists,
-                                            AdamFunctorMasterParamRemainder<scalar_t_1, float, int64_t>(),
-                                            beta1, beta2, bias_correction1, bias_correction2,
-                                            epsilon, lr, (adamMode_t)mode, weight_decay);));
-  AT_CUDA_CHECK(cudaGetLastError());
+  TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
+      g_in_type_te, g_in_type,
+      multi_tensor_apply<BLOCK_SIZE, 5>((int64_t)chunk_size, noop_flag, tensor_lists,
+                                        AdamFunctorMasterParamRemainder<g_in_type, float, int64_t>(),
+                                        device_id, stream, beta1, beta2, bias_correction1,
+                                        bias_correction2, epsilon, lr, (adamMode_t)mode,
+                                        weight_decay););
+  NVTE_CHECK_CUDA(cudaGetLastError());
 }

-void multi_tensor_adam_fp8_cuda(int chunk_size, at::Tensor noop_flag,
-                                std::vector<std::vector<at::Tensor>> tensor_lists, const float lr,
+void multi_tensor_adam_fp8_cuda(int chunk_size, Tensor noop_flag,
+                                std::vector<std::vector<Tensor *>> tensor_lists, const float lr,
                                 const float beta1, const float beta2, const float epsilon,
                                 const int step, const int mode, const int bias_correction,
-                                const float weight_decay, DType fp8_dtype) {
-  using namespace at;
+                                const float weight_decay, const DType fp8_dtype,
+                                const int device_id, cudaStream_t stream) {
+  const size_t num_tensor_lists = tensor_lists.size();
+  const size_t num_tensors_per_list = tensor_lists[0].size();

   // Handle bias correction mode
   float bias_correction1 = 1.0f, bias_correction2 = 1.0f;
@@ -725,10 +719,10 @@ void multi_tensor_adam_fp8_cuda(int chunk_size, at::Tensor noop_flag,
   size_t max_size = 0;
   bool requires_64bit_indexing = false;
-  for (auto it = tensor_lists.begin(); it != tensor_lists.end(); it++) {
-    for (auto it2 = it->begin(); it2 != it->end(); it2++) {
-      if (it2->numel() > max_size) {
-        max_size = it2->numel();
+  for (size_t i = 0; i < num_tensor_lists; i++) {
+    for (size_t j = 0; j < num_tensors_per_list; j++) {
+      if (tensor_lists[i][j]->numel() > max_size) {
+        max_size = tensor_lists[i][j]->numel();
         if (max_size >= INT_MAX) {
           requires_64bit_indexing = true;
           break;
@@ -740,66 +734,147 @@ void multi_tensor_adam_fp8_cuda(int chunk_size, at::Tensor noop_flag,
     }
   }

-  const auto g_in_type = tensor_lists[0][0].scalar_type();
-  auto tl_size = tensor_lists.size();
+  const auto g_in_type_te = tensor_lists[0][0]->dtype();

   // case 8: g, p_fp8, m, v, p_master, scale, amax, scale_inv
-  TORCH_CHECK(tl_size == 8, "tensor list must contain 8 tensors");
+  NVTE_CHECK(num_tensor_lists == 8, "tensor list must contain 8 tensors");

   if (requires_64bit_indexing) {
     TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
         fp8_dtype, FP8_T,
-        DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT(
-            g_in_type, 0, "adam",
-            multi_tensor_apply<BLOCK_SIZE, 5, true>(
-                (int64_t)chunk_size, noop_flag, tensor_lists,
-                AdamFunctorMaster<FP8_T, scalar_t_0, float, int64_t>(), beta1, beta2,
-                bias_correction1, bias_correction2, epsilon, lr, (adamMode_t)mode,
-                weight_decay);));
+        TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
+            g_in_type_te, g_in_type,
+            multi_tensor_apply<BLOCK_SIZE, 5, true>(
+                (int64_t)chunk_size, noop_flag, tensor_lists,
+                AdamFunctorMaster<FP8_T, g_in_type, float, int64_t>(), device_id, stream, beta1,
+                beta2, bias_correction1, bias_correction2, epsilon, lr, (adamMode_t)mode,
+                weight_decay);));
   } else {
     TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
         fp8_dtype, FP8_T,
-        DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT(
-            g_in_type, 0, "adam",
-            multi_tensor_apply<BLOCK_SIZE, 5, true>(
-                chunk_size, noop_flag, tensor_lists,
-                AdamFunctorMaster<FP8_T, scalar_t_0, float, int32_t>(), beta1, beta2,
-                bias_correction1, bias_correction2, epsilon, lr, (adamMode_t)mode,
-                weight_decay);));
+        TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
+            g_in_type_te, g_in_type,
+            multi_tensor_apply<BLOCK_SIZE, 5, true>(
+                chunk_size, noop_flag, tensor_lists,
+                AdamFunctorMaster<FP8_T, g_in_type, float, int32_t>(), device_id, stream, beta1,
+                beta2, bias_correction1, bias_correction2, epsilon, lr, (adamMode_t)mode,
+                weight_decay);));
   }

-  AT_CUDA_CHECK(cudaGetLastError());
+  NVTE_CHECK_CUDA(cudaGetLastError());
 }

-void multi_tensor_adam_capturable_cuda(int chunk_size, at::Tensor noop_flag,
-                                       std::vector<std::vector<at::Tensor>> tensor_lists,
-                                       at::Tensor lr, const float beta1, const float beta2,
-                                       const float epsilon, at::Tensor step, const int mode,
-                                       const int bias_correction, const float weight_decay,
-                                       at::Tensor inv_scale) {
-  using namespace at;
-  DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT(
-      tensor_lists[0][0].scalar_type(), 0, "adam",
-      multi_tensor_apply<BLOCK_SIZE, 4>(
-          chunk_size, noop_flag, tensor_lists, AdamCapturableFunctor<scalar_t_0, float>(), beta1,
-          beta2, step.data_ptr<int>(), bias_correction, epsilon, lr.data_ptr<float>(),
-          (adamMode_t)mode, weight_decay, inv_scale.data_ptr<float>());)
-  AT_CUDA_CHECK(cudaGetLastError());
+void multi_tensor_adam_capturable_cuda(int chunk_size, Tensor noop_flag,
+                                       std::vector<std::vector<Tensor *>> tensor_lists, Tensor lr,
+                                       const float beta1, const float beta2, const float epsilon,
+                                       Tensor step, const int mode, const int bias_correction,
+                                       const float weight_decay, Tensor inv_scale,
+                                       const int device_id, cudaStream_t stream) {
+  TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
+      tensor_lists[0][0]->dtype(), dtype,
+      multi_tensor_apply<BLOCK_SIZE, 4>(
+          chunk_size, noop_flag, tensor_lists, AdamCapturableFunctor<dtype, float>(), device_id,
+          stream, beta1, beta2, reinterpret_cast<int *>(step.data.dptr), bias_correction, epsilon,
+          reinterpret_cast<float *>(lr.data.dptr), (adamMode_t)mode, weight_decay,
+          reinterpret_cast<float *>(inv_scale.data.dptr));)
+  NVTE_CHECK_CUDA(cudaGetLastError());
 }

-void multi_tensor_adam_capturable_master_cuda(int chunk_size, at::Tensor noop_flag,
-                                              std::vector<std::vector<at::Tensor>> tensor_lists,
-                                              at::Tensor lr, const float beta1, const float beta2,
-                                              const float epsilon, at::Tensor step, const int mode,
-                                              const int bias_correction, const float weight_decay,
-                                              at::Tensor inv_scale) {
-  using namespace at;
-  DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT(
-      tensor_lists[0][0].scalar_type(), 0, "adam",
-      multi_tensor_apply<BLOCK_SIZE, 5>(
-          chunk_size, noop_flag, tensor_lists, AdamCapturableMasterFunctor<scalar_t_0, float>(),
-          beta1, beta2, step.data_ptr<int>(), bias_correction, epsilon, lr.data_ptr<float>(),
-          (adamMode_t)mode, weight_decay, inv_scale.data_ptr<float>());)
-  AT_CUDA_CHECK(cudaGetLastError());
+void multi_tensor_adam_capturable_master_cuda(int chunk_size, Tensor noop_flag,
+                                              std::vector<std::vector<Tensor *>> tensor_lists,
+                                              Tensor lr, const float beta1, const float beta2,
+                                              const float epsilon, Tensor step, const int mode,
+                                              const int bias_correction, const float weight_decay,
+                                              Tensor inv_scale, const int device_id,
+                                              cudaStream_t stream) {
+  TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
+      tensor_lists[0][0]->dtype(), dtype,
+      multi_tensor_apply<BLOCK_SIZE, 5>(
+          chunk_size, noop_flag, tensor_lists, AdamCapturableMasterFunctor<dtype, float>(),
+          device_id, stream, beta1, beta2, reinterpret_cast<int *>(step.data.dptr),
+          bias_correction, epsilon, reinterpret_cast<float *>(lr.data.dptr), (adamMode_t)mode,
+          weight_decay, reinterpret_cast<float *>(inv_scale.data.dptr));)
+  NVTE_CHECK_CUDA(cudaGetLastError());
 }

+}  // namespace multi_tensor_adam
+}  // namespace transformer_engine

+void nvte_multi_tensor_adam_cuda(int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists,
+                                 const size_t num_tensor_lists, const size_t num_tensors_per_list,
+                                 const float lr, const float beta1, const float beta2,
+                                 const float epsilon, const int step, const int mode,
+                                 const int bias_correction, const float weight_decay,
+                                 const int device_id, cudaStream_t stream) {
+  NVTE_API_CALL(nvte_multi_tensor_adam_cuda);
+  using namespace transformer_engine;
+  multi_tensor_adam::multi_tensor_adam_cuda(
+      chunk_size, *reinterpret_cast<Tensor *>(noop_flag),
+      convert_tensor_array(tensor_lists, num_tensor_lists, num_tensors_per_list), lr, beta1,
+      beta2, epsilon, step, mode, bias_correction, weight_decay, device_id, stream);
+}

+void nvte_multi_tensor_adam_param_remainder_cuda(
+    int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists,
+    const size_t num_tensor_lists, const size_t num_tensors_per_list, const float lr,
+    const float beta1, const float beta2, const float epsilon, const int step, const int mode,
+    const int bias_correction, const float weight_decay, const int device_id,
+    cudaStream_t stream) {
+  NVTE_API_CALL(nvte_multi_tensor_adam_param_remainder_cuda);
+  using namespace transformer_engine;
+  multi_tensor_adam::multi_tensor_adam_param_remainder_cuda(
+      chunk_size, *reinterpret_cast<Tensor *>(noop_flag),
+      convert_tensor_array(tensor_lists, num_tensor_lists, num_tensors_per_list), lr, beta1,
+      beta2, epsilon, step, mode, bias_correction, weight_decay, device_id, stream);
+}

+void nvte_multi_tensor_adam_fp8_cuda(int chunk_size, NVTETensor noop_flag,
+                                     NVTETensor **tensor_lists, const size_t num_tensor_lists,
+                                     const size_t num_tensors_per_list, const float lr,
+                                     const float beta1, const float beta2, const float epsilon,
+                                     const int step, const int mode, const int bias_correction,
+                                     const float weight_decay, const NVTEDType fp8_dtype,
+                                     const int device_id, cudaStream_t stream) {
+  NVTE_API_CALL(nvte_multi_tensor_adam_fp8_cuda);
+  using namespace transformer_engine;
+  multi_tensor_adam::multi_tensor_adam_fp8_cuda(
+      chunk_size, *reinterpret_cast<Tensor *>(noop_flag),
+      convert_tensor_array(tensor_lists, num_tensor_lists, num_tensors_per_list), lr, beta1,
+      beta2, epsilon, step, mode, bias_correction, weight_decay, static_cast<DType>(fp8_dtype),
+      device_id, stream);
+}

+void nvte_multi_tensor_adam_capturable_cuda(
+    int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists,
+    const size_t num_tensor_lists, const size_t num_tensors_per_list, NVTETensor lr,
+    const float beta1, const float beta2, const float epsilon, NVTETensor step, const int mode,
+    const int bias_correction, const float weight_decay, NVTETensor inv_scale,
+    const int device_id, cudaStream_t stream) {
+  NVTE_API_CALL(nvte_multi_tensor_adam_capturable_cuda);
+  using namespace transformer_engine;
+  multi_tensor_adam::multi_tensor_adam_capturable_cuda(
+      chunk_size, *reinterpret_cast<Tensor *>(noop_flag),
+      convert_tensor_array(tensor_lists, num_tensor_lists, num_tensors_per_list),
+      *reinterpret_cast<Tensor *>(lr), beta1, beta2, epsilon, *reinterpret_cast<Tensor *>(step),
+      mode, bias_correction, weight_decay, *reinterpret_cast<Tensor *>(inv_scale), device_id,
+      stream);
+}

+void nvte_multi_tensor_adam_capturable_master_cuda(
+    int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists,
+    const size_t num_tensor_lists, const size_t num_tensors_per_list, NVTETensor lr,
+    const float beta1, const float beta2, const float epsilon, NVTETensor step, const int mode,
+    const int bias_correction, const float weight_decay, NVTETensor inv_scale,
+    const int device_id, cudaStream_t stream) {
+  NVTE_API_CALL(nvte_multi_tensor_adam_capturable_master_cuda);
+  using namespace transformer_engine;
+  multi_tensor_adam::multi_tensor_adam_capturable_master_cuda(
+      chunk_size, *reinterpret_cast<Tensor *>(noop_flag),
+      convert_tensor_array(tensor_lists, num_tensor_lists, num_tensors_per_list),
+      *reinterpret_cast<Tensor *>(lr), beta1, beta2, epsilon, *reinterpret_cast<Tensor *>(step),
+      mode, bias_correction, weight_decay, *reinterpret_cast<Tensor *>(inv_scale), device_id,
+      stream);
+}
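For reference, the bias-correction factors handled at the top of each of these functions follow the standard Adam definition; a self-contained restatement of the `1 - pow(beta, step)` lines above (values here are the usual defaults, chosen for illustration):

#include <cassert>
#include <cmath>

int main() {
  const float beta1 = 0.9f, beta2 = 0.999f;
  const int step = 1;
  // With bias_correction != 0, the functions above compute:
  float bias_correction1 = 1.0f - std::pow(beta1, step);  // rescales exp_avg
  float bias_correction2 = 1.0f - std::pow(beta2, step);  // rescales exp_avg_sq
  assert(bias_correction1 > 0.0f && bias_correction2 > 0.0f);
  return 0;
}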
transformer_engine/pytorch/csrc/extensions/multi_tensor/multi_tensor_compute_scale.cu → transformer_engine/common/multi_tensor/compute_scale.cu
@@ -4,23 +4,21 @@
  * See LICENSE for license information.
  ************************************************************************/

-#include <ATen/ATen.h>
-#include <ATen/AccumulateType.h>
-#include <ATen/cuda/CUDAContext.h>
-#include <ATen/cuda/Exceptions.h>
-// Another possibility:
-// #include <torch/all.h>
 #include <assert.h>
+#include <cuda_fp8.h>
+#include <transformer_engine/multi_tensor.h>
+#include <transformer_engine/transformer_engine.h>
+
 #include <limits>
 // Stringstream is a big hammer, but I want to rely on operator<< for dtype.
 #include <sstream>

-#include "common/recipe/recipe_common.cuh"
-#include "common/utils.cuh"
+#include "../recipe/recipe_common.cuh"
+#include "../utils.cuh"
 #include "multi_tensor_apply.cuh"
-#include "type_shim.h"

+namespace transformer_engine {
+namespace multi_tensor_compute_scale {

 #define BLOCK_SIZE 256
@@ -57,12 +55,29 @@ struct ComputeScaleAndScaleInvFunctor {
   }
 };

-void multi_tensor_compute_scale_and_scale_inv_cuda(int chunk_size, at::Tensor noop_flag,
-                                                   std::vector<std::vector<at::Tensor>> tensor_lists,
-                                                   float max_fp8, bool force_pow_2_scales,
-                                                   float epsilon) {
-  using namespace at;
+void multi_tensor_compute_scale_and_scale_inv_cuda(int chunk_size, Tensor noop_flag,
+                                                   std::vector<std::vector<Tensor *>> tensor_lists,
+                                                   float max_fp8, bool force_pow_2_scales,
+                                                   float epsilon, const int device_id,
+                                                   cudaStream_t stream) {
   multi_tensor_apply<BLOCK_SIZE, 3>(chunk_size, noop_flag, tensor_lists,
-                                    ComputeScaleAndScaleInvFunctor(), max_fp8,
-                                    force_pow_2_scales, epsilon);
-  AT_CUDA_CHECK(cudaGetLastError());
+                                    ComputeScaleAndScaleInvFunctor(), device_id, stream, max_fp8,
+                                    force_pow_2_scales, epsilon);
+  NVTE_CHECK_CUDA(cudaGetLastError());
 }

+}  // namespace multi_tensor_compute_scale
+}  // namespace transformer_engine

+void nvte_multi_tensor_compute_scale_and_scale_inv_cuda(
+    int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists,
+    const size_t num_tensor_lists, const size_t num_tensors_per_list, float max_fp8,
+    int force_pow_2_scales, float epsilon, const int device_id, cudaStream_t stream) {
+  NVTE_API_CALL(nvte_multi_tensor_compute_scale_and_scale_inv_cuda);
+  using namespace transformer_engine;
+  multi_tensor_compute_scale::multi_tensor_compute_scale_and_scale_inv_cuda(
+      chunk_size, *reinterpret_cast<Tensor *>(noop_flag),
+      convert_tensor_array(tensor_lists, num_tensor_lists, num_tensors_per_list), max_fp8,
+      force_pow_2_scales, epsilon, device_id, stream);
+}
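The ComputeScaleAndScaleInvFunctor body is collapsed in this view, so the exact math is not shown here. A common formulation of the quantity such functors compute is sketched below; treat the clamping details and the power-of-two rounding as assumptions about typical FP8 recipes, not as a reading of this commit:

#include <math.h>

// Sketch: derive an FP8 scale and its inverse from an amax value.
// force_pow_2_scales rounds the scale down to a power of two so that
// scale and scale_inv are exact reciprocals in floating point.
void compute_scale_and_scale_inv(float amax, float max_fp8, bool force_pow_2_scales,
                                 float epsilon, float *scale, float *scale_inv) {
  float a = fmaxf(amax, epsilon);  // keep a tiny/zero amax from exploding the scale
  float s = (a > 0.f) ? max_fp8 / a : 1.f;
  if (force_pow_2_scales) s = exp2f(floorf(log2f(s)));
  *scale = s;
  *scale_inv = 1.f / s;
}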
transformer_engine/pytorch/csrc/extensions/multi_tensor/multi_tensor_l2norm_kernel.cu → transformer_engine/common/multi_tensor/l2norm.cu
@@ -4,18 +4,16 @@
  * See LICENSE for license information.
  ************************************************************************/

-#include <ATen/ATen.h>
-#include <ATen/AccumulateType.h>
-#include <ATen/cuda/CUDAContext.h>
-#include <ATen/cuda/Exceptions.h>
-#include <c10/cuda/CUDAGuard.h>
-// Another possibility:
-// #include <torch/all.h>
 #include <assert.h>
+#include <cuda_fp8.h>
+#include <transformer_engine/multi_tensor.h>
+#include <transformer_engine/transformer_engine.h>

 #include "../utils.cuh"
 #include "multi_tensor_apply.cuh"
-#include "type_shim.h"

+namespace transformer_engine {
+namespace multi_tensor_l2norm {

 #define BLOCK_SIZE 512
 #define ILP 4
@@ -31,6 +29,96 @@ __device__ __forceinline__ void load_store(T *dst, T *src, int dst_offset, int s
   ((LT *)dst)[dst_offset] = ((LT *)src)[src_offset];  // NOLINT(*)
 }

 template <typename T>
 __device__ __forceinline__ T reduce_block_into_lanes(T *x, T val, int lanes = 1,
                                                      bool share_result = false) {
   // lanes is intended to be <= 32.
   int tid = threadIdx.x + threadIdx.y * blockDim.x;
   int blockSize = blockDim.x * blockDim.y;
   // blockSize is intended to be a multiple of 32.

   if (blockSize >= 64) {
     x[tid] = val;
     __syncthreads();
   }

 #pragma unroll
   for (int i = (blockSize >> 1); i >= 64; i >>= 1) {
     if (tid < i) x[tid] = x[tid] + x[tid + i];
     __syncthreads();
   }

   T final;

   if (tid < 32) {
     if (blockSize >= 64)
       final = x[tid] + x[tid + 32];
     else
       final = val;
     // __SYNCWARP();

 #ifdef __HIP_PLATFORM_AMD__
 #pragma unroll
     for (int i = 16; i >= lanes; i >>= 1) final = final + __shfl_down(final, i);
 #else
 #pragma unroll
     for (int i = 16; i >= lanes; i >>= 1) final = final + __shfl_down_sync(0xffffffff, final, i);
 #endif
   }

   if (share_result) {
     if (tid < lanes) x[tid] = final;  // EpilogueOp
     // Make sure the smem result is visible to all warps.
   }
   __syncthreads();
   // Avoid potential write before read race when reduce_block_into_lanes is called back to back

   return final;
 }

 template <typename T>
 __device__ __forceinline__ T reduce_block_into_lanes_max_op(T *x, T val, int lanes = 1,
                                                             bool share_result = false) {
   // lanes is intended to be <= 32.
   int tid = threadIdx.x + threadIdx.y * blockDim.x;
   int blockSize = blockDim.x * blockDim.y;
   // blockSize is intended to be a multiple of 32.

   if (blockSize >= 64) {
     x[tid] = val;
     __syncthreads();
   }

 #pragma unroll
   for (int i = (blockSize >> 1); i >= 64; i >>= 1) {
     if (tid < i) x[tid] = fmaxf(fabsf(x[tid]), fabsf(x[tid + i]));
     __syncthreads();
   }

   T final;

   if (tid < 32) {
     if (blockSize >= 64)
       final = fmaxf(fabsf(x[tid]), fabsf(x[tid + 32]));
     else
       final = val;
     // __SYNCWARP();

 #pragma unroll
     for (int i = 16; i >= lanes; i >>= 1)
 #ifdef __HIP_PLATFORM_AMD__
       final = fmaxf(fabsf(final), fabsf(__shfl_down(final, i)));
 #else
       final = fmaxf(fabsf(final), fabsf(__shfl_down_sync(0xffffffff, final, i)));
 #endif
   }

   if (share_result) {
     if (tid < lanes) x[tid] = final;  // EpilogueOp
     // Make sure the smem result is visible to all warps.
     __syncthreads();
   }

   return final;
 }
template
<
typename
x_t
>
struct
L2NormFunctor
{
__device__
__forceinline__
void
operator
()(
int
chunk_size
,
volatile
int
*
noop_gmem
,
...
...
@@ -56,7 +144,7 @@ struct L2NormFunctor {
     x_t r_x[ILP];
     for (int i = 0; i < ILP; i++) {
       vals[i] = 0.f;
-      r_x[i] = 0;
+      r_x[i] = 0.f;
     }

     // to make things simple, we put aligned case in a different code path

@@ -126,7 +214,7 @@ struct UnscaleL2NormFunctor {
     x_t r_x[ILP];
     for (int i = 0; i < ILP; i++) {
       vals[i] = 0.f;
-      r_x[i] = 0;
+      r_x[i] = 0.f;
     }

     // to make things simple, we put aligned case in a different code path
@@ -310,103 +398,96 @@ __global__ void cleanup_v2(float *output, float *output_per_tensor, float *ret,
   }
 }

-std::tuple<at::Tensor, at::Tensor> multi_tensor_l2norm_cuda(
-    int chunk_size, at::Tensor noop_flag, std::vector<std::vector<at::Tensor>> tensor_lists,
-    at::optional<bool> per_tensor_python) {
-  bool per_tensor = per_tensor_python.has_value() ? per_tensor_python.value() : false;
-  auto float_options = tensor_lists[0][0].options().dtype(at::kFloat);
-  auto output = at::zeros({320}, float_options);
-
-  at::Tensor output_per_tensor;
-  at::Tensor ret_per_tensor;
-
-  int ntensors = tensor_lists[0].size();
-  int max_chunks_per_tensor = -1;
-
-  if (per_tensor) {
-    for (int t = 0; t < ntensors; t++) {
-      int max_chunks_this_tensor = (tensor_lists[0][t].numel() + chunk_size - 1) / chunk_size;
-      if (max_chunks_this_tensor > max_chunks_per_tensor)
-        max_chunks_per_tensor = max_chunks_this_tensor;
-    }
-    output_per_tensor = at::zeros({ntensors * max_chunks_per_tensor}, float_options);
-    ret_per_tensor = at::empty({ntensors}, float_options);
-  } else {
-    ret_per_tensor = at::empty({0}, float_options);
-  }
-
-  DISPATCH_FLOAT_HALF_AND_BFLOAT(
-      tensor_lists[0][0].scalar_type(), 0, "multi_tensor_l2norm_cuda",
-      multi_tensor_apply<BLOCK_SIZE, 1>(
-          chunk_size, noop_flag, tensor_lists, L2NormFunctor<scalar_t_0>(),
-          output.data_ptr<float>(),
-          per_tensor ? output_per_tensor.data_ptr<float>() : nullptr, per_tensor,
-          max_chunks_per_tensor);)
-  AT_CUDA_CHECK(cudaGetLastError());
-  // AT_CUDA_CHECK(cudaDeviceSynchronize());
+void multi_tensor_l2norm_cuda(int chunk_size, Tensor noop_flag,
+                              std::vector<std::vector<Tensor *>> tensor_lists, Tensor output,
+                              Tensor output_per_tensor, Tensor ret, Tensor ret_per_tensor,
+                              bool per_tensor, int max_chunks_per_tensor, const int device_id,
+                              cudaStream_t stream) {
+  TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
+      tensor_lists[0][0]->dtype(), dtype,
+      multi_tensor_apply<BLOCK_SIZE, 1>(
+          chunk_size, noop_flag, tensor_lists, L2NormFunctor<dtype>(), device_id, stream,
+          reinterpret_cast<float *>(output.data.dptr),
+          per_tensor ? reinterpret_cast<float *>(output_per_tensor.data.dptr) : nullptr,
+          per_tensor, max_chunks_per_tensor);)
+  NVTE_CHECK_CUDA(cudaGetLastError());

   // This involves one more small kernel launches, but will be negligible end to end.
   // I could get rid of these by hacking the functor + multi tensor harness with persistence
   // logic, but keeping it simple for now
-  auto ret = at::empty({1}, output.options());
-  const at::cuda::OptionalCUDAGuard device_guard(device_of(output));
-  auto stream = at::cuda::getCurrentCUDAStream();
-  cleanup<<<per_tensor ? ntensors : 1, 512, 0, stream>>>(
-      output.data_ptr<float>(),
-      per_tensor ? output_per_tensor.data_ptr<float>() : nullptr,
-      ret.data_ptr<float>(),
-      per_tensor ? ret_per_tensor.data_ptr<float>() : nullptr,
-      per_tensor, max_chunks_per_tensor);
-  return std::tuple<at::Tensor, at::Tensor>(ret, ret_per_tensor);
+  const OptionalCUDAGuard device_guard(device_id);
+  cleanup<<<per_tensor ? tensor_lists[0].size() : 1, 512, 0, stream>>>(
+      reinterpret_cast<float *>(output.data.dptr),
+      per_tensor ? reinterpret_cast<float *>(output_per_tensor.data.dptr) : nullptr,
+      reinterpret_cast<float *>(ret.data.dptr),
+      per_tensor ? reinterpret_cast<float *>(ret_per_tensor.data.dptr) : nullptr,
+      per_tensor, max_chunks_per_tensor);
 }
-std::tuple<at::Tensor, at::Tensor> multi_tensor_unscale_l2norm_cuda(
-    int chunk_size, at::Tensor noop_flag, std::vector<std::vector<at::Tensor>> tensor_lists,
-    at::Tensor inv_scale, at::optional<bool> per_tensor_python) {
-  bool per_tensor = per_tensor_python.has_value() ? per_tensor_python.value() : false;
-  auto float_options = tensor_lists[0][0].options().dtype(at::kFloat);
-  auto output = at::zeros({320}, float_options);
-
-  at::Tensor output_per_tensor;
-  at::Tensor ret_per_tensor;
-
-  int ntensors = tensor_lists[0].size();
-  int max_chunks_per_tensor = -1;
-
-  if (per_tensor) {
-    for (int t = 0; t < ntensors; t++) {
-      int max_chunks_this_tensor = (tensor_lists[0][t].numel() + chunk_size - 1) / chunk_size;
-      if (max_chunks_this_tensor > max_chunks_per_tensor)
-        max_chunks_per_tensor = max_chunks_this_tensor;
-    }
-    output_per_tensor = at::zeros({ntensors * max_chunks_per_tensor}, float_options);
-    ret_per_tensor = at::empty({ntensors}, float_options);
-  } else {
-    ret_per_tensor = at::empty({0}, float_options);
-  }
-
-  DISPATCH_FLOAT_HALF_AND_BFLOAT(
-      tensor_lists[0][0].scalar_type(), 0, "multi_tensor_unscale_l2norm_cuda",
-      multi_tensor_apply<BLOCK_SIZE, 1>(
-          chunk_size, noop_flag, tensor_lists, UnscaleL2NormFunctor<scalar_t_0>(),
-          inv_scale.data_ptr<float>(), output.data_ptr<float>(),
-          per_tensor ? output_per_tensor.data_ptr<float>() : nullptr, per_tensor,
-          max_chunks_per_tensor);)
-  AT_CUDA_CHECK(cudaGetLastError());
-  // AT_CUDA_CHECK(cudaDeviceSynchronize());
+void multi_tensor_unscale_l2norm_cuda(int chunk_size, Tensor noop_flag,
+                                      std::vector<std::vector<Tensor *>> tensor_lists,
+                                      Tensor output, Tensor output_per_tensor, Tensor ret,
+                                      Tensor ret_per_tensor, Tensor inv_scale, bool per_tensor,
+                                      int max_chunks_per_tensor, const int device_id,
+                                      cudaStream_t stream) {
+  TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
+      tensor_lists[0][0]->dtype(), dtype,
+      multi_tensor_apply<BLOCK_SIZE, 1>(
+          chunk_size, noop_flag, tensor_lists, UnscaleL2NormFunctor<dtype>(), device_id, stream,
+          reinterpret_cast<float *>(inv_scale.data.dptr),
+          reinterpret_cast<float *>(output.data.dptr),
+          per_tensor ? reinterpret_cast<float *>(output_per_tensor.data.dptr) : nullptr,
+          per_tensor, max_chunks_per_tensor);)
+  NVTE_CHECK_CUDA(cudaGetLastError());

   // This involves one more small kernel launches, but will be negligible end to end.
   // I could get rid of these by hacking the functor + multi tensor harness with persistence
   // logic, but keeping it simple for now
-  auto ret = at::empty({1}, output.options());
-  const at::cuda::OptionalCUDAGuard device_guard(device_of(output));
-  auto stream = at::cuda::getCurrentCUDAStream();
-  cleanup<<<per_tensor ? ntensors : 1, 512, 0, stream>>>(
-      output.data_ptr<float>(),
-      per_tensor ? output_per_tensor.data_ptr<float>() : nullptr,
-      ret.data_ptr<float>(),
-      per_tensor ? ret_per_tensor.data_ptr<float>() : nullptr,
-      per_tensor, max_chunks_per_tensor);
-  return std::tuple<at::Tensor, at::Tensor>(ret, ret_per_tensor);
+  const OptionalCUDAGuard device_guard(device_id);
+  cleanup<<<per_tensor ? tensor_lists[0].size() : 1, 512, 0, stream>>>(
+      reinterpret_cast<float *>(output.data.dptr),
+      per_tensor ? reinterpret_cast<float *>(output_per_tensor.data.dptr) : nullptr,
+      reinterpret_cast<float *>(ret.data.dptr),
+      per_tensor ? reinterpret_cast<float *>(ret_per_tensor.data.dptr) : nullptr,
+      per_tensor, max_chunks_per_tensor);
 }

+}  // namespace multi_tensor_l2norm
+}  // namespace transformer_engine
+void nvte_multi_tensor_l2norm_cuda(int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists,
+                                   const size_t num_tensor_lists,
+                                   const size_t num_tensors_per_list, NVTETensor output,
+                                   NVTETensor output_per_tensor, NVTETensor ret,
+                                   NVTETensor ret_per_tensor, int per_tensor,
+                                   int max_chunks_per_tensor, const int device_id,
+                                   cudaStream_t stream) {
+  NVTE_API_CALL(nvte_multi_tensor_l2norm_cuda);
+  using namespace transformer_engine;
+  multi_tensor_l2norm::multi_tensor_l2norm_cuda(
+      chunk_size, *reinterpret_cast<Tensor *>(noop_flag),
+      convert_tensor_array(tensor_lists, num_tensor_lists, num_tensors_per_list),
+      *reinterpret_cast<Tensor *>(output), *reinterpret_cast<Tensor *>(output_per_tensor),
+      *reinterpret_cast<Tensor *>(ret), *reinterpret_cast<Tensor *>(ret_per_tensor), per_tensor,
+      max_chunks_per_tensor, device_id, stream);
+}
+void nvte_multi_tensor_unscale_l2norm_cuda(int chunk_size, NVTETensor noop_flag,
+                                           NVTETensor **tensor_lists,
+                                           const size_t num_tensor_lists,
+                                           const size_t num_tensors_per_list, NVTETensor output,
+                                           NVTETensor output_per_tensor, NVTETensor ret,
+                                           NVTETensor ret_per_tensor, NVTETensor inv_scale,
+                                           int per_tensor, int max_chunks_per_tensor,
+                                           const int device_id, cudaStream_t stream) {
+  NVTE_API_CALL(nvte_multi_tensor_unscale_l2norm_cuda);
+  using namespace transformer_engine;
+  multi_tensor_l2norm::multi_tensor_unscale_l2norm_cuda(
+      chunk_size, *reinterpret_cast<Tensor *>(noop_flag),
+      convert_tensor_array(tensor_lists, num_tensor_lists, num_tensors_per_list),
+      *reinterpret_cast<Tensor *>(output), *reinterpret_cast<Tensor *>(output_per_tensor),
+      *reinterpret_cast<Tensor *>(ret), *reinterpret_cast<Tensor *>(ret_per_tensor),
+      *reinterpret_cast<Tensor *>(inv_scale), per_tensor, max_chunks_per_tensor, device_id,
+      stream);
+}
transformer_engine/pytorch/csrc/multi_tensor_apply.cuh → transformer_engine/common/multi_tensor/multi_tensor_apply.cuh

@@ -5,17 +5,62 @@
 ************************************************************************/

 #pragma once

-#include <ATen/ATen.h>
-#include <ATen/AccumulateType.h>
-#include <ATen/cuda/CUDAContext.h>
-#include <ATen/cuda/Exceptions.h>
 #include <assert.h>
-#include <c10/cuda/CUDAGuard.h>
 #include <cuda_runtime.h>
+#include <transformer_engine/multi_tensor.h>
+#include <transformer_engine/transformer_engine.h>

-#include "common/common.h"
+#include "../common.h"
 // This header is the one-stop shop for all your multi-tensor apply needs.

+// Change device if needed.
+class OptionalCUDAGuard {
+ public:
+  explicit OptionalCUDAGuard(int new_device) {
+    if (new_device < 0) return;
+    int current_device;
+    NVTE_CHECK_CUDA(cudaGetDevice(&current_device));
+    if (new_device != current_device) {
+      NVTE_CHECK_CUDA(cudaSetDevice(new_device));
+      device_changed_ = true;
+      prev_device_ = current_device;
+    }
+  }
+
+  OptionalCUDAGuard(const OptionalCUDAGuard &) = delete;
+  OptionalCUDAGuard &operator=(const OptionalCUDAGuard &) = delete;
+
+  OptionalCUDAGuard(OptionalCUDAGuard &&other) noexcept
+      : prev_device_(other.prev_device_), device_changed_(other.device_changed_) {
+    other.device_changed_ = false;
+  }
+
+  OptionalCUDAGuard &operator=(OptionalCUDAGuard &&other) noexcept {
+    if (this != &other) {
+      if (device_changed_) {
+        cudaSetDevice(prev_device_);
+      }
+      prev_device_ = other.prev_device_;
+      device_changed_ = other.device_changed_;
+      other.device_changed_ = false;
+    }
+    return *this;
+  }
+
+  ~OptionalCUDAGuard() {
+    if (device_changed_) {
+      NVTE_CHECK_CUDA(cudaSetDevice(prev_device_));
+    }
+  }
+
+ private:
+  int prev_device_;
+  bool device_changed_ = false;
+};

 // TODO: Kernel arg size limit may be <4KB for some other cards (ie Jetson)
 constexpr int depth_to_max_tensors[6] = {110, 64, 48, 36, 30, 24};
 constexpr int depth_to_max_blocks[6] = {320, 320, 320, 320, 320, 320};
@@ -46,62 +91,40 @@ __global__ void __launch_bounds__(block_size) multi_tensor_apply_kernel(int64_t
 }

 template <int64_t block_size, int depth, bool USE_FP8 = false, typename T, typename... ArgTypes>
-void multi_tensor_apply(int64_t chunk_size, const at::Tensor &noop_flag,
-                        const std::vector<std::vector<at::Tensor>> &tensor_lists, T callable,
-                        ArgTypes... args) {
-  if constexpr (USE_FP8) {
-    TORCH_CHECK(tensor_lists.size() == depth + 3,
-                "tensor_lists.size() != depth + 3, tensor_lists should have 3 more tensors (scale, "
-                "amax, scale_inv) for fp8");
-  } else {
-    TORCH_CHECK(tensor_lists.size() == depth, "tensor_lists.size() != depth");
-  }
-  int len0 = tensor_lists[0].size();
-  TORCH_CHECK(len0 > 0, "tensor_lists[0].size() is not > 0");
-  auto ref_device = tensor_lists[0][0].device();
-  TORCH_CHECK(ref_device.type() == at::kCUDA, "expected input to be on cuda");
-  for (int l = 0; l < depth; l++) {  // No range-based for because I need indices
-    TORCH_CHECK(tensor_lists[l].size() == len0, "Size mismatch among tensor lists");
-    for (int t = 0; t < tensor_lists[l].size(); t++) {
-      // TODO: Print which tensor fails.
-      bool contiguous_memory = tensor_lists[l][t].is_contiguous();
-      contiguous_memory =
-          (contiguous_memory || tensor_lists[l][t].is_contiguous(at::MemoryFormat::ChannelsLast) ||
-           tensor_lists[l][t].is_contiguous(at::MemoryFormat::ChannelsLast3d));
-      TORCH_CHECK(contiguous_memory, "A tensor was not contiguous.");
-      TORCH_CHECK(tensor_lists[l][t].device() == ref_device,
-                  "A tensor was not on the same device as the first tensor");
-      TORCH_CHECK(tensor_lists[l][t].numel() == tensor_lists[0][t].numel(), "Size mismatch");
-    }
-  }
-  if constexpr (USE_FP8) {
-    TORCH_CHECK(tensor_lists[depth].size() == len0 && tensor_lists[depth + 1].size() == len0,
-                "Size mismatch among tensor lists");
-  }
-  int ntensors = tensor_lists[0].size();
+void multi_tensor_apply(int64_t chunk_size, const transformer_engine::Tensor &noop_flag,
+                        std::vector<std::vector<transformer_engine::Tensor *>> tensor_lists,
+                        T callable, const int device_id, cudaStream_t stream, ArgTypes... args) {
+  const size_t num_tensor_lists = tensor_lists.size();
+  const size_t num_tensors_per_list = tensor_lists[0].size();
+  if constexpr (USE_FP8) {
+    NVTE_CHECK(num_tensor_lists == depth + 3,
+               "tensor_lists.size() != depth + 3, tensor_lists should have 3 more tensors (scale, "
+               "amax, scale_inv) for fp8");
+  } else {
+    NVTE_CHECK(num_tensor_lists == depth, "tensor_lists.size() != depth");
+  }

   TensorListMetadata<depth, USE_FP8> tl;

-  const at::cuda::OptionalCUDAGuard device_guard(device_of(tensor_lists[0][0]));
-  auto stream = at::cuda::getCurrentCUDAStream();
+  const OptionalCUDAGuard device_guard(device_id);

   tl.start_tensor_this_launch = 0;
   int loc_block_info = 0;
   int loc_tensor_info = 0;
   auto kernel =
       &multi_tensor_apply_kernel<block_size, TensorListMetadata<depth, USE_FP8>, T, ArgTypes...>;
-  for (int t = 0; t < ntensors; t++) {
-    tl.sizes[loc_tensor_info] = tensor_lists[0][t].numel();
+  for (int t = 0; t < num_tensors_per_list; t++) {
+    tl.sizes[loc_tensor_info] = tensor_lists[0][t]->numel();
     for (int d = 0; d < depth; d++)
-      tl.addresses[d][loc_tensor_info] = tensor_lists[d][t].data_ptr();
+      tl.addresses[d][loc_tensor_info] = tensor_lists[d][t]->data.dptr;
     if constexpr (USE_FP8) {
       for (int i = 0; i < 3; i++)
-        tl.fp8_meta_addresses[i][loc_tensor_info] = tensor_lists[depth + i][t].data_ptr();
+        tl.fp8_meta_addresses[i][loc_tensor_info] = tensor_lists[depth + i][t]->data.dptr;
     }
     loc_tensor_info++;
-    auto chunks_this_tensor = (tensor_lists[0][t].numel() + chunk_size - 1) / chunk_size;
+    auto chunks_this_tensor = (tensor_lists[0][t]->numel() + chunk_size - 1) / chunk_size;
     for (auto chunk = 0; chunk < chunks_this_tensor; chunk++) {
       tl.block_to_tensor[loc_block_info] = loc_tensor_info - 1;

@@ -111,12 +134,12 @@ void multi_tensor_apply(int64_t chunk_size, const at::Tensor &noop_flag,
       bool tensors_full = (loc_tensor_info == depth_to_max_tensors[depth - 1] &&
                            chunk == chunks_this_tensor - 1);
       bool blocks_full = (loc_block_info == depth_to_max_blocks[depth - 1]);
-      bool last_chunk = (t == ntensors - 1 && chunk == chunks_this_tensor - 1);
+      bool last_chunk = (t == num_tensors_per_list - 1 && chunk == chunks_this_tensor - 1);
       if (tensors_full || blocks_full || last_chunk) {
         kernel<<<loc_block_info, block_size, 0, stream>>>(
-            chunk_size, noop_flag.data_ptr<int>(), tl, callable, args...);
+            chunk_size, reinterpret_cast<int *>(noop_flag.data.dptr), tl, callable, args...);
-        AT_CUDA_CHECK(cudaGetLastError());
+        NVTE_CHECK_CUDA(cudaGetLastError());

         // Reset. The control flow possibilities here make my brain hurt.
         loc_block_info = 0;
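A short sketch of the RAII pattern the new OptionalCUDAGuard above enables (function name illustrative, assuming multi_tensor_apply.cuh is included): the guard switches to device_id on construction, is a no-op for negative ids, and restores the previous device when it goes out of scope.

#include <cuda_runtime.h>

void launch_on_device(int device_id, cudaStream_t stream) {
  const OptionalCUDAGuard device_guard(device_id);  // switch device if needed
  // ... enqueue kernels on `stream`, which is assumed to belong to device_id ...
}  // guard destructor restores the original device here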
transformer_engine/pytorch/csrc/extensions/multi_tensor/multi_tensor_scale_kernel.cu → transformer_engine/common/multi_tensor/scale.cu

@@ -4,19 +4,20 @@
  * See LICENSE for license information.
  ************************************************************************/

-#include <ATen/ATen.h>
-#include <ATen/AccumulateType.h>
-#include <ATen/cuda/CUDAContext.h>
-#include <ATen/cuda/Exceptions.h>
-// Another possibility:
-// #include <torch/all.h>
 #include <assert.h>
 #include <cuda_fp8.h>
+#include <transformer_engine/multi_tensor.h>
+#include <transformer_engine/transformer_engine.h>

 // Stringstream is a big hammer, but I want to rely on operator<< for dtype.
 #include <iostream>
 #include <sstream>

 #include "../utils.cuh"
 #include "multi_tensor_apply.cuh"
-#include "type_shim.h"
+
+namespace transformer_engine {
+namespace multi_tensor_scale {

 #define BLOCK_SIZE 512
 #define ILP 4
@@ -66,7 +67,7 @@ struct ScaleFunctor {
 #pragma unroll
       for (int ii = 0; ii < ILP; ii++) {
         r_out[ii] = static_cast<float>(r_in[ii]) * scale;
-        finite = finite && isfinite(r_in[ii]);
+        finite = finite && isfinite(static_cast<float>(r_in[ii]));
       }
       // store
       load_store(out, r_out, i_start, 0);
@@ -76,7 +77,7 @@ struct ScaleFunctor {
       for (int i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x * ILP) {
 #pragma unroll
         for (int ii = 0; ii < ILP; ii++) {
-          r_in[ii] = 0;
+          r_in[ii] = 0.f;
           int i = i_start + threadIdx.x + ii * blockDim.x;
           if (i < n && i < chunk_size) r_in[ii] = in[i];
         }
@@ -88,7 +89,7 @@ struct ScaleFunctor {
 #pragma unroll
         for (int ii = 0; ii < ILP; ii++) {
           r_out[ii] = static_cast<float>(r_in[ii]) * scale;
-          finite = finite && isfinite(r_in[ii]);
+          finite = finite && isfinite(static_cast<float>(r_in[ii]));
         }
 #pragma unroll
         for (int ii = 0; ii < ILP; ii++) {
@@ -101,20 +102,29 @@ struct ScaleFunctor {
   }
 };

-void multi_tensor_scale_cuda(int chunk_size, at::Tensor noop_flag,
-                             std::vector<std::vector<at::Tensor>> tensor_lists, float scale) {
-  using namespace at;
+void multi_tensor_scale_cuda(int chunk_size, Tensor noop_flag,
+                             std::vector<std::vector<Tensor *>> tensor_lists, float scale,
+                             const int device_id, cudaStream_t stream) {
   // The output (downscaled) type is always float.
   // If build times suffer, think about where to put this dispatch,
   // and what logic should be moved out of multi_tensor_apply.
-  DISPATCH_FLOAT_HALF_AND_BFLOAT(
-      tensor_lists[0][0].scalar_type(), 0, "multi_tensor_scale_cuda",
-      DISPATCH_FLOAT_HALF_AND_BFLOAT(
-          tensor_lists[1][0].scalar_type(), 1, "multi_tensor_scale_cuda",
-          multi_tensor_apply<BLOCK_SIZE, 2>(chunk_size, noop_flag, tensor_lists,
-                                            ScaleFunctor<scalar_t_0, scalar_t_1>(), scale);))
-  AT_CUDA_CHECK(cudaGetLastError());
-  // AT_CUDA_CHECK(cudaDeviceSynchronize());
+  TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
+      tensor_lists[0][0]->dtype(), p_in_type,
+      TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
+          tensor_lists[1][0]->dtype(), g_in_type,
+          multi_tensor_apply<BLOCK_SIZE, 2>(chunk_size, noop_flag, tensor_lists,
+                                            ScaleFunctor<p_in_type, g_in_type>(), device_id,
+                                            stream, scale);))
+  NVTE_CHECK_CUDA(cudaGetLastError());
 }

+}  // namespace multi_tensor_scale
+}  // namespace transformer_engine
+
+void nvte_multi_tensor_scale_cuda(int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists,
+                                  const size_t num_tensor_lists,
+                                  const size_t num_tensors_per_list, float scale,
+                                  const int device_id, cudaStream_t stream) {
+  NVTE_API_CALL(nvte_multi_tensor_scale_cuda);
+  using namespace transformer_engine;
+  multi_tensor_scale::multi_tensor_scale_cuda(
+      chunk_size, *reinterpret_cast<Tensor *>(noop_flag),
+      convert_tensor_array(tensor_lists, num_tensor_lists, num_tensors_per_list), scale,
+      device_id, stream);
+}
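The isfinite changes in the hunks above route the check through the float value rather than the raw element. A sketch of why that matters, under the assumption that the element type may be a reduced-precision type such as __nv_bfloat16, for which device code does not provide an isfinite overload of its own (helper name illustrative):

#include <cuda_bf16.h>

__device__ bool all_finite(const __nv_bfloat16 *vals, int n) {
  bool finite = true;
  for (int i = 0; i < n; ++i) {
    // Casting to float first keeps the finiteness check well-defined for any input type.
    finite = finite && isfinite(static_cast<float>(vals[i]));
  }
  return finite;
}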
transformer_engine/pytorch/csrc/extensions/multi_tensor/multi_tensor_sgd_kernel.cu → transformer_engine/common/multi_tensor/sgd.cu

@@ -4,14 +4,16 @@
  * See LICENSE for license information.
  ************************************************************************/

-#include <ATen/ATen.h>
-#include <ATen/AccumulateType.h>
-#include <ATen/cuda/CUDAContext.h>
-#include <ATen/cuda/Exceptions.h>
 #include <assert.h>
 #include <cuda_fp8.h>
+#include <transformer_engine/multi_tensor.h>
+#include <transformer_engine/transformer_engine.h>

 #include "../utils.cuh"
 #include "multi_tensor_apply.cuh"
-#include "type_shim.h"
+
+namespace transformer_engine {
+namespace multi_tensor_sgd {

 #define BLOCK_SIZE 512
 #define ILP 4
@@ -54,9 +56,9 @@ struct SGDFunctor {
     T_weight *mom_in = reinterpret_cast<T_weight *>(tl.addresses[2][tensor_loc]);
     mom_in += chunk_idx * chunk_size;

-    at::Half *model_weights_out = nullptr;
+    fp16 *model_weights_out = nullptr;
     if (N == 4) {
-      model_weights_out = (at::Half *)tl.addresses[3][tensor_loc];
+      model_weights_out = reinterpret_cast<fp16 *>(tl.addresses[3][tensor_loc]);
       model_weights_out += chunk_idx * chunk_size;
     }
@@ -112,7 +114,7 @@ struct SGDFunctor {
         weight_in[i] += (-lr * incoming_grads[ii]);
         // if necessary, write out an fp16 copy of the weights
-        if (N == 4) model_weights_out[i] = static_cast<at::Half>(weight_in[i]);
+        if (N == 4) model_weights_out[i] = static_cast<fp16>(weight_in[i]);
         // also write out the new momentum
         if (momentum != 0.f) mom_in[i] = incoming_moms[ii];
@@ -122,23 +124,23 @@ struct SGDFunctor {
   }
 };

-void multi_tensor_sgd_cuda(int chunk_size, at::Tensor noop_flag,
-                           std::vector<std::vector<at::Tensor>> tensor_lists, float wd,
-                           float momentum, float dampening, float lr, bool nesterov,
-                           bool first_run, bool wd_after_momentum, float scale) {
-  auto num_tensors = tensor_lists.size();
-  auto grad_type = tensor_lists[0][0].scalar_type();
-  auto weight_type = tensor_lists[1][0].scalar_type();
-
-  if (num_tensors == 4) {
-    for (int i = 0; i < tensor_lists[3].size(); i++)
-      TORCH_CHECK(tensor_lists[3][i].scalar_type() == at::ScalarType::Half,
-                  "Additional output tensors should always be fp16.");
-  }
-
-  TORCH_CHECK(noop_flag.device() == tensor_lists[0][0].device(),
-              "expected noop flag to be on the same device as tensors");
+void multi_tensor_sgd_cuda(int chunk_size, Tensor noop_flag,
+                           std::vector<std::vector<Tensor *>> tensor_lists, float wd,
+                           float momentum, float dampening, float lr, bool nesterov,
+                           bool first_run, bool wd_after_momentum, float scale,
+                           const int device_id, cudaStream_t stream) {
+  const size_t num_tensor_lists = tensor_lists.size();
+  const size_t num_tensors_per_list = tensor_lists[0].size();
+  auto grad_type = tensor_lists[0][0]->dtype();
+  auto weight_type = tensor_lists[1][0]->dtype();
+
+  if (num_tensor_lists == 4) {
+    for (int i = 0; i < num_tensors_per_list; i++)
+      NVTE_CHECK(tensor_lists[3][i]->dtype() == DType::kFloat16,
+                 "Additional output tensors should always be fp16.");
+  }

   // We have 3 possibilities to handle here, in terms of
   // grad_type, param_type, momentum_type, requires_fp16_copy
   // 1. fp16, fp16, fp16, No

@@ -150,53 +152,51 @@ void multi_tensor_sgd_cuda(int chunk_size, at::Tensor noop_flag,
   // we don't want the majority of them.

   // Case 1. fp16, fp16, fp16, No
-  if (grad_type == at::ScalarType::Half && weight_type == at::ScalarType::Half &&
-      num_tensors == 3) {
+  if (grad_type == DType::kFloat16 && weight_type == DType::kFloat16 && num_tensor_lists == 3) {
     multi_tensor_apply<BLOCK_SIZE, 3>(chunk_size, noop_flag, tensor_lists,
-                                      SGDFunctor<3, at::Half, at::Half>(), wd, momentum,
-                                      dampening, lr, nesterov, first_run, wd_after_momentum,
-                                      scale);
+                                      SGDFunctor<3, fp16, fp16>(), device_id, stream, wd,
+                                      momentum, dampening, lr, nesterov, first_run,
+                                      wd_after_momentum, scale);
   }
   // Case 2. fp16, fp32, fp32, No
   // else if (grad_type == at::ScalarType::Half &&
   //          weight_type == at::ScalarType::Float &&
   //          num_tensors == 3) {
   //   multi_tensor_apply<BLOCK_SIZE, 3>(
   //     chunk_size,
   //     noop_flag,
   //     tensor_lists,
   //     SGDFunctor<3, at::Half, float>(),
   //     wd,
   //     momentum,
   //     dampening,
   //     lr,
   //     nesterov,
   //     first_run,
   //     wd_after_momentum);
   // }
   // Case 2. fp32, fp32, fp32, No
-  else if (grad_type == at::ScalarType::Float &&  // NOLINT(*)
-           weight_type == at::ScalarType::Float && num_tensors == 3) {
+  else if (grad_type == DType::kFloat32 &&  // NOLINT(*)
+           weight_type == DType::kFloat32 && num_tensor_lists == 3) {
     multi_tensor_apply<BLOCK_SIZE, 3>(chunk_size, noop_flag, tensor_lists,
-                                      SGDFunctor<3, float, float>(), wd, momentum, dampening,
-                                      lr, nesterov, first_run, wd_after_momentum, scale);
+                                      SGDFunctor<3, float, float>(), device_id, stream, wd,
+                                      momentum, dampening, lr, nesterov, first_run,
+                                      wd_after_momentum, scale);
   }
   // Case 3. fp16, fp32, fp32, Yes
-  else if (grad_type == at::ScalarType::Half &&  // NOLINT(*)
-           weight_type == at::ScalarType::Float && num_tensors == 4) {
+  else if (grad_type == DType::kFloat16 &&  // NOLINT(*)
+           weight_type == DType::kFloat32 && num_tensor_lists == 4) {
     multi_tensor_apply<BLOCK_SIZE, 4>(chunk_size, noop_flag, tensor_lists,
-                                      SGDFunctor<4, at::Half, float>(), wd, momentum, dampening,
-                                      lr, nesterov, first_run, wd_after_momentum, scale);
+                                      SGDFunctor<4, fp16, float>(), device_id, stream, wd,
+                                      momentum, dampening, lr, nesterov, first_run,
+                                      wd_after_momentum, scale);
   }
   // Case 4. fp32, fp32, fp32, Yes
-  else if (grad_type == at::ScalarType::Float &&  // NOLINT(*)
-           weight_type == at::ScalarType::Float && num_tensors == 4) {
+  else if (grad_type == DType::kFloat32 &&  // NOLINT(*)
+           weight_type == DType::kFloat32 && num_tensor_lists == 4) {
     multi_tensor_apply<BLOCK_SIZE, 4>(chunk_size, noop_flag, tensor_lists,
-                                      SGDFunctor<4, float, float>(), wd, momentum, dampening,
-                                      lr, nesterov, first_run, wd_after_momentum, scale);
+                                      SGDFunctor<4, float, float>(), device_id, stream, wd,
+                                      momentum, dampening, lr, nesterov, first_run,
+                                      wd_after_momentum, scale);
   } else {
-    AT_ERROR(
-        "multi_tensor_sgd only supports some combinations of gradient & weight types. Given: ",
-        "gradient: ", grad_type, ", weight: ", weight_type, ", num_lists: ", num_tensors);
+    NVTE_ERROR("Unsupported combination of weight and gradient types.");
   }

-  AT_CUDA_CHECK(cudaGetLastError());
+  NVTE_CHECK_CUDA(cudaGetLastError());
 }

+}  // namespace multi_tensor_sgd
+}  // namespace transformer_engine
+
+void nvte_multi_tensor_sgd_cuda(int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists,
+                                const size_t num_tensor_lists,
+                                const size_t num_tensors_per_list, float wd, float momentum,
+                                float dampening, float lr, int nesterov, int first_run,
+                                int wd_after_momentum, float scale, const int device_id,
+                                cudaStream_t stream) {
+  NVTE_API_CALL(nvte_multi_tensor_sgd_cuda);
+  using namespace transformer_engine;
+  multi_tensor_sgd::multi_tensor_sgd_cuda(
+      chunk_size, *reinterpret_cast<Tensor *>(noop_flag),
+      convert_tensor_array(tensor_lists, num_tensor_lists, num_tensors_per_list), wd, momentum,
+      dampening, lr, nesterov, first_run, wd_after_momentum, scale, device_id, stream);
+}
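A hedged sketch of calling the new SGD wrapper for the three-list case (grads, weights, momenta, matching Cases 1 and 2 of the dispatch above); the chunk size and hyperparameters are illustrative:

#include <transformer_engine/multi_tensor.h>

void sgd_step(NVTETensor noop_flag, NVTETensor *grads, NVTETensor *weights, NVTETensor *moms,
              size_t num_tensors, float lr, cudaStream_t stream) {
  NVTETensor *tensor_lists[3] = {grads, weights, moms};  // list order per the dispatch above
  nvte_multi_tensor_sgd_cuda(/*chunk_size=*/65536, noop_flag, tensor_lists,
                             /*num_tensor_lists=*/3, /*num_tensors_per_list=*/num_tensors,
                             /*wd=*/0.0f, /*momentum=*/0.9f, /*dampening=*/0.0f, lr,
                             /*nesterov=*/0, /*first_run=*/1, /*wd_after_momentum=*/0,
                             /*scale=*/1.0f, /*device_id=*/0, stream);
}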
transformer_engine/common/permutation/permutation.cu

@@ -6,6 +6,8 @@
 #include <transformer_engine/permutation.h>

+#include <cub/cub.cuh>
+
 #include "../common.h"

 #ifdef __HIP_PLATFORM_AMD__
@@ -385,3 +387,11 @@ void nvte_unpermute(const NVTETensor input, NVTETensor output, NVTETensor row_id
                     reinterpret_cast<const float *>(prob_cu->data.dptr), num_rows, topK,
                     num_cols, stream););
 }

+void nvte_device_radix_sort_pairs(void *temp_storage, size_t *temp_storage_bytes, int *keys_in,
+                                  int *keys_out, int *values_in, int *values_out,
+                                  size_t num_items) {
+  NVTE_API_CALL(nvte_device_radix_sort_pairs);
+  cub::DeviceRadixSort::SortPairs(temp_storage, *temp_storage_bytes, keys_in, keys_out,
+                                  values_in, values_out, num_items);
+}
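This wrapper forwards directly to cub, so it follows cub's usual two-pass protocol: a first call with a null temp_storage pointer only writes the required scratch size, and a second call with allocated scratch performs the sort. A hedged usage sketch (error handling omitted, function name illustrative):

#include <cuda_runtime.h>
#include <transformer_engine/permutation.h>

void sort_pairs_example(int *keys_in, int *keys_out, int *values_in, int *values_out,
                        size_t num_items) {
  size_t temp_bytes = 0;
  // Pass 1: query the scratch size.
  nvte_device_radix_sort_pairs(nullptr, &temp_bytes, keys_in, keys_out, values_in, values_out,
                               num_items);
  void *temp = nullptr;
  cudaMalloc(&temp, temp_bytes);
  // Pass 2: perform the sort.
  nvte_device_radix_sort_pairs(temp, &temp_bytes, keys_in, keys_out, values_in, values_out,
                               num_items);
  cudaFree(temp);
}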
transformer_engine/common/recipe/fp8_block_scaling.cu (new file, 0 → 100644)

/*************************************************************************
 * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

#include <transformer_engine/recipe.h>

#include <cassert>

#include "../common.h"
#include "../utils.cuh"

namespace transformer_engine {
namespace fp8_block_scaling_recipe {

constexpr int kTileDim = 128;
constexpr int kThreadsPerBlock = 256;
template <typename IType>
__global__ void __launch_bounds__(kThreadsPerBlock) fp8_block_scaling_compute_partial_amax_kernel(
    const IType *input, float *amax_ptr, const size_t amax_stride_h, const size_t amax_stride_w,
    const size_t h, const size_t w, const size_t start_offset, const size_t len) {
  constexpr int kThreadsPerWarp = 32;
  constexpr int kLoopsPerRow = kTileDim / kThreadsPerWarp;
  constexpr int kNumWarps = kThreadsPerBlock / kThreadsPerWarp;
  constexpr int kLoopsPerCol = kTileDim / kNumWarps;

  const int tile_col = blockIdx.x;
  const int tile_row = blockIdx.y;
  const size_t end_offset = start_offset + len;
  const IType *input_minus_offset = input - start_offset;

  __shared__ float smem[kNumWarps];

  float amax = 0.0f;
  for (int loop_col = 0; loop_col < kLoopsPerCol; ++loop_col) {
    size_t r = tile_row * kTileDim + loop_col * kNumWarps + threadIdx.x / kThreadsPerWarp;
    for (int loop_row = 0; loop_row < kLoopsPerRow; ++loop_row) {
      size_t c = tile_col * kTileDim + loop_row * kThreadsPerWarp +
                 (threadIdx.x % kThreadsPerWarp);
      size_t idx = r * w + c;
      if (r < h && c < w && idx >= start_offset && idx < end_offset) {
        float other_amax = fabs(static_cast<float>(input_minus_offset[idx]));
        __builtin_assume(amax >= 0);
        __builtin_assume(other_amax >= 0);
        amax = fmaxf(amax, other_amax);
      }
    }
  }

  for (int delta = kThreadsPerWarp / 2; delta > 0; delta /= 2) {
    float other_amax = __shfl_down_sync(0xFFFFFFFF, amax, delta);
    __builtin_assume(amax >= 0);
    __builtin_assume(other_amax >= 0);
    amax = fmaxf(amax, other_amax);
  }
  if (threadIdx.x % kThreadsPerWarp == 0) {
    smem[threadIdx.x / kThreadsPerWarp] = amax;
  }

  __syncthreads();

  if (threadIdx.x == 0) {
    for (int i = 0; i < kNumWarps; ++i) {
      float other_amax = smem[i];
      __builtin_assume(amax >= 0);
      __builtin_assume(other_amax >= 0);
      amax = fmaxf(amax, other_amax);
    }
    amax_ptr[tile_row * amax_stride_h + tile_col * amax_stride_w] = amax;
  }
}
template <typename IType, typename OType, bool kWidthAligned>
__global__ void __launch_bounds__(kThreadsPerBlock) fp8_block_scaling_partial_cast_kernel(
    const IType *input, OType *output, const float *scale_ptr, const size_t scale_stride_h,
    const size_t scale_stride_w, const size_t h, const size_t w, const size_t start_offset,
    const size_t len) {
  using transformer_engine::Vec;

  static_assert(sizeof(OType) == 1);
  constexpr int kNumOutputElemsPerBank = 4 / sizeof(OType);
  constexpr int kThreadsPerWarp = 32;
  constexpr int kLoopsPerRow = kTileDim / kThreadsPerWarp;
  constexpr int kNumWarps = kThreadsPerBlock / kThreadsPerWarp;
  constexpr int kRowsPerWarp = kTileDim / kNumWarps;

  __shared__ OType smem[kTileDim][kTileDim + kNumOutputElemsPerBank];

  const int tile_w = blockIdx.x;
  const int tile_h = blockIdx.y;
  const size_t end_offset = start_offset + len;
  const IType *input_minus_offset = input - start_offset;
  OType *output_minus_offset = output - start_offset;

  const float scale = scale_ptr[tile_h * scale_stride_h + tile_w * scale_stride_w];

  // Load input data into shared memory
  bool skip_store = true;
  for (int i = 0; i < kRowsPerWarp; ++i) {
    for (int j = 0; j < kLoopsPerRow; ++j) {
      const int h_in_smem = threadIdx.x / kThreadsPerWarp * kRowsPerWarp + i;
      const int w_in_smem = threadIdx.x % kThreadsPerWarp + kThreadsPerWarp * j;
      const int h_in_input = tile_h * kTileDim + h_in_smem;
      const int w_in_input = tile_w * kTileDim + w_in_smem;
      const size_t idx_in_input = static_cast<size_t>(h_in_input) * w + w_in_input;
      if (h_in_input < h && w_in_input < w && idx_in_input >= start_offset &&
          idx_in_input < end_offset) {
        float inp = static_cast<float>(input_minus_offset[idx_in_input]) * scale;
        smem[h_in_smem][w_in_smem] = static_cast<OType>(inp);
        skip_store = false;
      }
    }
  }
  for (int delta = kThreadsPerWarp / 2; delta > 0; delta /= 2) {
    bool other_skip_store = __shfl_down_sync(0xFFFFFFFF, skip_store, delta);
    skip_store = skip_store && other_skip_store;
  }
  skip_store = __shfl_sync(0xFFFFFFFF, skip_store, 0);
  if (skip_store) {
    return;
  }

  // Store the casted data into the output.
  // Note that this store operation might write "out-of-bounds", but it is intentional:
  //   1. The "out-of-bounds" here only crosses the boundary of the "local shard" (i.e., the
  //      region from start_offset to end_offset), not the boundary of the entire output memory.
  //      Therefore, this out-of-bounds write will not cause illegal memory access.
  //   2. We assume that the subsequent all-gather operation happens in-place, so any parts that
  //      should not be updated here will be overwritten by the all-gather.
  // This tricky approach allows us to avoid checking whether each output index falls within
  // [start, end), resulting in a significant performance improvement.
  Vec<OType, kNumOutputElemsPerBank> vec_output;
  for (int i = 0; i < kRowsPerWarp; ++i) {
    const int row_in_smem = threadIdx.x / kThreadsPerWarp * kRowsPerWarp + i;
    const int col_in_smem = threadIdx.x % kThreadsPerWarp * kNumOutputElemsPerBank;
    for (int j = 0; j < kNumOutputElemsPerBank; ++j) {
      vec_output.data.elt[j] = smem[row_in_smem][col_in_smem + j];
    }
    const int row_in_output = tile_h * kTileDim + row_in_smem;
    const int col_in_output = tile_w * kTileDim + col_in_smem;
    const size_t idx_in_output = static_cast<size_t>(row_in_output) * w + col_in_output;
    if (row_in_output < h) {
      if constexpr (kWidthAligned) {
        vec_output.store_to(output_minus_offset + idx_in_output);
      } else {
        int num = min(static_cast<size_t>(kNumOutputElemsPerBank),
                      static_cast<size_t>(col_in_output < w ? w - col_in_output : 0));
        vec_output.store_to_elts(output_minus_offset, idx_in_output, num);
      }
    }
  }
}
void fp8_block_scaling_compute_partial_amax(const Tensor inp, Tensor amax, size_t h, size_t w,
                                            size_t amax_stride_h, size_t amax_stride_w,
                                            size_t start_offset, size_t block_len,
                                            cudaStream_t stream) {
  NVTE_CHECK(block_len == 128, "Currently only block_len = 128 is supported");
  size_t len = inp.numel();
  assert(h > 0 && w > 0);
  assert(start_offset < h * w);
  assert(start_offset + len <= h * w);

  size_t blocks_x = (w + kTileDim - 1) / kTileDim;
  size_t blocks_y = (h + kTileDim - 1) / kTileDim;
  assert(blocks_x <= std::numeric_limits<unsigned int>::max());
  assert(blocks_y <= std::numeric_limits<unsigned int>::max());
  dim3 grid(blocks_x, blocks_y);

  TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
      inp.dtype(), inp_dtype,
      fp8_block_scaling_compute_partial_amax_kernel<inp_dtype>
      <<<grid, kThreadsPerBlock, 0, stream>>>(
          reinterpret_cast<const inp_dtype *>(inp.data.dptr),
          reinterpret_cast<float *>(amax.data.dptr), amax_stride_h, amax_stride_w, h, w,
          start_offset, len);)
}
void fp8_block_scaling_partial_cast(const Tensor inp, Tensor out, const Tensor scale, size_t h,
                                    size_t w, size_t scale_stride_h, size_t scale_stride_w,
                                    size_t start_offset, size_t block_len, const DType out_dtype,
                                    cudaStream_t stream) {
  NVTE_CHECK(block_len == 128, "Currently only block_len = 128 is supported");
  size_t len = inp.numel();
  assert(h > 0 && w > 0);
  assert(start_offset < h * w);
  assert(start_offset + len <= h * w);

  size_t blocks_x = (w + kTileDim - 1) / kTileDim;
  size_t blocks_y = (h + kTileDim - 1) / kTileDim;
  assert(blocks_x <= std::numeric_limits<unsigned int>::max());
  assert(blocks_y <= std::numeric_limits<unsigned int>::max());
  dim3 grid(blocks_x, blocks_y);

  TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
      inp.dtype(), inp_dtype,
      TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
          out_dtype, fp8_type,
          TRANSFORMER_ENGINE_SWITCH_CONDITION(
              w % kTileDim == 0, kWidthAligned,
              fp8_block_scaling_partial_cast_kernel<inp_dtype, fp8_type, kWidthAligned>
              <<<grid, kThreadsPerBlock, 0, stream>>>(
                  reinterpret_cast<const inp_dtype *>(inp.data.dptr),
                  reinterpret_cast<fp8_type *>(out.data.dptr),
                  reinterpret_cast<const float *>(scale.data.dptr), scale_stride_h,
                  scale_stride_w, h, w, start_offset, len);)))
}

}  // namespace fp8_block_scaling_recipe
}  // namespace transformer_engine
void nvte_fp8_block_scaling_compute_partial_amax(const NVTETensor inp, NVTETensor amax, size_t h,
                                                 size_t w, size_t amax_stride_h,
                                                 size_t amax_stride_w, size_t start_offset,
                                                 size_t block_len, cudaStream_t stream) {
  NVTE_API_CALL(nvte_fp8_block_scaling_compute_partial_amax);
  using namespace transformer_engine;
  fp8_block_scaling_recipe::fp8_block_scaling_compute_partial_amax(
      *reinterpret_cast<const Tensor *>(inp), *reinterpret_cast<Tensor *>(amax), h, w,
      amax_stride_h, amax_stride_w, start_offset, block_len, stream);
}

void nvte_fp8_block_scaling_partial_cast(const NVTETensor inp, NVTETensor out,
                                         const NVTETensor scale, size_t h, size_t w,
                                         size_t scale_stride_h, size_t scale_stride_w,
                                         size_t start_offset, size_t block_len,
                                         const NVTEDType out_dtype, cudaStream_t stream) {
  NVTE_API_CALL(nvte_fp8_block_scaling_partial_cast);
  using namespace transformer_engine;
  fp8_block_scaling_recipe::fp8_block_scaling_partial_cast(
      *reinterpret_cast<const Tensor *>(inp), *reinterpret_cast<Tensor *>(out),
      *reinterpret_cast<const Tensor *>(scale), h, w, scale_stride_h, scale_stride_w,
      start_offset, block_len, static_cast<DType>(out_dtype), stream);
}
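A hedged sketch of how these two entry points might be combined for a sharded [h, w] tensor: one amax per 128x128 tile, a per-tile scale derived from the reduced amax, then a partial cast of the local shard. The row-major [tiles_h, tiles_w] layout for the amax/scale tensors and the amax-to-scale formula are assumptions, not taken from this file:

#include <transformer_engine/recipe.h>
#include <transformer_engine/transformer_engine.h>

void quantize_local_shard(NVTETensor inp, NVTETensor out, NVTETensor amax, NVTETensor scale,
                          size_t h, size_t w, size_t start_offset, cudaStream_t stream) {
  const size_t tiles_w = (w + 127) / 128;  // number of 128-wide tiles per row
  nvte_fp8_block_scaling_compute_partial_amax(inp, amax, h, w, /*amax_stride_h=*/tiles_w,
                                              /*amax_stride_w=*/1, start_offset,
                                              /*block_len=*/128, stream);
  // ... all-reduce amax across shards, then fill `scale` per tile,
  //     e.g. scale = fp8_max / amax (assumed recipe) ...
  nvte_fp8_block_scaling_partial_cast(inp, out, scale, h, w, /*scale_stride_h=*/tiles_w,
                                      /*scale_stride_w=*/1, start_offset, /*block_len=*/128,
                                      kNVTEFloat8E4M3, stream);
}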
transformer_engine/common/transformer_engine.cpp

@@ -10,6 +10,7 @@
 #include <iostream>

 #include "common.h"
+#include "common/util/cuda_runtime.h"

 namespace transformer_engine {
@@ -48,11 +49,11 @@ std::string to_string(const DType type) {
 std::string to_string(const NVTEScalingMode &mode) {
   switch (mode) {
     case NVTE_DELAYED_TENSOR_SCALING:
-      return "Delayed Tensor Scaling";
+      return "NVTE_DELAYED_TENSOR_SCALING";
     case NVTE_MXFP8_1D_SCALING:
-      return "MXFP8 1D Scaling";
+      return "NVTE_MXFP8_1D_SCALING";
    case NVTE_INVALID_SCALING:
-      return "Invalid Scaling";
+      return "NVTE_INVALID_SCALING";
   }
   return "Invalid Scaling";
 }
@@ -214,15 +215,13 @@ NVTEDType nvte_tensor_type(const NVTETensor tensor) {
 NVTEShape nvte_make_shape(const size_t *data, size_t ndim) {
   NVTEShape ret;
   if (ndim == 0) {
-    ret.data = nullptr;
     ret.ndim = 0;
     return ret;
   }
-  NVTE_CHECK(ndim <= sizeof(ret.owned_data) / sizeof(ret.owned_data[0]),
+  NVTE_CHECK(ndim <= sizeof(ret.data) / sizeof(ret.data[0]),
              "Too many dims for NVTEShape (requested: ", ndim,
-             ", max: ", sizeof(ret.owned_data) / sizeof(ret.owned_data[0]), ")");
-  std::copy(data, data + ndim, ret.owned_data);
-  ret.data = ret.owned_data;
+             ", max: ", sizeof(ret.data) / sizeof(ret.data[0]), ")");
+  std::copy(data, data + ndim, ret.data);
   ret.ndim = ndim;
   return ret;
 }
@@ -350,7 +349,7 @@ void nvte_set_tensor_param(NVTETensor *tensor, NVTETensorParam param_name,
 NVTEBasicTensor nvte_get_tensor_param(const NVTETensor tensor, NVTETensorParam param_name) {
   if (tensor == nullptr) {
-    return {nullptr, kNVTEFloat32, {nullptr, 0}};
+    return {nullptr, kNVTEFloat32, nvte_make_shape(nullptr, 0)};
   }
   const auto &t = *reinterpret_cast<const transformer_engine::Tensor *>(tensor);
   switch (param_name) {
@@ -483,3 +482,13 @@ void nvte_destroy_quantization_config(NVTEQuantizationConfig config) {
   delete reinterpret_cast<transformer_engine::QuantizationConfig *>(config);
 }

+int nvte_is_non_tn_fp8_gemm_supported() {
+  int deviceComputeCapability =
+      transformer_engine::cuda::sm_arch(transformer_engine::cuda::current_device());
+
+  // Note: this is temporary restriction and should be lifted in the future.
+  // (remove the note once it's done.)
+  return (deviceComputeCapability >= 100 && deviceComputeCapability < 120) ||
+         deviceComputeCapability >= 130;
+}
transformer_engine/common/util/cuda_runtime.cpp

@@ -134,9 +134,15 @@ bool supports_multicast(int device_id) {
   auto init = [&]() {
     CUdevice cudev;
    NVTE_CALL_CHECK_CUDA_DRIVER(cuDeviceGet, &cudev, device_id);
-    int result;
-    NVTE_CALL_CHECK_CUDA_DRIVER(cuDeviceGetAttribute, &result,
-                                CU_DEVICE_ATTRIBUTE_MULTICAST_SUPPORTED, cudev);
+    // Multicast support requires both CUDA12.1 UMD + KMD
+    int result = 0;
+    // Check if KMD >= 12.1
+    int driver_version;
+    NVTE_CHECK_CUDA(cudaDriverGetVersion(&driver_version));
+    if (driver_version >= 12010) {
+      NVTE_CALL_CHECK_CUDA_DRIVER(cuDeviceGetAttribute, &result,
+                                  CU_DEVICE_ATTRIBUTE_MULTICAST_SUPPORTED, cudev);
+    }
     cache[device_id] = static_cast<bool>(result);
   };
   std::call_once(flags[device_id], init);
transformer_engine/common/util/logging.h

@@ -23,10 +23,18 @@
 #endif  // __HIP_PLATFORM_AMD__

 #include <nvrtc.h>

 #include <iostream>
 #include <stdexcept>

 #include "../util/string.h"

+#define NVTE_WARN(...)                                                   \
+  do {                                                                   \
+    std::cerr << ::transformer_engine::concat_strings(                   \
+        __FILE__ ":", __LINE__, " in function ", __func__, ": ",         \
+        ::transformer_engine::concat_strings(__VA_ARGS__), "\n");        \
+  } while (false)
+
 #define NVTE_ERROR(...) \
   do { \
     throw ::std::runtime_error(::transformer_engine::concat_strings( \
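The new NVTE_WARN mirrors NVTE_ERROR's message formatting (file, line, function, then the concatenated arguments) but prints to stderr instead of throwing. A small usage sketch (condition and message illustrative):

void check_alignment(size_t nbytes) {
  if (nbytes % 16 != 0) {
    NVTE_WARN("Buffer of ", nbytes, " bytes is not 16-byte aligned; falling back to slow path.");
  }
}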