OpenDAS / TransformerEngine

Commit 2389ed3f, authored Aug 27, 2025 by yuguo

Merge branch 'release_v2.7' of https://github.com/NVIDIA/TransformerEngine into release_v2.7

Parents: 87e3e56e, 58c3ac80
Changes: 22 files. Showing 20 changed files with 244 additions and 193 deletions (+244, -193).
.gitmodules                                                                   +0   -3
build_tools/VERSION.txt                                                       +1   -1
examples/jax/encoder/test_single_gpu_encoder.py                               +3   -1
examples/jax/mnist/test_single_gpu_mnist.py                                   +3   -1
tests/jax/test_distributed_layernorm_mlp.py                                   +2   -2
tests/jax/test_layer.py                                                       +17  -4
tests/pytorch/attention/test_attention.py                                     +2   -0
transformer_engine/common/fused_attn/fused_attn.cpp                           +3   -2
transformer_engine/common/gemm/cublaslt_gemm.cu                               +15  -13
transformer_engine/common/transpose/quantize_transpose_square_blockwise.cu    +20  -43
transformer_engine/common/transpose/quantize_transpose_vector_blockwise.cu    +12  -42
transformer_engine/common/util/cast_kernels.cuh                               +21  -33
transformer_engine/jax/cpp_extensions/gemm.py                                 +10  -1
transformer_engine/jax/flax/transformer.py                                    +26  -0
transformer_engine/jax/quantize/helper.py                                     +0   -3
transformer_engine/jax/sharding.py                                            +6   -1
transformer_engine/pytorch/graph.py                                           +25  -4
transformer_engine/pytorch/ops/basic/basic_linear.py                          +30  -11
transformer_engine/pytorch/ops/fused/backward_linear_add.py                   +24  -14
transformer_engine/pytorch/ops/fused/backward_linear_scale.py                 +24  -14
.gitmodules

-[submodule "3rdparty/googletest"]
-	path = 3rdparty/googletest
-	url = https://github.com/google/googletest.git
 [submodule "3rdparty/cudnn-frontend"]
 	path = 3rdparty/cudnn-frontend
 	url = https://github.com/NVIDIA/cudnn-frontend.git
 [submodule "3rdparty/hipify_torch"]
 	path = 3rdparty/hipify_torch
 	url = https://github.com/ROCm/hipify_torch.git
build_tools/VERSION.txt

-2.8.0.dev0
+2.7.0
examples/jax/encoder/test_single_gpu_encoder.py

@@ -219,7 +219,9 @@ def train_and_evaluate(args):
     else:
         fp8_recipe = None

-    with te.fp8_autocast(enabled=args.use_fp8, fp8_recipe=fp8_recipe):
+    with te.fp8_autocast(
+        enabled=args.use_fp8, fp8_recipe=fp8_recipe, mesh_resource=te.sharding.MeshResource()
+    ):
         encoder = Net(num_embed)
         # We use nn.Embed, thus inputs need to be in int
         inputs = jnp.zeros(input_shape, dtype=jnp.int32)
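This commit threads an explicit mesh_resource through every single-GPU entry point (the MNIST example and the distributed LayerNormMLP tests below receive the same change), because fp8_autocast no longer substitutes a default MeshResource (see the transformer_engine/jax/quantize/helper.py and transformer_engine/jax/sharding.py diffs later in this commit). A minimal sketch of the new calling convention for a single device:

    import jax.numpy as jnp
    import transformer_engine.jax as te
    from transformer_engine.common import recipe

    # The empty MeshResource() declares that no mesh axes are in use.
    with te.fp8_autocast(
        enabled=True,
        fp8_recipe=recipe.DelayedScaling(),
        mesh_resource=te.sharding.MeshResource(),
    ):
        x = jnp.zeros((32, 128), dtype=jnp.bfloat16)
        # ... build and run the model under the autocast context ...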
examples/jax/mnist/test_single_gpu_mnist.py

@@ -193,7 +193,9 @@ def train_and_evaluate(args):
     else:
         fp8_recipe = None

-    with te.fp8_autocast(enabled=args.use_fp8, fp8_recipe=fp8_recipe):
+    with te.fp8_autocast(
+        enabled=args.use_fp8, fp8_recipe=fp8_recipe, mesh_resource=te.sharding.MeshResource()
+    ):
         cnn = Net(args.use_te)
         var_collect = cnn.init(init_rngs, jnp.empty(input_shape, dtype=jnp.bfloat16))
         tx = optax.sgd(args.lr, args.momentum)
tests/jax/test_distributed_layernorm_mlp.py

@@ -173,7 +173,7 @@ class TestDistributedLayernormMLP:
     )

     # Single GPU
-    with fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
+    with fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, mesh_resource=MeshResource()):
         single_jitter = jax.jit(
             value_and_grad_func,
             static_argnums=range(len(inputs), len(static_inputs) + len(inputs)),

@@ -330,7 +330,7 @@ class TestDistributedLayernormMLP:
     with use_jax_gemm(enabled=with_jax_gemm):
         # Single GPUs
-        with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe):
+        with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe, mesh_resource=MeshResource()):
             ln_mlp_single = LayerNormMLP(
                 layernorm_type=layernorm_type,
                 intermediate_dim=INTERMEDIATE,
tests/jax/test_layer.py

@@ -28,6 +28,7 @@ from transformer_engine.jax.quantize import (
     is_fp8_available,
     update_collections,
 )
+from transformer_engine.jax.sharding import MeshResource, global_shard_guard


 @pytest.fixture(autouse=True, scope="function")

@@ -490,11 +491,17 @@ class BaseTester:
     def test_forward(self, data_shape, dtype, attrs):
         """Test normal datatype forward"""
         QuantizeConfig.finalize()  # Ensure FP8 disabled.
-        self.runner(attrs).test_forward(data_shape, dtype)
+        with global_shard_guard(MeshResource()):
+            # Empty MeshResource is used as we are running on a single device
+            self.runner(attrs).test_forward(data_shape, dtype)

     def test_backward(self, data_shape, dtype, attrs):
         """Test normal datatype backward"""
         QuantizeConfig.finalize()  # Ensure FP8 disabled.
-        self.runner(attrs).test_backward(data_shape, dtype)
+        with global_shard_guard(MeshResource()):
+            # Empty MeshResource is used as we are running on a single device
+            self.runner(attrs).test_backward(data_shape, dtype)

     @pytest.mark.skipif(not is_fp8_supported, reason=reason)

@@ -502,6 +509,9 @@ class BaseTester:
     def test_forward_with_fp8(self, data_shape, dtype, attrs, fp8_recipe):
         """Test forward with fp8 enabled"""
         QuantizeConfig.initialize(fp8_recipe=fp8_recipe)
-        self.runner(attrs).test_forward(data_shape, dtype, rtol=1e-4, atol=1e-3)
+        with global_shard_guard(MeshResource()):
+            # Empty MeshResource is used as we are running on a single device
+            self.runner(attrs).test_forward(data_shape, dtype, rtol=1e-4, atol=1e-3)
         QuantizeConfig.finalize()

@@ -510,6 +520,9 @@ class BaseTester:
     def test_backward_with_fp8(self, data_shape, dtype, attrs, fp8_recipe):
         """Test backward with fp8 enabled"""
         QuantizeConfig.initialize(fp8_recipe=fp8_recipe)
-        self.runner(attrs).test_backward(data_shape, dtype, rtol=1e-4, atol=1e-3)
+        with global_shard_guard(MeshResource()):
+            # Empty MeshResource is used as we are running on a single device
+            self.runner(attrs).test_backward(data_shape, dtype, rtol=1e-4, atol=1e-3)
         QuantizeConfig.finalize()
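Since global_shard_guard is an ordinary context manager, a downstream test suite can apply the same single-device guard once through an autouse fixture rather than wrapping each call; a small sketch under that assumption:

    import pytest
    from transformer_engine.jax.sharding import MeshResource, global_shard_guard

    @pytest.fixture(autouse=True)
    def single_device_mesh():
        # Empty MeshResource: no mesh axes, i.e. plain single-device runs.
        with global_shard_guard(MeshResource()):
            yield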
tests/pytorch/attention/test_attention.py

@@ -274,6 +274,8 @@ model_configs_mla = {
     "mla_3_0": ModelConfig(8, 1, 16, 128, max_seqlen_kv=2048, head_dim_v=64),  # inference
     "mla_3_1": ModelConfig(8, 1, 16, 256, max_seqlen_kv=2048, head_dim_v=128),  # inference
+    "mla_3_2": ModelConfig(8, 1, 16, 192, max_seqlen_kv=2048, head_dim_v=128),  # inference
+    "mla_3_3": ModelConfig(8, 1, 16, 160, max_seqlen_kv=2048, head_dim_v=128),  # inference
     "mla_3_4": ModelConfig(8, 1, 16, 160, max_seqlen_kv=2048, head_dim_v=160),  # inference
 }
transformer_engine/common/fused_attn/fused_attn.cpp

@@ -252,8 +252,9 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
           (head_dim_qk == 192 && head_dim_v == 128 && is_training && sm_arch_ >= 100 &&
            cudnn_runtime_version >= 91100)) &&
          // 9.11/9.12 bug: 128 < d_qk <= 256, 128 < d_v <= 256 + Hopper + bprop + MLA
-         (!((cudnn_runtime_version == 91100 || cudnn_runtime_version == 91200) && is_training &&
-            sm_arch_ == 90 && head_dim_qk >= 128 && head_dim_v >= 128 &&
+         (!((cudnn_runtime_version == 91100 || cudnn_runtime_version == 91200 ||
+             cudnn_runtime_version == 91300) &&
+            is_training && sm_arch_ == 90 && head_dim_qk >= 128 && head_dim_v >= 128 &&
             !(head_dim_qk == 192 && head_dim_v == 128) && head_dim_qk != head_dim_v))) &&
          // bias type
          ((cudnn_runtime_version < 8906 && bias_type == NVTE_Bias_Type::NVTE_NO_BIAS) ||
transformer_engine/common/gemm/cublaslt_gemm.cu

@@ -532,22 +532,22 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
                                          &epilogue, sizeof(epilogue)));

   if (counter != nullptr) {
-#if !(CUDA_VERSION >= 12020 && CUBLAS_VERSION >= 13000)
-    NVTE_ERROR("Atomic GEMM requires CUDA >=12.2.0 and <13.0.0, but compile-time CUDA verson is ",
+#if !(CUDA_VERSION >= 12020 && CUDA_VERSION < 13000)
+    NVTE_ERROR("Atomic GEMM requires CUDA >=12.2.0 and <13.0.0, but compile-time CUDA version is ",
                CUDA_VERSION);
 #endif
 #if !(CUBLAS_VERSION >= 120205 && CUBLAS_VERSION < 130000)
-    NVTE_ERROR("Atomic GEMM requires cuBLAS >=12.2.5 and <13.0.0, but compile-time cuBLAS verson is ",
+    NVTE_ERROR("Atomic GEMM requires cuBLAS >=12.2.5 and <13.0.0, but compile-time cuBLAS version is ",
                CUBLAS_VERSION);
 #endif
 #if CUDA_VERSION >= 12020 && CUBLAS_VERSION >= 120205 && CUDA_VERSION < 13000 && \
     CUBLAS_VERSION < 130000
     NVTE_CHECK(cuda::cudart_version() >= 12020 && cuda::cudart_version() < 13000,
-               "Atomic GEMM requires CUDA >=12.2.0 and <13.0.0, but run-time CUDA verson is ",
+               "Atomic GEMM requires CUDA >=12.2.0 and <13.0.0, but run-time CUDA version is ",
                cuda::cudart_version());
     NVTE_CHECK(cublas_version() >= 120205 && cublas_version() < 130000,
-               "Atomic GEMM requires cuBLAS >=12.2.5 and <13.0.0, but run-time cuBLAS verson is ",
+               "Atomic GEMM requires cuBLAS >=12.2.5 and <13.0.0, but run-time cuBLAS version is ",
                cublas_version());

     if (m_split == 0) m_split = 1;
     if (n_split == 0) n_split = 1;

@@ -783,20 +783,22 @@ void nvte_cublas_atomic_gemm(const NVTETensor A, const NVTETensor B, NVTETensor
 #ifndef __HIP_PLATFORM_AMD__
   // Check CUDA and cuBLAS versions
-#if !(CUDA_VERSION >= 12020 && CUBLAS_VERSION >= 13000)
-  NVTE_ERROR("Atomic GEMM requires CUDA >=12.2.0 and <13.0.0, but compile-time CUDA verson is ",
+#if !(CUDA_VERSION >= 12020 && CUDA_VERSION < 13000)
+  NVTE_ERROR("Atomic GEMM requires CUDA >=12.2.0 and <13.0.0, but compile-time CUDA version is ",
              CUDA_VERSION);
 #endif
 #if !(CUBLAS_VERSION >= 120205 && CUBLAS_VERSION < 130000)
-  NVTE_ERROR("Atomic GEMM requires cuBLAS >=12.2.5 and <13.0.0, but compile-time cuBLAS verson is ",
+  NVTE_ERROR("Atomic GEMM requires cuBLAS >=12.2.5 and <13.0.0, but compile-time cuBLAS version is ",
              CUBLAS_VERSION);
 #endif
   NVTE_CHECK(cuda::cudart_version() >= 12020 && cuda::cudart_version() < 13000,
-             "Atomic GEMM requires CUDA version >=12.2.0 and <13.0.0, but run-time CUDA verson is ",
+             "Atomic GEMM requires CUDA version >=12.2.0 and <13.0.0, but run-time CUDA version is ",
              cuda::cudart_version());
   NVTE_CHECK(cublas_version() >= 120205 && cublas_version() < 130000,
-             "Atomic GEMM requires cuBLAS version >=12.2.5 and <13.0.0, but run-time cuBLAS verson is ",
+             "Atomic GEMM requires cuBLAS version >=12.2.5 and <13.0.0, but run-time cuBLAS version is ",
              cublas_version());
 #endif
transformer_engine/common/transpose/quantize_transpose_square_blockwise.cu

@@ -18,7 +18,6 @@
 #include "common/common.h"
 #include "common/recipe/recipe_common.cuh"
 #include "common/util/cuda_runtime.h"
-#include "common/util/ptx.cuh"
 #include "common/utils.cuh"

@@ -185,12 +184,6 @@ __global__ void __launch_bounds__(THREADS_PER_BLOCK)
     }
   }

-  // Trigger the next kernel here so that it's load from global memory can overlap with this kernel's
-  // store to global memory.
-#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
-  cudaTriggerProgrammaticLaunchCompletion();
-#endif
-
   // Step 3: Store cast output, Step 4: do transpose within thread tile
   OVecCast tmp_output_c;

@@ -426,12 +419,6 @@ __global__ void __launch_bounds__(THREADS_PER_BLOCK) block_scaled_cast_transpose
     }
   }

-  // Trigger the next kernel here so that it's load from global memory can overlap with this kernel's
-  // store to global memory.
-#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
-  cudaTriggerProgrammaticLaunchCompletion();
-#endif
-
   // Step 3: Store cast output, Step 4: do transpose within thread tile
   // Edge case: in the non-full tile case, there are three subcases
   // for full thread tile, it's the same thing here

@@ -939,15 +926,6 @@ void quantize_transpose_square_blockwise(const SimpleTensor& input, SimpleTensor
 #else
   const size_t num_blocks_x = DIVUP(row_length, BLOCK_TILE_DIM);
   const size_t num_blocks_y = DIVUP(num_rows, BLOCK_TILE_DIM);
-  dim3 grid(num_blocks_x, num_blocks_y, 1);
-  cudaLaunchAttribute attribute[1];
-  attribute[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
-  attribute[0].val.programmaticStreamSerializationAllowed = 1;
-  cudaLaunchConfig_t cfg = {grid, THREADS_PER_BLOCK, 0, stream, NULL, 0};
-  if (transformer_engine::cuda::sm_arch(transformer_engine::cuda::current_device()) >= 90) {
-    cfg.attrs = attribute;
-    cfg.numAttrs = 1;
-  }
 #endif
   TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(

@@ -962,6 +940,7 @@ void quantize_transpose_square_blockwise(const SimpleTensor& input, SimpleTensor
     dim3 grid(num_blocks_x, num_blocks_y, 1);
     const bool full_tile = row_length % block_len == 0 && num_rows % block_len == 0;
 #else
+  dim3 grid(num_blocks_x, num_blocks_y, 1);
   const bool full_tile = row_length % BLOCK_TILE_DIM == 0 && num_rows % BLOCK_TILE_DIM == 0;
 #endif

@@ -972,21 +951,19 @@ void quantize_transpose_square_blockwise(const SimpleTensor& input, SimpleTensor
         tensor_map_output_trans = get_tensor_map<OutputType>(output_t, num_rows, row_length);
       }
-      cudaLaunchKernelEx(
-          &cfg, block_scaled_cast_transpose_kernel<kReturnTranspose, float, InputType, OutputType>,
-          reinterpret_cast<const InputType*>(input.dptr), reinterpret_cast<OutputType*>(output.dptr),
-          reinterpret_cast<OutputType*>(output_t.dptr), reinterpret_cast<float*>(scale_inv.dptr),
-          reinterpret_cast<float*>(scale_inv_t.dptr), row_length, num_rows,
-          scale_stride_x, scale_stride_y, scale_t_stride_x, scale_t_stride_y, epsilon,
-          tensor_map_output_trans, pow_2_scale);
+      block_scaled_cast_transpose_kernel<kReturnTranspose, float, InputType, OutputType>
+          <<<grid, THREADS_PER_BLOCK, 0, stream>>>(
+              reinterpret_cast<const InputType*>(input.dptr),
+              reinterpret_cast<OutputType*>(output.dptr),
+              reinterpret_cast<OutputType*>(output_t.dptr),
+              reinterpret_cast<float*>(scale_inv.dptr),
+              reinterpret_cast<float*>(scale_inv_t.dptr), row_length, num_rows,
+              scale_stride_x, scale_stride_y, scale_t_stride_x, scale_t_stride_y, epsilon,
+              tensor_map_output_trans, pow_2_scale);
     } else {
-      cudaLaunchKernelEx(
-          &cfg,
-          block_scaled_cast_transpose_kernel_notaligned<kReturnTranspose, float, InputType, OutputType>,
+      block_scaled_cast_transpose_kernel_notaligned<kReturnTranspose, float, InputType, OutputType>
+          <<<grid, THREADS_PER_BLOCK, 0, stream>>>(
              reinterpret_cast<const InputType*>(input.dptr),
              reinterpret_cast<OutputType*>(output.dptr),
              reinterpret_cast<OutputType*>(output_t.dptr),
transformer_engine/common/transpose/quantize_transpose_vector_blockwise.cu

@@ -24,7 +24,6 @@
 #include "common/common.h"
 #include "common/recipe/recipe_common.cuh"
 #include "common/transpose/cast_transpose.h"
-#include "common/util/cuda_runtime.h"
 #include "common/utils.cuh"

 namespace transformer_engine {

@@ -252,14 +251,6 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo
   __syncthreads();

-  // If not return columnwise, we trigger the next kernel here so that it's load from global memory
-  // can overlap with this kernel's return rowwise.
-#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
-  if (!return_columnwise_gemm_ready && !return_columnwise_compact) {
-    cudaTriggerProgrammaticLaunchCompletion();
-  }
-#endif
-
   // Step 2: Cast and store to output_c
   if (return_rowwise) {
     constexpr int r_stride =

@@ -365,14 +356,6 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo
     }
   }

-  // If return columnwise, we trigger the next kernel here so that it's load from global memory
-  // can overlap with this kernel's return columnwise.
-#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
-  if (return_columnwise_gemm_ready || return_columnwise_compact) {
-    cudaTriggerProgrammaticLaunchCompletion();
-  }
-#endif
-
   // Step 3 (return_columnwise_gemm_ready): Transpose, cast and store to output_t
   if (return_columnwise_gemm_ready) {
     constexpr int c_stride =

@@ -1448,12 +1431,7 @@ void quantize_transpose_vector_blockwise(const SimpleTensor& input, SimpleTensor
 #else
   const size_t num_blocks_x = DIVUP(row_length, (size_t)kTileDim);
   const size_t num_blocks_y = DIVUP(num_rows, (size_t)kTileDim);
-  dim3 grid(num_blocks_x, num_blocks_y, 1);
-  cudaLaunchAttribute attribute[1];
-  attribute[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
-  attribute[0].val.programmaticStreamSerializationAllowed = 1;
-
 #endif
   TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
       input.dtype, InputType,

@@ -1463,6 +1441,7 @@ void quantize_transpose_vector_blockwise(const SimpleTensor& input, SimpleTensor
     dim3 grid(num_blocks_x, num_blocks_y, 1);
     const bool full_tile = row_length % block_len == 0 && num_rows % block_len == 0;
 #else
+  dim3 grid(num_blocks_x, num_blocks_y, 1);
   const bool full_tile = row_length % kTileDim == 0 && num_rows % kTileDim == 0;
 #endif
   TRANSFORMER_ENGINE_SWITCH_CONDITION(

@@ -1532,30 +1511,21 @@ void quantize_transpose_vector_blockwise(const SimpleTensor& input, SimpleTensor
   }
 #else
   size_t smem_bytes = kSMemSize * sizeof(InputType);
-  cudaLaunchConfig_t cfg = {grid, kThreadsPerBlock, smem_bytes, stream, NULL, 0};
-  if (transformer_engine::cuda::sm_arch(transformer_engine::cuda::current_device()) >= 90) {
-    cfg.attrs = attribute;
-    cfg.numAttrs = 1;
-  }
   // shared memory must be requested up
   if (smem_bytes >= 48 * 1024) {
     cudaError_t err = cudaFuncSetAttribute(
         &block_scaled_1d_cast_transpose_kernel<kAligned, float, InputType, OutputType>,
         cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes);
     NVTE_CHECK(err == cudaSuccess, "Failed to set dynamic shared memory size.");
   }
-  cudaLaunchKernelEx(
-      &cfg, block_scaled_1d_cast_transpose_kernel<kAligned, float, InputType, OutputType>,
-      reinterpret_cast<const InputType*>(input.dptr), reinterpret_cast<OutputType*>(output.dptr),
-      reinterpret_cast<OutputType*>(output_t.dptr), reinterpret_cast<float*>(scale_inv.dptr),
-      reinterpret_cast<float*>(scale_inv_t.dptr), row_length, num_rows, scale_stride_x,
-      scale_stride_y, scale_t_stride_x, scale_t_stride_y, epsilon, rowwise_option,
-      columnwise_option, pow2_scale);
+  block_scaled_1d_cast_transpose_kernel<kAligned, float, InputType, OutputType>
+      <<<grid, kThreadsPerBlock, smem_bytes, stream>>>(
+          reinterpret_cast<const InputType*>(input.dptr),
+          reinterpret_cast<OutputType*>(output.dptr),
+          reinterpret_cast<OutputType*>(output_t.dptr),
+          reinterpret_cast<float*>(scale_inv.dptr),
+          reinterpret_cast<float*>(scale_inv_t.dptr), row_length, num_rows, scale_stride_x,
+          scale_stride_y, scale_t_stride_x, scale_t_stride_y, epsilon, rowwise_option,
+          columnwise_option, pow2_scale);
 #endif
     )  // kAligned
   )  // OutputType
transformer_engine/common/util/cast_kernels.cuh

@@ -205,11 +205,6 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
       // Wait for the data to have arrived
       ptx::mbarrier_wait_parity(&mbar[stage], parity);

-      // Trigger the next kernel, so its TMA load can be overlapped with the current kernel
-      if (stage == STAGES - 1) {
-        cudaTriggerProgrammaticLaunchCompletion();
-      }
-
       float thread_amax = 0.0f;
       if constexpr (COLWISE_SCALING) {
         const size_t shmem_offset_base_colwise = buff * BUFF_DIM + tid_X_colwise;

@@ -1139,13 +1134,6 @@ void mxfp8_quantize(const Tensor &input, const Tensor *act_input,
   const size_t dshmem_size = in_mem + out_mem + TMA_SHMEM_ALIGNMENT;

-  cudaLaunchConfig_t cfg = {grid, block_size, dshmem_size, stream, NULL, 0};
-  // This kernel will only be called on sm100+, so no need to check sm_arch
-  cudaLaunchAttribute attribute[1];
-  attribute[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
-  attribute[0].val.programmaticStreamSerializationAllowed = 1;
-  cfg.attrs = attribute;
-  cfg.numAttrs = 1;
-
   switch (scaling_type) {
     case ScalingType::ROWWISE:
       cudaFuncSetAttribute(

@@ -1153,13 +1141,13 @@ void mxfp8_quantize(const Tensor &input, const Tensor *act_input,
               false, CHUNK_DIM_Y, CHUNK_DIM_X, THREADS_PER_CHUNK>,
           cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size);
-      cudaLaunchKernelEx(
-          &cfg,
-          cast_mxfp8_2D_kernel<IS_DBIAS, IS_DACT, IS_ACT, ParamOP, OP, IType, OType, true,
-                               false, CHUNK_DIM_Y, CHUNK_DIM_X, THREADS_PER_CHUNK>,
-          tensor_map_input, tensor_map_act_input, tensor_map_output_rowwise,
-          tensor_map_output_colwise, scales_rowwise_ptr, scales_colwise_ptr, noop_ptr,
-          workspace_ptr, amax_ptr, rows, cols, scale_stride_rowwise, scale_stride_colwise);
+      cast_mxfp8_2D_kernel<IS_DBIAS, IS_DACT, IS_ACT, ParamOP, OP, IType, OType, true,
+                           false, CHUNK_DIM_Y, CHUNK_DIM_X, THREADS_PER_CHUNK>
+          <<<grid, block_size, dshmem_size, stream>>>(
+              tensor_map_input, tensor_map_act_input, tensor_map_output_rowwise,
+              tensor_map_output_colwise, scales_rowwise_ptr, scales_colwise_ptr, noop_ptr,
+              workspace_ptr, amax_ptr, rows, cols, scale_stride_rowwise, scale_stride_colwise);
       break;
     case ScalingType::COLWISE:
       cudaFuncSetAttribute(

@@ -1167,13 +1155,13 @@ void mxfp8_quantize(const Tensor &input, const Tensor *act_input,
               true, CHUNK_DIM_Y, CHUNK_DIM_X, THREADS_PER_CHUNK>,
           cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size);
-      cudaLaunchKernelEx(
-          &cfg,
-          cast_mxfp8_2D_kernel<IS_DBIAS, IS_DACT, IS_ACT, ParamOP, OP, IType, OType, false,
-                               true, CHUNK_DIM_Y, CHUNK_DIM_X, THREADS_PER_CHUNK>,
-          tensor_map_input, tensor_map_act_input, tensor_map_output_rowwise,
-          tensor_map_output_colwise, scales_rowwise_ptr, scales_colwise_ptr, noop_ptr,
-          workspace_ptr, amax_ptr, rows, cols, scale_stride_rowwise, scale_stride_colwise);
+      cast_mxfp8_2D_kernel<IS_DBIAS, IS_DACT, IS_ACT, ParamOP, OP, IType, OType, false,
+                           true, CHUNK_DIM_Y, CHUNK_DIM_X, THREADS_PER_CHUNK>
+          <<<grid, block_size, dshmem_size, stream>>>(
+              tensor_map_input, tensor_map_act_input, tensor_map_output_rowwise,
+              tensor_map_output_colwise, scales_rowwise_ptr, scales_colwise_ptr, noop_ptr,
+              workspace_ptr, amax_ptr, rows, cols, scale_stride_rowwise, scale_stride_colwise);
       break;
     case ScalingType::BIDIMENSIONAL:
      cudaFuncSetAttribute(

@@ -1181,13 +1169,13 @@ void mxfp8_quantize(const Tensor &input, const Tensor *act_input,
               true, CHUNK_DIM_Y, CHUNK_DIM_X, THREADS_PER_CHUNK>,
           cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size);
-      cudaLaunchKernelEx(
-          &cfg,
-          cast_mxfp8_2D_kernel<IS_DBIAS, IS_DACT, IS_ACT, ParamOP, OP, IType, OType, true,
-                               true, CHUNK_DIM_Y, CHUNK_DIM_X, THREADS_PER_CHUNK>,
-          tensor_map_input, tensor_map_act_input, tensor_map_output_rowwise,
-          tensor_map_output_colwise, scales_rowwise_ptr, scales_colwise_ptr, noop_ptr,
-          workspace_ptr, amax_ptr, rows, cols, scale_stride_rowwise, scale_stride_colwise);
+      cast_mxfp8_2D_kernel<IS_DBIAS, IS_DACT, IS_ACT, ParamOP, OP, IType, OType, true,
+                           true, CHUNK_DIM_Y, CHUNK_DIM_X, THREADS_PER_CHUNK>
+          <<<grid, block_size, dshmem_size, stream>>>(
+              tensor_map_input, tensor_map_act_input, tensor_map_output_rowwise,
+              tensor_map_output_colwise, scales_rowwise_ptr, scales_colwise_ptr, noop_ptr,
+              workspace_ptr, amax_ptr, rows, cols, scale_stride_rowwise, scale_stride_colwise);
       break;
   }
transformer_engine/jax/cpp_extensions/gemm.py

@@ -8,6 +8,7 @@ import operator
 from collections.abc import Iterable
 from typing import Tuple, Sequence, Union
 from functools import partial, reduce
+import warnings

 import jax
 import jax.numpy as jnp

@@ -34,6 +35,7 @@ from ..quantize import (
     is_fp8_gemm_with_all_layouts_supported,
     apply_padding_to_scale_inv,
 )
+from ..sharding import global_mesh_resource
 from .misc import get_padded_spec

@@ -490,7 +492,8 @@ class GemmPrimitive(BasePrimitive):
         # Non-contracting dims of RHS always needs to be gathered along the FSDP axis
         rhs_non_cspecs = tuple(
-            None if spec is not None and "fsdp" in spec else spec for spec in rhs_non_cspecs
+            None if spec is not None and spec == global_mesh_resource().fsdp_resource else spec
+            for spec in rhs_non_cspecs
         )

         # Non-contracting dims of LHS to be gathered along the SP axis.

@@ -656,6 +659,12 @@ class GemmPrimitive(BasePrimitive):
             prefix = "GemmPrimitive_"

+            warnings.warn(
+                "Known issues with TE GemmPrimitives when Shardy propagation is enabled. For now,"
+                " please turn off Shardy by exporting the environment variable"
+                " 'JAX_USE_SHARDY_PARTITIONER=0' if you experience any problems."
+            )
+
             def _generate_operand_rules(name, ndim, cdims):
                 specs = []
                 ldims = tuple(i for i in range(ndim) if i not in cdims)
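The sharding-rule fix above replaces a substring test on the axis name with an exact comparison against the mesh axis registered as the FSDP resource. A self-contained toy (hypothetical axis names, not TE code) showing the behavioral difference:

    fsdp_axis = "zp"  # suppose the mesh's FSDP axis is not literally named "fsdp"
    specs = ("zp", "tp", None, "fsdp_like")

    # Old predicate: substring match on the name.
    old = tuple(None if s is not None and "fsdp" in s else s for s in specs)
    # -> ('zp', 'tp', None, None): misses "zp" and falsely clears "fsdp_like"

    # New predicate: compare with the axis actually registered as FSDP,
    # as global_mesh_resource().fsdp_resource does in the diff above.
    new = tuple(None if s is not None and s == fsdp_axis else s for s in specs)
    # -> (None, 'tp', None, 'fsdp_like')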
transformer_engine/jax/flax/transformer.py

@@ -26,6 +26,7 @@ from .module import LayerNorm, Softmax
 from ..attention import AttnBiasType, AttnMaskType, QKVLayout, SequenceDescriptor
 from ..attention import is_fused_attn_kernel_available, make_swa_mask, canonicalize_attn_mask_type
 from ..attention import fused_attn
+from ..attention import CPStrategy
 from ..softmax import SoftmaxType
 from ..sharding import num_of_devices
 from ..sharding import get_sharding_map_logic_axis_to_mesh_axis

@@ -274,6 +275,7 @@ class _FusedDotProductAttention(nn.Module):  # pylint: disable=too-few-public-me
     max_segments_per_seq: Optional[int] = 1
     context_parallel_causal_load_balanced: bool = False
     context_parallel_axis: str = ""
     context_parallel_strategy: CPStrategy = CPStrategy.DEFAULT
+    context_checkpoint_name: str = "context"

     @nn.compact

@@ -323,6 +325,7 @@ class _FusedDotProductAttention(nn.Module):  # pylint: disable=too-few-public-me
                 max_segments_per_seq=self.max_segments_per_seq,
                 context_parallel_causal_load_balanced=self.context_parallel_causal_load_balanced,
                 context_parallel_axis=self.context_parallel_axis,
                 context_parallel_strategy=self.context_parallel_strategy,
+                context_checkpoint_name=self.context_checkpoint_name,
             )
         elif self.qkv_layout.is_kvpacked():

@@ -350,6 +353,7 @@ class _FusedDotProductAttention(nn.Module):  # pylint: disable=too-few-public-me
                 max_segments_per_seq=self.max_segments_per_seq,
                 context_parallel_causal_load_balanced=self.context_parallel_causal_load_balanced,
                 context_parallel_axis=self.context_parallel_axis,
                 context_parallel_strategy=self.context_parallel_strategy,
+                context_checkpoint_name=self.context_checkpoint_name,
             )
         elif self.qkv_layout.is_separate():

@@ -372,6 +376,7 @@ class _FusedDotProductAttention(nn.Module):  # pylint: disable=too-few-public-me
                 max_segments_per_seq=self.max_segments_per_seq,
                 context_parallel_causal_load_balanced=self.context_parallel_causal_load_balanced,
                 context_parallel_axis=self.context_parallel_axis,
                 context_parallel_strategy=self.context_parallel_strategy,
+                context_checkpoint_name=self.context_checkpoint_name,
             )
         else:

@@ -505,6 +510,7 @@ class DotProductAttention(nn.Module):  # pylint: disable=too-few-public-methods
     context_parallel_causal_load_balanced (bool):
         Indicates the sequences are ordered for causal mask load balancing when running context parallelism.
     context_parallel_axis (str): The name of the context parallel axis.
     context_parallel_strategy (CPStrategy): The strategy of context parallel. 0: DEFAULT, 1: ALL_GATHER, 2: RING.
+    context_checkpoint_name (str): The name of the context checkpoint in the forward pass of fused attention.

     Optimization parameters

@@ -529,6 +535,7 @@ class DotProductAttention(nn.Module):  # pylint: disable=too-few-public-methods
     max_segments_per_seq: Optional[int] = 1
     context_parallel_causal_load_balanced: bool = False
     context_parallel_axis: str = ""
     context_parallel_strategy: str = "DEFAULT"
+    context_checkpoint_name: str = "context"

     @nn.compact

@@ -648,6 +655,24 @@ class DotProductAttention(nn.Module):  # pylint: disable=too-few-public-methods
         scale_factor = self.scale_factor
         del self.scale_factor

+        # case-insensitive mapping for context parallel strategy
+        cp_strategy_map = {
+            "DEFAULT": CPStrategy.DEFAULT,
+            "ALL_GATHER": CPStrategy.ALL_GATHER,
+            "ALLGATHER": CPStrategy.ALL_GATHER,  # Alternative spelling
+            "RING": CPStrategy.RING,
+        }
+        strategy_key = self.context_parallel_strategy.upper()
+        if strategy_key in cp_strategy_map:
+            context_parallel_strategy = cp_strategy_map[strategy_key]
+        else:
+            valid_strategies = list(cp_strategy_map.keys())
+            raise ValueError(
+                f"Invalid context parallel strategy: {self.context_parallel_strategy}. "
+                f"Valid options are: {valid_strategies} (case insensitive)"
+            )
+
         if not use_fused_attn:
             # unfused attention only supports splitted query, key, value
             if qkv_layout.is_qkvpacked():

@@ -696,6 +721,7 @@ class DotProductAttention(nn.Module):  # pylint: disable=too-few-public-methods
                 max_segments_per_seq=self.max_segments_per_seq,
                 context_parallel_causal_load_balanced=self.context_parallel_causal_load_balanced,
                 context_parallel_axis=self.context_parallel_axis,
                 context_parallel_strategy=context_parallel_strategy,
+                context_checkpoint_name=self.context_checkpoint_name,
             )(
                 query,
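With the mapping added above, DotProductAttention accepts the context-parallel strategy as a case-insensitive string and converts it to a CPStrategy enum at call time; unknown strings raise a ValueError listing the valid options. A hedged sketch of the resulting interface (head_dim and num_attention_heads are assumed field names taken from the class's usual configuration, not from this diff):

    from transformer_engine.jax.flax import DotProductAttention

    # Accepted spellings per cp_strategy_map: "default", "all_gather",
    # "allgather", "ring" (any casing).
    attn = DotProductAttention(
        head_dim=64,                        # assumed constructor field
        num_attention_heads=16,             # assumed constructor field
        context_parallel_strategy="ring",   # mapped to CPStrategy.RING
        context_checkpoint_name="context",  # new knob from this commit (default shown)
    )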
transformer_engine/jax/quantize/helper.py

@@ -404,9 +404,6 @@ def fp8_autocast(
     if fp8_recipe is None:
         fp8_recipe = recipe.DelayedScaling()

-    if mesh_resource is None:
-        mesh_resource = MeshResource()
-
     Config = DelayedScalingQuantizeConfig
     if isinstance(fp8_recipe, recipe.MXFP8BlockScaling):
         Config = BlockScalingQuantizeConfig
transformer_engine/jax/sharding.py

@@ -286,7 +286,7 @@ class MeshResource:
     cp_resource: str = None


-_GLOBAL_MESH_RESOURCE = MeshResource()
+_GLOBAL_MESH_RESOURCE = None


 @contextmanager

@@ -314,6 +314,11 @@ def global_mesh_resource() -> MeshResource:
     Returns:
         The current MeshResource instance
     """
+    assert _GLOBAL_MESH_RESOURCE is not None, (
+        "Global mesh resource is not set. Please set the MeshResource via a global_shard_guard"
+        " context. If you are not using multiple GPUs, you can use an empty MeshResource by"
+        " wrapping your program in 'with global_shard_guard(MeshResource()):'"
+    )
     return _GLOBAL_MESH_RESOURCE
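Taken together with the helper.py change above (fp8_autocast no longer substitutes a default MeshResource), this makes the mesh resource an explicit opt-in: any code path that reaches global_mesh_resource() outside a guard now fails fast instead of silently using an empty default. A sketch of the new behavior on a single device:

    from transformer_engine.jax.sharding import (
        MeshResource,
        global_mesh_resource,
        global_shard_guard,
    )

    # global_mesh_resource()  # outside a guard this now raises AssertionError
    with global_shard_guard(MeshResource()):  # empty resource: single device
        mesh = global_mesh_resource()         # OK inside the guard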
transformer_engine/pytorch/graph.py

@@ -4,6 +4,8 @@
 """Functions for CUDA Graphs support in FP8"""
 from collections.abc import Iterable
+import contextlib
+import gc
 from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar, Union

 import torch

@@ -58,6 +60,25 @@ def graph_pool_handle():
     return _graph_pool_handle()


+@contextlib.contextmanager
+def _graph_context_wrapper(*args, **kwargs):
+    """Wrapper around `torch.cuda.graph`.
+
+    This wrapper is a temporary workaround for a PyTorch bug:
+    automatic garbage collection can destroy a graph while another
+    graph is being captured, resulting in a CUDA error. See
+    https://github.com/pytorch/pytorch/pull/161037.
+
+    """
+    gc_is_enabled = gc.isenabled()
+    if gc_is_enabled:
+        gc.disable()
+    with torch.cuda.graph(*args, **kwargs):
+        yield
+    if gc_is_enabled:
+        gc.enable()
+
+
 def _make_graphed_callables(
     callables: SingleOrTuple[Callable],
     sample_args: SingleOrTuple[Tuple[torch.Tensor, ...]],

@@ -445,7 +466,7 @@ def _make_graphed_callables(
             args = sample_args[per_callable_fwd_idx]
             kwargs = sample_kwargs[per_callable_fwd_idx]
             fwd_graph = fwd_graphs[per_callable_fwd_idx]
-            with torch.cuda.graph(fwd_graph, pool=mempool):
+            with _graph_context_wrapper(fwd_graph, pool=mempool):
                 outputs = func(*args, **kwargs)
                 flatten_outputs, spec = _tree_flatten(outputs)
             per_callable_static_outputs[per_callable_fwd_idx] = tuple(flatten_outputs)

@@ -483,7 +504,7 @@ def _make_graphed_callables(
             torch.empty_like(o) if o.requires_grad else None for o in static_outputs
         )
         if is_training:
-            with torch.cuda.graph(bwd_graph, pool=mempool):
+            with _graph_context_wrapper(bwd_graph, pool=mempool):
                 grad_inputs = torch.autograd.grad(
                     outputs=tuple(o for o in static_outputs if o.requires_grad),
                     inputs=tuple(i for i in static_input_surface if i.requires_grad),

@@ -548,7 +569,7 @@ def _make_graphed_callables(
     per_callable_output_unflatten_spec = []
     graph_id = 0
     for func, args, kwargs, fwd_graph in zip(callables, sample_args, sample_kwargs, fwd_graphs):
-        with torch.cuda.graph(fwd_graph, pool=mempool):
+        with _graph_context_wrapper(fwd_graph, pool=mempool):
             outputs = func(*args, **kwargs)
         graph_callables[graph_id] = func
         graph_id += 1

@@ -570,7 +591,7 @@ def _make_graphed_callables(
         torch.empty_like(o) if o.requires_grad else None for o in static_outputs
     )
     if is_training:
-        with torch.cuda.graph(bwd_graph, pool=mempool):
+        with _graph_context_wrapper(bwd_graph, pool=mempool):
             grad_inputs = torch.autograd.grad(
                 outputs=tuple(o for o in static_outputs if o.requires_grad),
                 inputs=tuple(i for i in static_input_surface if i.requires_grad),
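For code that captures CUDA graphs directly (outside make_graphed_callables), the same workaround can be applied at the call site. A minimal sketch using PyTorch's standard capture API; note it adds a try/finally so GC is restored even if capture fails, a detail the wrapper above omits:

    import gc
    import torch

    g = torch.cuda.CUDAGraph()
    static_x = torch.randn(16, 256, device="cuda")

    # Suspend Python GC for the duration of the capture, mirroring
    # _graph_context_wrapper's workaround for pytorch/pytorch#161037.
    gc_was_enabled = gc.isenabled()
    if gc_was_enabled:
        gc.disable()
    try:
        with torch.cuda.graph(g):
            static_y = static_x * 2.0
    finally:
        if gc_was_enabled:
            gc.enable()

    g.replay()  # re-runs the captured multiply on static_x into static_y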
transformer_engine/pytorch/ops/basic/basic_linear.py

@@ -12,7 +12,6 @@ from typing import Any, Optional

 import torch

-from transformer_engine.pytorch.module.base import get_workspace
 from ...cpp_extensions import general_gemm
 from ...distributed import (
     CudaRNGStatesTracker,

@@ -20,18 +19,24 @@ from ...distributed import (
     reduce_scatter_along_first_dim,
 )
 from ...fp8 import FP8GlobalStateManager, Recipe
-from ...module.base import _2X_ACC_FPROP, _2X_ACC_DGRAD, _2X_ACC_WGRAD
+from ...module.base import (
+    _2X_ACC_FPROP,
+    _2X_ACC_DGRAD,
+    _2X_ACC_WGRAD,
+    get_dummy_wgrad,
+    get_workspace,
+)
 from ...tensor import Quantizer
 from ...tensor.float8_tensor import Float8Quantizer
 from ...tensor._internal.float8_tensor_base import Float8TensorBase
-from ..op import BasicOperation, OperationContext
-from .._common import maybe_dequantize, is_quantized_tensor
 from ...utils import (
     canonicalize_device,
     canonicalize_dtype,
     clear_tensor_data,
     devices_match,
 )
+from ..op import BasicOperation, OperationContext
+from .._common import maybe_dequantize, is_quantized_tensor


 def _wait_async(handle: Optional[Any]) -> None:

@@ -73,7 +78,8 @@ class BasicLinear(BasicOperation):
             weight's `main_grad` attribute instead of relying on PyTorch
             autograd. The weight's `main_grad` must be set externally and
             there is no guarantee that `grad` will be set or be
-            meaningful.
+            meaningful. This is primarily intented to integrate with
+            Megatron-LM.
         userbuffers_options, dict, optional
             Options for overlapping tensor-parallel communication with
             compute using Userbuffers. This feature is highly

@@ -979,20 +985,22 @@ class BasicLinear(BasicOperation):
         # Saved tensors from forward pass
         (x_local, w) = ctx.saved_tensors

-        # wgrad fusion
+        # Megatron-LM wgrad fusion
+        # Note: Get grad tensor from param so we can accumulate
+        # directly into it.
         accumulate_into_main_grad = self._accumulate_into_main_grad
         grad_weight = None
         if ctx.weight_requires_grad and accumulate_into_main_grad:
-            if hasattr(self.weight, "__fsdp_param__"):
-                self.weight.main_grad = self.weight.get_main_grad()
-            if not hasattr(self.weight, "main_grad"):
+            weight_param = self.weight
+            if hasattr(weight_param, "__fsdp_param__"):
+                weight_param.main_grad = weight_param.get_main_grad()
+            if not hasattr(weight_param, "main_grad"):
                 raise RuntimeError(
                     "BasicLinear op is configured with "
                     "accumulate_into_main_grad=True, "
                     "but weight parameter does not have main_grad attribute"
                 )
-            grad_weight = self.weight.main_grad.detach()
+            grad_weight = weight_param.main_grad.detach()
         else:
             accumulate_into_main_grad = False

@@ -1019,6 +1027,17 @@ class BasicLinear(BasicOperation):
         # Clear input tensor if possible
         clear_tensor_data(x_local)

+        # Megatron-LM wgrad fusion
+        # Note: Return dummy tensor for grad weight if needed.
+        if accumulate_into_main_grad:
+            grad_weight = None
+            weight_param = self.weight
+            if hasattr(weight_param, "grad_added_to_main_grad"):
+                weight_param.grad_added_to_main_grad = True
+                grad_weight = get_dummy_wgrad(
+                    list(weight_param.size()),
+                    weight_param.dtype,
+                    zero=getattr(weight_param, "zero_out_wgrad", False),
+                )
+
         return grad_input, [grad_weight]
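The backward-pass changes above (mirrored in the two fused backward ops that follow) implement the Megatron-LM wgrad-fusion contract: the framework preallocates a main_grad buffer on the parameter, the op accumulates the weight gradient directly into it, sets grad_added_to_main_grad, and hands autograd a dummy wgrad from get_dummy_wgrad. A rough usage sketch, assuming accumulate_into_main_grad is exposed as the BasicLinear constructor option its docstring describes:

    import torch
    from transformer_engine.pytorch import ops

    # Hypothetical setup mirroring how Megatron-LM prepares a parameter.
    linear = ops.BasicLinear(256, 256, device="cuda", accumulate_into_main_grad=True)
    linear.weight.main_grad = torch.zeros(
        linear.weight.shape, dtype=torch.float32, device="cuda"
    )

    x = torch.randn(16, 256, device="cuda", requires_grad=True)
    linear(x).sum().backward()

    # The real wgrad was accumulated into weight.main_grad; autograd only
    # saw a dummy tensor (zero-filled when weight.zero_out_wgrad is set).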
transformer_engine/pytorch/ops/fused/backward_linear_add.py

@@ -9,13 +9,10 @@ from typing import Optional

 import torch

-from transformer_engine.pytorch.ops.basic import BasicLinear, MakeExtraOutput
-from transformer_engine.pytorch.ops.op import (
-    FusedOperation,
-    FusibleOperation,
-    OperationContext,
-)
+from ...module.base import get_dummy_wgrad
+from ...utils import clear_tensor_data
+from ..basic import BasicLinear, MakeExtraOutput
+from ..op import FusedOperation, FusibleOperation, OperationContext


 class BackwardLinearAdd(FusedOperation):

@@ -53,20 +50,22 @@ class BackwardLinearAdd(FusedOperation):
         # Saved tensors from forward pass
         (x_local, w) = linear_op_ctx.saved_tensors

-        # wgrad fusion
+        # Megatron-LM wgrad fusion
+        # Note: Get grad tensor from param so we can accumulate
+        # directly into it.
         accumulate_into_main_grad = linear_op._accumulate_into_main_grad
         grad_weight = None
         if linear_op_ctx.weight_requires_grad and accumulate_into_main_grad:
-            if hasattr(linear_op.weight, "__fsdp_param__"):
-                linear_op.weight.main_grad = linear_op.weight.get_main_grad()
-            if not hasattr(linear_op.weight, "main_grad"):
+            weight_param = linear_op.weight
+            if hasattr(weight_param, "__fsdp_param__"):
+                weight_param.main_grad = weight_param.get_main_grad()
+            if not hasattr(weight_param, "main_grad"):
                 raise RuntimeError(
                     "BasicLinear op is configured with "
                     "accumulate_into_main_grad=True, "
                     "but weight parameter does not have main_grad attribute"
                 )
-            grad_weight = linear_op.weight.main_grad.detach()
+            grad_weight = weight_param.main_grad.detach()
         else:
             accumulate_into_main_grad = False

@@ -92,12 +91,23 @@ class BackwardLinearAdd(FusedOperation):
             grad_output_quantizer=linear_op_ctx.grad_output_quantizer,
             grad_input_quantizer=linear_op_ctx.grad_input_quantizer,
         )
-        if accumulate_into_main_grad:
-            grad_weight = None

+        # Clear input tensor if possible
+        clear_tensor_data(x_local)
+
+        # Megatron-LM wgrad fusion
+        # Note: Return dummy tensor for grad weight if needed.
+        if accumulate_into_main_grad:
+            grad_weight = None
+            weight_param = linear_op.weight
+            if hasattr(weight_param, "grad_added_to_main_grad"):
+                weight_param.grad_added_to_main_grad = True
+                grad_weight = get_dummy_wgrad(
+                    list(weight_param.size()),
+                    weight_param.dtype,
+                    zero=getattr(weight_param, "zero_out_wgrad", False),
+                )
+
         return grad_input, [(grad_weight,), ()], [(), ()]
transformer_engine/pytorch/ops/fused/backward_linear_scale.py

@@ -9,13 +9,10 @@ from typing import Optional

 import torch

-from ..basic import BasicLinear, ConstantScale
-from ..op import (
-    FusedOperation,
-    FusibleOperation,
-    OperationContext,
-)
+from ...module.base import get_dummy_wgrad
+from ...utils import clear_tensor_data
+from ..basic import BasicLinear, ConstantScale
+from ..op import FusedOperation, FusibleOperation, OperationContext


 class BackwardLinearScale(FusedOperation):

@@ -54,20 +51,22 @@ class BackwardLinearScale(FusedOperation):
         # Saved tensors from forward pass
         (x_local, w) = linear_op_ctx.saved_tensors

-        # wgrad fusion
+        # Megatron-LM wgrad fusion
+        # Note: Get grad tensor from param so we can accumulate
+        # directly into it.
         accumulate_into_main_grad = linear_op._accumulate_into_main_grad
         grad_weight = None
         if linear_op_ctx.weight_requires_grad and accumulate_into_main_grad:
-            if hasattr(linear_op.weight, "__fsdp_param__"):
-                linear_op.weight.main_grad = linear_op.weight.get_main_grad()
-            if not hasattr(linear_op.weight, "main_grad"):
+            weight_param = linear_op.weight
+            if hasattr(weight_param, "__fsdp_param__"):
+                weight_param.main_grad = weight_param.get_main_grad()
+            if not hasattr(weight_param, "main_grad"):
                 raise RuntimeError(
                     "BasicLinear op is configured with "
                     "accumulate_into_main_grad=True, "
                     "but weight parameter does not have main_grad attribute"
                 )
-            grad_weight = linear_op.weight.main_grad.detach()
+            grad_weight = weight_param.main_grad.detach()
         else:
             accumulate_into_main_grad = False

@@ -92,12 +91,23 @@ class BackwardLinearScale(FusedOperation):
             grad_output_quantizer=linear_op_ctx.grad_output_quantizer,
             grad_input_quantizer=linear_op_ctx.grad_input_quantizer,
         )
-        if accumulate_into_main_grad:
-            grad_weight = None

+        # Clear input tensor if possible
+        clear_tensor_data(x_local)
+
+        # Megatron-LM wgrad fusion
+        # Note: Return dummy tensor for grad weight if needed.
+        if accumulate_into_main_grad:
+            grad_weight = None
+            weight_param = linear_op.weight
+            if hasattr(weight_param, "grad_added_to_main_grad"):
+                weight_param.grad_added_to_main_grad = True
+                grad_weight = get_dummy_wgrad(
+                    list(weight_param.size()),
+                    weight_param.dtype,
+                    zero=getattr(weight_param, "zero_out_wgrad", False),
+                )
+
         return grad_input, [(), (grad_weight,)], [(), ()]