GitLab commit c520cba3 — OpenDAS / TransformerEngine

[DCU] Preliminary adaptation
Authored Mar 20, 2025 by yuguo · parent commit 5b6ef054

Showing 19 changed files with 1089 additions and 71 deletions (+1089 −71).
Changed files:

  transformer_engine/pytorch/attention.py                                        +30  −28
  transformer_engine/pytorch/cpp_extensions/__init__.py                           +3   −1
  transformer_engine/pytorch/cpp_extensions/gemm.py                              +82   −0
  transformer_engine/pytorch/csrc/common.h                                        +6   −2
  transformer_engine/pytorch/csrc/extensions.h                                   +10   −0
  transformer_engine/pytorch/csrc/extensions/attention.cu                        +15   −0
  transformer_engine/pytorch/csrc/extensions/gemm.cpp                           +121   −0
  transformer_engine/pytorch/csrc/extensions/misc.cpp                             +6   −0
  transformer_engine/pytorch/csrc/extensions/multi_tensor/multi_tensor_adam.cu    +9   −0
  transformer_engine/pytorch/csrc/extensions/pybind.cpp                           +6   −0
  transformer_engine/pytorch/csrc/type_shim.h                                    +10   −0
  transformer_engine/pytorch/dot_product_attention/utils.py                      +24   −7
  transformer_engine/pytorch/fp8.py                                              +15   −8
  transformer_engine/pytorch/jit.py                                              +23  −21
  transformer_engine/pytorch/module/base.py                                      +19   −0
  transformer_engine/pytorch/module/batched_linear.py  (new file)               +672   −0
  transformer_engine/pytorch/module/layernorm_mlp.py                              +3   −1
  transformer_engine/pytorch/triton/permutation.py                                +8   −1
  transformer_engine/pytorch/utils.py                                            +27   −2
transformer_engine/pytorch/attention.py (+30 −28)

@@ -17,6 +17,7 @@ import numpy as np
 from packaging.version import Version as PkgVersion

 import torch
+from torch.utils.cpp_extension import IS_HIP_EXTENSION
 import transformer_engine_torch as tex
 from transformer_engine.pytorch.utils import (

@@ -98,7 +99,7 @@ try:
 except PackageNotFoundError:
     if (
         torch.cuda.is_available()
-        and get_device_compute_capability() >= (8, 0)
+        and (IS_HIP_EXTENSION or get_device_compute_capability() >= (8, 0))
         and dpa_utils._NVTE_FLASH_ATTN
     ):
         attn_log.fa_logger.debug(

@@ -128,7 +129,7 @@ else:
     fa_utils.set_flash_attention_version()
 elif (
     torch.cuda.is_available()
-    and get_device_compute_capability() >= (8, 0)
+    and (IS_HIP_EXTENSION or get_device_compute_capability() >= (8, 0))
     and dpa_utils._NVTE_FLASH_ATTN
 ):
     attn_log.fa_logger.warning(

@@ -147,33 +148,34 @@ else:
 # Detect flash-attn v3 in the environment
 # This section will be removed when FA3 is released as a regular FA package,
 # i.e. flashattn-hopper 3.0.0 as flash-attn 3.0.0
-try:
-    fa_utils.fa3_version = PkgVersion(get_pkg_version("flashattn-hopper"))
-except PackageNotFoundError:
-    if (
-        torch.cuda.is_available()
-        and get_device_compute_capability() >= (9, 0)
-        and dpa_utils._NVTE_FLASH_ATTN
-    ):
-        attn_log.fa_logger.debug(
-            "flash-attn v3 is not installed. To use, please install it by \n%s",
-            fa_utils.v3_installation_steps,
-        )
-else:
-    from flashattn_hopper.flash_attn_interface import flash_attn_func as flash_attn_func_v3
-    from flashattn_hopper.flash_attn_interface import (
-        flash_attn_varlen_func as flash_attn_varlen_func_v3,
-    )
-    from flashattn_hopper.flash_attn_interface import _flash_attn_forward as _flash_attn_fwd_v3
-    from flashattn_hopper.flash_attn_interface import _flash_attn_backward as _flash_attn_bwd_v3
-    from flashattn_hopper.flash_attn_interface import (
-        _flash_attn_varlen_forward as _flash_attn_varlen_fwd_v3,
-    )
-    from flashattn_hopper.flash_attn_interface import (
-        _flash_attn_varlen_backward as _flash_attn_varlen_bwd_v3,
-    )
-    fa_utils.set_flash_attention_3_params()
+if not IS_HIP_EXTENSION:
+    try:
+        fa_utils.fa3_version = PkgVersion(get_pkg_version("flashattn-hopper"))
+    except PackageNotFoundError:
+        if (
+            torch.cuda.is_available()
+            and get_device_compute_capability() >= (9, 0)
+            and dpa_utils._NVTE_FLASH_ATTN
+        ):
+            attn_log.fa_logger.debug(
+                "flash-attn v3 is not installed. To use, please install it by \n%s",
+                fa_utils.v3_installation_steps,
+            )
+    else:
+        from flashattn_hopper.flash_attn_interface import flash_attn_func as flash_attn_func_v3
+        from flashattn_hopper.flash_attn_interface import (
+            flash_attn_varlen_func as flash_attn_varlen_func_v3,
+        )
+        from flashattn_hopper.flash_attn_interface import _flash_attn_forward as _flash_attn_fwd_v3
+        from flashattn_hopper.flash_attn_interface import _flash_attn_backward as _flash_attn_bwd_v3
+        from flashattn_hopper.flash_attn_interface import (
+            _flash_attn_varlen_forward as _flash_attn_varlen_fwd_v3,
+        )
+        from flashattn_hopper.flash_attn_interface import (
+            _flash_attn_varlen_backward as _flash_attn_varlen_bwd_v3,
+        )
+        fa_utils.set_flash_attention_3_params()

 # Global vars for available attention backends and ALiBi cache
 _attention_backends = {
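The hunk above fences all flash-attn v3 probing behind the HIP check. A minimal, hedged sketch of the same guard in isolation (names taken from this diff; the fallback behaviour shown is an assumption for illustration, not part of the commit):

from torch.utils.cpp_extension import IS_HIP_EXTENSION

flash_attn_func_v3 = None  # stays None on ROCm/DCU builds, where flashattn-hopper is never probed
if not IS_HIP_EXTENSION:
    try:
        from flashattn_hopper.flash_attn_interface import flash_attn_func as flash_attn_func_v3
    except ImportError:
        pass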
transformer_engine/pytorch/cpp_extensions/__init__.py (+3 −1)

@@ -5,5 +5,7 @@
 """Python interface for c++ extensions"""
 from transformer_engine_torch import *
-from .fused_attn import *
+from torch.utils.cpp_extension import IS_HIP_EXTENSION
+if not IS_HIP_EXTENSION:
+    from .fused_attn import *
 from .gemm import *
transformer_engine/pytorch/cpp_extensions/gemm.py (+82 −0)

@@ -224,3 +224,85 @@ def general_grouped_gemm(
     )

     return out, bias, gelu_input
+
+
+def general_batched_gemm(
+    A: List[torch.Tensor],
+    B: List[torch.Tensor],
+    out: List[torch.Tensor],
+    out_dtype: torch.dtype,
+    workspaces: List[torch.Tensor],
+    layout: str = "TN",
+    m_splits: Optional[List[int]] = None,
+    gelu: bool = False,
+    grad=False,
+    accumulate: bool = False,
+    bias: Optional[List[torch.Tensor]] = None,
+    use_bias: bool = False,
+    use_split_accumulator: bool = False,
+    D_dtype: Optional[tex.DType] = None,
+    single_output=False,
+) -> Tuple[List[torch.Tensor], ...]:
+    """TN-layout batched GEMM with FP8 inputs."""
+    num_gemms = len(A)
+    transa = layout[0] == "T"
+    transb = layout[1] == "T"
+    # assert [a.is_contiguous() for a in A]
+    # assert [b.is_contiguous() for b in B]
+
+    if isinstance(A[0], Float8TensorBase):
+        for a, b in zip(A, B):
+            assert_dim_for_fp8_exec(a._data)
+            assert_dim_for_fp8_exec(b._data)
+
+    empty_tensor = _empty_tensor()
+    empty_tensors = [empty_tensor] * num_gemms
+
+    # Use bfloat16 as default bias_dtype
+    gelu_input = empty_tensors
+    out_dtype = TE_DType[out[0].dtype] if D_dtype is None else D_dtype
+
+    sm_count = get_sm_count()
+    if grad and use_bias:
+        grad_bias = [
+            torch.empty(B[i].shape[1], dtype=out[0].dtype, device="cuda") for i in range(num_gemms)
+        ]
+    else:
+        grad_bias = empty_tensors
+    bias = bias if use_bias else empty_tensors
+    if use_bias:
+        bias_dtype = TE_DType[grad_bias[0].dtype] if grad else TE_DType[bias[0].dtype]
+    else:
+        bias_dtype = TE_DType[torch.bfloat16]
+
+    if gelu:
+        gelu_input = [
+            torch.empty_like(o, dtype=bias_dtype, memory_format=torch.contiguous_format)
+            for o in out
+        ]  # this should differ with respect to single output
+
+    bias = tex.te_general_batched_gemm(
+        A,
+        transa,
+        B,
+        transb,
+        out,
+        out_dtype,
+        m_splits,
+        grad_bias if grad else bias,
+        bias_dtype,
+        single_output,
+        gelu_input,  # this is pre_gelu_out
+        grad,  # grad
+        workspaces,
+        workspaces[0].shape[0],
+        accumulate,
+        use_split_accumulator,
+        sm_count - int(os.getenv("NVTE_EXT_MARGIN_SM", str(sm_count))),
+    )
+
+    return out, bias, gelu_input
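For orientation, here is a hedged sketch of how this wrapper is driven elsewhere in the commit, mirroring the single-output call used by batched_linear.py. The shapes and m_splits values are illustrative assumptions, and the call only works on a ROCm/DCU build where the batched-GEMM extension is compiled in:

import torch
from transformer_engine.pytorch.cpp_extensions import general_batched_gemm
from transformer_engine.pytorch.module.base import get_multi_stream_cublas_batchgemm_workspace

# Two GEMMs writing into one packed output; default "TN" layout means out_i = inp_i @ weight_i^T.
weights = [torch.randn(256, 128, device="cuda") for _ in range(2)]
inputs = [torch.randn(64, 128, device="cuda") for _ in range(2)]
out = torch.empty(128, 256, device="cuda")  # sum(m_splits) rows stacked into one tensor

general_batched_gemm(
    weights,
    inputs,
    [out],
    torch.float32,
    get_multi_stream_cublas_batchgemm_workspace(),
    single_output=True,
    m_splits=[64, 64],
)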
transformer_engine/pytorch/csrc/common.h (+6 −2)

@@ -14,11 +14,15 @@
 #include <ATen/cudnn/Handle.h>
 #include <ATen/native/DispatchStub.h>
 #include <c10/macros/Macros.h>
+#include <cuda_runtime.h>
+#ifndef USE_ROCM
 #include <cublasLt.h>
 #include <cuda.h>
-#include <cuda_bf16.h>
-#include <cuda_runtime.h>
 #include <cudnn.h>
+#include <cuda_bf16.h>
+#else
+#include <hip/hip_bf16.h>
+#endif
 #include <torch/extension.h>
 #include <torch/torch.h>
 #include <transformer_engine/activation.h>
transformer_engine/pytorch/csrc/extensions.h (+10 −0)

@@ -93,6 +93,16 @@ std::optional<std::vector<at::Tensor>> te_general_grouped_gemm(
     bool grad, std::vector<at::Tensor> workspace, size_t workspaceSize, bool accumulate,
     bool use_split_accumulator, int math_sm_count);

+#ifdef __HIP_PLATFORM_AMD__
+std::optional<std::vector<at::Tensor>> te_general_batched_gemm(
+    std::vector<py::handle> A, bool transa, std::vector<py::handle> B, bool transb,
+    std::optional<std::vector<at::Tensor>> D, transformer_engine::DType D_type,
+    std::vector<int64_t> m_splits, std::vector<at::Tensor> bias,
+    transformer_engine::DType bias_type, bool single_output,
+    std::vector<at::Tensor> pre_gelu_out, bool grad, std::vector<at::Tensor> workspace,
+    size_t workspaceSize, bool accumulate, bool use_split_accumulator, int math_sm_count);
+#endif
+
 /***************************************************************************************************
  * Transpose
  **************************************************************************************************/
transformer_engine/pytorch/csrc/extensions/attention.cu (+15 −0)

@@ -16,11 +16,16 @@ NVTE_Fused_Attn_Backend get_fused_attn_backend(
     float p_dropout, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
     size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, int64_t window_size_left,
     int64_t window_size_right) {
+#ifdef __HIP_PLATFORM_AMD__
+  static_assert(false, "get_fused_attn_backend is not supported in ROCm yet.");
+#else
   NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
       static_cast<NVTEDType>(q_dtype), static_cast<NVTEDType>(kv_dtype), qkv_layout, bias_type,
       attn_mask_type, p_dropout, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv,
       head_dim_qk, head_dim_v, window_size_left, window_size_right);
   return fused_attention_backend;
+#endif
 }

 // fast zero-fills of tensors

@@ -93,6 +98,10 @@ std::vector<py::object> fused_attn_fwd(
     const c10::optional<at::Tensor> cu_seqlens_kv_padded, py::handle s_quantizer,
     py::handle o_quantizer, const c10::optional<at::Tensor> Bias,
     const c10::optional<at::Generator> rng_gen, size_t rng_elts_per_thread) {
+#ifdef __HIP_PLATFORM_AMD__
+  static_assert(false, "fused_attn_fwd is not supported in ROCm yet.");
+#else
   using namespace transformer_engine;
   using namespace transformer_engine::pytorch;
   TensorWrapper te_Q, te_K, te_V, te_O, te_S;

@@ -254,6 +263,7 @@ std::vector<py::object> fused_attn_fwd(
   // if training, [O, softmax-related tensors, rng_state]; if inference, [O]
   return output_tensors;
+#endif
 }

 // fused attention BWD with separate Q, K and V

@@ -267,6 +277,10 @@ std::vector<py::object> fused_attn_bwd(
     const c10::optional<at::Tensor> cu_seqlens_q_padded,
     const c10::optional<at::Tensor> cu_seqlens_kv_padded, py::handle s_quantizer,
     py::handle dp_quantizer, py::handle dqkv_quantizer) {
+#ifdef __HIP_PLATFORM_AMD__
+  static_assert(false, "fused_attn_bwd is not supported in ROCm yet.");
+#else
   using namespace transformer_engine;
   using namespace transformer_engine::pytorch;
   auto none = py::none();

@@ -492,6 +506,7 @@ std::vector<py::object> fused_attn_bwd(
   nvte_tensor_pack_destroy(&nvte_aux_tensor_pack);

   return {py_dQ, py_dK, py_dV, py::cast(dBias)};
+#endif
 }

 namespace flash_attention {
transformer_engine/pytorch/csrc/extensions/gemm.cpp (+121 −0)

@@ -411,3 +411,124 @@ std::optional<std::vector<at::Tensor>> te_general_grouped_gemm(
       math_sm_count, at::cuda::getCurrentCUDAStream());

   return bias;
 }
+
+#ifdef USE_ROCM
+std::optional<std::vector<at::Tensor>> te_general_batched_gemm(
+    std::vector<py::handle> A, bool transa, std::vector<py::handle> B, bool transb,
+    std::optional<std::vector<at::Tensor>> D, transformer_engine::DType D_type,
+    std::vector<int64_t> m_splits, std::vector<at::Tensor> bias,
+    transformer_engine::DType bias_type, bool single_output,
+    std::vector<at::Tensor> pre_gelu_out, bool grad, std::vector<at::Tensor> workspace,
+    size_t workspaceSize, bool accumulate, bool use_split_accumulator, int math_sm_count) {
+  using namespace transformer_engine;
+  using namespace transformer_engine::pytorch;
+  std::vector<NVTETensor> te_A_vector, te_B_vector, te_D_vector, te_bias_vector,
+      te_pre_gelu_out_vector, te_workspace_vector;
+  std::vector<TensorWrapper> wrappers;
+  std::vector<at::Tensor> D_vectors;
+
+  auto none = py::none();
+  std::vector<size_t> single_output_begins;
+  std::vector<size_t> single_output_ends;
+  int slicing_dim;
+
+  if (single_output && D == std::nullopt) {
+    NVTE_ERROR("not implemented, D should be allocated for single output case.");
+  }
+  void *output_data_ptr;
+  if (single_output) {
+    output_data_ptr = (*D)[0].data_ptr();
+  }
+
+  for (size_t i = 0; i < A.size(); i++) {
+    auto te_A = makeTransformerEngineTensor(A[i], none);
+    auto te_B = makeTransformerEngineTensor(B[i], none);
+
+    // if there is single output
+    at::Tensor out_tensor;
+    auto size_t_shape =
+        pytorch::detail::getGemmOutputShape(te_A.shape(), transa, te_B.shape(), transb);
+    bool D_numel_is_zero = false;
+    std::vector<int64_t> D_shape;
+    for (size_t t : size_t_shape) {
+      D_shape.push_back(t);
+      if (t == 0) {
+        D_numel_is_zero = true;
+      }
+    }
+    auto dtype = GetATenDType(D_type);
+    auto opts = torch::TensorOptions().dtype(dtype).device(torch::kCUDA);
+    if (single_output) {
+      if (output_data_ptr == nullptr) {
+        out_tensor = at::empty(D_shape, opts);
+      } else {
+        // We need to check !D_numel_is_zero because if the final input portion has zero elements,
+        // output_data_ptr would point beyond the allocated memory of D. This would cause
+        // at::from_blob to fail as it would reference memory not allocated by CUDA.
+        if (!D_numel_is_zero) {
+          out_tensor = at::from_blob(output_data_ptr, D_shape, opts);
+        }
+      }
+      char *char_ptr = reinterpret_cast<char *>(output_data_ptr);
+      char_ptr += D_shape[0] * D_shape[1] * (*D)[0].element_size();
+      output_data_ptr = reinterpret_cast<void *>(char_ptr);
+      D_vectors.emplace_back(out_tensor);
+    } else {
+      if (D == std::nullopt) {
+        auto opts = torch::TensorOptions().dtype(dtype).device(torch::kCUDA);
+        out_tensor = at::empty(D_shape, opts);
+        D_vectors.emplace_back(out_tensor);
+      } else {
+        out_tensor = (*D)[i];
+      }
+    }
+
+    if (te_A.numel() == 0 || te_B.numel() == 0) {
+      if (out_tensor.numel() != 0 && !accumulate) out_tensor.zero_();
+      if (bias[i].numel() != 0 && grad) {
+        bias[i].zero_();
+      }
+      if (pre_gelu_out[i].numel() != 0) pre_gelu_out[i].zero_();
+      continue;
+    }
+
+    auto te_D = makeTransformerEngineTensor(out_tensor);
+    auto te_bias = makeTransformerEngineTensor(bias[i]);
+    auto te_pre_gelu_out = makeTransformerEngineTensor(pre_gelu_out[i]);
+
+    const auto gelu_shape = pre_gelu_out[i].data_ptr() == nullptr
+                                ? std::vector<size_t>{static_cast<size_t>(te_pre_gelu_out.size(0))}
+                                : std::vector<size_t>{static_cast<size_t>(te_pre_gelu_out.size(0)),
+                                                      static_cast<size_t>(te_pre_gelu_out.size(1))};
+    DType gelu_type = bias_type;
+    te_pre_gelu_out =
+        makeTransformerEngineTensor(get_data_ptr(pre_gelu_out[i]), gelu_shape, gelu_type);
+
+    te_A_vector.emplace_back(te_A.data());
+    te_B_vector.emplace_back(te_B.data());
+    te_D_vector.emplace_back(te_D.data());
+    te_bias_vector.emplace_back(te_bias.data());
+    te_pre_gelu_out_vector.emplace_back(te_pre_gelu_out.data());
+
+    wrappers.emplace_back(std::move(te_A));
+    wrappers.emplace_back(std::move(te_B));
+    wrappers.emplace_back(std::move(te_D));
+    wrappers.emplace_back(std::move(te_bias));
+    wrappers.emplace_back(std::move(te_pre_gelu_out));
+  }
+  for (size_t i = 0; i < workspace.size(); i++) {
+    auto wsp = makeTransformerEngineTensor(workspace[i].data_ptr(), {workspaceSize}, DType::kByte);
+    te_workspace_vector.emplace_back(wsp.data());
+    wrappers.emplace_back(std::move(wsp));
+  }
+  // For now, we only have the multi-stream cublas backend.
+  nvte_multi_stream_cublas_batchgemm(te_A_vector.data(), te_B_vector.data(), te_D_vector.data(),
+                                     te_bias_vector.data(), te_pre_gelu_out_vector.data(),
+                                     te_A_vector.size(), transa, transb, grad,
+                                     te_workspace_vector.data(), accumulate, use_split_accumulator,
+                                     math_sm_count, at::cuda::getCurrentCUDAStream());
+  return bias;
+}
+#endif  // USE_ROCM
transformer_engine/pytorch/csrc/extensions/misc.cpp (+6 −0)

@@ -6,6 +6,12 @@
 #include "extensions.h"

+#ifdef USE_ROCM
+size_t get_cublasLt_version() { int version = 10000000; return version; }
+size_t get_cudnn_version() { int version = 0; return version; }
+#else
 size_t get_cublasLt_version() { return cublasLtGetVersion(); }

 size_t get_cudnn_version() { return cudnnGetVersion(); }
+#endif
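These stubs matter on the Python side because other checks in this commit read them through the extension module; a small hedged illustration (the values shown are the placeholders defined in the hunk above, and apply only to a ROCm/DCU build of this commit):

import transformer_engine_torch as tex

# On ROCm/DCU builds the stubs above return fixed placeholders, so
# cuBLASLt/cuDNN version gates elsewhere in the package do not trip.
print(tex.get_cublasLt_version())  # 10000000
print(tex.get_cudnn_version())     # 0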
transformer_engine/pytorch/csrc/extensions/multi_tensor/multi_tensor_adam.cu (+9 −0)

@@ -8,7 +8,11 @@
 #include <ATen/AccumulateType.h>
 #include <ATen/cuda/CUDAContext.h>
 #include <ATen/cuda/Exceptions.h>
+#ifdef __HIP_PLATFORM_AMD__
+#include "amd_detail/hip_float8.h"
+#else
 #include <cuda_fp8.h>
+#endif

 // Another possibility:
 // #include <torch/all.h>

@@ -28,8 +32,13 @@ typedef enum {
 } adamMode_t;

 using MATH_T = float;
+#ifndef __HIP_PLATFORM_AMD__
 using fp8e4m3 = __nv_fp8_e4m3;
 using fp8e5m2 = __nv_fp8_e5m2;
+#else
+using fp8e4m3 = hip_f8<hip_f8_type::fp8>;
+using fp8e5m2 = hip_f8<hip_f8_type::bf8>;
+#endif
 using transformer_engine::DType;

 template <typename T>
transformer_engine/pytorch/csrc/extensions/pybind.cpp (+6 −0)

@@ -174,6 +174,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
         py::arg("input_list"), py::arg("output_list"), py::arg("quantizer_list"),
         py::arg("otype"));
   m.def("te_general_grouped_gemm", &te_general_grouped_gemm, "Grouped GEMM");
+#ifdef USE_ROCM
+  m.def("te_general_batched_gemm", &te_general_batched_gemm, "Batched GEMM");  /// rocblas
+#endif
   m.def("fused_attn_fwd", &fused_attn_fwd,
         "Fused Attention FP8/BF16/FP16 FWD with separate Q, K and V");
   m.def("fused_attn_bwd", &fused_attn_bwd,

@@ -207,6 +210,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   m.def("get_cudnn_version", &get_cudnn_version, "Get cuDNN version",
         py::call_guard<py::gil_scoped_release>());
   m.attr("_num_cublas_streams") = py::int_(transformer_engine::num_streams);
+#ifdef USE_ROCM
+  m.attr("_num_cublas_batchgemm_streams") = py::int_(transformer_engine::num_batchgemm_streams);
+#endif

   // Support THD format for Context Parallel
   m.def("thd_read_half_tensor", &thd_read_half_tensor,
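Because the batched-GEMM binding and its stream-count attribute are compiled in only under USE_ROCM, Python callers can probe for them before use. A hedged sketch (not part of the commit):

import transformer_engine_torch as tex

has_batched_gemm = hasattr(tex, "te_general_batched_gemm")
num_batchgemm_streams = getattr(tex, "_num_cublas_batchgemm_streams", 0)
print(has_batched_gemm, num_batchgemm_streams)  # (False, 0) on CUDA builds of this commit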
transformer_engine/pytorch/csrc/type_shim.h (+10 −0)

@@ -267,6 +267,8 @@
     AT_ERROR(#NAME, " not implemented for '", toString(TYPEIN), "'"); \
   }

+constexpr uint32_t THREADS_PER_WARP = 32;
+
 template <typename T>
 __device__ __forceinline__ T reduce_block_into_lanes(T *x, T val, int lanes = 1,

@@ -295,7 +297,11 @@ reduce_block_into_lanes(T *x, T val, int lanes = 1,
       // __SYNCWARP();
 #pragma unroll
+#ifdef __HIP_PLATFORM_AMD__
+    for (int i = 16; i >= lanes; i >>= 1)
+      final = final + __shfl_down(final, i, THREADS_PER_WARP);
+#else
     for (int i = 16; i >= lanes; i >>= 1)
       final = final + __shfl_down_sync(0xffffffff, final, i);
+#endif
   }

   if (share_result) {

@@ -337,7 +343,11 @@ reduce_block_into_lanes_max_op(T *x, T val, int lanes = 1,
 #pragma unroll
     for (int i = 16; i >= lanes; i >>= 1)
+#ifdef __HIP_PLATFORM_AMD__
+      final = fmaxf(fabsf(final), fabsf(__shfl_down(final, i, THREADS_PER_WARP)));
+#else
       final = fmaxf(fabsf(final), fabsf(__shfl_down_sync(0xffffffff, final, i)));
+#endif
   }

   if (share_result) {
transformer_engine/pytorch/dot_product_attention/utils.py (+24 −7)

@@ -37,7 +37,7 @@ from transformer_engine.pytorch.cpp_extensions.fused_attn import (
 from transformer_engine.pytorch.float8_tensor import Float8Tensor
 from transformer_engine.pytorch.fp8 import get_fp8_te_dtype
 from transformer_engine.pytorch.constants import TE_DType
+from torch.utils.cpp_extension import IS_HIP_EXTENSION
 from transformer_engine.pytorch.utils import (
     get_device_compute_capability,

@@ -347,7 +347,7 @@ def get_attention_backend(
         logger.debug("Disabling UnfusedDotProductAttention due to NVTE_UNFUSED_ATTN=0")

     # Filter: Compute capability
-    if device_compute_capability < (8, 0):
+    if not IS_HIP_EXTENSION and device_compute_capability < (8, 0):
         if use_flash_attention and FlashAttentionUtils.is_installed:
             logger.debug("Disabling FlashAttention as it requires compute capability sm80+")
             use_flash_attention = False

@@ -395,12 +395,22 @@ def get_attention_backend(
         if use_unfused_attention:
             logger.debug("Disabling UnfusedDotProductAttention as it does not support FP8")
             use_unfused_attention = False
+        # TODO: the ROCm fused attention backends do not support FP8 yet
+        if IS_HIP_EXTENSION and use_fused_attention:
+            logger.debug("Disabling ROCm FusedAttention as it does not support FP8")
+            use_fused_attention = False

     # Filter: Head dimension
-    if use_flash_attention and head_dim_qk != head_dim_v:
-        if FlashAttentionUtils.is_installed:
-            logger.debug("Disabling FlashAttention as it does not support MLA.")
-        use_flash_attention = False
+    if not IS_HIP_EXTENSION:
+        if use_flash_attention and head_dim_qk != head_dim_v:
+            if FlashAttentionUtils.is_installed:
+                logger.debug("Disabling FlashAttention as it does not support MLA.")
+            use_flash_attention = False
+    else:
+        if use_fused_attention and head_dim_qk != head_dim_v:
+            logger.debug("Disabling FusedAttention as it does not support MLA in rocm backend.")
+            use_fused_attention = False
     if use_flash_attention and (
         head_dim_qk > 256
         or head_dim_qk % 8 != 0

@@ -441,6 +451,12 @@ def get_attention_backend(
                 "padding between sequences, i.e. [a, a, PAD, b, b, b, PAD, c, PAD]"
             )
             use_flash_attention = False
+        if IS_HIP_EXTENSION and use_fused_attention and pad_between_seqs:
+            logger.debug(
+                "Disabling rocm fused attn for qkv_format = thd when there is "
+                "padding between sequences, i.e. [a, a, PAD, b, b, b, PAD, c, PAD]"
+            )
+            use_fused_attention = False

     # Filter: Dropout
     if attention_dropout != 0.0 and use_flash_attention and FlashAttentionUtils.use_v3:

@@ -839,7 +855,7 @@ def get_attention_backend(
     # Select FusedAttention for performance
     if (
-        use_flash_attention
+        use_flash_attention and (not IS_HIP_EXTENSION)
         and use_fused_attention
         and fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]
     ):

@@ -852,6 +868,7 @@ def get_attention_backend(
     if (
         use_flash_attention
         and use_fused_attention
+        and not IS_HIP_EXTENSION
         and fused_attention_backend == FusedAttnBackend["FP8"]
         and FlashAttentionUtils.use_v3
     ):
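Condensed, hedged restatement of the ROCm-specific filters added above; the helper below is illustrative only and does not exist in the codebase (the real get_attention_backend takes many more inputs and also handles the flash/unfused paths):

def rocm_attention_filters(use_fused, fp8, head_dim_qk, head_dim_v, pad_between_seqs):
    """Hypothetical summary of when this commit disables the ROCm fused backend."""
    if fp8 and use_fused:
        use_fused = False          # ROCm fused attention has no FP8 path yet
    if use_fused and head_dim_qk != head_dim_v:
        use_fused = False          # no MLA (mismatched QK/V head dims) support
    if use_fused and pad_between_seqs:
        use_fused = False          # no THD format with padding between sequences
    return use_fused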
transformer_engine/pytorch/fp8.py (+15 −8)

@@ -24,6 +24,7 @@ from transformer_engine.common.recipe import (
 from .constants import dist_group_type
 from .utils import get_device_compute_capability
 from .jit import jit_fuser
+from torch.utils.cpp_extension import IS_HIP_EXTENSION

 __all__ = ["fp8_autocast", "fp8_model_init"]

@@ -31,14 +32,20 @@ __all__ = ["fp8_autocast", "fp8_model_init"]
 def check_fp8_support() -> Tuple[bool, str]:
     """Return if fp8 support is available"""
-    if get_device_compute_capability() >= (9, 0):  # hopper and above
-        return True, ""
-    if get_device_compute_capability() < (8, 9):  # pre-ada
-        return False, "Device compute capability 8.9 or higher required for FP8 execution."
-    if tex.get_cublasLt_version() < 120103:
-        return False, "CublasLt version 12.1.3.x or higher required for FP8 execution on Ada."
-    if float(torch.version.cuda) < 12.1:
-        return False, "Cuda version 12.1 or higher required for FP8 execution on Ada."
+    if IS_HIP_EXTENSION:
+        if get_device_compute_capability() == (9, 4):
+            return True, ""
+        else:
+            return False, "DCU does not support FP8 for now."
+    else:
+        if get_device_compute_capability() >= (9, 0):  # hopper and above
+            return True, ""
+        if get_device_compute_capability() < (8, 9):  # pre-ada
+            return False, "Device compute capability 8.9 or higher required for FP8 execution."
+        if tex.get_cublasLt_version() < 120103:
+            return False, "CublasLt version 12.1.3.x or higher required for FP8 execution on Ada."
+        if float(torch.version.cuda) < 12.1:
+            return False, "Cuda version 12.1 or higher required for FP8 execution on Ada."
     return True, ""
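How callers typically consume this check after the change; a brief hedged sketch (the fallback message printed is the one added in the hunk above, seen only on DCU devices without compute capability (9, 4)):

from transformer_engine.pytorch.fp8 import check_fp8_support

fp8_available, reason = check_fp8_support()
if not fp8_available:
    # e.g. "DCU does not support FP8 for now." on unsupported DCU parts
    print(f"Falling back to BF16/FP16: {reason}")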
transformer_engine/pytorch/jit.py (+23 −21)

@@ -7,6 +7,7 @@ import os
 from typing import Callable, Optional, Tuple

 import torch
+from torch.utils.cpp_extension import IS_HIP_EXTENSION

 # pylint: disable=unnecessary-lambda-assignment

@@ -27,27 +28,28 @@ no_torch_dynamo = lambda recursive=True: lambda f: torch._dynamo.disable(f, recursive=recursive)

 def set_jit_fusion_options() -> None:
     """Set PyTorch JIT layer fusion options."""
-    # flags required to enable jit fusion kernels
-    TORCH_MAJOR = int(torch.__version__.split(".")[0])
-    TORCH_MINOR = int(torch.__version__.split(".")[1])
-    if TORCH_MAJOR == 2 and TORCH_MINOR >= 2:
-        pass
-    elif (TORCH_MAJOR == 2) or (TORCH_MAJOR == 1 and TORCH_MINOR >= 10):
-        # nvfuser
-        torch._C._jit_set_profiling_executor(True)
-        torch._C._jit_set_profiling_mode(True)
-        torch._C._jit_override_can_fuse_on_cpu(False)
-        torch._C._jit_override_can_fuse_on_gpu(False)
-        torch._C._jit_set_texpr_fuser_enabled(False)
-        torch._C._jit_set_nvfuser_enabled(True)
-        torch._C._debug_set_autodiff_subgraph_inlining(False)
-    else:
-        # legacy pytorch fuser
-        torch._C._jit_set_profiling_mode(False)
-        torch._C._jit_set_profiling_executor(False)
-        torch._C._jit_override_can_fuse_on_cpu(True)
-        torch._C._jit_override_can_fuse_on_gpu(True)
+    if not IS_HIP_EXTENSION:
+        # flags required to enable jit fusion kernels
+        TORCH_MAJOR = int(torch.__version__.split(".")[0])
+        TORCH_MINOR = int(torch.__version__.split(".")[1])
+        if TORCH_MAJOR == 2 and TORCH_MINOR >= 2:
+            pass
+        elif (TORCH_MAJOR == 2) or (TORCH_MAJOR == 1 and TORCH_MINOR >= 10):
+            # nvfuser
+            torch._C._jit_set_profiling_executor(True)
+            torch._C._jit_set_profiling_mode(True)
+            torch._C._jit_override_can_fuse_on_cpu(False)
+            torch._C._jit_override_can_fuse_on_gpu(False)
+            torch._C._jit_set_texpr_fuser_enabled(False)
+            torch._C._jit_set_nvfuser_enabled(True)
+            torch._C._debug_set_autodiff_subgraph_inlining(False)
+        else:
+            # legacy pytorch fuser
+            torch._C._jit_set_profiling_mode(False)
+            torch._C._jit_set_profiling_executor(False)
+            torch._C._jit_override_can_fuse_on_cpu(True)
+            torch._C._jit_override_can_fuse_on_gpu(True)


 @jit_fuser
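A brief hedged usage note on the change above: the fuser flags are configured only on CUDA builds, so on ROCm/DCU builds the call below is effectively a no-op.

from transformer_engine.pytorch.jit import set_jit_fusion_options

set_jit_fusion_options()  # nvfuser/legacy-fuser setup on CUDA; body skipped on IS_HIP_EXTENSION builds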
transformer_engine/pytorch/module/base.py (+19 −0)

@@ -35,6 +35,7 @@ from ..constants import dist_group_type
 from ..tensor import QuantizedTensor, Quantizer
 from ..tensor._internal.float8_tensor_base import Float8TensorBase
 from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase
+from torch.utils.cpp_extension import IS_HIP_EXTENSION

 __all__ = ["initialize_ub", "destroy_ub"]

@@ -42,6 +43,7 @@ _2X_ACC_FPROP = False
 _2X_ACC_DGRAD = True
 _2X_ACC_WGRAD = True
 _multi_stream_cublas_workspace = []
+_multi_stream_cublas_batchgemm_workspace = []
 _cublas_workspace = None
 _ub_communicators = None
 _NUM_MAX_UB_STREAMS = 3

@@ -51,6 +53,13 @@ layers_atomic_ring_exchange = []

 def get_cublas_workspace_size_bytes() -> None:
     """Return 32 MiB if using hopper, 4 MiB for all other architectures."""
+    # Env var to control the padding for blasLt
+    if IS_HIP_EXTENSION:
+        nvte_blaslt_nopad = int(os.environ.get("NVTE_BLASLT_NOPAD", 0))
+        if nvte_blaslt_nopad:
+            return 536_870_912
+        else:
+            return 1_073_741_824
     if torch.cuda.get_device_properties(torch.cuda.current_device()).major >= 9:
         return 33_554_432
     return 4_194_304

@@ -76,6 +85,16 @@ def get_multi_stream_cublas_workspace() -> List[torch.Tensor]:
     )
     return _multi_stream_cublas_workspace


+def get_multi_stream_cublas_batchgemm_workspace() -> List[torch.Tensor]:
+    """Returns workspace for multi-stream cublas batched GEMM."""
+    global _multi_stream_cublas_batchgemm_workspace
+    if not _multi_stream_cublas_batchgemm_workspace:
+        for _ in range(tex._num_cublas_batchgemm_streams):
+            _multi_stream_cublas_batchgemm_workspace.append(
+                torch.empty(get_cublas_workspace_size_bytes(), dtype=torch.uint8, device="cuda")
+            )
+    return _multi_stream_cublas_batchgemm_workspace


 def initialize_ub(
     shape: list,
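A hedged sketch of the new ROCm workspace sizing rule: setting NVTE_BLASLT_NOPAD=1 shrinks each per-stream batch-GEMM workspace from 1 GiB to 512 MiB. The env var must be set before the workspaces are first allocated; that ordering is an assumption about typical use, not stated by the commit.

import os

os.environ["NVTE_BLASLT_NOPAD"] = "1"  # read lazily on first workspace allocation (ROCm builds only)

from transformer_engine.pytorch.module.base import (
    get_cublas_workspace_size_bytes,
    get_multi_stream_cublas_batchgemm_workspace,
)

print(get_cublas_workspace_size_bytes())           # 536_870_912 on ROCm with NVTE_BLASLT_NOPAD=1
workspaces = get_multi_stream_cublas_batchgemm_workspace()  # one uint8 buffer per batch-GEMM stream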
transformer_engine/pytorch/module/batched_linear.py (new file, +672 −0)

# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

"""BatchedLinear API"""
from typing import Union, Optional, Callable, Tuple, List

import torch

import transformer_engine_torch as tex

from .base import (
    get_multi_stream_cublas_batchgemm_workspace,
    TransformerEngineBaseModule,
    _2X_ACC_FPROP,
    _2X_ACC_DGRAD,
    _2X_ACC_WGRAD,
)
from ..fp8 import FP8GlobalStateManager
from ..utils import (
    divide,
    cast_if_needed,
    assert_dim_for_fp8_exec,
    clear_tensor_data,
    init_method_constant,
    requires_grad,
)
from ..distributed import (
    set_tensor_model_parallel_attributes,
    get_distributed_world_size,
    is_fp8_activation_recompute_enabled,
    in_fp8_activation_recompute_phase,
)
from ..cpp_extensions import (
    general_batched_gemm,
)
from ..constants import GemmParallelModes, dist_group_type, TE_DType
from ..jit import no_torch_dynamo
from ..graph import is_graph_capturing
from ..tensor.float8_tensor import Float8Tensor
from ..cpu_offload import is_cpu_offload_enabled
from ..tensor.quantized_tensor import (
    QuantizedTensor,
    Quantizer,
    prepare_for_saving,
    restore_from_saved,
)

__all__ = ["BatchedLinear"]


class _BatchedLinear(torch.autograd.Function):
    """BatchedLinear semi-top level module

    Calls custom cuda extensions.
    """

    @staticmethod
    def forward(
        ctx,
        inp: torch.Tensor,
        m_splits: List[int],
        use_bias: bool,
        is_first_microbatch: Union[bool, None],
        fp8: bool,
        fp8_calibration: bool,
        input_quantizers: List[Quantizer],
        weight_quantizers: List[Quantizer],
        output_quantizers: List[Quantizer],
        grad_output_quantizers: List[Quantizer],
        fuse_wgrad_accumulation: bool,
        cpu_offloading: bool,
        sequence_parallel: bool,
        activation_dtype: torch.dtype,
        is_grad_enabled: bool,
        module,
        skip_fp8_weight_update,
        *weights_and_biases,
    ) -> torch.Tensor:
        # pylint: disable=missing-function-docstring
        num_gemms = len(m_splits)
        weights = weights_and_biases[:num_gemms]
        biases = weights_and_biases[num_gemms:]
        device = inp.device

        # TODO Support MXFP8  # pylint: disable=fixme
        if fp8 and FP8GlobalStateManager.get_fp8_recipe().mxfp8():
            raise NotImplementedError("BatchedLinear does not yet support MXFP8")
        # TODO Support Float8 Current Scaling  # pylint: disable=fixme
        if fp8 and FP8GlobalStateManager.get_fp8_recipe().float8_current_scaling():
            raise NotImplementedError("BatchedLinear does not yet support Float8 Current Scaling")

        # Make sure input dimensions are compatible
        in_features = weights[0].shape[-1]
        assert inp.shape[-1] == in_features, "GEMM not possible"
        inputmats = torch.split(inp.view(-1, in_features), m_splits)
        if fp8:
            assert_dim_for_fp8_exec(*inputmats, *weights)

        # Cast input to expected dtype
        inputmats_no_fp8 = [cast_if_needed(mat, activation_dtype) for mat in inputmats]
        inputmats = []

        weight_requires_grad = weights[0].requires_grad

        if input_quantizers[0] is not None:
            for input_quantizer in input_quantizers:
                input_quantizer.set_usage(
                    rowwise=True,
                    columnwise=(is_grad_enabled and weight_requires_grad),
                )
            columnwise_usage = is_grad_enabled and inp.requires_grad
            if not columnwise_usage:
                columnwise_usage = (
                    is_fp8_activation_recompute_enabled()
                    and not in_fp8_activation_recompute_phase()
                )
        if weight_quantizers[0] is not None:
            for weight_quantizer in weight_quantizers:
                weight_quantizer.set_usage(rowwise=True, columnwise=columnwise_usage)
        if output_quantizers[0] is not None:
            for output_quantizer in output_quantizers:
                output_quantizer.set_usage(rowwise=True, columnwise=False)

        if fp8:
            inputmats = tex.fused_multi_quantize(
                inputmats_no_fp8, None, input_quantizers, TE_DType[activation_dtype]
            )
            weights_fp8 = []
            bias_dtype = torch.bfloat16 if activation_dtype == torch.float32 else activation_dtype
            if not isinstance(weights[0], QuantizedTensor):
                # FP8 cast to workspace buffer
                update_workspace = is_first_microbatch is None or is_first_microbatch
                for i in range(num_gemms):
                    weight_fp8 = module.get_weight_workspace(
                        tensor=weights[i],
                        quantizer=weight_quantizers[i],
                        cache_name=(None if is_first_microbatch is None else f"weight{i}"),
                        update_workspace=update_workspace,
                        skip_update_flag=skip_fp8_weight_update,
                    )
                    weights_fp8.append(weight_fp8)
            else:
                weights_fp8 = weights
        else:
            inputmats = inputmats_no_fp8
            bias_dtype = activation_dtype
            weights_fp8 = [cast_if_needed(weight, activation_dtype) for weight in weights]

        biases = [cast_if_needed(bias, bias_dtype) for bias in biases] if use_bias else biases

        out = torch.empty(
            [sum(m_splits), weights_fp8[0].size(0)],
            dtype=activation_dtype,
            device=device,
        )

        _ = general_batched_gemm(
            weights_fp8,
            inputmats,
            [out],
            activation_dtype,
            get_multi_stream_cublas_batchgemm_workspace(),
            single_output=True,
            m_splits=m_splits,
            bias=biases,
            use_bias=use_bias,
            use_split_accumulator=_2X_ACC_FPROP,
        )

        if fp8_calibration:
            for i in range(num_gemms):
                # amax of input
                input_quantizers[i].calibrate(inputmats[i])
            for i in range(num_gemms):
                weight_quantizers[i].calibrate(weights[i])

        if is_grad_enabled:
            ctx.weights_shape_1 = weights[0].shape[1]

            tensors_to_save, tensor_objects = prepare_for_saving(
                *inputmats,
                *weights_fp8,
                *biases,
            )
            ctx.save_for_backward(*tensors_to_save)
            ctx.tensor_objects = tensor_objects

            ctx.weights_requires_grad = weights[0].requires_grad
            if fuse_wgrad_accumulation and ctx.weights_requires_grad:
                ctx.main_grads = [weights[i].main_grad for i in range(num_gemms)]
            else:
                ctx.main_grads = [None] * num_gemms
            ctx.device = device
            ctx.grad_output_quantizers = grad_output_quantizers
            ctx.m_splits = m_splits
            ctx.num_gemms = num_gemms
            ctx.activation_dtype = activation_dtype
            ctx.fp8 = fp8
            ctx.fuse_wgrad_accumulation = fuse_wgrad_accumulation
            ctx.cpu_offloading = cpu_offloading
            ctx.is_first_microbatch = is_first_microbatch
            ctx.use_bias = use_bias
            ctx.sequence_parallel = sequence_parallel
            ctx.inp_shape = inp.shape
            ctx.requires_dgrad = inp.requires_grad
            ctx.reduce_and_update_bwd_fp8_tensors = False
            if ctx.fp8 and requires_grad(inp, weights[0], biases[0]):
                ctx.reduce_and_update_bwd_fp8_tensors = (
                    ctx.reduce_and_update_bwd_fp8_tensors
                    or FP8GlobalStateManager.is_first_fp8_module()
                )

        # [*, in_features] -> [*, out_features] except first dimension changes for SP
        return out.view(-1, *inp.shape[1:-1], out.shape[-1])

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]:
        # pylint: disable=missing-function-docstring
        with torch.cuda.nvtx.range("_BatchedLinear_backward"):
            saved_tensors = restore_from_saved(ctx.tensor_objects, ctx.saved_tensors)
            N = ctx.num_gemms
            inputmats = saved_tensors[:N]
            weights = saved_tensors[N : 2 * N]
            biases = saved_tensors[2 * N : 3 * N]
            main_grads = ctx.main_grads

            if ctx.cpu_offloading and ctx.fuse_wgrad_accumulation:  # TODO
                for i in range(ctx.num_gemms):
                    w = torch.nn.Parameter(weights[i], weights[i].requires_grad)
                    w.main_grad = main_grads[i]
                    weights[i] = w

            # preprocess grad_output
            grad_output = grad_output.contiguous()
            grad_output_mats = torch.split(
                grad_output.view(-1, grad_output.shape[-1]), ctx.m_splits
            )
            grad_output = [None] * ctx.num_gemms
            grad_biases = [None] * ctx.num_gemms
            if ctx.fp8:
                if ctx.use_bias:
                    for i in range(ctx.num_gemms):
                        grad_biases[i], grad_output[i] = tex.bgrad_quantize(
                            grad_output_mats[i], ctx.grad_output_quantizers[i]
                        )
                else:
                    grad_output = tex.fused_multi_quantize(
                        grad_output_mats,
                        None,
                        ctx.grad_output_quantizers,
                        TE_DType[ctx.activation_dtype],
                    )
            else:
                grad_output = grad_output_mats

            if ctx.is_first_microbatch is not None:
                accumulate_wgrad_into_param_main_grad = (
                    ctx.fuse_wgrad_accumulation and not ctx.is_first_microbatch
                )
            else:
                accumulate_wgrad_into_param_main_grad = ctx.fuse_wgrad_accumulation

            if ctx.requires_dgrad:
                dgrad = torch.empty(
                    (sum(ctx.m_splits), ctx.weights_shape_1),
                    dtype=ctx.activation_dtype,
                    device=ctx.device,
                )

                general_batched_gemm(
                    weights,
                    grad_output,
                    [dgrad],
                    ctx.activation_dtype,
                    get_multi_stream_cublas_batchgemm_workspace(),
                    single_output=True,
                    layout="NN",
                    m_splits=ctx.m_splits,
                    grad=True,
                    use_split_accumulator=_2X_ACC_DGRAD,
                )

            if ctx.weights_requires_grad:
                if ctx.fuse_wgrad_accumulation:
                    wgrad_list = main_grads
                else:
                    wgrad_list = [
                        torch.empty(w.size(), dtype=ctx.activation_dtype, device=ctx.device)
                        for w in weights
                    ]
                # WGRAD
                _, grad_biases_, _ = general_batched_gemm(
                    inputmats,
                    grad_output,
                    wgrad_list,
                    ctx.activation_dtype,
                    get_multi_stream_cublas_batchgemm_workspace(),
                    layout="NT",
                    grad=True,
                    m_splits=ctx.m_splits,
                    use_bias=ctx.use_bias if grad_biases[0] is None else None,
                    bias=biases,
                    use_split_accumulator=_2X_ACC_WGRAD,
                    accumulate=accumulate_wgrad_into_param_main_grad,
                )
                for i in range(ctx.num_gemms):
                    if grad_biases[i] is None:
                        grad_biases[i] = grad_biases_[i]
                del grad_biases_

                # Deallocate input tensor
                clear_tensor_data(*inputmats)

                def handle_custom_ddp_from_mcore(w, wgrad):
                    if ctx.weights_requires_grad:
                        if ctx.fuse_wgrad_accumulation and hasattr(w, "grad_added_to_main_grad"):
                            w.grad_added_to_main_grad = True
                            if getattr(w, "zero_out_wgrad", False):
                                wgrad = torch.zeros(
                                    w.main_grad.shape,
                                    dtype=w.dtype,
                                    device=torch.cuda.current_device(),
                                    requires_grad=False,
                                )
                            else:
                                wgrad = torch.empty(
                                    w.main_grad.shape,
                                    dtype=w.dtype,
                                    device=torch.cuda.current_device(),
                                    requires_grad=False,
                                )
                        elif ctx.fuse_wgrad_accumulation:
                            wgrad = None
                    else:
                        wgrad = None
                    return wgrad

                wgrad_list = [
                    handle_custom_ddp_from_mcore(w, wgrad)
                    for w, wgrad in zip(weights, wgrad_list)
                ]
            else:
                wgrad_list = [None] * ctx.num_gemms

            if not ctx.use_bias:
                grad_biases = [None] * ctx.num_gemms

        if ctx.reduce_and_update_bwd_fp8_tensors and not is_graph_capturing():
            FP8GlobalStateManager.reduce_and_update_fp8_tensors(forward=False)

        return (
            dgrad.view(ctx.inp_shape) if ctx.requires_dgrad else None,
            # one None per non-tensor forward argument (m_splits ... skip_fp8_weight_update)
            None, None, None, None,
            None, None, None, None,
            None, None, None, None,
            None, None, None, None,
            *wgrad_list,
            *grad_biases,
        )


class BatchedLinear(TransformerEngineBaseModule):
    """Applies linear transformations to the incoming data list
    :math:`y_i = x_i A_i^T + b_i` in a batched way.

    Parameters
    ----------
    num_gemms : int
        number of GEMMs to be performed simultaneously.
    in_features : int
        size of each input sample.
    out_features : int
        size of each output sample.
    bias : bool, default = `True`
        if set to `False`, the layer will not learn an additive bias.
    init_method : Callable, default = `None`
        used for initializing weights in the following way: `init_method(weight)`.
        When set to `None`, defaults to `torch.nn.init.normal_(mean=0.0, std=0.023)`.
    get_rng_state_tracker : Callable, default = `None`
        used to get the random number generator state tracker for initializing weights.
    rng_tracker_name : str, default = `None`
        the param passed to get_rng_state_tracker to get the specific rng tracker.
    device : Union[torch.device, str], default = "cuda"
        The device on which the parameters of the model will be allocated. It is the user's
        responsibility to ensure all parameters are moved to the GPU before running the
        forward pass.

    Optimization parameters
    -----------------------
    fuse_wgrad_accumulation : bool, default = `False`
        if set to `True`, enables fusing of creation and accumulation of
        the weight gradient. When enabled, it is assumed that the weights
        have an additional `main_grad` attribute (used instead of the
        regular `grad`) which is a pre-allocated buffer of the correct
        size to accumulate gradients in.
    return_bias : bool, default = `False`
        when set to `True`, this module will not apply the additive bias itself, but
        instead return the bias value during the forward pass together with the
        output of the linear transformation :math:`y = xA^T`. This is useful when
        the bias addition can be fused to subsequent operations.
    params_dtype : torch.dtype, default = `torch.get_default_dtype()`
        it controls the type used to allocate the initial parameters. Useful when
        the model is trained with lower precision and the original FP32 parameters
        would not fit in GPU memory.
    """

    def __init__(
        self,
        num_gemms: int,
        in_features: int,
        out_features: int,
        sequence_parallel: bool = False,
        fuse_wgrad_accumulation: bool = False,
        tp_group: Optional[dist_group_type] = None,
        tp_size: int = 1,
        get_rng_state_tracker: Optional[Callable] = None,
        rng_tracker_name: Optional[str] = None,
        init_method: Optional[Callable] = None,
        bias: bool = True,
        return_bias: bool = False,
        params_dtype: Optional[torch.dtype] = None,
        parallel_mode: Optional[str] = None,
        device: Union[torch.device, str] = "cuda",
        ub_overlap_rs: bool = False,
        ub_overlap_ag: bool = False,
        ub_name: Optional[str] = None,
    ) -> None:
        super().__init__()

        params_dtype = torch.get_default_dtype() if params_dtype is None else params_dtype
        self.num_gemms = num_gemms
        self.in_features = in_features
        self.out_features = out_features
        self.fuse_wgrad_accumulation = fuse_wgrad_accumulation
        self.use_bias = bias
        self.return_bias = return_bias
        self.apply_bias = bias and not return_bias
        self.ub_overlap_rs = ub_overlap_rs
        self.ub_overlap_ag = ub_overlap_ag
        self.ub_name = ub_name
        assert (
            not ub_overlap_rs and not ub_overlap_ag
        ), "BatchedLinear doesn't support Userbuffer overlap."
        self.get_rng_state_tracker = get_rng_state_tracker
        self.rng_tracker_name = rng_tracker_name

        self._offsets = {"input": 0, "weight": num_gemms, "output": 2 * num_gemms, "grad_output": 0}

        if tp_group is None:
            self.tp_size = tp_size
            if tp_size == 1:
                self.set_tensor_parallel_group(tp_group)
        else:
            self.tp_size = get_distributed_world_size(tp_group)
            self.set_tensor_parallel_group(tp_group)
        self.set_nccl_overlap_warning_if_tp()

        self.parallel_mode = parallel_mode
        assert (
            self.parallel_mode in GemmParallelModes
        ), f"parallel_mode {parallel_mode} not supported"

        if self.parallel_mode == "column":
            self.out_features = divide(self.out_features, self.tp_size)
        elif self.parallel_mode == "row":
            self.in_features = divide(self.in_features, self.tp_size)

        self.sequence_parallel = (self.tp_size > 1) and sequence_parallel

        for i in range(self.num_gemms):
            # Construct weight parameter
            self.register_parameter(
                f"weight{i}",
                torch.nn.Parameter(
                    torch.empty(
                        self.out_features,
                        self.in_features,
                        device=device,
                        dtype=params_dtype,
                    ),
                ),
                init_fn=init_method,
                get_rng_state_tracker=get_rng_state_tracker,
                fp8_meta_index=self._offsets["weight"] + i,
            )

            # Construct bias parameters if needed
            if self.use_bias:
                self.register_parameter(
                    f"bias{i}",
                    torch.nn.Parameter(
                        torch.empty(
                            self.out_features,
                            device=device,
                            dtype=params_dtype,
                        ),
                    ),
                    init_fn=init_method_constant(0.0),
                )
            else:
                bias = torch.Tensor().to(dtype=params_dtype, device=device)
                setattr(self, f"bias{i}", bias)

        if self.primary_weights_in_fp8:
            self.init_fp8_metadata(num_gemms=self.num_gemms)

        self.reset_parameters(defer_init=device == "meta")

        # For RPL, bias has to be added after TP collectives
        # So it cannot be fused with the GEMM
        if self.parallel_mode == "row" and self.apply_bias:
            self.gemm_bias_unfused_add = True
        else:
            self.gemm_bias_unfused_add = False

    def reset_parameters(self, defer_init=False):
        super().reset_parameters(defer_init=defer_init)

        if not defer_init:
            # Set parallelism attributes for linear weights
            for i in range(self.num_gemms):
                set_tensor_model_parallel_attributes(
                    tensor=getattr(self, f"weight{i}"),
                    is_parallel=True,
                    dim=1 if self.parallel_mode == "row" else 0,
                    stride=1,
                )

            # Set parallelism attributes for linear biases
            if self.use_bias:
                for i in range(self.num_gemms):
                    if self.parallel_mode == "row":
                        setattr(
                            getattr(self, f"bias{i}"),
                            "sequence_parallel",
                            self.sequence_parallel,
                        )
                    elif self.parallel_mode == "column":
                        set_tensor_model_parallel_attributes(getattr(self, f"bias{i}"), True, 0, 1)

    @no_torch_dynamo()
    def forward(
        self,
        inp: torch.Tensor,
        m_splits: List[int],
        is_first_microbatch: Optional[bool] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]:
        """
        Apply the linear transformation to the input.

        Parameters
        ----------
        inp : torch.Tensor
            Input tensor.
        m_splits : List[int]
            List of integers representing the split of the input tensor.
        is_first_microbatch : {True, False, None}, default = None
            During training using either gradient accumulation or
            pipeline parallelism a minibatch of data is further split
            into microbatches. Between the microbatches of the same minibatch
            the model weights are not updated. Setting this parameter indicates
            whether the current microbatch is the first in a minibatch or not.
            When set, this parameter enables additional optimizations:

            * during FP8 training, it allows caching of the FP8 versions of
              the weights
            * it also allows skipping gradient accumulation during the
              first microbatch (since it is the first gradient being
              produced)
        """
        assert not isinstance(
            inp, Float8Tensor
        ), "BatchedLinear doesn't support input tensor in FP8."
        assert len(m_splits) == self.num_gemms, "Number of splits should match number of GEMMs."

        skip_fp8_weight_update = FP8GlobalStateManager.get_skip_fp8_weight_update_tensor()
        if skip_fp8_weight_update is not None:
            is_first_microbatch = False

        with self.prepare_forward(inp, num_gemms=self.num_gemms) as inp:
            weight_tensors = [getattr(self, f"weight{i}") for i in range(self.num_gemms)]
            bias_tensors = [getattr(self, f"bias{i}") for i in range(self.num_gemms)]
            if not self.fp8:
                weight_tensors = [
                    w.dequantize() if isinstance(w, QuantizedTensor) else w for w in weight_tensors
                ]

            input_quantizers, weight_quantizers, output_quantizers = (
                [None] * self.num_gemms,
                [None] * self.num_gemms,
                [None] * self.num_gemms,
            )
            grad_output_quantizers, _ = [None] * self.num_gemms, [None] * self.num_gemms
            if self.fp8:
                input_quantizers = [
                    self.quantizers["scaling_fwd"][self._offsets["input"] + i]
                    for i in range(self.num_gemms)
                ]
                for i in range(self.num_gemms):
                    input_quantizers[i].internal = True
                weight_quantizers = [
                    self.quantizers["scaling_fwd"][self._offsets["weight"] + i]
                    for i in range(self.num_gemms)
                ]
                for i in range(self.num_gemms):
                    weight_quantizers[i].internal = True
                if torch.is_grad_enabled():
                    grad_output_quantizers = [
                        self.quantizers["scaling_bwd"][self._offsets["input"] + i]
                        for i in range(self.num_gemms)
                    ]
                    for i in range(self.num_gemms):
                        grad_output_quantizers[i].internal = True

            if torch.is_grad_enabled():
                linear_fn = _BatchedLinear.apply
                args = []
            else:
                linear_fn = _BatchedLinear.forward
                args = [None]
            args += (
                inp,
                m_splits,
                self.apply_bias and not self.gemm_bias_unfused_add,
                is_first_microbatch,
                self.fp8,
                self.fp8_calibration,
                input_quantizers,
                weight_quantizers,
                output_quantizers,
                grad_output_quantizers,
                self.fuse_wgrad_accumulation,
                is_cpu_offload_enabled(),
                self.sequence_parallel,
                self.activation_dtype,
                torch.is_grad_enabled(),
                self,
                skip_fp8_weight_update,
                *weight_tensors,
                *bias_tensors,
            )
            out = linear_fn(*args)

        if self.gemm_bias_unfused_add:
            out_shape = out.shape
            out = torch.cat(
                [
                    o + cast_if_needed(b, self.activation_dtype)
                    for o, b in zip(
                        torch.split(out.view(-1, self.out_features), m_splits), bias_tensors
                    )
                ]
            ).view(out_shape)

        if self.return_bias:
            return out, [cast_if_needed(b, self.activation_dtype) for b in bias_tensors]
        return out
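A hedged usage sketch of the new module. The shapes and m_splits below are illustrative assumptions; as added in this commit the module relies on the ROCm-only batched-GEMM binding, so the call only runs on a ROCm/DCU build:

import torch
from transformer_engine.pytorch.module.batched_linear import BatchedLinear

# Four independent linear layers sharing one packed input of 4096 rows.
layer = BatchedLinear(num_gemms=4, in_features=1024, out_features=2048, bias=True, device="cuda")
inp = torch.randn(4096, 1024, device="cuda")

out = layer(inp, m_splits=[1024, 1024, 1024, 1024])  # one split per GEMM, summing to 4096
print(out.shape)  # expected: torch.Size([4096, 2048])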
transformer_engine/pytorch/module/layernorm_mlp.py (+3 −1)

@@ -12,6 +12,7 @@ from operator import mul as multiply_op
 import torch
 from torch.nn.parameter import Parameter
 from torch.nn import init
+from torch.utils.cpp_extension import IS_HIP_EXTENSION

 import transformer_engine_torch as tex

@@ -1454,7 +1455,8 @@ class LayerNormMLP(TransformerEngineBaseModule):
             fc2_weight = fc2_weight.from_float8()

         # Disable bias_gelu_nvfusion for determinism checkpointing in non-reentrant mode
-        if self.bias_gelu_nvfusion and not use_reentrant_activation_recompute():
+        if (not IS_HIP_EXTENSION and self.bias_gelu_nvfusion
+                and not use_reentrant_activation_recompute()):
             self.bias_gelu_nvfusion = False

         if torch.is_grad_enabled():
transformer_engine/pytorch/triton/permutation.py (+8 −1)

@@ -11,7 +11,14 @@ import triton
 import triton.language as tl

 from transformer_engine_torch import DType as TE_DType
+from torch.utils.cpp_extension import IS_HIP_EXTENSION
+
+if IS_HIP_EXTENSION:
+    e5m2_data_type = tl.float8e5b16
+    e4m3_data_type = tl.float8e4b8
+else:
+    e5m2_data_type = tl.float8e5
+    e4m3_data_type = tl.float8e4nv


 @triton.jit
 def _row_id_map_pass_1_kernel(
transformer_engine/pytorch/utils.py (+27 −2)

@@ -13,7 +13,7 @@ import torch
 import transformer_engine.pytorch.cpp_extensions as ext
 from .tensor.quantized_tensor import QuantizedTensor
+from torch.utils.cpp_extension import IS_HIP_EXTENSION


 def requires_grad(*tensors: Tuple[Optional[torch.Tensor], ...]) -> None:
     """Check if any of the given tensors require gradient."""

@@ -242,12 +242,34 @@ def assert_dim_for_fp8_exec(*tensors: List[torch.Tensor]) -> None:
             f"but got tensor with dims={list(tensor.size())}"
         )


+if IS_HIP_EXTENSION:
+    def is_mi200():
+        """Check whether this machine is an MI200/MI210/MI250."""
+        import re
+        return (
+            re.search("AMD Instinct MI2.0", torch.cuda.get_device_name(torch.cuda.current_device()))
+            is not None
+        )
+
+    def is_K100_AI():
+        """Check whether this machine is a K100_AI."""
+        import re
+        return (
+            re.search("K100_AI", torch.cuda.get_device_name(torch.cuda.current_device()))
+            is not None
+        )
+
+    def is_BW3000():
+        """Check whether this machine is a BW3000."""
+        import re
+        return (
+            re.search("BW", torch.cuda.get_device_name(torch.cuda.current_device()))
+            is not None
+        )
+
+
 def is_bf16_compatible() -> None:
     """Replaces torch.cuda.is_bf16_compatible() with an explicit
     check on device compute capability to enforce sm_80 or higher.
     """
-    return torch.cuda.get_device_capability()[0] >= 8
+    if IS_HIP_EXTENSION:
+        # only MI200 and MI300 machines support bf16
+        if get_device_compute_capability() == (9, 4) or is_mi200() or is_K100_AI() or is_BW3000():
+            return True
+        else:
+            return False
+    else:
+        return torch.cuda.get_device_capability()[0] >= 8


 def non_tn_fp8_gemm_supported() -> bool:

@@ -260,6 +282,9 @@ def non_tn_fp8_gemm_supported() -> bool:
 @functools.lru_cache(maxsize=None)
 def get_cudnn_version() -> Tuple[int, int, int]:
     """Runtime cuDNN version (major, minor, patch)"""
+    # ROCm fused attn does not use cudnn; return high numbers to avoid tests filtering it out
+    if IS_HIP_EXTENSION:
+        return (99, 0, 0)
     encoded_version = ext.get_cudnn_version()
     major_version_magnitude = 1000 if encoded_version < 90000 else 10000
     major, encoded_version = divmod(encoded_version, major_version_magnitude)
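The new helpers key off the reported device name and compute capability, so bf16 support can be probed the same way on DCU as on CUDA; a short hedged sketch (the printed device name is only an example):

import torch
from transformer_engine.pytorch.utils import is_bf16_compatible

if torch.cuda.is_available():
    print(torch.cuda.get_device_name(0))  # e.g. "AMD Instinct MI250X" or "K100_AI"
    print(is_bf16_compatible())           # True on MI200/MI300-class, K100_AI, and BW3000 devices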