Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
TransformerEngine
Commits
2b05e121
Commit
2b05e121
authored
Jun 17, 2025
by
yuguo
Browse files
Merge commit '
a69692ac
' of...
Merge commit '
a69692ac
' of
https://github.com/NVIDIA/TransformerEngine
parents
0fd441c2
a69692ac
Changes
245
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
395 additions
and
264 deletions
+395
-264
transformer_engine/common/multi_tensor/compute_scale.cu
transformer_engine/common/multi_tensor/compute_scale.cu
+1
-1
transformer_engine/common/multi_tensor/l2norm.cu
transformer_engine/common/multi_tensor/l2norm.cu
+7
-7
transformer_engine/common/multi_tensor/scale.cu
transformer_engine/common/multi_tensor/scale.cu
+1
-1
transformer_engine/common/multi_tensor/sgd.cu
transformer_engine/common/multi_tensor/sgd.cu
+1
-1
transformer_engine/common/normalization/common.cpp
transformer_engine/common/normalization/common.cpp
+20
-13
transformer_engine/common/normalization/common.h
transformer_engine/common/normalization/common.h
+5
-2
transformer_engine/common/normalization/layernorm/ln_api.cpp
transformer_engine/common/normalization/layernorm/ln_api.cpp
+23
-18
transformer_engine/common/normalization/rmsnorm/rmsnorm_api.cpp
...ormer_engine/common/normalization/rmsnorm/rmsnorm_api.cpp
+22
-16
transformer_engine/common/permutation/permutation.cu
transformer_engine/common/permutation/permutation.cu
+13
-22
transformer_engine/common/recipe/__init__.py
transformer_engine/common/recipe/__init__.py
+15
-52
transformer_engine/common/recipe/current_scaling.cu
transformer_engine/common/recipe/current_scaling.cu
+3
-3
transformer_engine/common/recipe/delayed_scaling.cu
transformer_engine/common/recipe/delayed_scaling.cu
+7
-7
transformer_engine/common/recipe/fp8_block_scaling.cu
transformer_engine/common/recipe/fp8_block_scaling.cu
+5
-5
transformer_engine/common/swizzle/swizzle.cu
transformer_engine/common/swizzle/swizzle.cu
+1
-2
transformer_engine/common/transformer_engine.cpp
transformer_engine/common/transformer_engine.cpp
+216
-62
transformer_engine/common/transpose/cast_transpose.cu
transformer_engine/common/transpose/cast_transpose.cu
+5
-5
transformer_engine/common/transpose/cast_transpose.h
transformer_engine/common/transpose/cast_transpose.h
+11
-9
transformer_engine/common/transpose/cast_transpose_fusion.cu
transformer_engine/common/transpose/cast_transpose_fusion.cu
+34
-33
transformer_engine/common/transpose/multi_cast_transpose.cu
transformer_engine/common/transpose/multi_cast_transpose.cu
+4
-4
transformer_engine/common/transpose/quantize_transpose_square_blockwise.cu
...e/common/transpose/quantize_transpose_square_blockwise.cu
+1
-1
No files found.
transformer_engine/common/multi_tensor/compute_scale.cu
View file @
2b05e121
...
...
@@ -77,7 +77,7 @@ void nvte_multi_tensor_compute_scale_and_scale_inv_cuda(
using
namespace
transformer_engine
;
multi_tensor_compute_scale
::
multi_tensor_compute_scale_and_scale_inv_cuda
(
chunk_size
,
*
reinterpret_cast
<
Tensor
*>
(
noop_flag
),
chunk_size
,
*
convertNVTE
Tensor
Check
(
noop_flag
),
convert_tensor_array
(
tensor_lists
,
num_tensor_lists
,
num_tensors_per_list
),
max_fp8
,
force_pow_2_scales
,
epsilon
,
device_id
,
stream
);
}
transformer_engine/common/multi_tensor/l2norm.cu
View file @
2b05e121
...
...
@@ -467,10 +467,10 @@ void nvte_multi_tensor_l2norm_cuda(int chunk_size, NVTETensor noop_flag, NVTETen
using
namespace
transformer_engine
;
multi_tensor_l2norm
::
multi_tensor_l2norm_cuda
(
chunk_size
,
*
reinterpret_cast
<
Tensor
*>
(
noop_flag
),
chunk_size
,
*
convertNVTE
Tensor
Check
(
noop_flag
),
convert_tensor_array
(
tensor_lists
,
num_tensor_lists
,
num_tensors_per_list
),
*
reinterpret_cast
<
Tensor
*>
(
output
),
*
reinterpret_cast
<
Tensor
*>
(
output_per_tensor
),
*
reinterpret_cast
<
Tensor
*>
(
ret
),
*
reinterpret_cast
<
Tensor
*>
(
ret_per_tensor
),
per_tensor
,
*
convertNVTE
Tensor
Check
(
output
),
*
convertNVTE
Tensor
Check
(
output_per_tensor
),
*
convertNVTE
Tensor
Check
(
ret
),
*
convertNVTE
Tensor
Check
(
ret_per_tensor
),
per_tensor
,
max_chunks_per_tensor
,
device_id
,
stream
);
}
...
...
@@ -485,9 +485,9 @@ void nvte_multi_tensor_unscale_l2norm_cuda(int chunk_size, NVTETensor noop_flag,
using
namespace
transformer_engine
;
multi_tensor_l2norm
::
multi_tensor_unscale_l2norm_cuda
(
chunk_size
,
*
reinterpret_cast
<
Tensor
*>
(
noop_flag
),
chunk_size
,
*
convertNVTE
Tensor
Check
(
noop_flag
),
convert_tensor_array
(
tensor_lists
,
num_tensor_lists
,
num_tensors_per_list
),
*
reinterpret_cast
<
Tensor
*>
(
output
),
*
reinterpret_cast
<
Tensor
*>
(
output_per_tensor
),
*
reinterpret_cast
<
Tensor
*>
(
ret
),
*
reinterpret_cast
<
Tensor
*>
(
ret_per_tensor
),
*
reinterpret_cast
<
Tensor
*>
(
inv_scale
),
per_tensor
,
max_chunks_per_tensor
,
device_id
,
stream
);
*
convertNVTE
Tensor
Check
(
output
),
*
convertNVTE
Tensor
Check
(
output_per_tensor
),
*
convertNVTE
Tensor
Check
(
ret
),
*
convertNVTE
Tensor
Check
(
ret_per_tensor
),
*
convertNVTE
Tensor
Check
(
inv_scale
),
per_tensor
,
max_chunks_per_tensor
,
device_id
,
stream
);
}
transformer_engine/common/multi_tensor/scale.cu
View file @
2b05e121
...
...
@@ -124,7 +124,7 @@ void nvte_multi_tensor_scale_cuda(int chunk_size, NVTETensor noop_flag, NVTETens
using
namespace
transformer_engine
;
multi_tensor_scale
::
multi_tensor_scale_cuda
(
chunk_size
,
*
reinterpret_cast
<
Tensor
*>
(
noop_flag
),
chunk_size
,
*
convertNVTE
Tensor
Check
(
noop_flag
),
convert_tensor_array
(
tensor_lists
,
num_tensor_lists
,
num_tensors_per_list
),
scale
,
device_id
,
stream
);
}
transformer_engine/common/multi_tensor/sgd.cu
View file @
2b05e121
...
...
@@ -196,7 +196,7 @@ void nvte_multi_tensor_sgd_cuda(int chunk_size, NVTETensor noop_flag, NVTETensor
using
namespace
transformer_engine
;
multi_tensor_sgd
::
multi_tensor_sgd_cuda
(
chunk_size
,
*
reinterpret_cast
<
Tensor
*>
(
noop_flag
),
chunk_size
,
*
convertNVTE
Tensor
Check
(
noop_flag
),
convert_tensor_array
(
tensor_lists
,
num_tensor_lists
,
num_tensors_per_list
),
wd
,
momentum
,
dampening
,
lr
,
nesterov
,
first_run
,
wd_after_momentum
,
scale
,
device_id
,
stream
);
}
transformer_engine/common/normalization/common.cpp
View file @
2b05e121
...
...
@@ -39,8 +39,6 @@ Compute always in FP32
namespace
transformer_engine
{
namespace
normalization
{
bool
&
use_zero_centered_gamma_in_weight_dtype
();
#ifndef __HIP_PLATFORM_AMD__
cudnn_frontend
::
NormFwdPhase_t
get_cudnn_forward_phase
(
const
bool
training
)
{
return
training
?
cudnn_frontend
::
NormFwdPhase_t
::
TRAINING
...
...
@@ -51,13 +49,17 @@ cudnn_frontend::NormFwdPhase_t get_cudnn_forward_phase(const bool training) {
TupleKeyType
get_key
(
NVTE_Norm_Backend
NormBackend
,
NVTE_Norm_Type
NormType
,
NVTE_Norm_Stage
NormStage
,
DType
wtype
,
DType
itype
,
DType
otype
,
DType
ctype
,
uint64_t
batch_size
,
uint64_t
hidden_size
,
bool
zero_centered_gamma
,
bool
is_tuned
,
NVTEScalingMode
mode
,
bool
training
)
{
// TODO: Add scaling_mode to general_key is needed
uint64_t
general_key
=
static_cast
<
uint32_t
>
(
itype
)
|
(
static_cast
<
uint32_t
>
(
otype
)
<<
3
)
|
(
static_cast
<
uint32_t
>
(
ctype
)
<<
6
)
|
(
static_cast
<
uint32_t
>
(
wtype
)
<<
9
)
|
(
uint32_t
(
NormType
)
<<
12
)
|
(
uint32_t
(
NormStage
))
<<
14
|
(
uint32_t
(
NormBackend
)
<<
16
)
|
(
uint32_t
(
zero_centered_gamma
)
<<
18
)
|
(
uint32_t
(
mode
)
<<
19
)
|
(
uint32_t
(
training
)
<<
22
);
bool
is_tuned
,
NVTEScalingMode
mode
,
bool
training
,
bool
gamma_in_weight_dtype
)
{
static_assert
(
NVTE_INVALID_SCALING
<
1024
,
"This function assumes at most 10 bits used in the scaling mode."
);
static_assert
(
kNVTENumTypes
<
32
,
"This function assumes at most 5 bits used in the NVTEDType"
);
uint64_t
general_key
=
static_cast
<
uint64_t
>
(
itype
)
|
(
static_cast
<
uint64_t
>
(
otype
)
<<
5
)
|
(
static_cast
<
uint64_t
>
(
ctype
)
<<
10
)
|
(
static_cast
<
uint64_t
>
(
wtype
)
<<
15
)
|
(
uint64_t
(
NormType
)
<<
20
)
|
(
uint64_t
(
NormStage
))
<<
22
|
(
uint64_t
(
NormBackend
)
<<
24
)
|
(
uint64_t
(
zero_centered_gamma
)
<<
26
)
|
(
uint64_t
(
mode
)
<<
27
)
|
(
uint64_t
(
training
)
<<
37
)
|
(
uint64_t
(
gamma_in_weight_dtype
)
<<
38
);
return
std
::
make_tuple
(
general_key
,
batch_size
,
hidden_size
,
is_tuned
);
}
...
...
@@ -216,8 +218,11 @@ CudnnNormalizationPlan::CudnnNormalizationPlan(NVTE_Norm_Type NormType, NVTE_Nor
}
const
auto
gamma_dtype
=
use_zero_centered_gamma_in_weight_dtype
()
?
wtype
:
ctype
;
NVTE_CHECK
(
gamma_dtype
==
DType
::
kFloat32
||
gamma_dtype
==
DType
::
kFloat16
||
gamma_dtype
==
DType
::
kBFloat16
,
"Gamma of type FP4 is not supported"
);
_scalar_dptr
=
std
::
make_unique
<
char
[]
>
(
typeTo
Size
(
gamma_dtype
));
_scalar_dptr
=
std
::
make_unique
<
char
[]
>
(
typeTo
NumBits
(
gamma_dtype
)
/
8
);
TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT
(
gamma_dtype
,
cpp_dtype
,
*
(
reinterpret_cast
<
cpp_dtype
*>
(
_scalar_dptr
.
get
()))
=
(
cpp_dtype
)
1.0
f
;);
...
...
@@ -490,11 +495,12 @@ NormalizationPlanBase* NormalizationPlanRegistry::getNormalizationPlan(
NVTE_Norm_Backend
NormBackend
,
NVTE_Norm_Type
NormType
,
NVTE_Norm_Stage
NormStage
,
DType
wtype
,
DType
itype
,
DType
otype
,
const
size_t
batch_size
,
const
size_t
hidden_size
,
const
size_t
sm_count
,
const
bool
zero_centered_gamma
,
const
bool
is_aligned
,
const
NVTEScalingMode
mode
,
const
bool
training
)
{
const
NVTEScalingMode
mode
,
const
bool
training
,
const
bool
gamma_in_weight_dtype
)
{
const
DType
ctype
=
DType
::
kFloat32
;
bool
is_tuned
=
is_aligned
&&
(
batch_size
%
4
==
0
);
auto
key
=
get_key
(
NormBackend
,
NormType
,
NormStage
,
wtype
,
itype
,
otype
,
ctype
,
batch_size
,
hidden_size
,
zero_centered_gamma
,
is_tuned
,
mode
,
training
);
auto
key
=
get_key
(
NormBackend
,
NormType
,
NormStage
,
wtype
,
itype
,
otype
,
ctype
,
batch_size
,
hidden_size
,
zero_centered_gamma
,
is_tuned
,
mode
,
training
,
gamma_in_weight_dtype
);
auto
it
=
normalizationPlanMap
.
find
(
key
);
if
(
it
!=
normalizationPlanMap
.
end
())
{
...
...
@@ -577,6 +583,7 @@ void nvte_enable_cudnn_norm_bwd(bool enable) {
#endif
}
// Only for testing, not thread-safe
void
nvte_enable_zero_centered_gamma_in_weight_dtype
(
bool
enable
)
{
NVTE_API_CALL
(
nvte_enable_zero_centered_gamma_in_weight_dtype
);
#ifdef USE_ROCM
...
...
transformer_engine/common/normalization/common.h
View file @
2b05e121
...
...
@@ -163,7 +163,7 @@ TupleKeyType get_key(NVTE_Norm_Backend NormBackend, NVTE_Norm_Type NormType,
NVTE_Norm_Stage
NormStage
,
DType
wtype
,
DType
itype
,
DType
otype
,
DType
ctype
,
uint64_t
batch_size
,
uint64_t
hidden_size
,
bool
zero_centered_gamma
,
bool
is_tuned
,
NVTEScalingMode
mode
=
NVTE_DELAYED_TENSOR_SCALING
,
bool
training
=
true
);
bool
training
=
true
,
bool
gamma_in_weight_dtype
=
false
);
template
<
typename
KernelParamsType
>
class
TeNormalizationRegistry
{
...
...
@@ -313,7 +313,8 @@ class NormalizationPlanRegistry {
NVTE_Norm_Backend
NormBackend
,
NVTE_Norm_Type
NormType
,
NVTE_Norm_Stage
NormStage
,
DType
wtype
,
DType
itype
,
DType
otype
,
const
size_t
batch_size
,
const
size_t
hidden_size
,
const
size_t
sm_count
,
const
bool
zero_centered_gamma
,
const
bool
is_aligned
,
const
NVTEScalingMode
mode
=
NVTE_DELAYED_TENSOR_SCALING
,
const
bool
training
=
true
);
const
NVTEScalingMode
mode
=
NVTE_DELAYED_TENSOR_SCALING
,
const
bool
training
=
true
,
const
bool
gamma_in_weight_dtype
=
false
);
private:
NormalizationPlanRegistry
()
{}
...
...
@@ -392,6 +393,8 @@ bool is_ptr_aligned(const Args*... ptrs) {
bool
use_cudnn_norm_fwd
();
bool
use_cudnn_norm_bwd
();
bool
&
use_zero_centered_gamma_in_weight_dtype
();
}
// namespace normalization
}
// namespace transformer_engine
...
...
transformer_engine/common/normalization/layernorm/ln_api.cpp
View file @
2b05e121
...
...
@@ -15,6 +15,7 @@
#include "../../common.h"
#include "../common.h"
#include "transformer_engine/transformer_engine.h"
namespace
transformer_engine
{
...
...
@@ -71,9 +72,11 @@ void layernorm_fwd(const Tensor& x, // BxSxhidden_size
bool
cudnn_backend
=
use_cudnn_norm_fwd
()
||
is_mxfp_scaling
(
z
->
scaling_mode
);
#endif
bool
gamma_in_weight_dtype
=
false
;
if
(
cudnn_backend
)
{
// TODO: add check for GPU ARCH
norm_backend
=
NVTE_Norm_Backend
::
Cudnn
;
gamma_in_weight_dtype
=
use_zero_centered_gamma_in_weight_dtype
();
}
else
{
norm_backend
=
NVTE_Norm_Backend
::
Te
;
is_aligned
=
is_ptr_aligned
(
z
->
data
.
dptr
,
x
.
data
.
dptr
,
gamma
.
data
.
dptr
,
beta
.
data
.
dptr
,
...
...
@@ -90,7 +93,8 @@ void layernorm_fwd(const Tensor& x, // BxSxhidden_size
z
->
data
.
dtype
,
// otype
x
.
data
.
shape
[
0
],
// batch_size
x
.
data
.
shape
[
1
],
// hidden_size
multiprocessorCount
,
zero_centered_gamma
,
is_aligned
,
z
->
scaling_mode
,
training
);
multiprocessorCount
,
zero_centered_gamma
,
is_aligned
,
z
->
scaling_mode
,
training
,
gamma_in_weight_dtype
);
if
(
workspace
->
data
.
shape
.
empty
())
{
workspace
->
data
.
shape
=
plan
->
getWorkspaceShape
();
...
...
@@ -108,11 +112,11 @@ void layernorm_fwd(const Tensor& x, // BxSxhidden_size
// Compute FP8 transpose if required
if
(
z
->
has_columnwise_data
()
&&
is_tensor_scaling
(
z
->
scaling_mode
))
{
Tensor
transpose_data
;
transpose_data
.
data
=
z
->
columnwi
se_data
;
t
ranspose_data
.
scaling_mode
=
z
->
scaling_mode
;
nvte_transpose
(
reinterpret
_cast
<
NVTETensor
>
(
z
),
reinterpret_cast
<
NVTETensor
>
(
&
transpose_data
)
,
stream
);
NVTE
Tensor
transpose_data
=
nvte_create_tensor
(
z
->
scaling_mode
)
;
Tensor
&
t
=
*
convertNVTETensor
(
transpo
se_data
)
;
t
.
data
=
z
->
columnwise_data
;
nvte_transpose
(
static
_cast
<
NVTETensor
>
(
*
z
),
transpose_data
,
stream
);
nvte_destroy_tensor
(
transpose_data
);
}
return
;
...
...
@@ -157,9 +161,11 @@ void layernorm_bwd(const Tensor& dz, const Tensor& x, const Tensor& mu, const Te
NVTE_Norm_Backend
norm_backend
;
bool
is_aligned
=
true
;
bool
gamma_in_weight_dtype
=
false
;
if
(
use_cudnn_norm_bwd
())
{
// TODO: add check for GPU ARCH
norm_backend
=
NVTE_Norm_Backend
::
Cudnn
;
gamma_in_weight_dtype
=
use_zero_centered_gamma_in_weight_dtype
();
}
else
{
norm_backend
=
NVTE_Norm_Backend
::
Te
;
is_aligned
=
is_ptr_aligned
(
x
.
data
.
dptr
,
gamma
.
data
.
dptr
,
mu
.
data
.
dptr
,
rsigma
.
data
.
dptr
,
...
...
@@ -172,7 +178,8 @@ void layernorm_bwd(const Tensor& dz, const Tensor& x, const Tensor& mu, const Te
gamma
.
data
.
dtype
,
// otype
x
.
data
.
shape
[
0
],
// batch_size
x
.
data
.
shape
[
1
],
// hidden_size
multiprocessorCount
,
zero_centered_gamma
,
is_aligned
);
multiprocessorCount
,
zero_centered_gamma
,
is_aligned
,
NVTE_DELAYED_TENSOR_SCALING
,
true
,
gamma_in_weight_dtype
);
if
(
workspace
->
data
.
shape
.
empty
())
{
workspace
->
data
.
shape
=
plan
->
getWorkspaceShape
();
...
...
@@ -195,11 +202,10 @@ void nvte_layernorm_fwd(const NVTETensor x, // BxSxhidden_size
const
bool
zero_centered_gamma
,
cudaStream_t
stream
)
{
NVTE_API_CALL
(
nvte_layernorm_fwd
);
using
namespace
transformer_engine
;
layernorm_fwd
(
*
reinterpret_cast
<
const
Tensor
*>
(
x
),
*
reinterpret_cast
<
const
Tensor
*>
(
gamma
),
*
reinterpret_cast
<
const
Tensor
*>
(
beta
),
epsilon
,
reinterpret_cast
<
Tensor
*>
(
z
),
reinterpret_cast
<
Tensor
*>
(
mu
),
reinterpret_cast
<
Tensor
*>
(
rsigma
),
reinterpret_cast
<
Tensor
*>
(
workspace
),
multiprocessorCount
,
zero_centered_gamma
,
stream
);
layernorm_fwd
(
*
convertNVTETensorCheck
(
x
),
*
convertNVTETensorCheck
(
gamma
),
*
convertNVTETensorCheck
(
beta
),
epsilon
,
convertNVTETensor
(
z
),
convertNVTETensor
(
mu
),
convertNVTETensor
(
rsigma
),
convertNVTETensor
(
workspace
),
multiprocessorCount
,
zero_centered_gamma
,
stream
);
}
void
nvte_layernorm_bwd
(
const
NVTETensor
dz
,
// BxSxhidden_size
...
...
@@ -212,10 +218,9 @@ void nvte_layernorm_bwd(const NVTETensor dz, // BxSxhidden_size
cudaStream_t
stream
)
{
NVTE_API_CALL
(
nvte_layernorm_bwd
);
using
namespace
transformer_engine
;
layernorm_bwd
(
*
reinterpret_cast
<
const
Tensor
*>
(
dz
),
*
reinterpret_cast
<
const
Tensor
*>
(
x
),
*
reinterpret_cast
<
const
Tensor
*>
(
mu
),
*
reinterpret_cast
<
const
Tensor
*>
(
rsigma
),
*
reinterpret_cast
<
const
Tensor
*>
(
gamma
),
reinterpret_cast
<
Tensor
*>
(
dx
),
reinterpret_cast
<
Tensor
*>
(
dgamma
),
reinterpret_cast
<
Tensor
*>
(
dbeta
),
reinterpret_cast
<
Tensor
*>
(
workspace
),
multiprocessorCount
,
zero_centered_gamma
,
stream
);
layernorm_bwd
(
*
convertNVTETensorCheck
(
dz
),
*
convertNVTETensorCheck
(
x
),
*
convertNVTETensorCheck
(
mu
),
*
convertNVTETensorCheck
(
rsigma
),
*
convertNVTETensorCheck
(
gamma
),
convertNVTETensor
(
dx
),
convertNVTETensor
(
dgamma
),
convertNVTETensor
(
dbeta
),
convertNVTETensor
(
workspace
),
multiprocessorCount
,
zero_centered_gamma
,
stream
);
}
transformer_engine/common/normalization/rmsnorm/rmsnorm_api.cpp
View file @
2b05e121
...
...
@@ -13,6 +13,7 @@
#include "../../common.h"
#include "../common.h"
#include "transformer_engine/normalization.h"
#include "transformer_engine/transformer_engine.h"
#include "transformer_engine/transpose.h"
namespace
transformer_engine
{
...
...
@@ -60,9 +61,11 @@ void rmsnorm_fwd(const Tensor &x, const Tensor &gamma, const float epsilon, Tens
bool
training
=
is_delayed_tensor_scaling
(
z
->
scaling_mode
)
||
(
z
->
columnwise_data
).
dptr
!=
nullptr
;
bool
gamma_in_weight_dtype
=
false
;
if
(
cudnn_backend
)
{
// TODO: add check for GPU ARCH
norm_backend
=
NVTE_Norm_Backend
::
Cudnn
;
gamma_in_weight_dtype
=
use_zero_centered_gamma_in_weight_dtype
();
}
else
{
norm_backend
=
NVTE_Norm_Backend
::
Te
;
is_aligned
=
is_ptr_aligned
(
z
->
data
.
dptr
,
x
.
data
.
dptr
,
gamma
.
data
.
dptr
,
rsigma
->
data
.
dptr
);
...
...
@@ -75,7 +78,8 @@ void rmsnorm_fwd(const Tensor &x, const Tensor &gamma, const float epsilon, Tens
z
->
data
.
dtype
,
// otype
x
.
data
.
shape
[
0
],
// batch_size
x
.
data
.
shape
[
1
],
// hidden_size
multiprocessorCount
,
zero_centered_gamma
,
is_aligned
,
z
->
scaling_mode
,
training
);
multiprocessorCount
,
zero_centered_gamma
,
is_aligned
,
z
->
scaling_mode
,
training
,
gamma_in_weight_dtype
);
if
(
workspace
->
data
.
shape
.
empty
())
{
workspace
->
data
.
shape
=
plan
->
getWorkspaceShape
();
...
...
@@ -93,11 +97,12 @@ void rmsnorm_fwd(const Tensor &x, const Tensor &gamma, const float epsilon, Tens
// Compute FP8 transpose if required
if
(
z
->
has_columnwise_data
()
&&
is_tensor_scaling
(
z
->
scaling_mode
))
{
Tensor
transpose_data
;
transpose_data
.
data
=
z
->
columnwise_data
;
transpose_data
.
scaling_mode
=
z
->
scaling_mode
;
nvte_transpose
(
reinterpret_cast
<
NVTETensor
>
(
z
),
reinterpret_cast
<
NVTETensor
>
(
&
transpose_data
),
stream
);
NVTETensor
transpose_data
=
nvte_create_tensor
(
z
->
scaling_mode
);
auto
*
t
=
convertNVTETensor
(
transpose_data
);
t
->
data
=
z
->
columnwise_data
;
nvte_transpose
(
static_cast
<
NVTETensor
>
(
*
z
),
transpose_data
,
stream
);
nvte_destroy_tensor
(
transpose_data
);
}
return
;
...
...
@@ -133,9 +138,11 @@ void rmsnorm_bwd(const Tensor &dz, const Tensor &x, const Tensor &rsigma, const
NVTE_Norm_Backend
norm_backend
;
bool
is_aligned
=
true
;
bool
gamma_in_weight_dtype
=
false
;
if
(
use_cudnn_norm_bwd
())
{
// TODO: add check for GPU ARCH
norm_backend
=
NVTE_Norm_Backend
::
Cudnn
;
gamma_in_weight_dtype
=
use_zero_centered_gamma_in_weight_dtype
();
}
else
{
norm_backend
=
NVTE_Norm_Backend
::
Te
;
is_aligned
=
is_ptr_aligned
(
x
.
data
.
dptr
,
gamma
.
data
.
dptr
,
rsigma
.
data
.
dptr
,
dx
->
data
.
dptr
,
...
...
@@ -148,7 +155,8 @@ void rmsnorm_bwd(const Tensor &dz, const Tensor &x, const Tensor &rsigma, const
gamma
.
data
.
dtype
,
// otype
x
.
data
.
shape
[
0
],
// batch_size
x
.
data
.
shape
[
1
],
// hidden_size
multiprocessorCount
,
zero_centered_gamma
,
is_aligned
);
multiprocessorCount
,
zero_centered_gamma
,
is_aligned
,
NVTE_DELAYED_TENSOR_SCALING
,
true
,
gamma_in_weight_dtype
);
if
(
workspace
->
data
.
shape
.
empty
())
{
workspace
->
data
.
shape
=
plan
->
getWorkspaceShape
();
...
...
@@ -171,10 +179,9 @@ void nvte_rmsnorm_fwd(const NVTETensor x, // Nxhidden_size
cudaStream_t
stream
)
{
NVTE_API_CALL
(
nvte_rmsnorm_fwd
);
using
namespace
transformer_engine
;
rmsnorm_fwd
(
*
reinterpret_cast
<
const
Tensor
*>
(
x
),
*
reinterpret_cast
<
const
Tensor
*>
(
gamma
),
epsilon
,
reinterpret_cast
<
Tensor
*>
(
z
),
reinterpret_cast
<
Tensor
*>
(
rsigma
),
reinterpret_cast
<
Tensor
*>
(
workspace
),
multiprocessorCount
,
zero_centered_gamma
,
stream
);
rmsnorm_fwd
(
*
convertNVTETensorCheck
(
x
),
*
convertNVTETensorCheck
(
gamma
),
epsilon
,
convertNVTETensor
(
z
),
convertNVTETensor
(
rsigma
),
convertNVTETensor
(
workspace
),
multiprocessorCount
,
zero_centered_gamma
,
stream
);
}
void
nvte_rmsnorm_bwd
(
const
NVTETensor
dz
,
// Nxhidden_size
...
...
@@ -186,9 +193,8 @@ void nvte_rmsnorm_bwd(const NVTETensor dz, // Nxhidden_size
cudaStream_t
stream
)
{
NVTE_API_CALL
(
nvte_rmsnorm_bwd
);
using
namespace
transformer_engine
;
rmsnorm_bwd
(
*
reinterpret_cast
<
const
Tensor
*>
(
dz
),
*
reinterpret_cast
<
const
Tensor
*>
(
x
),
*
reinterpret_cast
<
const
Tensor
*>
(
rsigma
),
*
reinterpret_cast
<
const
Tensor
*>
(
gamma
),
reinterpret_cast
<
Tensor
*>
(
dx
),
reinterpret_cast
<
Tensor
*>
(
dgamma
),
reinterpret_cast
<
Tensor
*>
(
workspace
),
multiprocessorCount
,
zero_centered_gamma
,
stream
);
rmsnorm_bwd
(
*
convertNVTETensorCheck
(
dz
),
*
convertNVTETensorCheck
(
x
),
*
convertNVTETensorCheck
(
rsigma
),
*
convertNVTETensorCheck
(
gamma
),
convertNVTETensor
(
dx
),
convertNVTETensor
(
dgamma
),
convertNVTETensor
(
workspace
),
multiprocessorCount
,
zero_centered_gamma
,
stream
);
}
transformer_engine/common/permutation/permutation.cu
View file @
2b05e121
...
...
@@ -334,22 +334,16 @@ void nvte_permute(const NVTETensor input, NVTETensor output, const NVTETensor so
NVTETensor
row_id_map
,
const
NVTETensor
prob
,
NVTETensor
prob_grad
,
const
NVTETensor
input_fwd
,
const
int
num_rows
,
const
int
topK
,
const
int
num_cols
,
const
int
num_out_tokens
,
cudaStream_t
stream
)
{
using
namespace
transformer_engine
;
NVTE_API_CALL
(
nvte_permute
);
const
transformer_engine
::
Tensor
*
input_cu
=
reinterpret_cast
<
const
transformer_engine
::
Tensor
*>
(
input
);
const
transformer_engine
::
Tensor
*
output_cu
=
reinterpret_cast
<
const
transformer_engine
::
Tensor
*>
(
output
);
const
transformer_engine
::
Tensor
*
sorted_row_id_cu
=
reinterpret_cast
<
const
transformer_engine
::
Tensor
*>
(
sorted_row_id
);
const
transformer_engine
::
Tensor
*
row_id_map_cu
=
reinterpret_cast
<
const
transformer_engine
::
Tensor
*>
(
row_id_map
);
const
transformer_engine
::
Tensor
*
prob_cu
=
reinterpret_cast
<
const
transformer_engine
::
Tensor
*>
(
prob
);
const
transformer_engine
::
Tensor
*
prob_grad_cu
=
reinterpret_cast
<
const
transformer_engine
::
Tensor
*>
(
prob_grad
);
const
transformer_engine
::
Tensor
*
input_fwd_cu
=
reinterpret_cast
<
const
transformer_engine
::
Tensor
*>
(
input_fwd
);
const
Tensor
*
input_cu
=
convertNVTETensorCheck
(
input
);
const
Tensor
*
output_cu
=
convertNVTETensorCheck
(
output
);
const
Tensor
*
sorted_row_id_cu
=
convertNVTETensorCheck
(
sorted_row_id
);
const
Tensor
*
row_id_map_cu
=
convertNVTETensorCheck
(
row_id_map
);
const
Tensor
*
prob_cu
=
convertNVTETensorCheck
(
prob
);
const
Tensor
*
prob_grad_cu
=
convertNVTETensorCheck
(
prob_grad
);
const
Tensor
*
input_fwd_cu
=
convertNVTETensorCheck
(
input_fwd
);
TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY
(
input_cu
->
data
.
dtype
,
T
,
...
...
@@ -366,16 +360,13 @@ void nvte_permute(const NVTETensor input, NVTETensor output, const NVTETensor so
void
nvte_unpermute
(
const
NVTETensor
input
,
NVTETensor
output
,
NVTETensor
row_id_map
,
const
NVTETensor
prob
,
const
int
num_rows
,
const
int
topK
,
const
int
num_cols
,
cudaStream_t
stream
)
{
using
namespace
transformer_engine
;
NVTE_API_CALL
(
nvte_unpermute
);
const
transformer_engine
::
Tensor
*
input_cu
=
reinterpret_cast
<
const
transformer_engine
::
Tensor
*>
(
input
);
const
transformer_engine
::
Tensor
*
output_cu
=
reinterpret_cast
<
const
transformer_engine
::
Tensor
*>
(
output
);
const
transformer_engine
::
Tensor
*
row_id_map_cu
=
reinterpret_cast
<
const
transformer_engine
::
Tensor
*>
(
row_id_map
);
const
transformer_engine
::
Tensor
*
prob_cu
=
reinterpret_cast
<
const
transformer_engine
::
Tensor
*>
(
prob
);
const
Tensor
*
input_cu
=
convertNVTETensorCheck
(
input
);
const
Tensor
*
output_cu
=
convertNVTETensorCheck
(
output
);
const
Tensor
*
row_id_map_cu
=
convertNVTETensorCheck
(
row_id_map
);
const
Tensor
*
prob_cu
=
convertNVTETensorCheck
(
prob
);
TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY
(
input_cu
->
data
.
dtype
,
T
,
...
...
transformer_engine/common/recipe/__init__.py
View file @
2b05e121
...
...
@@ -180,6 +180,7 @@ class DelayedScaling(Recipe):
def
__repr__
(
self
)
->
str
:
return
(
f
"recipe_type=
{
self
.
__class__
.
__name__
}
, "
f
"margin=
{
self
.
margin
}
, "
f
"format=
{
str
(
self
.
fp8_format
).
split
(
'.'
)[
1
]
}
, "
f
"amax_history_len=
{
self
.
amax_history_len
}
, "
...
...
@@ -192,42 +193,12 @@ class DelayedScaling(Recipe):
class
Float8CurrentScaling
(
Recipe
):
"""
Use the per-tensor current scaling factor strategy.
Parameters
----------
fp8_format : {Format.E4M3, Format.HYBRID}, default = Format.HYBRID
Controls the FP8 data format used during forward and backward
pass.
fp8_quant_fwd_inp: QParams, default QParams{power_2_scale=False, amax_epsilon=0.0}
used for quantization of input tensor x
fp8_quant_fwd_weight: QParams, default QParams{power_2_scale=False, amax_epsilon=0.0}
used for quantization of weight tensor w
fp8_quant_bwd_grad: QParams, default QParams{power_2_scale=False, amax_epsilon=0.0}
used for quantization of gradient tensor dY
fp8_gemm_fprop: MMParams, default MMParams.use_split_accumulator=False
used for calculating output y in forward pass
fp8_gemm_dgrad: MMParams, default MMParams.use_split_accumulator=True
use for calculating dgrad in backward pass
fp8_gemm_wgrad: MMParams, default MMParams.use_split_accumulator=True
use for calculating dgrad in backward pass
fp8_dpa: bool, default = `False`
Whether to enable FP8 dot product attention (DPA). When the model is placed in an
`fp8_autocast(enabled=True)` region and `fp8_dpa` is set to `True`, DPA casts the
inputs from higher precision to FP8, performs attention in FP8, and casts tensors
back to higher precision as outputs. FP8 DPA currently is only supported in the
`FusedAttention` backend.
fp8_mha: bool, default = `False`
Whether to enable FP8 multi-head attention (MHA). When `True`, it removes the casting
operations mentioned above at the DPA boundaries. Currently only standard MHA modules
i.e. `LayerNormLinear/Linear + DPA + Linear`, are supported for this feature. When
`fp8_mha = False, fp8_dpa = True`, a typical MHA module works as
`LayerNormLinear (BF16 output) -> (cast to FP8 ) FP8 DPA (cast to BF16) -> Linear`.
When `fp8_mha = True, fp8_dpa = True`, it becomes
`LayerNormLinear (FP8 output) -> FP8 DPA -> Linear`.
Notes
-----
* `fp8_dpa` and `fp8_mha` are Beta features, and their API and functionality are
subject to change in future Transformer Engine releases.
"""
fp8_format
:
Format
=
Format
.
HYBRID
...
...
@@ -242,9 +213,13 @@ class Float8CurrentScaling(Recipe):
def
__post_init__
(
self
)
->
None
:
assert
self
.
fp8_format
!=
Format
.
E5M2
,
"Pure E5M2 training is not supported."
assert
(
not
self
.
fp8_dpa
and
not
self
.
fp8_mha
),
"FP8 attention is not supported for Float8CurrentScaling."
def
__repr__
(
self
)
->
str
:
return
(
f
"recipe_type=
{
self
.
__class__
.
__name__
}
, "
f
"format=
{
str
(
self
.
fp8_format
).
split
(
'.'
)[
1
]
}
, "
f
"fp8_quant_fwd_inp=
{
self
.
fp8_quant_fwd_inp
}
, "
f
"fp8_quant_fwd_weight=
{
self
.
fp8_quant_fwd_weight
}
, "
...
...
@@ -291,7 +266,11 @@ class MXFP8BlockScaling(Recipe):
assert
self
.
fp8_format
!=
Format
.
E5M2
,
"Pure E5M2 training is not supported."
def
__repr__
(
self
)
->
str
:
return
f
"margin=
{
self
.
margin
}
, format=
{
str
(
self
.
fp8_format
).
split
(
'.'
)[
1
]
}
,"
return
(
f
"recipe_type=
{
self
.
__class__
.
__name__
}
, "
f
"margin=
{
self
.
margin
}
, "
f
"format=
{
str
(
self
.
fp8_format
).
split
(
'.'
)[
1
]
}
"
)
@
dataclass
()
...
...
@@ -313,32 +292,12 @@ class Float8BlockScaling(Recipe):
NOTE: To relax the default constraint that scales be powers of 2, set env variable
NVTE_FP8_BLOCK_SCALING_FP32_SCALES=1 to override it for the recipe defaults.
export NVTE_FP8_BLOCK_SCALING_FP32_SCALES=1
Or initialize the Recipe with non-default QParams in code for increased control.
Parameters
----------
fp8_format : {Format.E4M3, Format.HYBRID}, default = Format.E4M3
Controls the FP8 data format used during forward and backward
pass.
fp8_quant_fwd_inp: QParams, default QParams{power_2_scale=True, amax_epsilon=0.0}
used for quantization of input tensor x
fp8_quant_fwd_weight: QParams, default QParams{power_2_scale=True, amax_epsilon=0.0}
used for quantization of weight tensor w
fp8_quant_bwd_grad: QParams, default QParams{power_2_scale=True, amax_epsilon=0.0}
used for quantization of gradient tensor dY
x_block_scaling_dim: Choice to use 1x128 (1 dimensional) or 128x128 (2 dimensional)
qblock scaling for x.
w_block_scaling_dim: Choice to use 1x128 (1 dimensional) or 128x128 (2 dimensional)
qblock scaling for w.
grad_block_scaling_dim: Choice to use 1x128 (1 dimensional) or 128x128 (2 dimensional)
qblock scaling for grad.
fp8_gemm_fprop: MMParams, default MMParams.use_split_accumulator=False
used for calculating output y in forward pass
fp8_gemm_dgrad: MMParams, default MMParams.use_split_accumulator=True
use for calculating dgrad in backward pass
fp8_gemm_wgrad: MMParams, default MMParams.use_split_accumulator=True
use for calculating dgrad in backward pass
"""
use_f32_scales
:
bool
=
os
.
getenv
(
"NVTE_FP8_BLOCK_SCALING_FP32_SCALES"
,
"0"
)
==
"1"
...
...
@@ -372,9 +331,13 @@ class Float8BlockScaling(Recipe):
assert
self
.
fp8_gemm_fprop
.
use_split_accumulator
,
"Split accumulator required for fprop."
assert
self
.
fp8_gemm_dgrad
.
use_split_accumulator
,
"Split accumulator required for dgrad."
assert
self
.
fp8_gemm_wgrad
.
use_split_accumulator
,
"Split accumulator required for wgrad."
assert
(
not
self
.
fp8_dpa
and
not
self
.
fp8_mha
),
"FP8 attention is not supported for Float8BlockScaling."
def
__repr__
(
self
)
->
str
:
return
(
f
"recipe_type=
{
self
.
__class__
.
__name__
}
, "
f
"format=
{
str
(
self
.
fp8_format
).
split
(
'.'
)[
1
]
}
, "
f
"fp8_quant_fwd_inp=
{
self
.
fp8_quant_fwd_inp
}
, "
f
"fp8_quant_fwd_weight=
{
self
.
fp8_quant_fwd_weight
}
, "
...
...
transformer_engine/common/recipe/current_scaling.cu
View file @
2b05e121
...
...
@@ -112,7 +112,7 @@ void nvte_compute_amax(const NVTETensor input_, const NVTETensor output_, cudaSt
// Check input tensor
NVTE_CHECK
(
input_
!=
nullptr
,
"Invalid input tensor (got NULL)"
);
const
auto
&
input
=
*
reinterpret_cast
<
const
Tensor
*>
(
input_
);
const
auto
&
input
=
*
convertNVTE
Tensor
Check
(
input_
);
NVTE_CHECK
(
input
.
scaling_mode
==
NVTE_DELAYED_TENSOR_SCALING
,
"Input tensor for amax computation must unquantized, "
"but got scaling_mode="
,
...
...
@@ -125,7 +125,7 @@ void nvte_compute_amax(const NVTETensor input_, const NVTETensor output_, cudaSt
// Check output tensor
NVTE_CHECK
(
output_
!=
nullptr
,
"Invalid output tensor (got NULL)"
);
auto
&
output
=
*
reinterpret_cast
<
Tensor
*>
(
output_
);
auto
&
output
=
*
convertNVTE
Tensor
Check
(
output_
);
NVTE_CHECK
(
output
.
scaling_mode
==
NVTE_DELAYED_TENSOR_SCALING
,
"Output tensor for amax computation must be FP8 tensor with per-tensor scaling, "
"but got scaling_mode="
,
...
...
@@ -170,7 +170,7 @@ void nvte_compute_scale_from_amax(NVTETensor output_, const NVTEQuantizationConf
// Check output tensor
NVTE_CHECK
(
output_
!=
nullptr
,
"Invalid output tensor (got NULL)"
);
auto
&
output
=
*
reinterpret_cast
<
Tensor
*>
(
output_
);
auto
&
output
=
*
convertNVTE
Tensor
Check
(
output_
);
NVTE_CHECK
(
output
.
scaling_mode
==
NVTE_DELAYED_TENSOR_SCALING
,
"Tensor must be FP8 tensor with per-tensor scaling, "
"but got scaling_mode="
,
...
...
transformer_engine/common/recipe/delayed_scaling.cu
View file @
2b05e121
...
...
@@ -397,9 +397,9 @@ void nvte_delayed_scaling_recipe_amax_and_scale_update(
NVTE_API_CALL
(
nvte_delayed_scaling_recipe_amax_and_scale_update
);
using
namespace
transformer_engine
;
delayed_scaling_recipe
::
amax_and_scale_update
(
*
reinterpret_cast
<
const
Tensor
*>
(
amax_history
),
*
reinterpret_cast
<
const
Tensor
*>
(
scale
),
reinterpret_cast
<
Tensor
*>
(
updated_amax_history
),
reinterpret_cast
<
Tensor
*>
(
updated_scale
),
amax_compute_algo
,
static_cast
<
DType
>
(
fp8_dtype
),
margin
,
stream
);
*
convertNVTE
Tensor
Check
(
amax_history
),
*
convertNVTE
Tensor
Check
(
scale
),
convertNVTE
Tensor
(
updated_amax_history
),
convertNVTE
Tensor
(
updated_scale
),
amax_compute_algo
,
static_cast
<
DType
>
(
fp8_dtype
),
margin
,
stream
);
}
void
nvte_delayed_scaling_recipe_amax_and_scale_update_after_reduction
(
...
...
@@ -411,10 +411,10 @@ void nvte_delayed_scaling_recipe_amax_and_scale_update_after_reduction(
size_t
num_tensors
=
amax_histories
.
size
();
std
::
vector
<
Tensor
*>
t_amax_histories
,
t_scales
;
for
(
size_t
i
=
0
;
i
<
num_tensors
;
i
++
)
{
t_amax_histories
.
push_back
(
reinterpret_cast
<
Tensor
*>
(
amax_histories
[
i
]));
t_scales
.
push_back
(
reinterpret_cast
<
Tensor
*>
(
scales
[
i
]));
t_amax_histories
.
push_back
(
convertNVTE
Tensor
(
amax_histories
[
i
]));
t_scales
.
push_back
(
convertNVTE
Tensor
(
scales
[
i
]));
}
delayed_scaling_recipe
::
amax_and_scale_update_after_reduction
(
*
reinterpret_cast
<
const
Tensor
*>
(
amax_reduction_buffer
),
t_amax_histories
,
t_scales
,
amax_compute_algo
,
static_cast
<
DType
>
(
fp8_dtype
),
margin
,
stream
);
*
convertNVTE
Tensor
Check
(
amax_reduction_buffer
),
t_amax_histories
,
t_scales
,
amax_compute_algo
,
static_cast
<
DType
>
(
fp8_dtype
),
margin
,
stream
);
}
transformer_engine/common/recipe/fp8_block_scaling.cu
View file @
2b05e121
...
...
@@ -244,8 +244,8 @@ void nvte_fp8_block_scaling_compute_partial_amax(const NVTETensor inp, NVTETenso
NVTE_API_CALL
(
nvte_fp8_block_scaling_compute_partial_amax
);
using
namespace
transformer_engine
;
fp8_block_scaling_recipe
::
fp8_block_scaling_compute_partial_amax
(
*
reinterpret_cast
<
const
Tensor
*>
(
inp
),
*
reinterpret_cast
<
Tensor
*>
(
amax
),
h
,
w
,
amax_stride_h
,
amax_stride_w
,
start_offset
,
block_len
,
stream
);
*
convertNVTE
Tensor
Check
(
inp
),
*
convertNVTE
Tensor
Check
(
amax
),
h
,
w
,
amax_stride_h
,
amax_stride_w
,
start_offset
,
block_len
,
stream
);
}
void
nvte_fp8_block_scaling_partial_cast
(
const
NVTETensor
inp
,
NVTETensor
out
,
...
...
@@ -256,7 +256,7 @@ void nvte_fp8_block_scaling_partial_cast(const NVTETensor inp, NVTETensor out,
NVTE_API_CALL
(
nvte_fp8_block_scaling_partial_cast
);
using
namespace
transformer_engine
;
fp8_block_scaling_recipe
::
fp8_block_scaling_partial_cast
(
*
reinterpret_cast
<
const
Tensor
*>
(
inp
),
*
reinterpret_cast
<
Tensor
*>
(
out
)
,
*
reinterpret_cast
<
const
Tensor
*>
(
scale
),
h
,
w
,
scale_stride_h
,
scale_stride_w
,
start_offset
,
block_len
,
static_cast
<
DType
>
(
out_dtype
),
stream
);
*
convertNVTE
Tensor
Check
(
inp
),
*
convertNVTETensorCheck
(
out
),
*
convertNVTETensorCheck
(
scale
),
h
,
w
,
scale_stride_h
,
scale_stride_w
,
start_offset
,
block_len
,
static_cast
<
DType
>
(
out_dtype
),
stream
);
}
transformer_engine/common/swizzle/swizzle.cu
View file @
2b05e121
...
...
@@ -514,6 +514,5 @@ void swizzle_scaling_factors(const Tensor* input, Tensor* output, cudaStream_t s
void
nvte_swizzle_scaling_factors
(
const
NVTETensor
input
,
NVTETensor
output
,
cudaStream_t
stream
)
{
NVTE_API_CALL
(
nvte_swizzle_scaling_factors
);
using
namespace
transformer_engine
;
swizzle_scaling_factors
(
reinterpret_cast
<
const
Tensor
*>
(
input
),
reinterpret_cast
<
Tensor
*>
(
output
),
stream
);
swizzle_scaling_factors
(
convertNVTETensorCheck
(
input
),
convertNVTETensorCheck
(
output
),
stream
);
}
transformer_engine/common/transformer_engine.cpp
View file @
2b05e121
...
...
@@ -6,20 +6,27 @@
#include <transformer_engine/transformer_engine.h>
#include <atomic>
#include <climits>
#include <cstring>
#include <iostream>
#include <mutex>
#include "common.h"
#include "common/util/cuda_runtime.h"
#include "common/util/logging.h"
namespace
transformer_engine
{
size_t
typeTo
Size
(
const
DType
type
)
{
size_t
typeTo
NumBits
(
const
DType
type
)
{
TRANSFORMER_ENGINE_TYPE_SWITCH_ALL
(
type
,
T
,
return
TypeInfo
<
T
>::
size
;);
// NOLINT(*)
}
bool
is_fp8_dtype
(
const
DType
t
)
{
return
t
==
DType
::
kFloat8E4M3
||
t
==
DType
::
kFloat8E5M2
;
}
size_t
typeToSize
(
const
DType
type
)
{
NVTE_CHECK
(
type
!=
DType
::
kFloat4E2M1
,
"typeToSize() Does not support FP4 data type."
);
return
typeToNumBits
(
type
)
/
8
;
}
std
::
string
to_string
(
const
DType
type
)
{
switch
(
type
)
{
...
...
@@ -37,6 +44,8 @@ std::string to_string(const DType type) {
return
"Float8E5M2"
;
case
DType
::
kFloat8E8M0
:
return
"Float8E8M0"
;
case
DType
::
kFloat4E2M1
:
return
"Float4E2M1"
;
case
DType
::
kInt32
:
return
"Int32"
;
case
DType
::
kInt64
:
...
...
@@ -52,6 +61,8 @@ std::string to_string(const NVTEScalingMode &mode) {
return
"NVTE_DELAYED_TENSOR_SCALING"
;
case
NVTE_MXFP8_1D_SCALING
:
return
"NVTE_MXFP8_1D_SCALING"
;
case
NVTE_FWD_NVFP4_BWD_MXFP8_SCALING
:
return
"NVTE_FWD_NVFP4_BWD_MXFP8_SCALING"
;
case
NVTE_INVALID_SCALING
:
return
"NVTE_INVALID_SCALING"
;
}
...
...
@@ -81,10 +92,13 @@ void CheckScaleTensorShape(const Tensor &t, const std::string &name) {
t
.
columnwise_scale_inv
.
shape
,
")"
);
}
}
else
{
if
(
t
.
scaling_mode
==
NVTE_MXFP8_1D_SCALING
)
{
if
(
t
.
scaling_mode
==
NVTE_MXFP8_1D_SCALING
||
t
.
scaling_mode
==
NVTE_FWD_NVFP4_BWD_MXFP8_SCALING
)
{
// Need (4, 128) alignment even for e8 scaling factor
auto
block_alignment
=
std
::
vector
<
size_t
>
{
128ul
,
4ul
};
size_t
expected_x
,
expected_y
,
alignment
;
const
size_t
block_size_rowwise
=
(
t
.
scaling_mode
==
NVTE_MXFP8_1D_SCALING
)
?
32
:
16
;
const
size_t
block_size_colwise
=
32
;
if
(
t
.
has_data
())
{
alignment
=
block_alignment
[
0
];
...
...
@@ -92,7 +106,8 @@ void CheckScaleTensorShape(const Tensor &t, const std::string &name) {
DIVUP
(
DIVUP
(
t
.
flat_first_dim
(),
static_cast
<
size_t
>
(
1
)),
alignment
)
*
alignment
;
alignment
=
block_alignment
[
1
];
expected_y
=
DIVUP
(
DIVUP
(
t
.
flat_last_dim
(),
static_cast
<
size_t
>
(
32
)),
alignment
)
*
alignment
;
DIVUP
(
DIVUP
(
t
.
flat_last_dim
(),
static_cast
<
size_t
>
(
block_size_rowwise
)),
alignment
)
*
alignment
;
const
auto
&
expected
=
std
::
vector
<
size_t
>
{
expected_x
,
expected_y
};
NVTE_CHECK
(
t
.
scale_inv
.
shape
==
expected
,
"Tensor
\"
"
,
name
,
"
\"
has invalid scale_inv shape (expected "
,
expected
,
", got "
,
...
...
@@ -101,7 +116,8 @@ void CheckScaleTensorShape(const Tensor &t, const std::string &name) {
if
(
t
.
has_columnwise_data
())
{
alignment
=
block_alignment
[
1
];
expected_x
=
DIVUP
(
DIVUP
(
t
.
flat_first_dim
(),
static_cast
<
size_t
>
(
32
)),
alignment
)
*
alignment
;
DIVUP
(
DIVUP
(
t
.
flat_first_dim
(),
static_cast
<
size_t
>
(
block_size_colwise
)),
alignment
)
*
alignment
;
alignment
=
block_alignment
[
0
];
expected_y
=
DIVUP
(
DIVUP
(
t
.
flat_last_dim
(),
static_cast
<
size_t
>
(
1
)),
alignment
)
*
alignment
;
const
auto
&
expected
=
std
::
vector
<
size_t
>
{
expected_x
,
expected_y
};
...
...
@@ -192,24 +208,139 @@ void CheckOutputTensor(const Tensor &t, const std::string &name, bool allow_empt
CheckScaleTensorShape
(
t
,
name
);
}
class
TensorAllocator
{
public:
static
TensorAllocator
&
instance
()
{
static
TensorAllocator
allocator
;
return
allocator
;
}
~
TensorAllocator
()
{}
NVTETensor
Allocate
(
NVTEScalingMode
mode
)
{
std
::
lock_guard
<
std
::
mutex
>
lock
(
mutex
);
if
(
!
free_list
.
empty
())
{
uintptr_t
index
=
free_list
.
back
();
NVTETensor
ret
=
reinterpret_cast
<
NVTETensor
>
(
index
);
free_list
.
pop_back
();
if
(
debug
)
{
std
::
cout
<<
"Allocated "
<<
index
<<
" from free list. Free list size: "
<<
free_list
.
size
()
<<
" and capacity "
<<
free_list
.
capacity
()
<<
std
::
endl
;
}
// 1-based indexing
memory
[
index
-
1
].
scaling_mode
=
mode
;
return
ret
;
}
if
(
memory
.
size
()
<
memory
.
capacity
())
{
memory
.
emplace_back
();
Tensor
&
t
=
memory
.
back
();
size
=
memory
.
size
();
// 1-based indexing
uintptr_t
index
=
memory
.
size
();
if
(
debug
)
{
std
::
cout
<<
"Allocated "
<<
index
<<
". Memory size: "
<<
memory
.
size
()
<<
" and capacity "
<<
memory
.
capacity
()
<<
std
::
endl
;
}
t
.
scaling_mode
=
mode
;
t
.
nvte_tensor
=
reinterpret_cast
<
NVTETensor
>
(
index
);
return
reinterpret_cast
<
NVTETensor
>
(
index
);
}
NVTE_ERROR
(
"Cannot allocate a new NVTETensor. Maximum number of tensors reached: "
,
MAX_TENSOR_NUM
,
". There is probably a memory leak in your application."
);
}
void
Free
(
NVTETensor
t
)
{
std
::
lock_guard
<
std
::
mutex
>
lock
(
mutex
);
uintptr_t
index
=
reinterpret_cast
<
uintptr_t
>
(
t
);
if
(
index
==
0
)
return
;
NVTE_CHECK
(
index
<=
memory
.
size
(),
"Invalid tensor."
);
free_list
.
push_back
(
index
);
// Clean up
memory
[
index
-
1
].
clear
();
if
(
debug
)
{
std
::
cout
<<
"Freed "
<<
index
<<
". Free list size: "
<<
free_list
.
size
()
<<
" and capacity "
<<
free_list
.
capacity
()
<<
std
::
endl
;
}
}
void
Free
(
NVTETensor
*
t
,
size_t
N
)
{
std
::
lock_guard
<
std
::
mutex
>
lock
(
mutex
);
for
(
size_t
i
=
0
;
i
<
N
;
++
i
)
{
uintptr_t
index
=
reinterpret_cast
<
uintptr_t
>
(
t
[
i
]);
if
(
index
==
0
)
continue
;
NVTE_CHECK
(
index
<=
memory
.
size
(),
"Invalid tensor."
);
free_list
.
push_back
(
index
);
// Clean up
memory
[
index
-
1
].
clear
();
}
if
(
debug
)
{
std
::
cout
<<
"Freed range of"
<<
N
<<
" tensors. Free list size: "
<<
free_list
.
size
()
<<
" and capacity "
<<
free_list
.
capacity
()
<<
std
::
endl
;
}
}
Tensor
*
convertNVTETensor
(
NVTETensor
t
)
{
uintptr_t
index
=
reinterpret_cast
<
uintptr_t
>
(
t
);
// 1-based indexing to enable 0-initialization of NVTETensor
// to be invalid tensor
static_assert
(
nullptr
==
0
);
if
(
index
!=
0
&&
index
<=
size
)
{
return
&
(
memory
[
index
-
1
]);
}
return
nullptr
;
}
void
setDebug
(
bool
debug
)
{
std
::
lock_guard
<
std
::
mutex
>
lock
(
mutex
);
this
->
debug
=
debug
;
}
private:
TensorAllocator
()
{
std
::
lock_guard
<
std
::
mutex
>
lock
(
mutex
);
memory
.
reserve
(
MAX_TENSOR_NUM
);
}
std
::
mutex
mutex
;
std
::
atomic
<
size_t
>
size
;
// Allocate at most 20 MB for tensors
// Should be replaced by virtual memory allocation
const
size_t
MAX_TENSOR_NUM
=
20
*
1024
*
1024
/
sizeof
(
Tensor
);
std
::
vector
<
uintptr_t
>
free_list
;
std
::
vector
<
Tensor
>
memory
;
bool
debug
=
false
;
};
Tensor
*
convertNVTETensor
(
const
NVTETensor
t
)
{
return
TensorAllocator
::
instance
().
convertNVTETensor
(
t
);
}
Tensor
*
convertNVTETensorCheck
(
const
NVTETensor
t
)
{
Tensor
*
ptr
=
TensorAllocator
::
instance
().
convertNVTETensor
(
t
);
NVTE_CHECK
(
ptr
!=
nullptr
,
"Invalid tensor."
);
return
ptr
;
}
}
// namespace transformer_engine
NVTETensor
nvte_create_tensor
(
NVTEScalingMode
scaling_mode
)
{
transformer_engine
::
Tensor
*
ret
=
new
transformer_engine
::
Tensor
;
ret
->
scaling_mode
=
scaling_mode
;
NVTETensor
ret
=
transformer_engine
::
TensorAllocator
::
instance
().
Allocate
(
scaling_mode
);
return
ret
;
}
void
nvte_destroy_tensor
(
NVTETensor
tensor
)
{
if
(
tensor
==
nullptr
)
return
;
auto
*
t
=
reinterpret_cast
<
transformer_engine
::
Tensor
*>
(
tensor
);
delete
t
;
transformer_engine
::
TensorAllocator
::
instance
().
Free
(
tensor
);
}
void
nvte_destroy_tensors
(
NVTETensor
*
tensors
,
size_t
N
)
{
transformer_engine
::
TensorAllocator
::
instance
().
Free
(
tensors
,
N
);
}
NVTEDType
nvte_tensor_type
(
const
NVTETensor
tensor
)
{
if
(
tensor
==
nullptr
)
return
kNVTEFloat32
;
return
static_cast
<
NVTEDType
>
(
reinterpret_cast
<
const
transformer_engine
::
Tensor
*>
(
tensor
)
->
dtype
());
auto
*
t
=
transformer_engine
::
convertNVTETensor
(
tensor
)
;
if
(
t
==
nullptr
)
return
kNVTEFloat32
;
return
static_cast
<
NVTEDType
>
(
t
->
dtype
());
}
NVTEShape
nvte_make_shape
(
const
size_t
*
data
,
size_t
ndim
)
{
...
...
@@ -227,23 +358,24 @@ NVTEShape nvte_make_shape(const size_t *data, size_t ndim) {
}
NVTEShape
nvte_tensor_shape
(
const
NVTETensor
tensor
)
{
if
(
tensor
==
nullptr
)
{
auto
*
t
=
transformer_engine
::
convertNVTETensor
(
tensor
);
if
(
t
==
nullptr
)
{
NVTE_ERROR
(
"Invalid tensor"
);
}
// Determine tensor shape depending on tensor format
const
auto
&
t
=
*
reinterpret_cast
<
const
transformer_engine
::
Tensor
*>
(
tensor
);
std
::
vector
<
size_t
>
shape
=
t
.
shape
();
const
std
::
vector
<
size_t
>
&
shape
=
t
->
shape
();
return
nvte_make_shape
(
shape
.
data
(),
shape
.
size
());
}
NVTEShape
nvte_tensor_columnwise_shape
(
const
NVTETensor
tensor
)
{
if
(
tensor
==
nullptr
)
{
auto
*
t
=
transformer_engine
::
convertNVTETensor
(
tensor
);
if
(
t
==
nullptr
)
{
NVTE_ERROR
(
"Invalid tensor"
);
}
const
auto
&
t
=
*
reinterpret_cast
<
const
transformer_engine
::
Tensor
*>
(
tensor
)
;
return
nvte_make_shape
(
t
.
columnwise_data
.
shape
.
data
(),
t
.
columnwise_data
.
shape
.
size
());
const
std
::
vector
<
size_t
>
&
shape
=
t
->
columnwise_data
.
shape
;
return
nvte_make_shape
(
shape
.
data
(),
shape
.
size
());
}
size_t
nvte_tensor_ndims
(
const
NVTETensor
tensor
)
{
return
nvte_tensor_shape
(
tensor
).
ndim
;
}
...
...
@@ -264,83 +396,97 @@ size_t nvte_tensor_numel(const NVTETensor tensor) {
return
numel
;
}
size_t
nvte_tensor_element_size_bits
(
const
NVTETensor
tensor
)
{
auto
*
t
=
transformer_engine
::
convertNVTETensor
(
tensor
);
if
(
t
==
nullptr
)
return
8
*
sizeof
(
float
);
return
transformer_engine
::
typeToNumBits
(
t
->
dtype
());
}
size_t
nvte_tensor_element_size
(
const
NVTETensor
tensor
)
{
if
(
tensor
==
nullptr
)
return
sizeof
(
float
);
const
auto
&
t
=
*
reinterpret_cast
<
const
transformer_engine
::
Tensor
*>
(
tensor
);
return
transformer_engine
::
typeToSize
(
t
.
dtype
());
auto
*
t
=
transformer_engine
::
convertNVTETensor
(
tensor
);
if
(
t
==
nullptr
)
return
sizeof
(
float
);
NVTE_CHECK
(
!
is_fp4_dtype
(
t
->
dtype
()),
"For FP4 type please use the nvte_tensor_element_size_bits."
);
return
nvte_tensor_element_size_bits
(
tensor
)
/
8
;
}
size_t
nvte_tensor_size_bytes
(
const
NVTETensor
tensor
)
{
auto
*
t
=
transformer_engine
::
convertNVTETensor
(
tensor
);
if
(
t
==
nullptr
)
return
0
;
return
(
nvte_tensor_numel
(
tensor
)
*
nvte_tensor_element_size_bits
(
tensor
))
/
8
;
}
void
*
nvte_tensor_data
(
const
NVTETensor
tensor
)
{
if
(
tensor
==
nullptr
)
return
nullptr
;
const
auto
&
t
=
*
reinterpret_cast
<
const
transformer_engine
::
Tensor
*>
(
tensor
)
;
return
t
.
data
.
dptr
;
auto
*
t
=
transformer_engine
::
convertNVTETensor
(
tensor
)
;
if
(
t
=
=
nullptr
)
return
nullptr
;
return
t
->
data
.
dptr
;
}
void
*
nvte_tensor_columnwise_data
(
const
NVTETensor
tensor
)
{
if
(
tensor
==
nullptr
)
return
nullptr
;
const
auto
&
t
=
*
reinterpret_cast
<
const
transformer_engine
::
Tensor
*>
(
tensor
)
;
return
t
.
columnwise_data
.
dptr
;
auto
*
t
=
transformer_engine
::
convertNVTETensor
(
tensor
)
;
if
(
t
=
=
nullptr
)
return
nullptr
;
return
t
->
columnwise_data
.
dptr
;
}
float
*
nvte_tensor_amax
(
const
NVTETensor
tensor
)
{
if
(
tensor
==
nullptr
)
return
nullptr
;
const
auto
&
t
=
*
reinterpret_cast
<
const
transformer_engine
::
Tensor
*>
(
tensor
)
;
NVTE_CHECK
(
t
.
amax
.
dtype
==
transformer_engine
::
DType
::
kFloat32
,
auto
*
t
=
transformer_engine
::
convertNVTETensor
(
tensor
)
;
if
(
t
=
=
nullptr
)
return
nullptr
;
NVTE_CHECK
(
t
->
amax
.
dtype
==
transformer_engine
::
DType
::
kFloat32
,
"Tensor's amax must have Float32 type!"
);
return
reinterpret_cast
<
float
*>
(
t
.
amax
.
dptr
);
return
reinterpret_cast
<
float
*>
(
t
->
amax
.
dptr
);
}
float
*
nvte_tensor_scale
(
const
NVTETensor
tensor
)
{
if
(
tensor
==
nullptr
)
return
nullptr
;
const
auto
&
t
=
*
reinterpret_cast
<
const
transformer_engine
::
Tensor
*>
(
tensor
)
;
NVTE_CHECK
(
t
.
scale
.
dtype
==
transformer_engine
::
DType
::
kFloat32
,
auto
*
t
=
transformer_engine
::
convertNVTETensor
(
tensor
)
;
if
(
t
=
=
nullptr
)
return
nullptr
;
NVTE_CHECK
(
t
->
scale
.
dtype
==
transformer_engine
::
DType
::
kFloat32
,
"Tensor's scale must have Float32 type!"
);
return
reinterpret_cast
<
float
*>
(
t
.
scale
.
dptr
);
return
reinterpret_cast
<
float
*>
(
t
->
scale
.
dptr
);
}
float
*
nvte_tensor_scale_inv
(
const
NVTETensor
tensor
)
{
if
(
tensor
==
nullptr
)
return
nullptr
;
const
auto
&
t
=
*
reinterpret_cast
<
const
transformer_engine
::
Tensor
*>
(
tensor
)
;
return
reinterpret_cast
<
float
*>
(
t
.
scale_inv
.
dptr
);
auto
*
t
=
transformer_engine
::
convertNVTETensor
(
tensor
)
;
if
(
t
=
=
nullptr
)
return
nullptr
;
return
reinterpret_cast
<
float
*>
(
t
->
scale_inv
.
dptr
);
}
void
*
nvte_tensor_columnwise_scale_inv
(
const
NVTETensor
tensor
)
{
if
(
tensor
==
nullptr
)
return
nullptr
;
const
auto
&
t
=
*
reinterpret_cast
<
const
transformer_engine
::
Tensor
*>
(
tensor
)
;
return
t
.
columnwise_scale_inv
.
dptr
;
auto
*
t
=
transformer_engine
::
convertNVTETensor
(
tensor
)
;
if
(
t
=
=
nullptr
)
return
nullptr
;
return
t
->
columnwise_scale_inv
.
dptr
;
}
NVTEShape
nvte_tensor_scale_inv_shape
(
const
NVTETensor
tensor
)
{
if
(
tensor
==
nullptr
)
{
auto
*
t
=
transformer_engine
::
convertNVTETensor
(
tensor
);
if
(
t
==
nullptr
)
{
return
nvte_make_shape
(
nullptr
,
0
);
}
const
auto
&
t
=
*
reinterpret_cast
<
const
transformer_engine
::
Tensor
*>
(
tensor
);
return
nvte_make_shape
(
t
.
scale_inv
.
shape
.
data
(),
t
.
scale_inv
.
shape
.
size
());
return
nvte_make_shape
(
t
->
scale_inv
.
shape
.
data
(),
t
->
scale_inv
.
shape
.
size
());
}
void
nvte_set_tensor_param
(
NVTETensor
*
tensor
,
NVTETensorParam
param_name
,
const
NVTEBasicTensor
*
param
)
{
NVTE_CHECK
(
tensor
!=
nullptr
,
"Tensor pointer can't be NULL."
);
NVTE_CHECK
(
*
tensor
!=
nullptr
,
"Te
nsor
is not allocated."
);
auto
&
t
=
*
reinterpret_cast
<
tra
ns
f
or
mer_engine
::
Tensor
*>
(
*
tensor
);
auto
*
t
=
tra
ns
f
or
mer_engine
::
convertNVTETensor
(
*
tensor
);
NVTE_CHECK
(
t
!
=
nullptr
,
"Te
nsor
is not allocated."
);
switch
(
param_name
)
{
case
kNVTERowwiseData
:
t
.
data
=
*
param
;
t
->
data
=
*
param
;
break
;
case
kNVTEColumnwiseData
:
t
.
columnwise_data
=
*
param
;
t
->
columnwise_data
=
*
param
;
break
;
case
kNVTEScale
:
t
.
scale
=
*
param
;
t
->
scale
=
*
param
;
break
;
case
kNVTEAmax
:
t
.
amax
=
*
param
;
t
->
amax
=
*
param
;
break
;
case
kNVTERowwiseScaleInv
:
t
.
scale_inv
=
*
param
;
t
->
scale_inv
=
*
param
;
break
;
case
kNVTEColumnwiseScaleInv
:
t
.
columnwise_scale_inv
=
*
param
;
t
->
columnwise_scale_inv
=
*
param
;
break
;
default:
NVTE_ERROR
(
"Unknown tensor parameter!"
);
...
...
@@ -351,7 +497,7 @@ NVTEBasicTensor nvte_get_tensor_param(const NVTETensor tensor, NVTETensorParam p
if
(
tensor
==
nullptr
)
{
return
{
nullptr
,
kNVTEFloat32
,
nvte_make_shape
(
nullptr
,
0
)};
}
const
auto
&
t
=
*
reinterpret_cast
<
const
transformer_engine
::
Tensor
*>
(
tensor
);
const
auto
&
t
=
*
transformer_engine
::
convertNVTE
Tensor
Check
(
tensor
);
switch
(
param_name
)
{
case
kNVTERowwiseData
:
return
t
.
data
;
...
...
@@ -371,28 +517,30 @@ NVTEBasicTensor nvte_get_tensor_param(const NVTETensor tensor, NVTETensorParam p
}
NVTEScalingMode
nvte_tensor_scaling_mode
(
const
NVTETensor
tensor
)
{
const
auto
&
t
=
*
reinterpret_cast
<
const
transformer_engine
::
Tensor
*>
(
tensor
);
if
(
tensor
==
nullptr
)
{
return
NVTE_DELAYED_TENSOR_SCALING
;
}
const
auto
&
t
=
*
transformer_engine
::
convertNVTETensorCheck
(
tensor
);
return
t
.
scaling_mode
;
}
void
nvte_tensor_pack_create
(
NVTETensorPack
*
pack
)
{
for
(
int
i
=
0
;
i
<
pack
->
MAX_SIZE
;
i
++
)
{
pack
->
tensors
[
i
]
=
reinterpret_cast
<
NVTETensor
>
(
new
transformer_engine
::
Tensor
);
pack
->
tensors
[
i
]
=
transformer_engine
::
TensorAllocator
::
instance
().
Allocate
(
NVTE_DELAYED_TENSOR_SCALING
);
}
}
void
nvte_tensor_pack_destroy
(
NVTETensorPack
*
pack
)
{
for
(
int
i
=
0
;
i
<
pack
->
MAX_SIZE
;
i
++
)
{
auto
*
t
=
reinterpret_cast
<
transformer_engine
::
Tensor
*>
(
pack
->
tensors
[
i
]);
delete
t
;
}
transformer_engine
::
TensorAllocator
::
instance
().
Free
(
pack
->
tensors
,
pack
->
MAX_SIZE
);
}
void
nvte_zero_tensor
(
const
NVTETensor
tensor
,
cudaStream_t
stream
)
{
const
auto
&
t
=
*
reinterpret_cast
<
const
transformer_engine
::
Tensor
*>
(
tensor
);
if
(
tensor
==
nullptr
)
return
;
const
auto
&
t
=
*
transformer_engine
::
convertNVTETensorCheck
(
tensor
);
// Zero out tensor data if allocated
if
(
t
.
data
.
dptr
!=
nullptr
)
{
size_t
size_in_bytes
=
nvte_tensor_
element_size
(
tensor
)
*
nvte_tensor_numel
(
tensor
);
const
size_t
size_in_bytes
=
nvte_tensor_
size_bytes
(
tensor
);
cudaMemsetAsync
(
t
.
data
.
dptr
,
0
,
size_in_bytes
,
stream
);
}
// Set amax to 0 if allocated
...
...
@@ -440,6 +588,9 @@ void nvte_get_quantization_config_attribute(NVTEQuantizationConfig config,
case
kNVTEQuantizationConfigNoopTensor
:
std
::
memcpy
(
buf
,
&
config_
.
noop_tensor
,
attr_size
);
break
;
case
kNVTEQuantizationConfigFloat8BlockScaleTensorFormat
:
std
::
memcpy
(
buf
,
&
config_
.
float8_block_scale_tensor_format
,
attr_size
);
break
;
default:
NVTE_ERROR
(
"Unsupported NVTEQuantizationConfigAttribute (got "
,
static_cast
<
int
>
(
attr
),
")"
);
}
...
...
@@ -472,6 +623,9 @@ void nvte_set_quantization_config_attribute(NVTEQuantizationConfig config,
case
kNVTEQuantizationConfigNoopTensor
:
std
::
memcpy
(
&
config_
.
noop_tensor
,
buf
,
attr_size
);
break
;
case
kNVTEQuantizationConfigFloat8BlockScaleTensorFormat
:
std
::
memcpy
(
&
config_
.
float8_block_scale_tensor_format
,
buf
,
attr_size
);
break
;
default:
NVTE_ERROR
(
"Unsupported NVTEQuantizationConfigAttribute (got "
,
static_cast
<
int
>
(
attr
),
")"
);
}
...
...
transformer_engine/common/transpose/cast_transpose.cu
View file @
2b05e121
...
...
@@ -348,15 +348,15 @@ void nvte_cast_transpose(const NVTETensor input, NVTETensor output, cudaStream_t
NVTE_API_CALL
(
nvte_cast_transpose
);
using
namespace
transformer_engine
;
auto
noop
=
Tensor
();
transformer_engine
::
detail
::
cast_transpose
(
*
reinterpret_cast
<
const
Tensor
*>
(
input
),
noop
,
reinterpret_cast
<
Tensor
*>
(
output
),
stream
);
transformer_engine
::
detail
::
cast_transpose
(
*
convertNVTE
Tensor
Check
(
input
),
noop
,
convertNVTE
Tensor
(
output
),
stream
);
}
void
nvte_cast_transpose_with_noop
(
const
NVTETensor
input
,
const
NVTETensor
noop
,
NVTETensor
output
,
cudaStream_t
stream
)
{
NVTE_API_CALL
(
nvte_cast_transpose_with_noop
);
using
namespace
transformer_engine
;
transformer_engine
::
detail
::
cast_transpose
(
*
reinterpret_cast
<
const
Tensor
*>
(
input
),
*
reinterpret_cast
<
const
Tensor
*>
(
noop
),
reinterpret_cast
<
Tensor
*>
(
output
),
stream
);
transformer_engine
::
detail
::
cast_transpose
(
*
convertNVTE
Tensor
Check
(
input
),
*
convertNVTE
Tensor
Check
(
noop
),
convertNVTE
Tensor
(
output
),
stream
);
}
transformer_engine/common/transpose/cast_transpose.h
View file @
2b05e121
...
...
@@ -31,25 +31,27 @@ void quantize_transpose_square_blockwise(const SimpleTensor &input, SimpleTensor
// enum class for rowwise usage
enum
class
FP8BlockwiseRowwiseOption
{
// No rowwise data
// No rowwise data
, skip rowwise quantization
NONE
,
// Rowwise data, scales in GEMM format
ROWWISE
//
TODO: FP8 all gather requires some changes.
// 1. Compact scales are better for gathering than the GEMM format.
ROWWISE
_GEMM_READY
,
//
Rowwise data, scales in compact format, needs extra processing (padding, transposing) before GEMM
ROWWISE_COMPACT
};
// enum class for columnwise usage
// For Hopper sm90 with only TN fp8 gemm, there is need to do columnwise transpose when doing 1D block scaling
enum
class
FP8BlockwiseColumnwiseOption
{
// No columnwise data
// No columnwise data
, skip columnwise quantization
NONE
,
// Columnwise data transposed from original shape.
// Scales in GEMM format corresponding to GEMM ingesting transposed column data.
COLUMNWISE_TRANSPOSE
// TODO: FP8 all gather requires some changes.
// 1. The transpose gets in the way of the all gather.
// 2. Compact scales are better for gathering than the GEMM format.
// On Hopper sm90, GEMM_READY means that columnwise quantization also fuses transpose op
// On higher sm versions with TN,NT,NN fp8 gemm, GEMM_READY doesn't fuse transpose
COLUMNWISE_GEMM_READY
,
// Columnwise data in original shape
// Scales in compact format, needs extra processing (padding, transposing) before GEMM
COLUMNWISE_COMPACT
};
void
quantize_transpose_vector_blockwise
(
const
SimpleTensor
&
input
,
SimpleTensor
&
scale_inv
,
...
...
transformer_engine/common/transpose/cast_transpose_fusion.cu
View file @
2b05e121
...
...
@@ -17,6 +17,7 @@
#include "../util/string.h"
#include "../utils.cuh"
#include "cast_transpose.h"
#include "common/common.h"
namespace
transformer_engine
{
...
...
@@ -196,17 +197,18 @@ void populate_cast_transpose_dbias_workspace_config(const Tensor &cast_output, /
workspace
->
data
.
dtype
=
DType
::
kFloat32
;
}
else
{
// Check that workspace matches expected size
const
size_t
workspace_size
=
const
size_t
workspace_size
=
get_buffer_size_bytes
(
std
::
accumulate
(
workspace
->
data
.
shape
.
begin
(),
workspace
->
data
.
shape
.
end
(),
1
,
std
::
multiplies
<
size_t
>
())
*
typeToSize
(
workspace
->
data
.
dtype
);
const
size_t
required_size
=
num_rows_partial_dbias
*
row_length
*
typeToSize
(
DType
::
kFloat32
);
std
::
multiplies
<
size_t
>
()),
workspace
->
data
.
dtype
);
const
size_t
required_size
=
get_buffer_size_bytes
(
num_rows_partial_dbias
,
row_length
,
DType
::
kFloat32
);
NVTE_CHECK
(
!
workspace
->
data
.
shape
.
empty
(),
"Invalid workspace dims (expected ("
,
num_rows_partial_dbias
,
","
,
row_length
,
"), found ())"
);
NVTE_CHECK
(
workspace_size
>=
required_size
,
"Invalid workspace (expected dims=("
,
num_rows_partial_dbias
,
","
,
row_length
,
"), dtype="
,
to_string
(
DType
::
kFloat32
),
"; found dims="
,
workspace
->
data
.
shape
,
", dtype="
,
typeTo
Size
(
workspace
->
data
.
dtype
),
")"
);
", dtype="
,
typeTo
NumBits
(
workspace
->
data
.
dtype
),
"
bits
)"
);
}
}
...
...
@@ -1337,9 +1339,8 @@ void nvte_cast_transpose_dbias(const NVTETensor input, NVTETensor output, NVTETe
constexpr
const
NVTETensor
activation_input
=
nullptr
;
cast_transpose_fused
<
IS_DBIAS
,
IS_DACT
,
IS_ACT
,
ComputeType
,
Empty
,
nullptr
>
(
*
reinterpret_cast
<
const
Tensor
*>
(
input
),
reinterpret_cast
<
const
Tensor
*>
(
activation_input
),
reinterpret_cast
<
Tensor
*>
(
output
),
reinterpret_cast
<
Tensor
*>
(
dbias
),
reinterpret_cast
<
Tensor
*>
(
workspace
),
stream
);
*
convertNVTETensorCheck
(
input
),
convertNVTETensor
(
activation_input
),
convertNVTETensor
(
output
),
convertNVTETensor
(
dbias
),
convertNVTETensor
(
workspace
),
stream
);
}
void
nvte_cast_transpose_dbias_dgelu
(
const
NVTETensor
input
,
const
NVTETensor
act_input
,
...
...
@@ -1354,9 +1355,9 @@ void nvte_cast_transpose_dbias_dgelu(const NVTETensor input, const NVTETensor ac
constexpr
bool
IS_ACT
=
false
;
cast_transpose_fused
<
IS_DBIAS
,
IS_DACT
,
IS_ACT
,
ComputeType
,
Empty
,
dgelu
<
fp32
,
fp32
>>
(
*
reinterpret_cast
<
const
Tensor
*>
(
input
),
reinterpret_cast
<
const
Tensor
*>
(
act_input
),
reinterpret_cast
<
Tensor
*>
(
output
),
reinterpret_cast
<
Tensor
*>
(
dbias
),
reinterpret_cast
<
Tensor
*>
(
workspace
),
stream
);
*
convertNVTE
Tensor
Check
(
input
),
convertNVTE
Tensor
Check
(
act_input
),
convertNVTE
Tensor
Check
(
output
),
convertNVTETensorCheck
(
dbias
),
convertNVTETensor
(
workspace
),
stream
);
}
void
nvte_cast_transpose_dbias_dsilu
(
const
NVTETensor
input
,
const
NVTETensor
silu_input
,
...
...
@@ -1371,9 +1372,9 @@ void nvte_cast_transpose_dbias_dsilu(const NVTETensor input, const NVTETensor si
constexpr
bool
IS_ACT
=
false
;
cast_transpose_fused
<
IS_DBIAS
,
IS_DACT
,
IS_ACT
,
ComputeType
,
Empty
,
dsilu
<
fp32
,
fp32
>>
(
*
reinterpret_cast
<
const
Tensor
*>
(
input
),
reinterpret_cast
<
const
Tensor
*>
(
silu_input
),
reinterpret_cast
<
Tensor
*>
(
output
),
reinterpret_cast
<
Tensor
*>
(
dbias
),
reinterpret_cast
<
Tensor
*>
(
workspace
),
stream
);
*
convertNVTE
Tensor
Check
(
input
),
convertNVTE
Tensor
Check
(
silu_input
),
convertNVTE
Tensor
Check
(
output
),
convertNVTETensorCheck
(
dbias
),
convertNVTETensor
(
workspace
),
stream
);
}
void
nvte_cast_transpose_dbias_drelu
(
const
NVTETensor
input
,
const
NVTETensor
relu_input
,
...
...
@@ -1388,9 +1389,9 @@ void nvte_cast_transpose_dbias_drelu(const NVTETensor input, const NVTETensor re
constexpr
bool
IS_ACT
=
false
;
cast_transpose_fused
<
IS_DBIAS
,
IS_DACT
,
IS_ACT
,
ComputeType
,
Empty
,
drelu
<
fp32
,
fp32
>>
(
*
reinterpret_cast
<
const
Tensor
*>
(
input
),
reinterpret_cast
<
const
Tensor
*>
(
relu_input
),
reinterpret_cast
<
Tensor
*>
(
output
),
reinterpret_cast
<
Tensor
*>
(
dbias
),
reinterpret_cast
<
Tensor
*>
(
workspace
),
stream
);
*
convertNVTE
Tensor
Check
(
input
),
convertNVTE
Tensor
Check
(
relu_input
),
convertNVTE
Tensor
Check
(
output
),
convertNVTETensorCheck
(
dbias
),
convertNVTETensor
(
workspace
),
stream
);
}
void
nvte_cast_transpose_dbias_dsrelu
(
const
NVTETensor
input
,
const
NVTETensor
srelu_input
,
...
...
@@ -1405,9 +1406,9 @@ void nvte_cast_transpose_dbias_dsrelu(const NVTETensor input, const NVTETensor s
constexpr
bool
IS_ACT
=
false
;
cast_transpose_fused
<
IS_DBIAS
,
IS_DACT
,
IS_ACT
,
ComputeType
,
Empty
,
dsrelu
<
fp32
,
fp32
>>
(
*
reinterpret_cast
<
const
Tensor
*>
(
input
),
reinterpret_cast
<
const
Tensor
*>
(
srelu_input
),
reinterpret_cast
<
Tensor
*>
(
output
),
reinterpret_cast
<
Tensor
*>
(
dbias
),
reinterpret_cast
<
Tensor
*>
(
workspace
),
stream
);
*
convertNVTE
Tensor
Check
(
input
),
convertNVTE
Tensor
Check
(
srelu_input
),
convertNVTE
Tensor
Check
(
output
),
convertNVTETensorCheck
(
dbias
),
convertNVTETensor
(
workspace
),
stream
);
}
void
nvte_cast_transpose_dbias_dqgelu
(
const
NVTETensor
input
,
const
NVTETensor
qgelu_input
,
...
...
@@ -1422,9 +1423,9 @@ void nvte_cast_transpose_dbias_dqgelu(const NVTETensor input, const NVTETensor q
constexpr
bool
IS_ACT
=
false
;
cast_transpose_fused
<
IS_DBIAS
,
IS_DACT
,
IS_ACT
,
ComputeType
,
Empty
,
dqgelu
<
fp32
,
fp32
>>
(
*
reinterpret_cast
<
const
Tensor
*>
(
input
),
reinterpret_cast
<
const
Tensor
*>
(
qgelu_input
),
reinterpret_cast
<
Tensor
*>
(
output
),
reinterpret_cast
<
Tensor
*>
(
dbias
),
reinterpret_cast
<
Tensor
*>
(
workspace
),
stream
);
*
convertNVTE
Tensor
Check
(
input
),
convertNVTE
Tensor
Check
(
qgelu_input
),
convertNVTE
Tensor
Check
(
output
),
convertNVTETensorCheck
(
dbias
),
convertNVTETensor
(
workspace
),
stream
);
}
void
nvte_dgeglu_cast_transpose
(
const
NVTETensor
input
,
const
NVTETensor
gated_act_input
,
...
...
@@ -1434,8 +1435,8 @@ void nvte_dgeglu_cast_transpose(const NVTETensor input, const NVTETensor gated_a
using
namespace
transformer_engine
::
detail
;
dgated_act_cast_transpose
<
ComputeType
,
Empty
,
dgelu
<
fp32
,
fp32
>
,
gelu
<
fp32
,
fp32
>>
(
*
reinterpret_cast
<
const
Tensor
*>
(
input
),
*
reinterpret_cast
<
const
Tensor
*>
(
gated_act_input
),
reinterpret_cast
<
Tensor
*>
(
output
),
stream
);
*
convertNVTE
Tensor
Check
(
input
),
*
convertNVTE
Tensor
Check
(
gated_act_input
),
convertNVTE
Tensor
Check
(
output
),
stream
);
}
void
nvte_dswiglu_cast_transpose
(
const
NVTETensor
input
,
const
NVTETensor
swiglu_input
,
...
...
@@ -1445,8 +1446,8 @@ void nvte_dswiglu_cast_transpose(const NVTETensor input, const NVTETensor swiglu
using
namespace
transformer_engine
::
detail
;
dgated_act_cast_transpose
<
ComputeType
,
Empty
,
dsilu
<
fp32
,
fp32
>
,
silu
<
fp32
,
fp32
>>
(
*
reinterpret_cast
<
const
Tensor
*>
(
input
),
*
reinterpret_cast
<
const
Tensor
*>
(
swiglu_input
),
reinterpret_cast
<
Tensor
*>
(
output
),
stream
);
*
convertNVTE
Tensor
Check
(
input
),
*
convertNVTE
Tensor
Check
(
swiglu_input
),
convertNVTE
Tensor
Check
(
output
),
stream
);
}
void
nvte_dreglu_cast_transpose
(
const
NVTETensor
input
,
const
NVTETensor
gated_act_input
,
...
...
@@ -1456,8 +1457,8 @@ void nvte_dreglu_cast_transpose(const NVTETensor input, const NVTETensor gated_a
using
namespace
transformer_engine
::
detail
;
dgated_act_cast_transpose
<
ComputeType
,
Empty
,
drelu
<
fp32
,
fp32
>
,
relu
<
fp32
,
fp32
>>
(
*
reinterpret_cast
<
const
Tensor
*>
(
input
),
*
reinterpret_cast
<
const
Tensor
*>
(
gated_act_input
),
reinterpret_cast
<
Tensor
*>
(
output
),
stream
);
*
convertNVTE
Tensor
Check
(
input
),
*
convertNVTE
Tensor
Check
(
gated_act_input
),
convertNVTE
Tensor
Check
(
output
),
stream
);
}
void
nvte_dsreglu_cast_transpose
(
const
NVTETensor
input
,
const
NVTETensor
gated_act_input
,
...
...
@@ -1467,8 +1468,8 @@ void nvte_dsreglu_cast_transpose(const NVTETensor input, const NVTETensor gated_
using
namespace
transformer_engine
::
detail
;
dgated_act_cast_transpose
<
ComputeType
,
Empty
,
dsrelu
<
fp32
,
fp32
>
,
srelu
<
fp32
,
fp32
>>
(
*
reinterpret_cast
<
const
Tensor
*>
(
input
),
*
reinterpret_cast
<
const
Tensor
*>
(
gated_act_input
),
reinterpret_cast
<
Tensor
*>
(
output
),
stream
);
*
convertNVTE
Tensor
Check
(
input
),
*
convertNVTE
Tensor
Check
(
gated_act_input
),
convertNVTE
Tensor
Check
(
output
),
stream
);
}
void
nvte_dqgeglu_cast_transpose
(
const
NVTETensor
input
,
const
NVTETensor
gated_act_input
,
...
...
@@ -1478,6 +1479,6 @@ void nvte_dqgeglu_cast_transpose(const NVTETensor input, const NVTETensor gated_
using
namespace
transformer_engine
::
detail
;
dgated_act_cast_transpose
<
ComputeType
,
Empty
,
dqgelu
<
fp32
,
fp32
>
,
qgelu
<
fp32
,
fp32
>>
(
*
reinterpret_cast
<
const
Tensor
*>
(
input
),
*
reinterpret_cast
<
const
Tensor
*>
(
gated_act_input
),
reinterpret_cast
<
Tensor
*>
(
output
),
stream
);
*
convertNVTE
Tensor
Check
(
input
),
*
convertNVTE
Tensor
Check
(
gated_act_input
),
convertNVTE
Tensor
Check
(
output
),
stream
);
}
transformer_engine/common/transpose/multi_cast_transpose.cu
View file @
2b05e121
...
...
@@ -237,8 +237,8 @@ void multi_cast_transpose(const std::vector<Tensor*> input_list, std::vector<Ten
// Input matrices are divided into tiles
// Note: Each tile is a warp_size x warp_size grid of nvec_out x nvec_in subtiles
const
int
tile_dim_m
=
THREADS_PER_WARP
*
desired_store_size
/
typeTo
Size
(
otype
);
const
int
tile_dim_n
=
THREADS_PER_WARP
*
desired_load_size
/
typeTo
Size
(
itype
);
const
int
tile_dim_m
=
THREADS_PER_WARP
*
desired_store_size
*
8
/
typeTo
NumBits
(
otype
);
const
int
tile_dim_n
=
THREADS_PER_WARP
*
desired_load_size
*
8
/
typeTo
NumBits
(
itype
);
// Add tensors to kernel argument struct
MultiCastTransposeArgs
kernel_args_aligned
,
kernel_args_unaligned
;
...
...
@@ -334,8 +334,8 @@ void nvte_multi_cast_transpose(size_t num_tensors, const NVTETensor* input_list,
using
namespace
transformer_engine
;
std
::
vector
<
Tensor
*>
input_list_
,
output_list_
;
for
(
size_t
i
=
0
;
i
<
num_tensors
;
++
i
)
{
input_list_
.
push_back
(
reinterpret_cast
<
Tensor
*>
(
const_cast
<
NVTETensor
&>
(
input_list
[
i
]))
)
;
output_list_
.
push_back
(
reinterpret_cast
<
Tensor
*>
(
output_list
[
i
]));
input_list_
.
push_back
(
convert
NVTETensor
Check
(
input_list
[
i
]));
output_list_
.
push_back
(
convertNVTE
Tensor
Check
(
output_list
[
i
]));
}
multi_cast_transpose
(
input_list_
,
output_list_
,
stream
);
}
transformer_engine/common/transpose/quantize_transpose_square_blockwise.cu
View file @
2b05e121
...
...
@@ -483,7 +483,7 @@ CUtensorMap get_tensor_map(const SimpleTensor& tensor, size_t global_dim_x, size
CUtensorMap
tensor_map_output_trans
{};
create_2D_tensor_map
(
tensor_map_output_trans
,
tensor
,
global_dim_y
,
global_dim_x
,
/*shmemY=*/
BLOCK_TILE_DIM
,
/*shmemX=*/
BLOCK_TILE_DIM
,
/*stride_elems=*/
global_dim_x
,
/*offset_elems=*/
0
,
sizeof
(
OutputType
));
/*stride_elems=*/
global_dim_x
,
/*offset_elems=*/
0
,
sizeof
(
OutputType
)
*
8
);
return
tensor_map_output_trans
;
}
#endif
...
...
Prev
1
…
3
4
5
6
7
8
9
10
11
…
13
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment