OpenDAS / TransformerEngine / Commits / f8c2af4c

Commit f8c2af4c authored May 21, 2025 by yuguo

Merge commit '1d903f5e' of https://github.com/NVIDIA/TransformerEngine

Parents: e92773a3, 1d903f5e
Changes: 211
Showing 20 changed files with 1856 additions and 1953 deletions (+1856, -1953)
transformer_engine/pytorch/csrc/type_converters.cpp                    +0    -0
transformer_engine/pytorch/csrc/type_shim.h                            +0    -359
transformer_engine/pytorch/csrc/util.cpp                               +77   -0
transformer_engine/pytorch/csrc/util.h                                 +0    -2
transformer_engine/pytorch/distributed.py                              +48   -18
transformer_engine/pytorch/fp8.py                                      +2    -2
transformer_engine/pytorch/graph.py                                    +3    -1
transformer_engine/pytorch/jit.py                                      +11   -10
transformer_engine/pytorch/module/_common.py                           +0    -37
transformer_engine/pytorch/module/base.py                              +189  -33
transformer_engine/pytorch/module/grouped_linear.py                    +13   -7
transformer_engine/pytorch/module/layernorm.py                         +4    -7
transformer_engine/pytorch/module/layernorm_linear.py                  +349  -252
transformer_engine/pytorch/module/layernorm_mlp.py                     +363  -247
transformer_engine/pytorch/module/linear.py                            +362  -266
transformer_engine/pytorch/module/rmsnorm.py                           +4    -6
transformer_engine/pytorch/ops/basic/basic_linear.py                   +8    -3
transformer_engine/pytorch/ops/fused/userbuffers_backward_linear.py    +285  -420
transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py     +127  -282
transformer_engine/pytorch/setup.py                                    +11   -1
transformer_engine/pytorch/csrc/extensions/type_converters.cpp → transformer_engine/pytorch/csrc/type_converters.cpp
View file @ f8c2af4c
File moved

transformer_engine/pytorch/csrc/type_shim.h deleted 100644 → 0
View file @ e92773a3
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#pragma once
#include <ATen/ATen.h>
#include "common/utils.cuh"
// Forward/backward compatiblity hack around
// https://github.com/pytorch/pytorch/commit/3aeb78079bcd68282fe9117088e138b77318e288
// pending more future-proof guidance from upstream.
// struct TypeShim
// {
// const at::Type& payload;
// TypeShim(const at::Type& type) : payload(type) {}
// // Enable trivial conversion to a const at::Type& for pre-3aeb78
// operator const at::Type&(){ return payload; };
// // Enable dispatch switch statements to take *this directly for post-3aeb78
// //operator at::ScalarType(){ return payload.; };
// };
#define DISPATCH_FLOAT_AND_HALF(TYPE, LEVEL, NAME, ...) \
switch (TYPE) { \
case at::ScalarType::Float: { \
using scalar_t_##LEVEL = float; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Half: { \
using scalar_t_##LEVEL = at::Half; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}
#define DISPATCH_FLOAT_HALF_AND_BFLOAT(TYPE, LEVEL, NAME, ...) \
switch (TYPE) { \
case at::ScalarType::Float: { \
using scalar_t_##LEVEL = float; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Half: { \
using scalar_t_##LEVEL = at::Half; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::BFloat16: { \
using scalar_t_##LEVEL = at::BFloat16; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}
#define DISPATCH_FLOAT_HALF_AND_BYTE(TYPE, LEVEL, NAME, ...) \
switch (TYPE) { \
case at::ScalarType::Float: { \
using scalar_t_##LEVEL = float; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Half: { \
using scalar_t_##LEVEL = at::Half; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Byte: { \
using scalar_t_##LEVEL = uint8_t; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}
#define DISPATCH_DOUBLE_FLOAT_AND_HALF(TYPE, LEVEL, NAME, ...) \
switch (TYPE) { \
case at::ScalarType::Double: { \
using scalar_t_##LEVEL = double; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Float: { \
using scalar_t_##LEVEL = float; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Half: { \
using scalar_t_##LEVEL = at::Half; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}
#define DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT(TYPE, LEVEL, NAME, ...) \
switch (TYPE) { \
case at::ScalarType::Double: { \
using scalar_t_##LEVEL = double; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Float: { \
using scalar_t_##LEVEL = float; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Half: { \
using scalar_t_##LEVEL = at::Half; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::BFloat16: { \
using scalar_t_##LEVEL = at::BFloat16; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}
#define DISPATCH_DOUBLE_AND_FLOAT(TYPE, LEVEL, NAME, ...) \
switch (TYPE) { \
case at::ScalarType::Double: { \
using scalar_t_##LEVEL = double; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Float: { \
using scalar_t_##LEVEL = float; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}
#define DISPATCH_HALF_AND_BFLOAT(TYPE, NAME, ...) \
switch (TYPE) { \
case at::ScalarType::Half: { \
using scalar_t = at::Half; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::BFloat16: { \
using scalar_t = at::BFloat16; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}
#define DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(TYPEIN, TYPEOUT, NAME, ...) \
switch (TYPEIN) { \
case at::ScalarType::Float: { \
using scalar_t_in = float; \
switch (TYPEOUT) { \
case at::ScalarType::Float: { \
using scalar_t_out = float; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Half: { \
using scalar_t_out = at::Half; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::BFloat16: { \
using scalar_t_out = at::BFloat16; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPEOUT), "'"); \
} \
break; \
} \
case at::ScalarType::Half: { \
using scalar_t_in = at::Half; \
using scalar_t_out = at::Half; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::BFloat16: { \
using scalar_t_in = at::BFloat16; \
using scalar_t_out = at::BFloat16; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPEIN), "'"); \
}
#define DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(TYPEIN, TYPEOUT, NAME, ...) \
switch (TYPEIN) { \
case at::ScalarType::Double: { \
using scalar_t_in = double; \
switch (TYPEOUT) { \
case at::ScalarType::Double: { \
using scalar_t_out = double; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Float: { \
using scalar_t_out = float; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Half: { \
using scalar_t_out = at::Half; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::BFloat16: { \
using scalar_t_out = at::BFloat16; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPEOUT), "'"); \
} \
break; \
} \
case at::ScalarType::Float: { \
using scalar_t_in = float; \
switch (TYPEOUT) { \
case at::ScalarType::Float: { \
using scalar_t_out = float; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Half: { \
using scalar_t_out = at::Half; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::BFloat16: { \
using scalar_t_out = at::BFloat16; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPEOUT), "'"); \
} \
break; \
} \
case at::ScalarType::Half: { \
using scalar_t_in = at::Half; \
using scalar_t_out = at::Half; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::BFloat16: { \
using scalar_t_in = at::BFloat16; \
using scalar_t_out = at::BFloat16; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPEIN), "'"); \
}
template <typename T>
__device__ __forceinline__ T reduce_block_into_lanes(T *x, T val, int lanes = 1,
                                                     bool share_result = false) {
  // lanes is intended to be <= 32.
  int tid = threadIdx.x + threadIdx.y * blockDim.x;
  int blockSize = blockDim.x * blockDim.y;  // blockSize is intended to be a multiple of 32.

  if (blockSize >= 64) {
    x[tid] = val;
    __syncthreads();
  }

#pragma unroll
  for (int i = (blockSize >> 1); i >= 64; i >>= 1) {
    if (tid < i) x[tid] = x[tid] + x[tid + i];
    __syncthreads();
  }

  T final;

  if (tid < 32) {
    if (blockSize >= 64)
      final = x[tid] + x[tid + 32];
    else
      final = val;
    // __SYNCWARP();

#pragma unroll
#ifdef __HIP_PLATFORM_AMD__
    for (int i = 16; i >= lanes; i >>= 1)
      final = final + __shfl_down(final, i, THREADS_PER_WARP);
#else
    for (int i = 16; i >= lanes; i >>= 1) final = final + __shfl_down_sync(0xffffffff, final, i);
#endif
  }

  if (share_result) {
    if (tid < lanes) x[tid] = final;  // EpilogueOp
    // Make sure the smem result is visible to all warps.
  }
  __syncthreads();  // Avoid potential write before read race when reduce_block_into_lanes is called back to back

  return final;
}
template <typename T>
__device__ __forceinline__ T reduce_block_into_lanes_max_op(T *x, T val, int lanes = 1,
                                                            bool share_result = false) {
  // lanes is intended to be <= 32.
  int tid = threadIdx.x + threadIdx.y * blockDim.x;
  int blockSize = blockDim.x * blockDim.y;  // blockSize is intended to be a multiple of 32.

  if (blockSize >= 64) {
    x[tid] = val;
    __syncthreads();
  }

#pragma unroll
  for (int i = (blockSize >> 1); i >= 64; i >>= 1) {
    if (tid < i) x[tid] = fmaxf(fabsf(x[tid]), fabsf(x[tid + i]));
    __syncthreads();
  }

  T final;

  if (tid < 32) {
    if (blockSize >= 64)
      final = fmaxf(fabsf(x[tid]), fabsf(x[tid + 32]));
    else
      final = val;
    // __SYNCWARP();

#pragma unroll
    for (int i = 16; i >= lanes; i >>= 1)
#ifdef __HIP_PLATFORM_AMD__
      final = fmaxf(fabsf(final), fabsf(__shfl_down(final, i, THREADS_PER_WARP)));
#else
      final = fmaxf(fabsf(final), fabsf(__shfl_down_sync(0xffffffff, final, i)));
#endif
  }

  if (share_result) {
    if (tid < lanes) x[tid] = final;  // EpilogueOp
    // Make sure the smem result is visible to all warps.
    __syncthreads();
  }

  return final;
}
transformer_engine/pytorch/csrc/extensions/swizzle.cpp → transformer_engine/pytorch/csrc/util.cpp
View file @ f8c2af4c
...
@@ -4,10 +4,10 @@
 * See LICENSE for license information.
 ************************************************************************/

#include "extensions.h"
#include "transformer_engine/transformer_engine.h"
#include "util.h"
#include "common.h"

std::optional<at::Tensor> swizzle_scaling_factors(transformer_engine::TensorWrapper &input,
                                                  bool rowwise) {
  using namespace transformer_engine::pytorch;
...
@@ -45,80 +45,33 @@ std::optional<at::Tensor> swizzle_scaling_factors(transformer_engine::TensorWrap
  transformer_engine::TensorWrapper input_cu(NVTE_MXFP8_1D_SCALING);
  transformer_engine::TensorWrapper output_cu(NVTE_MXFP8_1D_SCALING);

  if (rowwise) {
    input_cu.set_rowwise_data(input.dptr(), DType::kFloat8E4M3, input_shape);
    input_cu.set_rowwise_scale_inv(scale_inv_dptr, DType::kFloat8E8M0, scale_inv_shape);
    output_cu.set_rowwise_data(input.dptr(), DType::kFloat8E4M3, input_shape);
    output_cu.set_rowwise_scale_inv(swizzled_scale_inv_dptr, DType::kFloat8E8M0, scale_inv_shape);
    input_cu.set_rowwise_data(input.dptr(), transformer_engine::DType::kFloat8E4M3, input_shape);
    input_cu.set_rowwise_scale_inv(scale_inv_dptr, transformer_engine::DType::kFloat8E8M0,
                                   scale_inv_shape);
    output_cu.set_rowwise_data(input.dptr(), transformer_engine::DType::kFloat8E4M3, input_shape);
    output_cu.set_rowwise_scale_inv(swizzled_scale_inv_dptr, transformer_engine::DType::kFloat8E8M0,
                                    scale_inv_shape);
  } else {
    input_cu.set_columnwise_data(input.columnwise_dptr(), DType::kFloat8E4M3, input_shape);
    input_cu.set_columnwise_scale_inv(scale_inv_dptr, DType::kFloat8E8M0, scale_inv_shape);
    output_cu.set_columnwise_data(input.columnwise_dptr(), DType::kFloat8E4M3, input_shape);
    output_cu.set_columnwise_scale_inv(swizzled_scale_inv_dptr, DType::kFloat8E8M0, scale_inv_shape);
    input_cu.set_columnwise_data(input.columnwise_dptr(), transformer_engine::DType::kFloat8E4M3,
                                 input_shape);
    input_cu.set_columnwise_scale_inv(scale_inv_dptr, transformer_engine::DType::kFloat8E8M0,
                                      scale_inv_shape);
    output_cu.set_columnwise_data(input.columnwise_dptr(), transformer_engine::DType::kFloat8E4M3,
                                  input_shape);
    output_cu.set_columnwise_scale_inv(swizzled_scale_inv_dptr,
                                       transformer_engine::DType::kFloat8E8M0, scale_inv_shape);
  }

  // Launch kernel
  nvte_swizzle_scaling_factors(input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream());

  if (rowwise) {
    input.set_rowwise_scale_inv(swizzled_scale_inv_dptr, DType::kFloat8E8M0, scale_inv_shape);
    input.set_rowwise_scale_inv(swizzled_scale_inv_dptr, transformer_engine::DType::kFloat8E8M0,
                                scale_inv_shape);
  } else {
    input.set_columnwise_scale_inv(swizzled_scale_inv_dptr, DType::kFloat8E8M0, scale_inv_shape);
    input.set_columnwise_scale_inv(swizzled_scale_inv_dptr, transformer_engine::DType::kFloat8E8M0,
                                   scale_inv_shape);
  }

  return swizzled_scale_inv;
}
at::Tensor rowwise_swizzle(at::Tensor input, at::Tensor scale_inv) {
  using namespace transformer_engine::pytorch;

  NVTE_CHECK(input.element_size() == 1, "8-bit input required for swizzling scaling factors.");

  auto options = at::TensorOptions().dtype(scale_inv.dtype()).device(torch::kCUDA);
  auto swizzled_scale_inv = at::empty_like(scale_inv, options);
  void *scale_inv_dptr = getDataPtr(scale_inv, 0);
  void *swizzled_scale_inv_dptr = getDataPtr(swizzled_scale_inv, 0);

  auto input_cu = makeTransformerEngineTensor(input.data_ptr(), getTensorShape(input),
                                              DType::kFloat8E4M3, nullptr, nullptr, scale_inv_dptr,
                                              getTensorShape(scale_inv), NVTE_MXFP8_1D_SCALING);
  auto output_cu = makeTransformerEngineTensor(
      input.data_ptr(), getTensorShape(input), DType::kFloat8E4M3, nullptr, nullptr,
      swizzled_scale_inv_dptr, getTensorShape(swizzled_scale_inv), NVTE_MXFP8_1D_SCALING);

  // Launch kernel
  nvte_swizzle_scaling_factors(input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream());

  return swizzled_scale_inv;
}
at::Tensor columnwise_swizzle(at::Tensor input, at::Tensor scale_inv) {
  using namespace transformer_engine::pytorch;

  NVTE_CHECK(input.element_size() == 1, "8-bit input required for swizzling scaling factors.");

  auto options = at::TensorOptions().dtype(scale_inv.dtype()).device(torch::kCUDA);
  auto swizzled_scale_inv = at::empty_like(scale_inv, options);

  // Return immediately if tensor is empty
  if (scale_inv.numel() == 0) {
    return swizzled_scale_inv;
  }

  void *scale_inv_dptr = getDataPtr(scale_inv, 0);
  void *swizzled_scale_inv_dptr = getDataPtr(swizzled_scale_inv, 0);

  auto input_cu = makeTransformerEngineTensor(
      nullptr, input.data_ptr(), {1}, getTensorShape(input), DType::kFloat8E4M3, nullptr, nullptr,
      nullptr, scale_inv_dptr, {1}, getTensorShape(scale_inv), NVTE_MXFP8_1D_SCALING);
  auto output_cu = makeTransformerEngineTensor(
      nullptr, input.data_ptr(), {1}, getTensorShape(input), DType::kFloat8E4M3, nullptr, nullptr,
      nullptr, swizzled_scale_inv_dptr, {1}, getTensorShape(swizzled_scale_inv),
      NVTE_MXFP8_1D_SCALING);

  // Launch kernel
  nvte_swizzle_scaling_factors(input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream());

  return swizzled_scale_inv;
}
transformer_engine/pytorch/csrc/util.h
View file @ f8c2af4c
...
@@ -13,8 +13,6 @@
#include "transformer_engine/transformer_engine.h"

bool non_tn_fp8_gemm_supported();

/* Swizzle the scaling factor of the input tensor.
 *
 * The returned swizzled scaling factor tensor should be kept alive during the GEMM.
...
transformer_engine/pytorch/distributed.py
View file @ f8c2af4c
...
@@ -19,7 +19,12 @@ from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp._common_utils import _get_module_fsdp_state
from torch.distributed.fsdp._traversal_utils import _get_fsdp_states_with_modules
from .utils import non_tn_fp8_gemm_supported, safely_set_viewless_tensor_data, needs_quantized_gemm
from . import torch_version
from .utils import (
    is_non_tn_fp8_gemm_supported,
    safely_set_viewless_tensor_data,
    needs_quantized_gemm,
)
from .constants import dist_group_type
from .fp8 import FP8GlobalStateManager, fp8_autocast
from .tensor.float8_tensor import Float8Quantizer, Float8Tensor, Float8CurrentScalingQuantizer
...
@@ -267,17 +272,36 @@ def _get_active_autocast_contexts():
    """
    autocast_cached = torch.is_autocast_cache_enabled()

    gpu_autocast_enabled = torch.is_autocast_enabled()
    gpu_autocast_dtype = torch.get_autocast_gpu_dtype()
    gpu_autocast_ctx = torch.cuda.amp.autocast(
        gpu_autocast_enabled, gpu_autocast_dtype, autocast_cached
    )
    if torch_version() >= (2, 4, 0):
        gpu_autocast_enabled = torch.is_autocast_enabled("cuda")
        gpu_autocast_dtype = torch.get_autocast_dtype("cuda")
        gpu_autocast_ctx = torch.amp.autocast(
            "cuda",
            enabled=gpu_autocast_enabled,
            dtype=gpu_autocast_dtype,
            cache_enabled=autocast_cached,
        )

    cpu_autocast_enabled = torch.is_autocast_cpu_enabled()
    cpu_autocast_dtype = torch.get_autocast_cpu_dtype()
    cpu_autocast_ctx = torch.cpu.amp.autocast(
        cpu_autocast_enabled, cpu_autocast_dtype, autocast_cached
    )
        cpu_autocast_enabled = torch.is_autocast_enabled("cpu")
        cpu_autocast_dtype = torch.get_autocast_dtype("cpu")
        cpu_autocast_ctx = torch.amp.autocast(
            "cpu",
            enabled=cpu_autocast_enabled,
            dtype=cpu_autocast_dtype,
            cache_enabled=autocast_cached,
        )
    else:
        gpu_autocast_enabled = torch.is_autocast_enabled()
        gpu_autocast_dtype = torch.get_autocast_gpu_dtype()
        gpu_autocast_ctx = torch.cuda.amp.autocast(
            gpu_autocast_enabled, gpu_autocast_dtype, autocast_cached
        )
        cpu_autocast_enabled = torch.is_autocast_cpu_enabled()
        cpu_autocast_dtype = torch.get_autocast_cpu_dtype()
        cpu_autocast_ctx = torch.cpu.amp.autocast(
            cpu_autocast_enabled, cpu_autocast_dtype, autocast_cached
        )

    return gpu_autocast_ctx, cpu_autocast_ctx
...
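Aside: the `_get_active_autocast_contexts` hunk above gates on `torch_version()` so that the device-keyed autocast API is used on PyTorch 2.4+ and the legacy `torch.cuda.amp` API otherwise. A minimal standalone sketch of the same pattern, assuming `torch_version` is a `(major, minor, patch)` tuple as in this diff (illustrative only, not the module's exact code):

import torch

def make_gpu_autocast_ctx(torch_version: tuple, cache_enabled: bool):
    """Sketch: rebuild a CUDA autocast context matching the currently active settings."""
    if torch_version >= (2, 4, 0):
        # Device-keyed API available from PyTorch 2.4 onwards
        enabled = torch.is_autocast_enabled("cuda")
        dtype = torch.get_autocast_dtype("cuda")
        return torch.amp.autocast("cuda", enabled=enabled, dtype=dtype, cache_enabled=cache_enabled)
    # Legacy API on older PyTorch releases
    enabled = torch.is_autocast_enabled()
    dtype = torch.get_autocast_gpu_dtype()
    return torch.cuda.amp.autocast(enabled, dtype, cache_enabled)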
@@ -561,7 +585,9 @@ def has_te_modules(network):
    """
    from .module import LayerNorm, RMSNorm
    from .module.base import TransformerEngineBaseModule
    from .attention import UnfusedDotProductAttention, DotProductAttention, MultiheadAttention
    from .attention.dot_product_attention.backends import UnfusedDotProductAttention
    from .attention.dot_product_attention.dot_product_attention import DotProductAttention
    from .attention.multi_head_attention import MultiheadAttention
    from .transformer import TransformerLayer

    te_classes_list = [
...
@@ -893,8 +919,10 @@ def _all_gather_fp8(
    # Note: We cannot directly all-gather the transposed FP8 tensor,
    # so temporarily modify quantizer to avoid creating FP8 transpose.
    if not isinstance(inp, Float8TensorBase):
        if quantizer is None:
            raise ValueError("Input tensor is not FP8 and no quantizer was provided")
        assert isinstance(quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer))
        # we cannot directly gather the transposed fp8 tensor
        # so we need to disable columnwise usage for the quantizer
        # and then set it back to the original value after quantizing
        init_rowwise_usage = quantizer.rowwise_usage
        init_columnwise_usage = quantizer.columnwise_usage
        quantizer.set_usage(rowwise=True, columnwise=False)
...
@@ -938,7 +966,7 @@ def _all_gather_fp8(
    # Make sure FP8 transpose is populated if needed
    needs_transpose = (
        quantizer is not None and quantizer.columnwise_usage and not non_tn_fp8_gemm_supported()
        quantizer is not None and quantizer.columnwise_usage and not is_non_tn_fp8_gemm_supported()
    )
    if needs_transpose:
        if handle is not None:
...
@@ -1037,11 +1065,11 @@ def _all_gather_mxfp8(
        dtype = inp.dtype
    elif isinstance(inp, MXFP8TensorBase):
        if inp._rowwise_data is not None:
            in_shape = inp._rowwise_data.device.size()
            in_shape = inp._rowwise_data.size()
            device = inp._rowwise_data.device
            dtype = inp._rowwise_data.dtype
        elif inp._columnwise_data is not None:
            in_shape = inp._columnwise_data.device.size()
            in_shape = inp._columnwise_data.size()
            device = inp._columnwise_data.device
            dtype = inp._columnwise_data.dtype
        else:
...
@@ -1474,7 +1502,9 @@ def _is_te_module(module):
    """
    from .module import LayerNorm, RMSNorm
    from .module.base import TransformerEngineBaseModule
    from .attention import UnfusedDotProductAttention, DotProductAttention, MultiheadAttention
    from .attention.dot_product_attention.dot_product_attention import DotProductAttention
    from .attention.dot_product_attention.backends import UnfusedDotProductAttention
    from .attention.multi_head_attention import MultiheadAttention
    from .transformer import TransformerLayer

    te_classes_list = [
...
transformer_engine/pytorch/fp8.py
View file @ f8c2af4c
...
@@ -520,8 +520,8 @@ class FP8GlobalStateManager:
            return

        # Store updated amaxes and scales from phase 1 post forward.
        fp8_meta["updated_amax_history_fwd"] = fp8_meta["scaling_fwd"].amax_history
        fp8_meta["updated_scale_fwd"] = fp8_meta["scaling_fwd"].scale
        fp8_meta["updated_amax_history_fwd"] = fp8_meta["scaling_fwd"].amax_history.clone()
        fp8_meta["updated_scale_fwd"] = fp8_meta["scaling_fwd"].scale.clone()

        # Retrieve stashed amaxes and scales from phase 1 pre forward.
        buffer_position_key = "global_fp8_buffer_pos_fwd_recompute"
...
transformer_engine/pytorch/graph.py
View file @ f8c2af4c
...
@@ -536,7 +536,9 @@ def _make_graphed_callables(
                # Only Set the FP8 meta for the modules included by forward
                continue
            fp8_recipe = FP8GlobalStateManager.get_fp8_recipe()
            from transformer_engine.pytorch.attention import DotProductAttention
            from transformer_engine.pytorch.attention.dot_product_attention import (
                DotProductAttention,
            )
            if (
                isinstance(m, DotProductAttention)
...
transformer_engine/pytorch/jit.py
View file @ f8c2af4c
...
@@ -8,6 +8,9 @@ from functools import wraps
from typing import Callable, Optional, Tuple

import torch

from . import torch_version
from .utils import gpu_autocast_ctx
from torch.utils.cpp_extension import IS_HIP_EXTENSION

# pylint: disable=unnecessary-lambda-assignment
...
@@ -32,13 +35,13 @@ def lazy_compile(func):
jit_fuser = lambda func: func
if torch.__version__ >= "2" and bool(int(os.getenv("NVTE_TORCH_COMPILE", "1"))):
if torch_version() >= (2, 0, 0) and bool(int(os.getenv("NVTE_TORCH_COMPILE", "1"))):
    jit_fuser = lazy_compile

# See: https://github.com/NVIDIA/TransformerEngine/issues/597
dropout_fuser = torch.jit.script
if torch.__version__ >= "2.2" and bool(int(os.getenv("NVTE_TORCH_COMPILE", "1"))):
if torch_version() >= (2, 2, 0) and bool(int(os.getenv("NVTE_TORCH_COMPILE", "1"))):
    dropout_fuser = lazy_compile
...
@@ -51,11 +54,9 @@ def set_jit_fusion_options() -> None:
    if not IS_HIP_EXTENSION:
        """Set PyTorch JIT layer fusion options."""
        # flags required to enable jit fusion kernels
        TORCH_MAJOR = int(torch.__version__.split(".")[0])
        TORCH_MINOR = int(torch.__version__.split(".")[1])
        if TORCH_MAJOR == 2 and TORCH_MINOR >= 2:
        if torch_version() >= (2, 2, 0):
            pass
        elif (TORCH_MAJOR == 2) or (TORCH_MAJOR == 1 and TORCH_MINOR >= 10):
        elif torch_version() >= (1, 10, 0):
            # nvfuser
            torch._C._jit_set_profiling_executor(True)
            torch._C._jit_set_profiling_mode(True)
...
@@ -124,7 +125,7 @@ def dgelu_fused_(grad_output: torch.Tensor, inp: torch.Tensor) -> torch.Tensor:
def bias_gelu_fused(inp: torch.Tensor, bias: torch.Tensor) -> torch.Tensor:
    """Disable native AMP for bias_gelu_fused_"""
    with torch.cuda.amp.autocast(enabled=False):
    with gpu_autocast_ctx(enabled=False):
        if bias is not None and bias.numel() != 0:
            return bias_gelu_fused_(inp, bias)
        return gelu_fused_(inp)
...
@@ -134,7 +135,7 @@ def bgrad_dgelu_fused(
    grad_output: torch.Tensor, inp: torch.Tensor, bias: torch.Tensor
) -> Tuple[Optional[torch.Tensor], torch.Tensor]:
    """Disable native AMP for `bgrad_dgelu_fused_`"""
    with torch.cuda.amp.autocast(enabled=False):
    with gpu_autocast_ctx(enabled=False):
        if bias is not None and bias.numel() != 0:
            return bgrad_dgelu_fused_(grad_output, inp, bias)
        return None, dgelu_fused_(grad_output, inp)
...
@@ -175,7 +176,7 @@ def bias_dropout_add_fused_train(
) -> torch.Tensor:
    """Disable native AMP and enable grad for BDA"""
    with torch.enable_grad():
        with torch.cuda.amp.autocast(enabled=False):
        with gpu_autocast_ctx(enabled=False):
            return bias_dropout_add_fused_train_(x, bias, residual, prob)
...
@@ -191,7 +192,7 @@ def bias_dropout_add_fused_inference(
    x: torch.Tensor, bias: torch.Tensor, residual: torch.Tensor, prob: float
) -> torch.Tensor:
    """Disable native AMP for BDA"""
    with torch.cuda.amp.autocast(enabled=False):
    with gpu_autocast_ctx(enabled=False):
        return bias_dropout_add_fused_inference_(x, bias, residual, prob)
...
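Aside: the jit.py hunks above replace string comparisons on `torch.__version__` with tuple comparisons via `torch_version()`, while keeping the `NVTE_TORCH_COMPILE` environment toggle. A hedged sketch of that gating pattern; `pick_fusers` and `_te_compile_enabled` are illustrative names, and `torch.compile` stands in for the module's `lazy_compile` wrapper:

import os
import torch

def _te_compile_enabled() -> bool:
    # NVTE_TORCH_COMPILE defaults to "1"; setting it to "0" disables compile-based fusion.
    return bool(int(os.getenv("NVTE_TORCH_COMPILE", "1")))

def pick_fusers(torch_version: tuple):
    """Sketch: select fusion backends the way jit.py does after this change."""
    jit_fuser = lambda func: func        # default: no fusion
    dropout_fuser = torch.jit.script     # default: TorchScript for the dropout helpers
    if torch_version >= (2, 0, 0) and _te_compile_enabled():
        jit_fuser = torch.compile        # stand-in for lazy_compile
    if torch_version >= (2, 2, 0) and _te_compile_enabled():
        dropout_fuser = torch.compile    # stand-in for lazy_compile
    return jit_fuser, dropout_fuser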
transformer_engine/pytorch/module/_common.py
View file @ f8c2af4c
...
@@ -6,8 +6,6 @@
from typing import Any, List, Optional, Tuple, Union, Callable
from dataclasses import dataclass
from functools import reduce
from operator import mul as multiply_op
import queue

import torch
...
@@ -15,7 +13,6 @@ import torch
from .. import cpp_extensions as tex
from ..constants import TE_DType
from ..utils import get_default_init_method
from ..tensor.float8_tensor import Float8Tensor
import warnings

try:
    from lightop import rmsnorm_forward, rmsnorm_backward
...
@@ -24,7 +21,6 @@ except ImportError:
    enable_lightop = False
    warnings.warn(
        "Failed to import lightop module. Falling back to alternative implementation.", UserWarning
    )

def _get_normalization_func(normalization: str, forward: bool):
    fwd_normalization_funcs = {
        "LayerNorm": tex.layernorm_fwd,
...
@@ -40,39 +36,6 @@ def _get_normalization_func(normalization: str, forward: bool):
    return bwd_normalization_funcs[normalization]


def _fix_gathered_fp8_transpose(fp8_tensor: Float8Tensor, tp_size: int) -> Float8Tensor:
    """Reorder FP8 transposes after Userbuffers gather.

    The all-gather is performed in-place in the Float8Tensor's
    row-wise data, and afterwards we need to do a transpose to get the
    correct ordering. This misuses data fields in Float8Tensor and
    should be considered an evil hack. It would be best to move
    transpose logic into CommOverlap::get_buffer.

    Responsibility for fixing: adener, tmoon

    """
    assert isinstance(fp8_tensor, Float8Tensor), "Tensor is not a Float8Tensor"
    assert tp_size > 1, "The tensor transpose cannot be interleaved when TP size is 1"
    assert fp8_tensor._data is not None, "The tensor does not hold any rowwise data"
    assert (
        fp8_tensor._data.shape[0] % tp_size == 0
    ), "Leading dimension of data is not divisble by TP size"
    data = fp8_tensor._data
    batched_size = reduce(multiply_op, data.shape[1:])
    interleaved_shape = [tp_size, data.shape[0] // tp_size, batched_size]
    transposed_shape = [data.shape[0] // tp_size, batched_size * tp_size]
    fp8_tensor._transpose = (
        data.view(interleaved_shape).transpose(0, 1).contiguous().view(transposed_shape)
    )
    fp8_tensor._transpose_invalid = False
    fp8_tensor._data = None
    return fp8_tensor


def apply_normalization(
    inputmat: torch.Tensor,
    ln_out: torch.Tensor,
...
transformer_engine/pytorch/module/base.py
View file @ f8c2af4c
...
@@ -4,6 +4,7 @@
"""Base modules and utilities for TransformerEngine PyTorch API"""
import io
import math
import os
import pickle
import warnings
...
@@ -35,10 +36,13 @@ from ..distributed import (
    _fsdp_gather_tensors,
)
from ..constants import dist_group_type
from ..tensor import QuantizedTensor, Quantizer
from ..tensor.quantized_tensor import QuantizedTensor, QuantizedTensorBase, Quantizer
from ..tensor.float8_tensor import Float8Quantizer, Float8CurrentScalingQuantizer
from ..tensor.mxfp8_tensor import MXFP8Quantizer
from ..tensor.float8_blockwise_tensor import Float8BlockQuantizer
from ..tensor._internal.float8_tensor_base import Float8TensorBase
from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase
from ..utils import torch_get_autocast_gpu_dtype
from ..tensor._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase
from ...common.recipe import Recipe
from ...debug.pytorch.debug_state import TEDebugState
...
@@ -451,6 +455,142 @@ def destroy_ub():
    layers_atomic_ring_exchange = []


def fill_userbuffers_buffer_for_all_gather(
    comm,
    local_tensor: torch.Tensor,
    quantizer: Optional[Quantizer],
    process_group,
) -> tuple[torch.Tensor | QuantizedTensorBase, torch.Tensor | QuantizedTensorBase]:
    """Fill local shard of Userbuffers buffer with data for all-gather

    Returns the full tensor and the local shard, both using the
    Userbuffers buffer as their underlying data. These tensors should
    be used carefully (e.g. only immediately before and after a
    Userbuffers operation) since the underlying data may be
    overwritten by other Userbuffers operations.

    May perform blocking communication if needed for the gathered
    tensor's metadata, e.g. scaling factors.

    """

    # Tensor dimensions
    local_shape = local_tensor.size()
    if not local_shape:
        raise ValueError(f"Invalid local tensor (shape={tuple(local_shape)})")
    process_group_size = torch.distributed.get_world_size(process_group)
    global_shape = list(local_shape)
    global_shape[0] *= process_group_size

    # Unquantized data
    if quantizer is None:
        if isinstance(local_tensor, QuantizedTensorBase):
            local_tensor = local_tensor.dequantize()
        if comm.is_fp8_ubuf():
            raise RuntimeError(
                "Attempting to all-gather unquantized tensor, "
                "but Userbuffers is initialized with FP8 buffers"
            )
        comm.copy_into_buffer(local_tensor, local_chunk=True)
        global_tensor = comm.get_buffer(shape=global_shape)
        return global_tensor, local_tensor

    # FP8 data
    if isinstance(quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)):
        if not isinstance(local_tensor, Float8TensorBase):
            if isinstance(local_tensor, QuantizedTensorBase):
                local_tensor.dequantize()
            quantizer.set_usage(rowwise=True, columnwise=False)
            local_tensor = quantizer(local_tensor)
        if not comm.is_fp8_ubuf():
            raise RuntimeError(
                "Attempting to all-gather FP8 tensor, "
                "but Userbuffers is not initialized with FP8 buffers"
            )
        comm.copy_into_buffer(local_tensor._data, local_chunk=True)
        global_tensor_data = comm.get_buffer(shape=global_shape)
        global_tensor = Float8TensorBase(
            data=global_tensor_data,
            fp8_scale_inv=local_tensor._scale_inv,
            fp8_dtype=local_tensor._fp8_dtype,
            quantizer=quantizer,
        )
        return global_tensor, local_tensor

    # MXFP8 data
    if isinstance(quantizer, MXFP8Quantizer):

        # Cast to MXFP8 if needed
        if not isinstance(local_tensor, MXFP8TensorBase):
            if isinstance(local_tensor, QuantizedTensorBase):
                local_tensor.dequantize()
            local_tensor = quantizer(local_tensor)
        if not comm.is_fp8_ubuf():
            raise RuntimeError(
                "Attempting to all-gather MXFP8 tensor, "
                "but Userbuffers is not initialized with FP8 buffers"
            )

        # Check which MXFP8 buffer to communicate
        if quantizer.rowwise_usage == quantizer.columnwise_usage:
            raise ValueError(
                "Userbuffers can only communicate one MXFP8 buffer at a time, "
                f"but quantizer has rowwise_usage={quantizer.rowwise_usage}, "
                f"columnwise_usage={quantizer.columnwise_usage}"
            )
        with_rowwise_data = quantizer.rowwise_usage

        # Copy MXFP8 data to local chunk of Userbuffers buffer
        local_data = (
            local_tensor._rowwise_data if with_rowwise_data else local_tensor._columnwise_data
        )
        comm.copy_into_buffer(local_data, local_chunk=True)

        # Gather scaling-inverses
        if math.prod(local_shape[:-1]) % 128 != 0:
            raise ValueError(
                "Userbuffers requires MXFP8 tensor dims that are divisible by 128, "
                f"but got MXFP8 tensor with shape={tuple(local_shape)}"
            )
        local_scale_inv = (
            local_tensor._rowwise_scale_inv
            if with_rowwise_data
            else local_tensor._columnwise_scale_inv
        )
        local_scale_inv_size = list(local_scale_inv.size())
        global_scale_inv = torch.empty(
            [process_group_size * local_scale_inv_size[0]] + local_scale_inv_size[1:],
            dtype=local_scale_inv.dtype,
            device=local_scale_inv.device,
        )
        torch.distributed.all_gather_into_tensor(
            global_scale_inv,
            local_scale_inv,
            group=process_group,
        )

        # Construct MXFP8 tensor with Userbuffers buffer
        rowwise_data, rowwise_scale_inv = None, None
        columnwise_data, columnwise_scale_inv = None, None
        global_data = comm.get_buffer(shape=global_shape)
        if with_rowwise_data:
            rowwise_data, rowwise_scale_inv = global_data, global_scale_inv
        else:
            columnwise_data, columnwise_scale_inv = global_data, global_scale_inv
        global_tensor = MXFP8TensorBase(
            rowwise_data=rowwise_data,
            rowwise_scale_inv=rowwise_scale_inv,
            columnwise_data=columnwise_data,
            columnwise_scale_inv=columnwise_scale_inv,
            fp8_dtype=local_tensor._fp8_dtype,
            quantizer=quantizer,
        )
        return global_tensor, local_tensor

    # Unsupported data format
    raise ValueError(f"Unsupported quantizer for Userbuffers ({quantizer})")
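Aside: the call sites added later in this commit use the helper as sketched below. Variable names mirror the backward-pass code in base.py; `ub_comm` is a Userbuffers communicator obtained from `get_ub(...)`, and this wrapper is illustrative rather than part of the module:

import torch
from typing import Optional

def gather_through_userbuffers(ub_comm, local_shard: torch.Tensor,
                               quantizer: Optional[object], tp_group) -> torch.Tensor:
    """Sketch of the new gather path: returns the full tensor backed by the UB buffer."""
    gathered, _local = fill_userbuffers_buffer_for_all_gather(
        ub_comm,      # Userbuffers communicator
        local_shard,  # local shard, gathered along dim 0
        quantizer,    # None -> unquantized buffer path; FP8/MXFP8 quantizer -> FP8 buffer path
        tp_group,     # tensor-parallel process group
    )
    return gathered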
class TransformerEngineBaseModule(torch.nn.Module, ABC):
    """Base TE module."""
...
@@ -625,7 +765,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
        reset("scaling_fwd")
        reset("scaling_bwd")

    def get_extra_state(self) -> torch.Tensor:
    def get_extra_state(self) -> Optional[torch.Tensor]:
        """Save before checkpointing."""

        # This implementation is working around a few issues:
...
@@ -659,25 +799,26 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):

        # Store FP8 state if needed
        state = None
        fp8_checkpoint = self.fp8_meta["fp8_checkpoint"] or self.fp8 or self.fp8_calibration
        if fp8_checkpoint:
            # Copy tensors to CPU and store
            state = {}
            state["recipe"] = self.fp8_meta["recipe"]
            if state["recipe"].delayed():
                state["scale_fwd"] = to_cpu(self.fp8_meta["scaling_fwd"].scale)
                state["amax_history_fwd"] = to_cpu(self.fp8_meta["scaling_fwd"].amax_history)
                state["scale_bwd"] = to_cpu(self.fp8_meta["scaling_bwd"].scale)
                state["amax_history_bwd"] = to_cpu(self.fp8_meta["scaling_bwd"].amax_history)
            # Store other pickelable values
            extra = {}
            for k, v in self.fp8_meta.items():
                if k != "buffer_index_and_autocast_key" and isinstance(
                    v, (bool, int, float, str, tuple, list)
                ):
                    extra[k] = v
            state["extra_fp8_variables"] = extra
        if not fp8_checkpoint:
            return None

        # Copy tensors to CPU and store
        state = {}
        state["recipe"] = self.fp8_meta["recipe"]
        if state["recipe"].delayed():
            state["scale_fwd"] = to_cpu(self.fp8_meta["scaling_fwd"].scale)
            state["amax_history_fwd"] = to_cpu(self.fp8_meta["scaling_fwd"].amax_history)
            state["scale_bwd"] = to_cpu(self.fp8_meta["scaling_bwd"].scale)
            state["amax_history_bwd"] = to_cpu(self.fp8_meta["scaling_bwd"].amax_history)

        # Store other pickelable values
        extra = {}
        for k, v in self.fp8_meta.items():
            if k != "buffer_index_and_autocast_key" and isinstance(
                v, (bool, int, float, str, tuple, list)
            ):
                extra[k] = v
        state["extra_fp8_variables"] = extra

        # Serialize state into byte tensor
        torch.cuda.synchronize()
...
@@ -685,7 +826,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
        state_serialized = torch.frombuffer(state_serialized, dtype=torch.uint8)
        return state_serialized

    def set_extra_state(self, state: torch.Tensor) -> None:
    def set_extra_state(self, state: Optional[torch.Tensor]) -> None:
        """Load previous state."""
        if state is None:
            return
...
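Aside: `get_extra_state`/`set_extra_state` keep the existing pickle-to-uint8-tensor serialization but now return `None` early when there is no FP8 state to checkpoint. A minimal standalone sketch of that serialization round trip (not the module's exact code; helper names are illustrative):

import io
import pickle
import torch

def state_to_tensor(state: dict) -> torch.Tensor:
    """Serialize a picklable state dict into a flat uint8 tensor, as get_extra_state does."""
    buf = io.BytesIO()
    pickle.dump(state, buf)
    return torch.frombuffer(bytearray(buf.getvalue()), dtype=torch.uint8)

def tensor_to_state(state_tensor: torch.Tensor) -> dict:
    """Recover the state dict from the uint8 tensor, as set_extra_state does."""
    return pickle.loads(state_tensor.numpy().tobytes())

# Round trip
restored = tensor_to_state(state_to_tensor({"recipe": "delayed", "scale_fwd": [1.0, 2.0]}))
assert restored["recipe"] == "delayed"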
@@ -734,7 +875,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
        """Get activation data type for AMP."""
        # Native AMP (`torch.autocast`) gets highest priority
        if torch.is_autocast_enabled():
            self.activation_dtype = torch.get_autocast_gpu_dtype()
            self.activation_dtype = torch_get_autocast_gpu_dtype()
            return

        # All checks after this have already been performed once, thus skip
...
@@ -898,11 +1039,15 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
        # Non-FP8 case: bgrad is fused with wgrad for this case.
        if not ctx.fp8 and not ctx.debug:
            if gather_grad_output:
                if not ctx.ub_overlap_ag or ctx.ub_obj_gradout is None:
                if not ctx.ub_overlap_ag or ctx.ub_obj_gradout is None:
                    # Perform NCCL all-gather
                    grad_output, _ = gather_along_first_dim(grad_output, ctx.tp_group)
                else:
                    ctx.ub_obj_gradout.copy_into_buffer(grad_output, quantizer, local_chunk=True)
                    grad_output = ctx.ub_obj_gradout.get_buffer(quantizer)
                else:
                    # Initialize Userbuffers all-gather
                    grad_output, _ = fill_userbuffers_buffer_for_all_gather(
                        ctx.ub_obj_gradout,
                        grad_output,
                        None,
                        ctx.tp_group,
                    )
            return grad_output, None

        # FP8 with all-gather: unfused bgrad, fused cast + transpose
...
@@ -925,8 +1070,12 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
                    grad_output = quantizer(grad_output)
                # Copy into communication buffer, and replace original gradient with it
                ctx.ub_obj_gradout.copy_into_buffer(grad_output, quantizer, local_chunk=True)
                grad_output = ctx.ub_obj_gradout.get_buffer(quantizer)
                grad_output, _ = fill_userbuffers_buffer_for_all_gather(
                    ctx.ub_obj_gradout,
                    grad_output,
                    quantizer,
                    ctx.tp_group,
                )
            else:
                grad_output, _ = gather_along_first_dim(
                    grad_output,
...
@@ -1140,7 +1289,16 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
            raise ValueError(
                "tensor and quantizer kwargs must be provided to construct FP8 workspace"
            )
        if cache_name is not None:
            # Ensure the tensor in the cache is an instance of torch.Tensor,
            # as it persists beyond a single forward pass.
            # Setting internal=True would cause the data to be removed in prepare_for_saving(...).
            quantizer_internal = quantizer.internal
            quantizer.internal = False
        out = quantizer.quantize(tensor, dtype=workspace_dtype)
        if cache_name is not None:
            quantizer.internal = quantizer_internal

        # Update cache
        if cache_name is not None:
...
@@ -1188,7 +1346,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
        if self.wgrad_store is None or not self.wgrad_store.delay_wgrad_compute():
            return
        with torch.cuda.nvtx.range(f"_{self.__class__.__name__}_wgrad"):
            (wgrad, grad_bias_, _, _), _ = self.wgrad_store.pop()
            (wgrad, bgrad), _ = self.wgrad_store.pop()
            if not self.fuse_wgrad_accumulation:
                unfused_weights = [getattr(self, name) for name in self.weight_names]
                weight_tensor = noop_cat(unfused_weights)
...
@@ -1197,9 +1355,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
            if self.use_bias:
                bias_tensor = noop_cat([getattr(self, name) for name in self.bias_names])
                if bias_tensor.grad is None:
                    bias_tensor.grad = grad_bias_.to(bias_tensor.dtype)
            del grad_bias_
            del wgrad
                bias_tensor.grad = bgrad.to(bias_tensor.dtype)

    def _validate_name(self):
        """
...
transformer_engine/pytorch/module/grouped_linear.py
View file @ f8c2af4c
...
@@ -4,6 +4,7 @@
"""GroupedLinear API"""
from typing import Union, Optional, Callable, Tuple, List
import warnings
import functools

import torch
...
@@ -43,7 +44,7 @@ from ..graph import is_graph_capturing
from ..cpu_offload import is_cpu_offload_enabled
from ..tensor.quantized_tensor import (
    QuantizedTensor,
    QuantizedTensorBase,
    Quantizer,
    prepare_for_saving,
    restore_from_saved,
...
@@ -182,11 +183,11 @@ class _GroupedLinear(torch.autograd.Function):
            # TODO: update after #1638 is merged. # pylint: disable=fixme
            if weight_requires_grad:
                for inputmat in inputmats:
                    if isinstance(inputmat, QuantizedTensor):
                    if isinstance(inputmat, QuantizedTensorBase):
                        inputmat.update_usage(rowwise_usage=False, columnwise_usage=True)
            if inp.requires_grad:
                for weight in weights_fp8:
                    if isinstance(weight, QuantizedTensor):
                    if isinstance(weight, QuantizedTensorBase):
                        weight.update_usage(columnwise_usage=True)

            tensors_to_save, tensor_objects = prepare_for_saving(
...
@@ -299,7 +300,7 @@ class _GroupedLinear(torch.autograd.Function):
            )
            for weight, quantizer in zip(weights, ctx.weight_quantizers):
                if quantizer is not None and isinstance(weight, QuantizedTensor):
                if quantizer is not None and isinstance(weight, QuantizedTensorBase):
                    weight.update_usage(
                        rowwise_usage=quantizer.rowwise_usage,
                        columnwise_usage=quantizer.columnwise_usage,
...
@@ -663,7 +664,7 @@ class GroupedLinear(TransformerEngineBaseModule):
                     produced)
        """
        assert not isinstance(
            inp, QuantizedTensor
            inp, QuantizedTensorBase
        ), "GroupedLinear doesn't support input tensor in FP8."
        assert len(m_splits) == self.num_gemms, "Number of splits should match number of GEMMs."
...
@@ -675,9 +676,14 @@ class GroupedLinear(TransformerEngineBaseModule):
        weight_tensors = [getattr(self, f"weight{i}") for i in range(self.num_gemms)]
        bias_tensors = [getattr(self, f"bias{i}") for i in range(self.num_gemms)]
        if not self.fp8:
        if not self.fp8 and any(isinstance(w, QuantizedTensorBase) for w in weight_tensors):
            warnings.warn(
                "You are using quantized weights without quantized compute. "
                "Please make sure this is intentional."
            )
            weight_tensors = [
                w.dequantize() if isinstance(w, QuantizedTensor) else w for w in weight_tensors
                w.dequantize() if isinstance(w, QuantizedTensorBase) else w for w in weight_tensors
            ]

        input_quantizers, weight_quantizers, output_quantizers = (
...
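Aside: the GroupedLinear hunk above now warns and dequantizes when quantized weights are used without FP8 compute. A small standalone sketch of that guard; `QuantizedTensorBase` below is a stub standing in for the class imported in the diff:

import warnings

class QuantizedTensorBase:  # stand-in for transformer_engine's quantized-tensor base class
    def dequantize(self):
        raise NotImplementedError

def maybe_dequantize_weights(weights, fp8_enabled: bool):
    """Dequantize weights when running a non-FP8 forward with quantized parameters."""
    if not fp8_enabled and any(isinstance(w, QuantizedTensorBase) for w in weights):
        warnings.warn(
            "You are using quantized weights without quantized compute. "
            "Please make sure this is intentional."
        )
        weights = [w.dequantize() if isinstance(w, QuantizedTensorBase) else w for w in weights]
    return weights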
transformer_engine/pytorch/module/layernorm.py
View file @ f8c2af4c
...
@@ -94,6 +94,9 @@ class LayerNorm(_LayerNormOp):
            )
            kwargs["dtype"] = params_dtype

        # Flag for sequence parallelism (custom Megatron-LM integration)
        self.sequence_parallel: Optional[bool] = sequence_parallel

        # Initialize layer norm operation
        super().__init__(
            normalized_shape,
...
@@ -102,12 +105,6 @@ class LayerNorm(_LayerNormOp):
            **kwargs,
        )

        # Flag for sequence parallelism (custom Megatron-LM integration)
        self.sequence_parallel: Optional[bool] = sequence_parallel
        if sequence_parallel is not None:
            self.weight.sequence_parallel = sequence_parallel
            self.bias.sequence_parallel = sequence_parallel

    def reset_layer_norm_parameters(self) -> None:
        """Init LN params"""
        warnings.warn(
...
@@ -136,7 +133,7 @@ class LayerNorm(_LayerNormOp):
        super().reset_parameters()

        # Set flag for sequence parallelism (custom Megatron-LM integration)
        if getattr(self, "sequence_parallel", None) is not None:
        if self.sequence_parallel is not None:
            self.weight.sequence_parallel = self.sequence_parallel
            self.bias.sequence_parallel = self.sequence_parallel
...
transformer_engine/pytorch/module/layernorm_linear.py
View file @ f8c2af4c
...
@@ -9,7 +9,6 @@ from typing import Callable, Dict, Optional, Tuple, Union
from functools import reduce
from operator import mul as multiply_op
import functools

import torch
from torch.nn import init
...
@@ -18,6 +17,7 @@ import transformer_engine_torch as tex
from transformer_engine.common.recipe import Recipe
from transformer_engine.pytorch import torch_version
from .base import (
    fill_userbuffers_buffer_for_all_gather,
    get_workspace,
    get_ub,
    TransformerEngineBaseModule,
...
@@ -53,9 +53,10 @@ from ..distributed import (
from ..constants import GemmParallelModes, dist_group_type
from ..jit import no_torch_dynamo
from ..graph import is_graph_capturing
from ._common import apply_normalization, noop_cat, _fix_gathered_fp8_transpose, WeightGradStore
from ._common import apply_normalization, noop_cat, WeightGradStore
from ..tensor.quantized_tensor import (
    QuantizedTensor,
    QuantizedTensorBase,
    Quantizer,
    prepare_for_saving,
    restore_from_saved,
...
@@ -144,8 +145,10 @@ class _LayerNormLinear(torch.autograd.Function):
        # Make sure input dimensions are compatible
        out_features, in_features = weight.shape
        inp_shape = inp.shape
        inp_requires_grad = inp.requires_grad
        assert inp_shape[-1] == in_features, "GEMM not possible"
        inputmat = inp.view((-1, in_features))
        inp = inp.view((-1, in_features))
        inputmat = inp
        if fp8:
            assert_dim_for_fp8_exec(inputmat, weight)
...
@@ -158,42 +161,43 @@ class _LayerNormLinear(torch.autograd.Function):
            nvtx_range_pop(f"{nvtx_label}.norm_input_cast")

        tp_world_size = get_distributed_world_size(tp_group)
        ub_overlap_ag_fprop = (
            ub_overlap_ag_fprop and is_grad_enabled and not return_layernorm_output
        )
        weight_requires_grad = weight.requires_grad
        backward_needs_input = is_grad_enabled and weight_requires_grad
        with_input_all_gather = parallel_mode == "column" and sequence_parallel

        # Check if Userbuffers is supported
        if fp8:
            if any([ub_overlap_ag_fprop, ub_overlap_rs_fprop]) and not (
                FP8GlobalStateManager.get_fp8_recipe().float8_per_tensor_scaling()
            ):
                raise NotImplementedError(
                    "Comm+GEMM overlap is only supported with FP8 delayed scaling or per-tensor"
                    " current scaling"
                )

        # Configure Userbuffers communication (comm+GEMM overlap)
        ub_obj = None
        ub_type = None
        ub_overlap_ag_fprop = (
            ub_overlap_ag_fprop and is_grad_enabled and not return_layernorm_output
        )
        if ub_overlap_rs_fprop:
            ub_obj = get_ub(ub_name + "_fprop")
            ub_type = tex.CommOverlapType.RS
        elif ub_overlap_ag_fprop:
            ub_obj = get_ub(ub_name + "_fprop")
            ub_type = tex.CommOverlapType.AG

        # Configure quantizer for norm output
        if fp8:
            if input_quantizer is None:
                raise ValueError("Missing quantizer for input tensor")
            columnwise_usage = backward_needs_input
            if (
                columnwise_usage
                and with_input_all_gather
                and not isinstance(input_quantizer, MXFP8Quantizer)
            input_quantizer.set_usage(rowwise=True, columnwise=backward_needs_input)
            if with_input_all_gather and isinstance(
                input_quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)
            ):
                columnwise_usage = False
            input_quantizer.set_usage(rowwise=True, columnwise=columnwise_usage)
                # All-gather is not supported with FP8 column-wise data
                input_quantizer.set_usage(columnwise=False)

        # Avoid quantized norm kernel if norm output will be returned
        # or if a gather of ln_out must be in high precision.
        # Do TP communication in high precision if quantized format
        # does not support communication
        force_hp_blockwise_ln_out_gather = (
            fp8 and with_input_all_gather and isinstance(input_quantizer, Float8BlockQuantizer)
        )  # Perform TP communication in high precision.
            )
        # Avoid quantized norm kernel if norm output will be returned
        # or if a gather of ln_out must be in high precision.
        with_quantized_norm = (
            fp8
            and not return_layernorm_output
...
@@ -215,16 +219,19 @@ class _LayerNormLinear(torch.autograd.Function):
            fwd_ln_sm_margin,
            zero_centered_gamma,
        )
        nvtx_range_pop(f"{nvtx_label}.norm")

        # Store unquantized layer norm output if we need to return it
        ln_out_return = None
        if return_layernorm_output or return_layernorm_output_gathered:
            ln_out_return = ln_out
        nvtx_range_pop(f"{nvtx_label}.norm")

        # Prepare GEMM input
        # ------------------------------------------------------
        # Prepare GEMM input tensor
        # Note: Cast to expected dtype and perform tensor-parallel communication
        # ------------------------------------------------------
        nvtx_range_push(f"{nvtx_label}.gemm_input_cast_comm")
        ln_out_total = None
        ub_obj_fprop = None
        if with_input_all_gather:
            if return_layernorm_output_gathered:
                # Perform all-gather in high precision if gathered
...
@@ -232,47 +239,53 @@ class _LayerNormLinear(torch.autograd.Function):
                ln_out_total, _ = gather_along_first_dim(ln_out, tp_group)
                ln_out_return = ln_out_total
                if fp8 or debug:
                    ln_out = input_quantizer(ln_out)
                    if not force_hp_blockwise_ln_out_gather:
                        ln_out = input_quantizer(ln_out)
                    input_quantizer.set_usage(rowwise=True, columnwise=False)
                    ln_out_total = input_quantizer(ln_out_total)
            else:
                quantizer = None
                if fp8 or debug:
                    quantizer = input_quantizer
                    if not with_quantized_norm and not force_hp_blockwise_ln_out_gather:
                        ln_out = input_quantizer(ln_out)
                    input_quantizer.set_usage(rowwise=True, columnwise=False)
                if ub_overlap_ag_fprop:
                    # Copy into Userbuffers buffer
                    ub_obj_fprop = get_ub(ub_name + "_fprop")
                    ub_obj_fprop.get_buffer(input_quantizer, local_chunk=True).copy_(ln_out)
                    ln_out_total = ub_obj_fprop.get_buffer(input_quantizer)
                else:
                    # All-gather with NCCL
                    ln_out = quantizer(ln_out)
                    quantizer.set_usage(rowwise=True, columnwise=False)
                if ub_overlap_ag_fprop:
                    # Initialize Userbuffers all-gather
                    ln_out_total, _ = fill_userbuffers_buffer_for_all_gather(
                        ub_obj,
                        ln_out,
                        quantizer,
                        tp_group,
                    )
                else:
                    # Perform NCCL all-gather
                    ln_out_total, _ = gather_along_first_dim(
                        ln_out,
                        tp_group,
                        quantizer=(input_quantizer if fp8 or debug else None),
                        quantizer=quantizer,
                    )
        else:
            if (fp8 or debug) and not with_quantized_norm:
                ln_out = input_quantizer(ln_out)
            ln_out_total = ln_out
        nvtx_range_pop(f"{nvtx_label}.gemm_input_cast_comm")
        # ------------------------------------------------------
        # GEMM input tensor is ready...
        # ------------------------------------------------------

        # Cast weight to expected dtype
        # ------------------------------------------------------
        # Prepare weight tensor
        # ------------------------------------------------------
        weightmat = weight
        quantized_weight = False
        if not fp8 and not debug:
            weightmat = cast_if_needed(weightmat, activation_dtype)
        else:
            quantized_weight = not isinstance(weight, QuantizedTensor)
        if fp8 or debug:
            quantized_weight = not isinstance(weight, QuantizedTensorBase)

            # Configure quantizer
            if weight_quantizer is not None:
                weight_quantizer.set_usage(rowwise=True, columnwise=True)

            # FP8 cast to workspace buffer
            # Get quantized weight
            update_workspace = is_first_microbatch is None or is_first_microbatch
            weightmat = module.get_weight_workspace(
                tensor=weight,
                quantizer=weight_quantizer,
...
@@ -282,17 +295,21 @@ class _LayerNormLinear(torch.autograd.Function):
                fsdp_group=fsdp_group,
                workspace_dtype=activation_dtype,
            )
            weightmat.update_usage(rowwise_usage=True)
        else:
            weightmat = cast_if_needed(weightmat, activation_dtype)  # Cast for AMP
        # ------------------------------------------------------
        # Weight tensor is ready for GEMM...
        # ------------------------------------------------------

        # Cast bias to expected dtype
        bias_dtype = activation_dtype
        if needs_quantized_gemm(ln_out_total) and activation_dtype == torch.float32:
            # cuBLAS does not support FP8 GEMM with FP32 bias, so we cast to BF16
            bias_dtype = torch.bfloat16
        bias = cast_if_needed(bias, bias_dtype) if bias is not None else bias

        # Configure output quantizer
        if output_quantizer is not None:
            output_quantizer.set_usage(rowwise=True, columnwise=False)

        # Calibrate quantizers if needed
        if not fp8 and fp8_calibration:
            if input_quantizer is not None:
...
@@ -300,47 +317,80 @@ class _LayerNormLinear(torch.autograd.Function):
            if weight_quantizer is not None:
                weight_quantizer.calibrate(weight)

        ub_obj = None
        ub_type = None
        rs_out = None
        if ub_overlap_rs_fprop:
            ub_obj = get_ub(ub_name + "_fprop")
            ub_type = tex.CommOverlapType.RS
            out_shape = [reduce(multiply_op, inp_shape[:-1]) // tp_world_size, out_features]
            rs_out = torch.empty(out_shape, dtype=activation_dtype, device=ln_out_total.device)
        elif ub_overlap_ag_fprop:
            ub_obj = get_ub(ub_name + "_fprop")
            ub_type = tex.CommOverlapType.AG
            if fp8:
                assert ub_obj.is_fp8_ubuf(), "AG overlap with FP8 GEMM inputs requires FP8 buffer."
            ln_out_total = ub_obj.get_buffer(input_quantizer)

        nvtx_range_push(f"{nvtx_label}.gemm")
        fprop_gemm_use_split_accumulator = _2X_ACC_FPROP
        # Choose whether to use GEMM kernel with split accumulator
        use_split_accumulator = _2X_ACC_FPROP
        if fp8:
            recipe = FP8GlobalStateManager.get_fp8_recipe()
            if hasattr(recipe, "fp8_gemm_fprop"):
                fprop_gemm_use_split_accumulator = recipe.fp8_gemm_fprop.use_split_accumulator
                use_split_accumulator = recipe.fp8_gemm_fprop.use_split_accumulator

        # Configure output quantizer
        if output_quantizer is not None:
            output_quantizer.set_usage(rowwise=True, columnwise=False)

        out, *_, rs_out = general_gemm(
        # Output buffer for Userbuffers reduce-scatter
        reduce_scatter_out = None
        if ub_overlap_rs_fprop:
            out_shape = list(inp_shape)
            out_shape[0] //= tp_world_size
            out_shape[-1] = out_features
            reduce_scatter_out = torch.empty(out_shape, dtype=activation_dtype, device=inp.device)

        # ------------------------------------------------------
        # Forward GEMM
        # Note: y = x * w^T
        # ------------------------------------------------------
        nvtx_range_push(f"{nvtx_label}.gemm")
        gemm_out, *_, reduce_scatter_out = general_gemm(
            weightmat,
            ln_out_total,
            get_workspace(),
            quantization_params=output_quantizer,
            out_dtype=activation_dtype,
            bias=bias,
            use_split_accumulator=fprop_gemm_use_split_accumulator,
            use_split_accumulator=use_split_accumulator,
            ub=ub_obj,
            ub_type=ub_type,
            extra_output=rs_out,
            extra_output=reduce_scatter_out,
        )
        nvtx_range_pop(f"{nvtx_label}.gemm")
        # ------------------------------------------------------
        # Finished forward GEMM...
        # ------------------------------------------------------

        # Deallocate GEMM input tensor if no longer needed
        if not weight.requires_grad and not return_layernorm_output:
            ln_out = ln_out_total = None
            clear_tensor_data(ln_out, ln_out_total)

        # ------------------------------------------------------
        # Prepare output tensor
        # Note: Perform tensor-parallel communication
        # ------------------------------------------------------
        out = None
        if ub_overlap_rs_fprop:
            out = reduce_scatter_out
        elif parallel_mode == "row" and tp_size > 1:
            nvtx_range_push(f"{nvtx_label}.row_parallel_comm")
            out = gemm_out
            if sequence_parallel:
                out, _ = reduce_scatter_along_first_dim(out, tp_group)
            elif tensor_parallel:
                if symmetric_ar_type is not None:
                    out, _ = symmetric_all_reduce(out, tp_group, all_reduce_type=symmetric_ar_type)
                else:
                    out, _ = allreduce(out, tp_group)
            nvtx_range_pop(f"{nvtx_label}.row_parallel_comm")
        else:
            out = gemm_out
        out = out.view(-1, *inp_shape[1:-1], out_features)
        # ------------------------------------------------------
        # Output tensor is ready to return...
        # ------------------------------------------------------

        if not weight.requires_grad:
            if not return_layernorm_output:
                ln_out = ln_out_total = None
            clear_tensor_data(ln_out, ln_out_total)

        # ------------------------------------------------------
        # Cache state for backward pass
        # ------------------------------------------------------
        if is_grad_enabled:
            ctx.weight_quantizer = weight_quantizer
...
@@ -351,19 +401,15 @@ class _LayerNormLinear(torch.autograd.Function):
            # Input with column-wise usage is needed for wgrad GEMM.
            if backward_needs_input:
-               if isinstance(ln_out, QuantizedTensor):
+               if isinstance(ln_out, QuantizedTensorBase):
                    # For sequence parallel in vanilla FP8, rowwise data is
                    # to gather the input. For MXFP8, columnwise only data
                    # can be allgathered.
                    if isinstance(ln_out, MXFP8TensorBase) or not ctx.ln_out_needs_gather:
                        ln_out.update_usage(rowwise_usage=False)
                        # For force_hp_blockwise_ln_out_gather, we should
                        # be saving the unquantized ln_out to ctx.
                        assert not force_hp_blockwise_ln_out_gather

            # Weight with column-wise usage is needed for dgrad GEMM.
-           if isinstance(weightmat, QuantizedTensor):
+           if isinstance(weightmat, QuantizedTensorBase):
                weightmat.update_usage(columnwise_usage=True)

            if cpu_offloading:
...
...
@@ -406,7 +452,7 @@ class _LayerNormLinear(torch.autograd.Function):
            )
            ctx.save_for_backward(*tensors_to_save)
            ctx.tensor_objects = tensor_objects

-           ctx.requires_dgrad = inp.requires_grad
+           ctx.requires_dgrad = inp_requires_grad
            ctx.requires_wgrad = weight.requires_grad
            ctx.quantized_weight = quantized_weight
            if fuse_wgrad_accumulation and weight.requires_grad:
...
...
@@ -439,7 +485,7 @@ class _LayerNormLinear(torch.autograd.Function):
            ctx.ub_bulk_wgrad = ub_bulk_wgrad
            ctx.ub_bulk_dgrad = ub_bulk_dgrad
            ctx.ub_name = ub_name
-           ctx.requires_dgrad = inp.requires_grad
+           ctx.requires_dgrad = inp_requires_grad
            ctx.normalization = normalization
            ctx.reduce_and_update_bwd_fp8_tensors = False
            if ctx.fp8 and requires_grad(inp, ln_weight, ln_bias, weight, bias):
...
...
@@ -450,29 +496,16 @@ class _LayerNormLinear(torch.autograd.Function):
            ctx.wgrad_store = wgrad_store
            ctx.debug = debug

-       # Row Parallel Linear
-       if ub_overlap_rs_fprop:
-           out = rs_out
-       elif parallel_mode == "row":
-           nvtx_range_push(f"{nvtx_label}.row_parallel_comm")
-           if sequence_parallel:
-               out, _ = reduce_scatter_along_first_dim(out, tp_group)
-           elif tensor_parallel:
-               if symmetric_ar_type is not None:
-                   out, _ = symmetric_all_reduce(out, tp_group, all_reduce_type=symmetric_ar_type)
-               else:
-                   out, _ = allreduce(out, tp_group)
-           nvtx_range_pop(f"{nvtx_label}.row_parallel_comm")
-
-       # [*, in_features] -> [*, out_features] except first dimension changes for SP
-       out = out.view(-1, *inp_shape[1:-1], out_features)
+       # ----- Cached state for backward pass is ready... -----

        if return_layernorm_output:
            if return_layernorm_output_gathered:
-               shape = list(inp.shape)
+               shape = list(inp_shape)
                shape[0] *= tp_size
                return out, ln_out_return.view(shape)
-           return out, ln_out_return.view_as(inp)
+           return out, ln_out_return.view(inp_shape)
        return out

    @staticmethod
...
...
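Editorial aside: the reshape convention used above when the LayerNorm output was gathered across the tensor-parallel group can be summarized with a short, self-contained sketch (shapes only, no TransformerEngine calls; the helper name is hypothetical).

import torch

def reshape_ln_out(ln_out_flat: torch.Tensor, inp_shape: tuple, tp_size: int, gathered: bool) -> torch.Tensor:
    if gathered:
        shape = list(inp_shape)
        shape[0] *= tp_size          # sequence dim was all-gathered across TP ranks
        return ln_out_flat.view(shape)
    return ln_out_flat.view(inp_shape)  # same shape as the module input otherwise

# Example: tp_size=4, local input [128, 8, 1024] -> gathered ln_out [512, 8, 1024]
out = reshape_ln_out(torch.zeros(512 * 8 * 1024), (128, 8, 1024), 4, gathered=True)
assert out.shape == (512, 8, 1024)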
@@ -487,24 +520,6 @@ class _LayerNormLinear(torch.autograd.Function):
        nvtx_label = f"{nvtx_label}.{ctx.ub_name}"

        with torch.cuda.nvtx.range("_LayerNormLinear_backward"):
-           if (
-               ctx.fp8
-               and any(
-                   [
-                       ctx.ub_overlap_ag,
-                       ctx.ub_overlap_rs_dgrad,
-                       ctx.ub_bulk_dgrad,
-                       ctx.ub_bulk_wgrad,
-                   ]
-               )
-               and (ctx.fp8_recipe is not None)
-           ):
-               if not ctx.fp8_recipe.float8_per_tensor_scaling():
-                   raise NotImplementedError(
-                       "Comm+GEMM overlap is only supported with FP8 delayed scaling or per-tensor"
-                       " current scaling"
-                   )
            saved_tensors = ctx.saved_tensors
            (  # pylint: disable=unbalanced-tuple-unpacking
                inputmat,
...
...
@@ -549,66 +564,50 @@ class _LayerNormLinear(torch.autograd.Function):
            if ctx.requires_wgrad and ctx.fuse_wgrad_accumulation:
                origin_weight.main_grad = main_grad

            # Configure Userbuffers communication (comm+GEMM overlap)
            ctx.ub_obj_gradout = None
            ub_obj_dgrad = None
            ub_obj_wgrad = None
            ub_type_dgrad = None
            ub_type_wgrad = None
            dgrad_shape = [reduce(multiply_op, ctx.inp_shape[:-1]), ctx.inp_shape[-1]]
-           rs_out = None
-           dgrad_bulk = None
            if ctx.ub_overlap_ag:
                # Overlap grad_output all-gather with dgrad compute
                ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad")
                ub_obj_dgrad = ctx.ub_obj_gradout
                ub_type_dgrad = tex.CommOverlapType.AG
            elif ctx.ub_overlap_rs_dgrad:
                # Overlap dgrad reduce-scatter with dgrad compute
                ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad")
                ub_obj_dgrad = ctx.ub_obj_gradout
                ub_type_dgrad = tex.CommOverlapType.RS
-               rs_out = torch.empty(dgrad_shape, dtype=ctx.activation_dtype, device=inputmat.device)
            else:
                if ctx.ub_bulk_dgrad:
                    # Overlap inputmat all-gather with dgrad compute
-                   # NOTE: Copying into communication buffer will always prefer rowwise data,
-                   # and will copy columnwise data if rowwise does not exist. In that case,
-                   # the all-gather will apply to the leading dimension of the transpose,
-                   # which then needs to be interleaved correctly before WGRAD.
                    ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad")
                    ub_obj_dgrad = ctx.ub_obj_gradout
                    ub_type_dgrad = tex.CommOverlapType.AG
-                   ub_obj_dgrad.copy_into_buffer(ln_out, ctx.input_quantizer, local_chunk=True)
                if ctx.ub_bulk_wgrad:
                    # Overlap dgrad reduce-scatter with wgrad compute
                    ub_obj_wgrad = get_ub(ctx.ub_name + "_wgrad")
                    ub_type_wgrad = tex.CommOverlapType.RS
-                   ub_obj_wgrad.set_buffer_params(ctx.grad_input_quantizer)
-                   dgrad_bulk = ub_obj_wgrad.get_buffer(ctx.grad_input_quantizer)

+           # ----- Prepare grad output tensor -----
+           # Note: Cast to expected dtype and perform tensor-parallel communication

            # Configure quantizer for grad output tensor
            # Note: dgrad GEMM requires row-wise usage, wgrad GEMM
            # requires column-wise usage
            if ctx.grad_output_quantizer is not None:
-               rowwise_usage = True
-               columnwise_usage = True
-               if ctx.ub_overlap_ag and isinstance(
-                   ctx.grad_output_quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer),
-               ):
-                   # If data is in FP8 and communication is handled
-                   # with Userbuffers, we compute FP8 transposes
-                   # manually
-                   columnwise_usage = False
-               ctx.grad_output_quantizer.set_usage(
-                   rowwise=rowwise_usage, columnwise=columnwise_usage,
-               )
+               quantizer = ctx.grad_output_quantizer
+               quantizer.set_usage(rowwise=True, columnwise=True)
+               if ctx.ub_overlap_ag:
+                   # Userbuffers only supports communication for one
+                   # tensor usage at a time. Configure quantizer with
+                   # usage for only dgrad GEMM.
+                   quantizer.set_usage(columnwise=False)

-           # Prepare grad output tensor
-           # Note: Cast to expected dtype and perform tensor-parallel communication
...
...
@@ -624,12 +623,21 @@ class _LayerNormLinear(torch.autograd.Function):
            )
            nvtx_range_pop(f"{nvtx_label}.grad_output_preprocess")
-           # Launch tensor-parallel communication for LayerNorm out tensor
+           # ----- Grad output tensor is ready for computing grad input... -----

+           # ----- Prepare GEMM input tensor -----
+           # Note: Input tensor is needed for wgrad GEMM. Tensor-parallel
+           # communication is overlapped with dgrad GEMM.
            ln_out_total = None
            ln_out_total_work = None
-           if ctx.ln_out_needs_gather and not ctx.ub_bulk_dgrad:
+           if ctx.ln_out_needs_gather:
                quantizer = None
-               if ctx.input_quantizer is not None:
+               if ctx.input_quantizer is not None and not ctx.force_hp_blockwise_ln_out_gather:
                    quantizer = ctx.input_quantizer
                    if isinstance(quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)):
                        # If data is in FP8, we compute FP8 transposes manually
...
...
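Editorial aside: the "row-wise vs column-wise usage" distinction that drives the quantizer configuration above can be illustrated with a small stand-in class (pure PyTorch, not the TE quantizer/tensor API): the dgrad GEMM consumes the tensor in its natural layout while the wgrad GEMM wants its transpose, which may be materialized lazily.

import torch

class TwoLayoutTensor:
    """Keeps a tensor plus (optionally) its transpose, mimicking rowwise/columnwise usage."""
    def __init__(self, data: torch.Tensor, rowwise: bool = True, columnwise: bool = False):
        self.rowwise_data = data if rowwise else None
        self.columnwise_data = data.t().contiguous() if columnwise else None

    def update_usage(self, columnwise_usage: bool):
        # Materialize the transpose lazily, e.g. right before the wgrad GEMM.
        if columnwise_usage and self.columnwise_data is None:
            self.columnwise_data = self.rowwise_data.t().contiguous()

grad_out = TwoLayoutTensor(torch.randn(16, 32), rowwise=True)
grad_out.update_usage(columnwise_usage=True)   # needed for the wgrad GEMM
assert grad_out.columnwise_data.shape == (32, 16)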
@@ -637,70 +645,92 @@ class _LayerNormLinear(torch.autograd.Function):
                    else:
                        # wgrad GEMM requires input with column-wise usage
                        quantizer.set_usage(rowwise=False, columnwise=True)
-               nvtx_range_push(f"{nvtx_label}.column_parallel_comm_input")
-               # async_op is not compatible with high precision gather since
-               # gather_along_first_dim does not offer callback chaining.
-               gather_quantizer = None if ctx.force_hp_blockwise_ln_out_gather else quantizer
-               ln_out_total, ln_out_total_work = gather_along_first_dim(
-                   ln_out, ctx.tp_group, async_op=True, quantizer=gather_quantizer,
-               )
-               nvtx_range_pop(f"{nvtx_label}.column_parallel_comm_input")
+               if ctx.ub_bulk_dgrad:
+                   ln_out_total, _ = fill_userbuffers_buffer_for_all_gather(
+                       ub_obj_dgrad, ln_out, quantizer, ctx.tp_group,
+                   )
+               else:
+                   nvtx_range_push(f"{nvtx_label}.column_parallel_comm_input")
+                   ln_out_total, ln_out_total_work = gather_along_first_dim(
+                       ln_out, ctx.tp_group, async_op=True, quantizer=quantizer,
+                   )
+                   nvtx_range_pop(f"{nvtx_label}.column_parallel_comm_input")
            else:
                ln_out_total = ln_out

-           # Check whether to output wgrad GEMM directly into main grad
-           if ctx.is_first_microbatch is not None:
-               accumulate_wgrad_into_param_main_grad = (
-                   ctx.fuse_wgrad_accumulation and not ctx.is_first_microbatch
-               )
-           else:
-               accumulate_wgrad_into_param_main_grad = ctx.fuse_wgrad_accumulation

-           # dgrad GEMM
-           if ctx.grad_input_quantizer is not None:
-               ctx.grad_input_quantizer.set_usage(rowwise=True, columnwise=False)
-           nvtx_range_push(f"{nvtx_label}.dgrad_gemm")
-           dgrad_gemm_use_split_accumulator = _2X_ACC_DGRAD
+           # ----- Input tensor is ready for computing grad weight... -----

+           # ----- Compute grad input tensor -----
+           # Note: Gradient w.r.t. GEMM input (i.e. norm output).

+           # Make sure required data is available
+           if isinstance(grad_output, QuantizedTensorBase):
+               grad_output.update_usage(rowwise_usage=True)
+           if ctx.weight_quantizer is not None and isinstance(weight, QuantizedTensorBase):
+               weight.update_usage(columnwise_usage=True)

+           # Choose whether to use GEMM kernel with split accumulator
+           use_split_accumulator = _2X_ACC_DGRAD
            if ctx.fp8:
                recipe = ctx.fp8_recipe
                if hasattr(recipe, "fp8_gemm_dgrad"):
-                   dgrad_gemm_use_split_accumulator = recipe.fp8_gemm_dgrad.use_split_accumulator
+                   use_split_accumulator = recipe.fp8_gemm_dgrad.use_split_accumulator

+           # Update grad input quantizer
            if ctx.grad_input_quantizer is not None:
                ctx.grad_input_quantizer.set_usage(rowwise=True, columnwise=False)

-           if ctx.weight_quantizer is not None and isinstance(weight, QuantizedTensor):
-               weight.update_usage(
-                   rowwise_usage=ctx.weight_quantizer.rowwise_usage,
-                   columnwise_usage=ctx.weight_quantizer.columnwise_usage,
-               )

+           # Output buffers for Userbuffers reduce-scatter
+           gemm_out = None
+           reduce_scatter_out = None
+           if ctx.ub_overlap_rs_dgrad:
+               reduce_scatter_out = torch.empty(
+                   dgrad_shape, dtype=ctx.activation_dtype, device=grad_outputs[0].device
+               )
-           dgrad, *_ = general_gemm(
+           elif ctx.ub_bulk_wgrad:
+               gemm_out = ub_obj_wgrad.get_buffer(local_chunk=False)

+           # dgrad GEMM
+           # Note: dx = dy * w
+           nvtx_range_push(f"{nvtx_label}.dgrad_gemm")
+           gemm_out, *_, reduce_scatter_out = general_gemm(
                weight, grad_output, get_workspace(),
                layout="NN", grad=True,
                quantization_params=ctx.grad_input_quantizer,
-               out=dgrad_bulk,
+               out=gemm_out,
                out_dtype=ctx.activation_dtype,
-               use_split_accumulator=dgrad_gemm_use_split_accumulator,
+               use_split_accumulator=use_split_accumulator,
                ub=ub_obj_dgrad, ub_type=ub_type_dgrad,
-               extra_output=rs_out,
+               extra_output=reduce_scatter_out,
                bulk_overlap=ctx.ub_bulk_dgrad,
            )
            nvtx_range_pop(f"{nvtx_label}.dgrad_gemm")

-           # Launch tensor-parallel communication
+           # Prepare grad input tensor
+           # Note: Perform tensor-parallel communication
            dgrad = None
            dgrad_work = None
            if ctx.ub_overlap_rs_dgrad:
-               dgrad = rs_out
-           elif ctx.parallel_mode == "column" and not ctx.ub_bulk_wgrad:
+               dgrad = reduce_scatter_out
+           elif ctx.ub_bulk_wgrad:
+               dgrad = ub_obj_wgrad.get_buffer(local_chunk=True)
+           elif ctx.parallel_mode == "column" and ctx.tp_size > 1:
                nvtx_range_push(f"{nvtx_label}.column_parallel_comm_dgrad")
+               dgrad = gemm_out
                if ctx.sequence_parallel:
                    if ctx.return_layernorm_output and ctx.return_layernorm_output_gathered:
                        dgrad = dgrad + grad_outputs[1].view_as(dgrad)
                    dgrad, dgrad_work = reduce_scatter_along_first_dim(
                        dgrad,
                        ctx.tp_group,
...
...
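Editorial aside: the async handles (`dgrad_work`, `ln_out_total_work`) above exist so that a collective launched with `async_op=True` can overlap with the following GEMM. A minimal sketch of that pattern in plain PyTorch (the function name is hypothetical and an initialized torch.distributed process group is assumed):

import torch
import torch.distributed as dist

def dgrad_then_wgrad(dgrad_full: torch.Tensor, inp: torch.Tensor, grad_out: torch.Tensor, tp_group):
    tp_size = dist.get_world_size(tp_group)
    dgrad_local = torch.empty((dgrad_full.shape[0] // tp_size, *dgrad_full.shape[1:]),
                              dtype=dgrad_full.dtype, device=dgrad_full.device)
    work = dist.reduce_scatter_tensor(dgrad_local, dgrad_full.contiguous(),
                                      group=tp_group, async_op=True)
    wgrad = grad_out.t() @ inp          # wgrad GEMM overlaps with the collective
    work.wait()                         # synchronize before the gradient is consumed
    return dgrad_local, wgrad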
@@ -709,41 +739,55 @@ class _LayerNormLinear(torch.autograd.Function):
                    else:
                        dgrad, dgrad_work = allreduce(dgrad, ctx.tp_group, async_op=True)
                    nvtx_range_pop(f"{nvtx_label}.column_parallel_comm_dgrad")
+           else:
+               dgrad = gemm_out
+           # ----- Grad input tensor has been computed... -----

-           # Compute grad weight tensor
+           # ----- Compute grad weight -----
            wgrad = None
            if ctx.requires_wgrad:
-               # Synchronize tensor-parallel communication for input tensor
-               if ctx.ub_bulk_dgrad:
-                   ln_out_total = ub_obj_dgrad.get_buffer(ctx.input_quantizer)
-                   if ctx.fp8:
-                       # FP8 GEMM on Hopper only supports TN layout so the gathered input must have
-                       # a valid transpose.
-                       if ln_out._data is None:
-                           # All-gather executed on columnwise data and result is in rowwise data,
-                           # so we need to fix the interleaving before WGRAD.
-                           ln_out_total = _fix_gathered_fp8_transpose(ln_out_total, ctx.tp_size)
-                       else:
-                           ln_out_total._create_transpose()
+               # Prepare input tensor
+               # Note: Synchronize tensor-parallel communication and
+               # make sure required data is available
                if ln_out_total_work is not None:
                    ln_out_total_work.wait()
                    ln_out_total_work = None
-               if ctx.input_quantizer is not None and not isinstance(ln_out_total, QuantizedTensor):
-                   # Async gather may have been done in BF16
-                   # call quantizer after gather.
-                   ctx.input_quantizer.set_usage(rowwise=False, columnwise=True)
-                   ln_out_total = ctx.input_quantizer(ln_out_total)
-               # Make sure GEMM inputs have required data
-               if isinstance(ln_out_total, QuantizedTensor):
-                   ln_out_total.update_usage(columnwise_usage=True)
-               if isinstance(grad_output, QuantizedTensor):
-                   grad_output.update_usage(columnwise_usage=True)
+               if ctx.fp8 or ctx.debug:
+                   if isinstance(ln_out_total, QuantizedTensorBase):
+                       ln_out_total.update_usage(columnwise_usage=True)
+                   else:
+                       ctx.input_quantizer.set_usage(rowwise=False, columnwise=True)
+                       ln_out_total = ctx.input_quantizer(ln_out_total)

+               # Prepare grad output tensor
+               # Note: Synchronize tensor-parallel communication and
+               # make sure required data is available
                if ctx.ub_overlap_ag and isinstance(ctx.grad_output_quantizer, MXFP8Quantizer):
                    # UB does not support overlapping grad output
                    # all-gather with wgrad GEMM. Also, we can't
                    # convert row-scaled MXFP8 to column-scaled, so we
                    # can't reuse the grad output that was gathered
                    # for the dgrad GEMM. We work around with blocking
                    # all-gather for column-scaled MXFP8 data.
                    ctx.grad_output_quantizer.set_usage(rowwise=False, columnwise=True)
                    grad_output, _ = gather_along_first_dim(
                        grad_outputs[0], ctx.tp_group, quantizer=ctx.grad_output_quantizer,
                    )
+               if ctx.fp8 or ctx.debug:
+                   if isinstance(grad_output, QuantizedTensorBase):
+                       grad_output.update_usage(columnwise_usage=True)
+                   else:
+                       ctx.grad_output_quantizer.set_usage(rowwise=False, columnwise=True)
+                       grad_output = ctx.grad_output_quantizer(grad_output)

                # Figure out whether to use split accumulator
                use_split_accumulator = _2X_ACC_WGRAD
...
...
@@ -752,55 +796,95 @@ class _LayerNormLinear(torch.autograd.Function):
                    if hasattr(recipe, "fp8_gemm_wgrad"):
                        use_split_accumulator = recipe.fp8_gemm_wgrad.use_split_accumulator

-               # Output buffer for overlapping grad input
+               # Figure out whether to output wgrad GEMM directly into main grad
+               if ctx.is_first_microbatch is not None:
+                   accumulate_wgrad_into_param_main_grad = (
+                       ctx.fuse_wgrad_accumulation and not ctx.is_first_microbatch
+                   )
+               else:
+                   accumulate_wgrad_into_param_main_grad = ctx.fuse_wgrad_accumulation

+               # Output buffer for overlapping FP8 grad input
                # reduce-scatter with wgrad GEMM
                reduce_scatter_out = None
                if ctx.ub_bulk_wgrad and ub_obj_wgrad.is_fp8_ubuf():
-                   rs_out = torch.empty(dgrad_shape, dtype=ctx.activation_dtype, device=inputmat.device)
+                   reduce_scatter_out = torch.empty(
+                       dgrad_shape, dtype=ctx.activation_dtype, device=grad_outputs[0].device
+                   )

-               # wgrad GEMM
-               # Note: Fuse with bgrad computation if needed
-               nvtx_range_push(f"{nvtx_label}.wgrad_gemm")
-               general_gemm_wgrad = functools.partial(
-                   general_gemm,
-                   out_dtype=(main_grad.dtype if ctx.fuse_wgrad_accumulation else ctx.activation_dtype),
-                   workspace=get_workspace(),
-                   layout="NT",
-                   grad=True,
-                   bias=(bias if (grad_bias is None and not ctx.fp8) else None),
-                   out=main_grad if ctx.fuse_wgrad_accumulation else None,
-                   use_split_accumulator=use_split_accumulator,
-                   accumulate=accumulate_wgrad_into_param_main_grad,
-                   quantization_params=ctx.grad_weight_quantizer,
-                   ub=ub_obj_wgrad,
-                   ub_type=ub_type_wgrad,
-                   extra_output=rs_out,
-                   bulk_overlap=ctx.ub_bulk_wgrad,
-               )
+               # Arguments to include in wgrad GEMM closure
+               wgrad_gemm_kwargs = {
+                   "workspace": get_workspace(),
+                   "out_dtype": (main_grad.dtype if ctx.fuse_wgrad_accumulation else ctx.activation_dtype),
+                   "quantization_params": ctx.grad_weight_quantizer,
+                   "accumulate": accumulate_wgrad_into_param_main_grad,
+                   "layout": "NT",
+                   "out": main_grad if ctx.fuse_wgrad_accumulation else None,
+                   "bias": (bias if (grad_bias is None and not ctx.fp8) else None),
+                   "use_split_accumulator": use_split_accumulator,
+                   "grad": True,
+                   "ub": ub_obj_wgrad,
+                   "ub_type": ub_type_wgrad,
+                   "extra_output": reduce_scatter_out,
+                   "bulk_overlap": ctx.ub_bulk_wgrad,
+               }

+               def wgrad_gemm(
+                   x: torch.Tensor,
+                   dy: torch.Tensor,
+               ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+                   """Perform wgrad GEMM: dw = dy^T * x
+                   May be fused with bgrad computation.
+                   May be called outside of this function to enable
+                   some advanced communication/compute overlapping.
+                   """
+                   nvtx_range_push(f"{nvtx_label}.wgrad_gemm")
+                   dw, db, *_ = general_gemm(x, dy, **wgrad_gemm_kwargs)
+                   nvtx_range_pop(f"{nvtx_label}.wgrad_gemm")
+                   return dw, db

+               # Choose whether to call wgrad GEMM now or delay
                if ctx.wgrad_store is not None and ctx.wgrad_store.delay_wgrad_compute():
-                   ctx.wgrad_store.put([ln_out_total, grad_output], general_gemm_wgrad)
+                   if (
+                       wgrad_gemm_kwargs["ub"] is not None
+                       or wgrad_gemm_kwargs["ub_type"] is not None
+                       or wgrad_gemm_kwargs["extra_output"] is not None
+                       or wgrad_gemm_kwargs["bulk_overlap"]
+                   ):
+                       raise NotImplementedError(
+                           "Delayed weight grad computation is not supported "
+                           "with Userbuffers (tensor-parallel communication overlapping)"
+                       )
+                   ctx.wgrad_store.put([ln_out_total, grad_output], wgrad_gemm)
                else:
-                   wgrad, grad_bias_, _, rs_out = general_gemm_wgrad(ln_out_total, grad_output)
+                   # Call wgrad GEMM now
+                   wgrad, grad_bias_ = wgrad_gemm(ln_out_total, grad_output)

                    # Update grad bias if needed
                    if grad_bias is None:
                        grad_bias = grad_bias_
                    del grad_bias_

-                   # Deallocate input tensor
+                   # Deallocate input tensor if permitted
                    if not ctx.return_layernorm_output:
                        # TODO (pgadzinski) - deallocate transpose only  # pylint: disable=fixme
                        clear_tensor_data(ln_out_total)
-               nvtx_range_pop(f"{nvtx_label}.wgrad_gemm")

                # Update grad input if overlapping reduce-scatter with wgrad GEMM
                if ctx.ub_bulk_wgrad:
                    if ub_obj_wgrad.is_fp8_ubuf():
-                       dgrad = rs_out
+                       dgrad = reduce_scatter_out
                    else:
-                       dgrad = ub_obj_wgrad.get_buffer(None, local_chunk=True)
+                       dgrad = ub_obj_wgrad.get_buffer(local_chunk=True).clone()

+               # ----- Grad weight has been computed... -----

            # Don't return grad bias if not needed
            if not ctx.use_bias:
...
...
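Editorial aside: the delayed-wgrad path above queues the GEMM inputs together with a closure and executes them later (e.g. from backward_dw), which lets pipeline schedules reorder weight-gradient work. A minimal stand-in in plain Python (not the TE WeightGradStore implementation; the class and method names here are hypothetical):

from collections import deque
import torch

class DelayedWgradStore:
    def __init__(self, delay: bool):
        self._delay = delay
        self._queue = deque()

    def delay_wgrad_compute(self) -> bool:
        return self._delay

    def put(self, tensors, wgrad_fn):
        # Stash the saved tensors and the GEMM closure for later execution.
        self._queue.append((tensors, wgrad_fn))

    def pop(self):
        tensors, wgrad_fn = self._queue.popleft()
        return wgrad_fn(*tensors)

store = DelayedWgradStore(delay=True)
x, dy = torch.randn(8, 4), torch.randn(8, 3)
store.put([x, dy], lambda x_, dy_: (dy_.t() @ x_, dy_.sum(0)))  # dw and db, as in wgrad_gemm()
dw, db = store.pop()   # executed later, outside the autograd backward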
@@ -879,7 +963,7 @@ class _LayerNormLinear(torch.autograd.Function):
            nvtx_range_pop(f"{nvtx_label}.reduce_and_update_fp8_tensors")

        # Scatter fp8 weight buffers
-       # if ctx.fp8 and not isinstance(weight, QuantizedTensor):
+       # if ctx.fp8 and not isinstance(weight, QuantizedTensorBase):
        #     _fsdp_scatter_tensors(ctx.fsdp_group, weight_fp8)
        return (
...
...
@@ -1405,6 +1489,10 @@ class LayerNormLinear(TransformerEngineBaseModule):
"Splitting QuantizedTensor into multiple params is not supported"
)
else
:
warnings
.
warn
(
"You are using quantized weights without quantized compute. "
"Please make sure this is intentional."
)
unfused_weights
=
[
w
.
dequantize
()
for
w
in
unfused_weights
]
weight_tensor
=
noop_cat
(
unfused_weights
)
...
...
@@ -1511,7 +1599,7 @@ class LayerNormLinear(TransformerEngineBaseModule):
        grad_output_quantizer = None
        output_quantizer = None
        input_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_INPUT]
-       input_quantizer.internal = False
+       input_quantizer.internal = True
        weight_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_WEIGHT]
        weight_quantizer.internal = True
        if fp8_output:
...
...
@@ -1579,3 +1667,12 @@ class LayerNormLinear(TransformerEngineBaseModule):
            self.quantizers["scaling_bwd"][
                tex.FP8BwdTensors.GRAD_OUTPUT1
            ].amax_epsilon = recipe.fp8_quant_bwd_grad.amax_epsilon

+       # parallel related
+       if self.sequence_parallel and self.parallel_mode == "row":
+           # customize grad_output_quantizer with amax reduction TP group
+           self.quantizers["scaling_bwd"][
+               tex.FP8BwdTensors.GRAD_OUTPUT1
+           ].with_amax_reduction = True
+           self.quantizers["scaling_bwd"][
+               tex.FP8BwdTensors.GRAD_OUTPUT1
+           ].amax_reduction_group = self.tp_group
transformer_engine/pytorch/module/layernorm_mlp.py  View file @ f8c2af4c
...
...
@@ -8,7 +8,6 @@ import warnings
from typing import Callable, Optional, Tuple, Union
from functools import reduce
from operator import mul as multiply_op
-import functools

import torch
from torch.nn.parameter import Parameter
...
...
@@ -43,7 +43,6 @@ from ..utils import (
    assert_dim_for_fp8_exec,
    clear_tensor_data,
    requires_grad,
    non_tn_fp8_gemm_supported,
    needs_quantized_gemm,
)
from ..distributed import (
...
...
@@ -67,10 +66,11 @@ from ..tensor.float8_tensor import (
)
from ..tensor.mxfp8_tensor import MXFP8Quantizer
from ..tensor.float8_blockwise_tensor import Float8BlockQuantizer
-from ._common import apply_normalization, _fix_gathered_fp8_transpose, WeightGradStore
+from ._common import apply_normalization, WeightGradStore
from ..cpu_offload import is_cpu_offload_enabled, mark_activation_offload
from ..tensor.quantized_tensor import (
    QuantizedTensor,
+   QuantizedTensorBase,
    Quantizer,
    prepare_for_saving,
    restore_from_saved,
...
...
@@ -201,24 +201,16 @@ class _LayerNormMLP(torch.autograd.Function):
    ) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]:
        # pylint: disable=missing-function-docstring

-       in_features, inp_shape = ln_weight.numel(), inp.shape
+       # Make sure input dimensions are compatible
+       in_features, inp_shape = ln_weight.numel(), inp.shape
        assert inp_shape[-1] == in_features, "GEMM not possible"
        inputmat = inp.view((-1, in_features))
        if fp8:
            assert_dim_for_fp8_exec(inputmat, fc1_weight, fc2_weight)
-           if any([ub_overlap_ag, ub_overlap_rs]) and not (
-               FP8GlobalStateManager.get_fp8_recipe().float8_per_tensor_scaling()
-           ):
-               raise NotImplementedError(
-                   "Comm+GEMM overlap is only supported with FP8 delayed scaling or per-tensor"
-                   " current scaling"
-               )

        activation_func = _act_func(
            activation, FP8GlobalStateManager.get_fp8_recipe() if fp8 else None
        )[0]
-       device = inp.device

        # Cast for native AMP
        inputmat = cast_if_needed(inputmat, activation_dtype)
...
...
@@ -226,6 +218,38 @@ class _LayerNormMLP(torch.autograd.Function):
        if ln_bias is not None:
            ln_bias = cast_if_needed(ln_bias, activation_dtype)

+       tp_world_size = get_distributed_world_size(tp_group)
+       backwards_needs_fc1_input = is_grad_enabled and fc1_weight.requires_grad
+       device = inp.device

+       # Configure Userbuffers communication (comm+GEMM overlap)
+       ub_overlap_ag = ub_overlap_ag and is_grad_enabled and not return_layernorm_output_gathered
+       ub_overlap_rs = ub_overlap_rs and is_grad_enabled

+       # Choose whether to use GEMM kernel with split accumulator
+       use_split_accumulator = _2X_ACC_FPROP
+       if fp8:
+           recipe = FP8GlobalStateManager.get_fp8_recipe()
+           if hasattr(recipe, "fp8_gemm_fprop"):
+               use_split_accumulator = recipe.fp8_gemm_fprop.use_split_accumulator

+       # Configure quantizer for norm output
+       if fp8:
+           if fc1_input_quantizer is None:
+               raise ValueError("Missing quantizer for FC1 input tensor")
+           fc1_input_quantizer.set_usage(rowwise=True, columnwise=backwards_needs_fc1_input)
+           if sequence_parallel and isinstance(
+               fc1_input_quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)
+           ):
+               # All-gather is not supported with FP8 column-wise data
+               fc1_input_quantizer.set_usage(columnwise=False)

+       # Do TP communication in high precision if quantized format
+       # does not support communication
+       force_hp_fc1_input_gather = (
+           fp8 and sequence_parallel and isinstance(fc1_input_quantizer, Float8BlockQuantizer)
+       )

        # for fp8 DelayedScaling: layernorm output = FP8
        # only output of the linear is returned
        # for return_layernorm_output: layernorm output = High precision, then cast to FP8
...
...
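Editorial aside: the decision above about gathering the FC1 input in high precision versus quantized form can be captured in a small helper sketch (plain Python; the function name and the `supports_allgather` flag are hypothetical stand-ins for the per-recipe checks such as blockwise FP8 forcing the high-precision path):

def plan_fc1_input_gather(fp8: bool, sequence_parallel: bool, supports_allgather: bool):
    # Gather in high precision when the quantized format cannot be all-gathered directly;
    # otherwise quantize first and gather the quantized tensor to save a cast.
    force_hp_gather = fp8 and sequence_parallel and not supports_allgather
    quantize_before_gather = fp8 and not force_hp_gather
    return force_hp_gather, quantize_before_gather

assert plan_fc1_input_gather(True, True, supports_allgather=False) == (True, False)
assert plan_fc1_input_gather(True, True, supports_allgather=True) == (False, True)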
@@ -241,29 +265,6 @@ class _LayerNormMLP(torch.autograd.Function):
            # Kernels not available for norm fusion.
            with_quantized_norm = False

-       tp_world_size = get_distributed_world_size(tp_group)
-       ub_overlap_ag = ub_overlap_ag and is_grad_enabled and not return_layernorm_output_gathered
-       ub_overlap_rs = ub_overlap_rs and is_grad_enabled
-       backwards_needs_fc1_input = is_grad_enabled and fc1_weight.requires_grad
-       # TODO(kwyss): Support FP8 allgather of Float8BlockQuantizer recipe
-       force_hp_fc1_input_gather = (
-           fp8 and sequence_parallel and isinstance(fc1_input_quantizer, Float8BlockQuantizer)
-       )  # Perform TP communication in high precision.

-       # Configure quantizer for norm output
-       if fp8:
-           if fc1_input_quantizer is None:
-               raise ValueError("Missing quantizer for FC1 input tensor")
-           columnwise_usage = backwards_needs_fc1_input
-           if (
-               columnwise_usage
-               and sequence_parallel
-               and not isinstance(fc1_input_quantizer, MXFP8Quantizer)
-           ):
-               columnwise_usage = False
-           fc1_input_quantizer.set_usage(rowwise=True, columnwise=columnwise_usage)

        # Apply normalization
        ln_out, mu, rsigma = apply_normalization(
            inputmat,
...
...
@@ -297,39 +298,43 @@ class _LayerNormMLP(torch.autograd.Function):
                    fc1_input_quantizer.set_usage(rowwise=True, columnwise=False)
                    ln_out_total = fc1_input_quantizer(ln_out_total)
            else:
+               quantizer = None
                if fp8 or debug:
+                   quantizer = fc1_input_quantizer
-                   if not with_quantized_norm and not force_hp_fc1_input_gather:
-                       ln_out = fc1_input_quantizer(ln_out)
-               fc1_input_quantizer.set_usage(rowwise=True, columnwise=False)
                if ub_overlap_ag:
                    # Copy into Userbuffers buffer
                    ub_obj_lnout = get_ub("fc1_fprop")
-                   ub_obj_lnout.get_buffer(fc1_input_quantizer, local_chunk=True).copy_(ln_out)
-                   ln_out_total = ub_obj_lnout.get_buffer(fc1_input_quantizer)
+                   ln_out_total, _ = fill_userbuffers_buffer_for_all_gather(
+                       ub_obj_lnout, ln_out, quantizer, tp_group,
+                   )
                else:
                    # All-gather with NCCL
                    ln_out_total, _ = gather_along_first_dim(
                        ln_out, tp_group,
-                       quantizer=(fc1_input_quantizer if fp8 or debug else None),
+                       quantizer=quantizer,
                    )
        else:
+           # NOTE: force_hp_fc1_input_gather is redundant with else, but
+           # here for clarity. We should not quantize ln_out if bwd needs
+           # to gather in hp.
-           if (fp8 or debug) and not with_quantized_norm:
+           if (fp8 or debug) and not with_quantized_norm and not force_hp_fc1_input_gather:
                ln_out = fc1_input_quantizer(ln_out)
            ln_out_total = ln_out

        # Cast weights to expected dtype
        fc1_weight_final = fc1_weight
        fc2_weight_final = fc2_weight
        if fp8 or debug:
+           # If weights are not quantized, we call get_weight_workspace,
+           # which handles weight caching etc.
            # FP8 cast to workspace buffer
            update_workspace = is_first_microbatch is None or is_first_microbatch
            fc1_weight_quantizer.set_usage(rowwise=True, columnwise=True)
+           fc2_weight_quantizer.set_usage(rowwise=True, columnwise=True)
            fc1_weight_final = module.get_weight_workspace(
                tensor=fc1_weight,
                quantizer=fc1_weight_quantizer,
...
...
@@ -339,7 +344,6 @@ class _LayerNormMLP(torch.autograd.Function):
                fsdp_group=fsdp_group,
                workspace_dtype=activation_dtype,
            )
-           fc2_weight_quantizer.set_usage(rowwise=True, columnwise=True)
            fc2_weight_final = module.get_weight_workspace(
                tensor=fc2_weight,
                quantizer=fc2_weight_quantizer,
...
...
@@ -349,6 +353,8 @@ class _LayerNormMLP(torch.autograd.Function):
                fsdp_group=fsdp_group,
                workspace_dtype=activation_dtype,
            )
+           fc1_weight_final.update_usage(rowwise_usage=True)
+           fc2_weight_final.update_usage(rowwise_usage=True)
        else:
            fc1_weight_final = cast_if_needed(fc1_weight_final, activation_dtype)
            fc2_weight_final = cast_if_needed(fc2_weight_final, activation_dtype)
...
...
@@ -356,6 +362,7 @@ class _LayerNormMLP(torch.autograd.Function):
        # Cast biases to expected dtype
        bias_dtype = activation_dtype
        if needs_quantized_gemm(ln_out_total) and activation_dtype == torch.float32:
+           # cuBLAS does not support FP8 GEMM with FP32 bias, so we cast to BF16
            bias_dtype = torch.bfloat16
        if fc1_bias is not None:
            fc1_bias = cast_if_needed(fc1_bias, bias_dtype)
...
...
@@ -369,7 +376,9 @@ class _LayerNormMLP(torch.autograd.Function):
            if fc1_weight_quantizer is not None:
                fc1_weight_quantizer.calibrate(fc1_weight)

+       # ----- FC1 GEMM -----
        # There are 2 fusions possible:
        # - gemm_gelu_fusion - default for full precision, optional for fp8 - need to turn on gemm_gelu_fusion,
...
...
@@ -401,11 +410,16 @@ class _LayerNormMLP(torch.autograd.Function):
                fc1_bias if not bias_gelu_fusion else None
            ),  # otherwise bias is added later (fused with gelu)
            gelu=gemm_gelu_fusion,
-           accumulate=_2X_ACC_FPROP,
+           use_split_accumulator=use_split_accumulator,
            ub=ub_obj_lnout,
            ub_type=tex.CommOverlapType.AG if ub_overlap_ag else None,
        )
+       # ----- Finished FC1 GEMM... -----

        # Deallocate FC1 GEMM input tensor if no longer needed
        if not is_grad_enabled and (ln_out_total is not ln_out_return):
            clear_tensor_data(ln_out_total)
...
...
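Editorial aside: for readers unfamiliar with the GEMM+GELU fusion mentioned above, here is the plain PyTorch equivalent of what the fused FC1 GEMM with gelu=True computes (bias add followed by GELU on the GEMM output); the fused kernel simply avoids extra memory round-trips. The function name is hypothetical.

import torch
import torch.nn.functional as F

def fc1_reference(x: torch.Tensor, w: torch.Tensor, b: torch.Tensor):
    pre_act = x @ w.t() + b          # FC1 GEMM output ("fc1_out")
    return pre_act, F.gelu(pre_act)  # activation output fed to FC2

x, w, b = torch.randn(8, 16), torch.randn(32, 16), torch.randn(32)
fc1_out, act_out = fc1_reference(x, w, b)
assert act_out.shape == (8, 32)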
@@ -439,45 +453,66 @@ class _LayerNormMLP(torch.autograd.Function):
                fc2_input_quantizer.calibrate(act_out)
                fc2_weight_quantizer.calibrate(fc2_weight)

        # Configure Userbuffers reduce-scatter if needed
        ub_obj_fc2out = None
-       rs_out = None
-       fc2_out = None
+       reduce_scatter_out = None
        if ub_overlap_rs:
            ub_obj_fc2out = get_ub("fc2_fprop")
            dim_size = list(act_out.size())
-           dim_size[0] = dim_size[0] // tp_world_size
-           dim_size[1] = fc2_weight.size(0)
-           rs_out = torch.empty(dim_size, dtype=activation_dtype, device=device)
-       else:
-           dim_size = list(act_out.size())
-           dim_size[1] = fc2_weight.size(0)
-           fc2_out = torch.empty(dim_size, dtype=activation_dtype, device=device)
+           dim_size[0] //= tp_world_size
+           dim_size[-1] = fc2_weight.size(0)
+           reduce_scatter_out = torch.empty(dim_size, dtype=activation_dtype, device=device)

+       # ----- FC2 GEMM -----
-       _ = general_gemm(
+       gemm_out, *_, reduce_scatter_out = general_gemm(
            fc2_weight_final, act_out, get_workspace(),
            out_dtype=activation_dtype,
            bias=fc2_bias,
            quantization_params=fc2_output_quantizer,
-           out=fc2_out,
-           use_split_accumulator=_2X_ACC_FPROP,
+           use_split_accumulator=use_split_accumulator,
            ub=ub_obj_fc2out,
            ub_type=tex.CommOverlapType.RS if ub_overlap_rs else None,
-           extra_output=rs_out,
+           extra_output=reduce_scatter_out,
        )
+       # ----- Finished FC2 GEMM... -----

+       # Deallocate tensors if no longer needed
+       if not is_grad_enabled:
+           clear_tensor_data(act_out, fc1_out_without_bias, fc1_out)

+       # Prepare output tensor
+       # Note: Perform tensor-parallel communication if needed
+       fc2_out = None
+       if ub_overlap_rs:
+           fc2_out = reduce_scatter_out
+       elif set_parallel_mode and sequence_parallel:
+           fc2_out, _ = reduce_scatter_along_first_dim(gemm_out, tp_group)
+       elif set_parallel_mode and tensor_parallel:
+           if symmetric_ar_type is not None:
+               fc2_out, _ = symmetric_all_reduce(gemm_out, tp_group, all_reduce_type=symmetric_ar_type)
+           else:
+               fc2_out, _ = allreduce(gemm_out, tp_group)
+       else:
+           fc2_out = gemm_out
+       fc2_out = fc2_out.view(-1, *inp_shape[1:-1], fc2_out.shape[-1])

-       # Weight with column-wise usage is needed for dgrad GEMM.
+       # Cache state for backward pass
        if is_grad_enabled:
-           if isinstance(fc1_weight_final, QuantizedTensor):
+           # Weight with column-wise usage is needed for dgrad GEMM.
+           if isinstance(fc1_weight_final, QuantizedTensorBase):
                fc1_weight_final.update_usage(columnwise_usage=True)
-           if isinstance(fc2_weight_final, QuantizedTensor):
+           if isinstance(fc2_weight_final, QuantizedTensorBase):
                fc2_weight_final.update_usage(columnwise_usage=True)

-       if not is_grad_enabled:
-           clear_tensor_data(act_out, fc1_out_without_bias, fc1_out)
-       else:
            if cpu_offloading:
                mark_activation_offload(
                    inputmat, mu, rsigma, ln_out, fc1_out, fc1_out_without_bias, act_out
...
...
@@ -504,8 +539,6 @@ class _LayerNormMLP(torch.autograd.Function):
                if not return_layernorm_output:
                    clear_tensor_data(ln_out)
                    ln_out = None
-               elif force_hp_fc1_input_gather:
-                   assert not isinstance(ln_out, QuantizedTensor)
            if not fc2_weight.requires_grad:
                clear_tensor_data(act_out)
                act_out = None
...
...
@@ -592,28 +625,12 @@ class _LayerNormMLP(torch.autograd.Function):
            ctx.wgrad_store = wgrad_store

-       # Row Parallel Linear
-       if ub_overlap_rs:
-           fc2_out = rs_out
-       elif set_parallel_mode and sequence_parallel:
-           fc2_out, _ = reduce_scatter_along_first_dim(fc2_out, tp_group)
-       elif set_parallel_mode and tensor_parallel:
-           if symmetric_ar_type is not None:
-               fc2_out, _ = symmetric_all_reduce(fc2_out, tp_group, all_reduce_type=symmetric_ar_type)
-           else:
-               fc2_out, _ = allreduce(fc2_out, tp_group)

-       # [*, in_features] -> [*, out_features] except first dimension changes for SP
-       fc2_out = fc2_out.view(-1, *inp_shape[1:-1], fc2_out.shape[-1])

        if return_layernorm_output:
            if return_layernorm_output_gathered:
                shape = list(inp_shape)
                shape[0] *= tp_size
                return fc2_out, ln_out_return.view(shape)
-           return fc2_out, ln_out_return.view_as(inp)
+           return fc2_out, ln_out_return.view(inp_shape)
        return fc2_out

    @staticmethod
...
...
@@ -622,24 +639,6 @@ class _LayerNormMLP(torch.autograd.Function):
    ) -> Tuple[Union[torch.Tensor, None], ...]:
        # pylint: disable=missing-function-docstring
        with torch.cuda.nvtx.range("_LayerNormMLP_backward"):
-           if (
-               ctx.fp8
-               and any(
-                   [
-                       ctx.ub_overlap_ag,
-                       ctx.ub_overlap_rs_dgrad,
-                       ctx.ub_bulk_dgrad,
-                       ctx.ub_bulk_wgrad,
-                   ]
-               )
-               and (ctx.fp8_recipe is not None)
-           ):
-               if not ctx.fp8_recipe.float8_per_tensor_scaling():
-                   raise NotImplementedError(
-                       "Comm+GEMM overlap is only supported with FP8 delayed scaling or per-tensor"
-                       " current scaling"
-                   )
            saved_tensors = ctx.saved_tensors
            (  # pylint: disable=unbalanced-tuple-unpacking
                inputmat,
...
...
@@ -699,6 +698,16 @@ class _LayerNormMLP(torch.autograd.Function):
            #     fc2_weight_fp8 if ctx.fp8 and not isinstance(fc2_weight, Float8Tensor) else None,
            # )

+           # Choose whether to use GEMM kernel with split accumulator
+           dgrad_use_split_accumulator = _2X_ACC_DGRAD
+           wgrad_use_split_accumulator = _2X_ACC_WGRAD
+           if ctx.fp8:
+               recipe = ctx.fp8_recipe
+               if hasattr(recipe, "fp8_gemm_dgrad"):
+                   dgrad_use_split_accumulator = recipe.fp8_gemm_dgrad.use_split_accumulator
+               if hasattr(recipe, "fp8_gemm_wgrad"):
+                   wgrad_use_split_accumulator = recipe.fp8_gemm_wgrad.use_split_accumulator

            # No need to do bulk DGRAD/WGRAD overlap if WGRAD is not required
            ctx.ub_bulk_dgrad = ctx.fc1_weight_requires_grad and ctx.ub_bulk_dgrad
            ctx.ub_bulk_wgrad = ctx.fc1_weight_requires_grad and ctx.ub_bulk_wgrad
...
...
@@ -707,20 +716,13 @@ class _LayerNormMLP(torch.autograd.Function):
            # Note: dgrad GEMM requires row-wise usage, wgrad GEMM
            # requires column-wise usage
            if ctx.fc2_grad_output_quantizer is not None:
-               rowwise_usage = True
-               columnwise_usage = True
-               if ctx.ub_overlap_ag and isinstance(
-                   ctx.fc2_grad_output_quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer),
-               ):
-                   # If data is in FP8 and communication is handled
-                   # with Userbuffers, we compute FP8 transposes
-                   # manually
-                   columnwise_usage = False
-               ctx.fc2_grad_output_quantizer.set_usage(
-                   rowwise=rowwise_usage, columnwise=columnwise_usage,
-               )
+               quantizer = ctx.fc2_grad_output_quantizer
+               quantizer.set_usage(rowwise=True, columnwise=True)
+               if ctx.ub_overlap_ag:
+                   # Userbuffers only supports communication for one
+                   # tensor usage at a time. Configure quantizer with
+                   # usage for only dgrad GEMM.
+                   quantizer.set_usage(columnwise=False)

            # Prepare FC2 grad output tensor
            # Note: Cast to expected dtype and perform tensor-parallel communication
...
...
@@ -738,14 +740,10 @@ class _LayerNormMLP(torch.autograd.Function):
            # Launch tensor-parallel communication for FC1 GEMM input
            ln_out_total = None
            ln_out_total_work = None
-           if (
-               ctx.fc1_weight_requires_grad
-               and ctx.tensor_parallel
-               and ctx.sequence_parallel
-               and not ctx.ub_bulk_dgrad
-           ):
+           ub_obj_fc1_dgrad = None
+           if ctx.fc1_weight_requires_grad and ctx.tensor_parallel and ctx.sequence_parallel:
                quantizer = None
-               if ctx.fp8 or ctx.debug:
+               if ctx.fp8 or ctx.debug and not ctx.force_hp_fc1_input_gather:
                    quantizer = ctx.fc1_input_quantizer
                    if isinstance(quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)):
                        # If data is in FP8, we compute FP8 transposes manually
...
...
@@ -753,13 +751,21 @@ class _LayerNormMLP(torch.autograd.Function):
                    else:
                        # wgrad GEMM requires input with column-wise usage
                        quantizer.set_usage(rowwise=False, columnwise=True)
-               gather_quantizer = None if ctx.force_hp_fc1_input_gather else quantizer
-               ln_out_total, ln_out_total_work = gather_along_first_dim(
-                   ln_out, ctx.tp_group, async_op=True, quantizer=gather_quantizer,
-               )
+               if ctx.ub_bulk_dgrad:
+                   ub_obj_fc1_dgrad = get_ub("fc1_dgrad")
+                   ln_out_total, _ = fill_userbuffers_buffer_for_all_gather(
+                       ub_obj_fc1_dgrad, ln_out, quantizer, ctx.tp_group,
+                   )
+               else:
+                   ln_out_total, ln_out_total_work = gather_along_first_dim(
+                       ln_out, ctx.tp_group, async_op=True, quantizer=quantizer,
+                   )
            else:
                ln_out_total = ln_out
...
...
@@ -770,6 +776,11 @@ class _LayerNormMLP(torch.autograd.Function):
                )
            else:
                accumulate_wgrad_into_param_main_grad = ctx.fuse_wgrad_accumulation

+           # ----- FC2 DGRAD -----
            # There are 6 possible fusion paths
            # 1 high-precision bias_gelu_fusion: gemm, FC1_bias + gelu,
            # 2 high-precision fc2_dgrad_gemm_gelu_fusion: gemm + gelu, FC1_bias + quantize
...
...
@@ -784,12 +795,15 @@ class _LayerNormMLP(torch.autograd.Function):
                and (not ctx.debug)
            )

            # FC2 DGRAD; Unconditional
-           if ctx.fc2_weight_quantizer is not None and isinstance(ctx.fc2_weight, QuantizedTensor):
-               ctx.fc2_weight.update_usage(
-                   rowwise_usage=ctx.fc2_weight_quantizer.rowwise_usage,
-                   columnwise_usage=ctx.fc2_weight_quantizer.columnwise_usage,
-               )
+           # Make sure required data is available
+           if isinstance(grad_output, QuantizedTensorBase):
+               grad_output.update_usage(rowwise_usage=True)
+           if ctx.fc2_weight_quantizer is not None and isinstance(ctx.fc2_weight, QuantizedTensorBase):
+               ctx.fc2_weight.update_usage(columnwise_usage=True)

+           # Perform GEMM
            gemm_output, *_ = general_gemm(
                fc2_weight,
                grad_output,
...
@@ -804,52 +818,107 @@ class _LayerNormMLP(torch.autograd.Function):
out_dtype
=
ctx
.
activation_dtype
,
gelu
=
fc2_dgrad_gemm_gelu_fusion
,
gelu_in
=
fc1_out
if
fc2_dgrad_gemm_gelu_fusion
else
None
,
use_split_accumulator
=
_2X_ACC_DGRAD
,
use_split_accumulator
=
dgrad_use_split_accumulator
,
ub
=
ub_obj_fc2_dgrad
,
ub_type
=
tex
.
CommOverlapType
.
AG
if
ctx
.
ub_overlap_ag
else
None
,
)
# Prepare input grad tensor
dact
=
None
fc2_dgrad
=
None
if
fc2_dgrad_gemm_gelu_fusion
:
dact
=
gemm_output
fc2_dgrad
=
None
else
:
fc2_dgrad
=
gemm_output
# --------------------------------------------------
# Finished FC2 DGRAD...
# --------------------------------------------------
# --------------------------------------------------
# FC2 WGRAD
# --------------------------------------------------
fc2_wgrad
=
None
if
ctx
.
fc2_weight_requires_grad
:
if
isinstance
(
act_out
,
QuantizedTensor
):
act_out
.
update_usage
(
rowwise_usage
=
True
,
columnwise_usage
=
True
)
if
isinstance
(
grad_output
,
QuantizedTensor
):
grad_output
.
update_usage
(
rowwise_usage
=
True
,
columnwise_usage
=
True
)
# Prepare input tensor
# Note: Synchronize tensor-parallel communication and
# make sure required data is available
if
ctx
.
fp8
or
ctx
.
debug
:
if
isinstance
(
act_out
,
QuantizedTensorBase
):
act_out
.
update_usage
(
columnwise_usage
=
True
)
else
:
ctx
.
fc2_input_quantizer
.
set_usage
(
rowwise
=
False
,
columnwise
=
True
)
act_out
=
ctx
.
fc2_input_quantizer
(
act_out
)
# Prepare grad output tensor
# Note: Synchronize tensor-parallel communication and
# make sure required data is available
if
ctx
.
ub_overlap_ag
and
isinstance
(
ctx
.
fc2_grad_output_quantizer
,
MXFP8Quantizer
):
# UB does not support overlapping grad output
# all-gather with wgrad GEMM. Also, we can't
# convert row-scaled MXFP8 to column-scaled, so we
# can't reuse the grad output that was gathered
# for the dgrad GEMM. We work around with blocking
# all-gather for column-scaled MXFP8 data.
ctx
.
fc2_grad_output_quantizer
.
set_usage
(
rowwise
=
False
,
columnwise
=
True
)
grad_output
,
_
=
gather_along_first_dim
(
grad_outputs
[
0
],
ctx
.
tp_group
,
quantizer
=
ctx
.
fc2_grad_output_quantizer
,
)
if
ctx
.
fp8
or
ctx
.
debug
:
if
isinstance
(
grad_output
,
QuantizedTensorBase
):
grad_output
.
update_usage
(
columnwise_usage
=
True
)
else
:
ctx
.
fc2_grad_output_quantizer
.
set_usage
(
rowwise
=
False
,
columnwise
=
True
)
grad_output
=
ctx
.
fc2_grad_output_quantizer
(
grad_output
)
# Whether to set grad arg in general_gemm
grad_arg
=
True
if
ctx
.
fp8
and
ctx
.
fp8_recipe
.
float8_block_scaling
():
grad_arg
=
False
general_gemm_fc2_wgrad
=
functools
.
partial
(
general_gemm
,
out_dtype
=
(
# Arguments to include in wgrad GEMM closure
fc2_wgrad_gemm_kwargs
=
{
"workspace"
:
get_workspace
(),
"out_dtype"
:
(
origin_fc2_weight
.
main_grad
.
dtype
if
ctx
.
fuse_wgrad_accumulation
else
ctx
.
activation_dtype
),
workspace
=
get_workspace
(),
quantization_params
=
ctx
.
fc2_grad_weight_quantizer
,
# wgrad in high precision
layout
=
"NT"
,
grad
=
grad_arg
,
bias
=
fc2_bias
if
fc2_bias
is
not
None
and
fc2_bias_grad
is
None
else
None
,
accumulate
=
accumulate_wgrad_into_param_main_grad
,
use_split_accumulator
=
_2X_ACC_WGRAD
,
out
=
origin_fc2_weight
.
main_grad
if
ctx
.
fuse_wgrad_accumulation
else
None
,
)
"quantization_params"
:
ctx
.
fc2_grad_weight_quantizer
,
# wgrad in high precision
"accumulate"
:
accumulate_wgrad_into_param_main_grad
,
"layout"
:
"NT"
,
"out"
:
origin_fc2_weight
.
main_grad
if
ctx
.
fuse_wgrad_accumulation
else
None
,
"bias"
:
fc2_bias
if
fc2_bias
is
not
None
and
fc2_bias_grad
is
None
else
None
,
"use_split_accumulator"
:
wgrad_use_split_accumulator
,
"grad"
:
grad_arg
,
}
def
fc2_wgrad_gemm
(
x
:
torch
.
Tensor
,
dy
:
torch
.
Tensor
,
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]:
"""Perform FC2 WGRAD GEMM
May be called outside of this function to enable
some advanced communication/compute overlapping.
"""
dw
,
db
,
*
_
=
general_gemm
(
x
,
dy
,
**
fc2_wgrad_gemm_kwargs
)
return
dw
,
db
# Choose whether to call wgrad GEMM now or delay
if
ctx
.
wgrad_store
is
not
None
and
ctx
.
wgrad_store
.
delay_wgrad_compute
():
ctx
.
wgrad_store
.
put
([
act_out
,
grad_output
],
general_gemm_fc2_wgrad
)
fc2_wgrad
=
None
ctx
.
wgrad_store
.
put
([
act_out
,
grad_output
],
fc2_wgrad_gemm
)
else
:
fc2_wgrad
,
fc2_bias_grad_
,
*
_
=
general_gemm_fc2_wgrad
(
act_out
,
grad_output
,
)
# Call wgrad GEMM now
fc2_wgrad
,
fc2_bias_grad_
=
fc2_wgrad_gemm
(
act_out
,
grad_output
)
# Update grad bias if needed
if
fc2_bias_grad
is
None
:
if
(
ctx
.
fp8
...
...
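Editorial aside: the fused wgrad/bgrad GEMM used above returns both the weight gradient and, when a bias is supplied, the bias gradient. For reference, the plain PyTorch semantics are (the helper name is hypothetical):

import torch

def wgrad_bgrad_reference(x: torch.Tensor, dy: torch.Tensor):
    dw = dy.t() @ x          # [out_features, in_features]
    db = dy.sum(dim=0)       # [out_features], bias grad fused into the same pass
    return dw, db

x, dy = torch.randn(8, 16), torch.randn(8, 32)
dw, db = wgrad_bgrad_reference(x, dy)
assert dw.shape == (32, 16) and db.shape == (32,)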
@@ -858,12 +927,17 @@ class _LayerNormMLP(torch.autograd.Function):
                        ):
                            # BGRAD not fused with GEMM for float8 blockwise gemm.
                            fc2_bias_grad_ = act_out.view(-1, act_out.shape[-1]).sum(dim=0)
                        fc2_bias_grad = fc2_bias_grad_
                    del fc2_bias_grad_

+               # Deallocate input tensor if permitted
+               if ctx.wgrad_store is not None and not ctx.wgrad_store.delay_wgrad_compute():
+                   clear_tensor_data(act_out)
+           # ----- Finished FC2 WGRAD... -----

            # bias computation
            fc1_bias_grad = None
            fuse_gemm_and_bias_fc1_wgrad = False
...
...
@@ -927,63 +1001,69 @@ class _LayerNormMLP(torch.autograd.Function):
            ub_obj_fc1_dgrad = None
            ub_obj_fc1_wgrad = None
            ub_type_fc1_dgrad = None
            ub_type_fc1_wgrad = None
            fc1_dgrad_shape = [reduce(multiply_op, inputmat.shape[:-1]), inputmat.shape[-1]]
-           fc1_dgrad_rs_out = None
-           fc1_dgrad_bulk = None
            if ctx.ub_overlap_rs_dgrad:
                # Overlap DGRAD+RS
                ub_obj_fc1_dgrad = get_ub("fc1_dgrad")
                ub_type_fc1_dgrad = tex.CommOverlapType.RS
-               fc1_dgrad_rs_out = torch.empty(
-                   fc1_dgrad_shape, dtype=ctx.activation_dtype, device="cuda"
-               )
            else:
                if ctx.ub_bulk_dgrad:
                    # Overlap ln_out all-gather with DGRAD compute
-                   # NOTE: Copying into communication buffer will always prefer rowwise data,
-                   # and will copy columnwise data if rowwise does not exist. In that case,
-                   # the all-gather will apply to the leading dimension of the transpose,
-                   # which then needs to be interleaved correctly before WGRAD.
                    ub_obj_fc1_dgrad = get_ub("fc1_dgrad")
                    ub_type_fc1_dgrad = tex.CommOverlapType.AG
-                   ub_obj_fc1_dgrad.copy_into_buffer(
-                       ln_out, ctx.fc1_input_quantizer, local_chunk=True
-                   )
                if ctx.ub_bulk_wgrad:
                    # Overlap FC1 DGRAD reduce-scatter with WGRAD compute
                    ub_obj_fc1_wgrad = get_ub("fc1_wgrad")
-                   fc1_dgrad_bulk = ub_obj_fc1_wgrad.get_buffer(None)
                    ub_type_fc1_wgrad = tex.CommOverlapType.RS

-           # FC1 DGRAD: Unconditional
+           # ----- FC1 DGRAD -----

+           # Make sure required data is available
            if ctx.fc1_weight_quantizer is not None and isinstance(
-               ctx.fc1_weight_quantizer, QuantizedTensor
+               ctx.fc1_weight_quantizer, QuantizedTensorBase
            ):
-               ctx.fc1_weight.update_usage(
-                   rowwise_usage=ctx.fc1_weight_quantizer.rowwise_usage,
-                   columnwise_usage=ctx.fc1_weight_quantizer.columnwise_usage,
-               )
+               ctx.fc1_weight.update_usage(columnwise_usage=True)

+           # Output buffers for Userbuffers reduce-scatter
+           gemm_out = None
+           reduce_scatter_out = None
+           if ctx.ub_overlap_rs_dgrad:
+               reduce_scatter_out = torch.empty(
+                   fc1_dgrad_shape, dtype=ctx.activation_dtype, device="cuda"
+               )
-           fc1_dgrad, *_, fc1_dgrad_rs_out = general_gemm(
+           if ctx.ub_bulk_wgrad:
+               gemm_out = ub_obj_fc1_wgrad.get_buffer(local_chunk=False)

+           # dgrad GEMM
+           gemm_out, *_, reduce_scatter_out = general_gemm(
                fc1_weight, dact, get_workspace(),
-               out=fc1_dgrad_bulk,
+               out=gemm_out,
                out_dtype=ctx.activation_dtype,
                quantization_params=ctx.fc1_grad_input_quantizer,
                layout="NN",
                grad=True,
+               use_split_accumulator=dgrad_use_split_accumulator,
                ub=ub_obj_fc1_dgrad,
                ub_type=ub_type_fc1_dgrad,
-               extra_output=fc1_dgrad_rs_out,
+               extra_output=reduce_scatter_out,
                bulk_overlap=ctx.ub_bulk_dgrad,
            )

-           # Overlap dgrad-RS/AR with wgrad
+           # Prepare grad input tensor
+           # Note: Perform tensor-parallel communication
+           fc1_dgrad = None
+           fc1_dgrad_work = None
            if ctx.ub_overlap_rs_dgrad:
-               fc1_dgrad = fc1_dgrad_rs_out
+               fc1_dgrad = reduce_scatter_out
+           elif ctx.ub_bulk_wgrad:
+               fc1_dgrad = ub_obj_fc1_wgrad.get_buffer(local_chunk=True)
            elif ctx.set_parallel_mode and not ctx.ub_bulk_wgrad:
+               fc1_dgrad = gemm_out
                if ctx.sequence_parallel:
                    if ctx.return_layernorm_output and ctx.return_layernorm_output_gathered:
                        fc1_dgrad = fc1_dgrad + grad_outputs[1].view_as(fc1_dgrad)
...
...
@@ -994,90 +1074,125 @@ class _LayerNormMLP(torch.autograd.Function):
                    )
                elif ctx.tensor_parallel:
                    fc1_dgrad, fc1_dgrad_work = allreduce(fc1_dgrad, ctx.tp_group, async_op=True)
+           else:
+               fc1_dgrad = gemm_out
+           # ----- Finished FC1 DGRAD... -----

+           # ----- FC1 WGRAD -----
            fc1_wgrad = None
            if ctx.fc1_weight_requires_grad:
-               # Synchronize tensor-parallel communication for FC1 GEMM input tensor
-               if ctx.ub_bulk_dgrad:
-                   ln_out_total = ub_obj_fc1_dgrad.get_buffer(ctx.fc1_input_quantizer)
-                   if ctx.fp8:
-                       if ln_out._data is None:
-                           # All-gather executed on columnwise data and result is in rowwise data,
-                           # so we need to fix the interleaving before WGRAD.
-                           ln_out_total = _fix_gathered_fp8_transpose(ln_out_total, ctx.tp_size)
-                       elif not non_tn_fp8_gemm_supported():
-                           # FP8 GEMM on Hopper only supports TN layout so the gathered input must
-                           # have a valid transpose.
-                           ln_out_total._create_transpose()
+               # Prepare input tensor
+               # Note: Synchronize tensor-parallel communication and
+               # make sure required data is available
                if ln_out_total_work is not None:
                    ln_out_total_work.wait()
                    ln_out_total_work = None
-               if ctx.fc1_input_quantizer is not None and not isinstance(ln_out_total, QuantizedTensor):
-                   # Async gather in BF16 does not asynchronously
-                   # call quantizer after gather.
-                   ctx.fc1_input_quantizer.set_usage(rowwise=False, columnwise=True)
-                   ln_out_total = ctx.fc1_input_quantizer(ln_out_total)
-               # Make sure GEMM inputs have required data
-               if isinstance(ln_out_total, QuantizedTensor):
-                   ln_out_total.update_usage(columnwise_usage=True)
-               if isinstance(dact, QuantizedTensor):
-                   dact.update_usage(columnwise_usage=True)
+               if ctx.fp8 or ctx.debug:
+                   if isinstance(ln_out_total, QuantizedTensorBase):
+                       ln_out_total.update_usage(columnwise_usage=True)
+                   else:
+                       ctx.fc1_input_quantizer.set_usage(rowwise=False, columnwise=True)
+                       ln_out_total = ctx.fc1_input_quantizer(ln_out_total)

+               # Prepare grad output tensor
+               # Note: Synchronize tensor-parallel communication and
+               # make sure required data is available
+               if ctx.fp8 or ctx.debug:
+                   if isinstance(dact, QuantizedTensorBase):
+                       dact.update_usage(columnwise_usage=True)
+                   else:
+                       ctx.fc1_grad_output_quantizer.set_usage(rowwise=False, columnwise=True)
+                       dact = ctx.fc1_grad_output_quantizer(dact)

                # Output buffer for overlapping grad input
                # reduce-scatter with wgrad GEMM
                reduce_scatter_out = None
                if ctx.ub_bulk_wgrad and ub_obj_fc1_wgrad.is_fp8_ubuf():
-                   fc1_dgrad_rs_out = torch.empty(
+                   reduce_scatter_out = torch.empty(
                        fc1_dgrad_shape, dtype=ctx.activation_dtype, device="cuda"
                    )

-               # wgrad GEMM
-               general_gemm_fc1_wgrad = functools.partial(
-                   general_gemm,
-                   out_dtype=(
-                       origin_fc1_weight.main_grad.dtype
-                       if ctx.fuse_wgrad_accumulation
-                       else ctx.activation_dtype
-                   ),
-                   workspace=get_workspace(),
-                   layout="NT",
-                   quantization_params=ctx.fc1_grad_weight_quantizer,
-                   grad=fuse_gemm_and_bias_fc1_wgrad,
-                   bias=fc1_bias if fuse_gemm_and_bias_fc1_wgrad else None,
-                   accumulate=accumulate_wgrad_into_param_main_grad,
-                   out=origin_fc1_weight.main_grad if ctx.fuse_wgrad_accumulation else None,
-                   ub=ub_obj_fc1_wgrad,
-                   ub_type=tex.CommOverlapType.RS if ctx.ub_bulk_wgrad else None,
-                   extra_output=fc1_dgrad_rs_out,
-                   bulk_overlap=ctx.ub_bulk_wgrad,
-               )
+               # Arguments to include in wgrad GEMM closure
+               fc1_wgrad_gemm_kwargs = {
+                   "workspace": get_workspace(),
+                   "out_dtype": (
+                       origin_fc1_weight.main_grad.dtype
+                       if ctx.fuse_wgrad_accumulation
+                       else ctx.activation_dtype
+                   ),
+                   "quantization_params": ctx.fc1_grad_weight_quantizer,
+                   "accumulate": accumulate_wgrad_into_param_main_grad,
+                   "layout": "NT",
+                   "out": origin_fc1_weight.main_grad if ctx.fuse_wgrad_accumulation else None,
+                   "bias": fc1_bias if fuse_gemm_and_bias_fc1_wgrad else None,
+                   "use_split_accumulator": wgrad_use_split_accumulator,
+                   "grad": fuse_gemm_and_bias_fc1_wgrad,
+                   "ub": ub_obj_fc1_wgrad,
+                   "ub_type": ub_type_fc1_wgrad,
+                   "extra_output": reduce_scatter_out,
+                   "bulk_overlap": ctx.ub_bulk_wgrad,
+               }

+               def fc1_wgrad_gemm(
+                   x: torch.Tensor,
+                   dy: torch.Tensor,
+                   _is_delayed: bool = True,
+               ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+                   """Perform FC1 WGRAD GEMM
+                   May be called outside of this function to enable
+                   some advanced communication/compute overlapping.
+                   """
+                   dw, db, *_ = general_gemm(x, dy, **fc1_wgrad_gemm_kwargs)
+                   return dw, db

+               # Choose whether to call wgrad GEMM now or delay
                if ctx.wgrad_store is not None and ctx.wgrad_store.delay_wgrad_compute():
-                   ctx.wgrad_store.put([ln_out_total, dact], general_gemm_fc1_wgrad)
+                   if (
+                       fc1_wgrad_gemm_kwargs["ub"] is not None
+                       or fc1_wgrad_gemm_kwargs["ub_type"] is not None
+                       or fc1_wgrad_gemm_kwargs["extra_output"] is not None
+                       or fc1_wgrad_gemm_kwargs["bulk_overlap"]
+                   ):
+                       raise NotImplementedError(
+                           "Delayed weight grad computation is not supported "
+                           "with Userbuffers (tensor-parallel communication overlapping)"
+                       )
+                   ctx.wgrad_store.put([ln_out_total, dact], fc1_wgrad_gemm)
                    fc1_wgrad = None
                    if fuse_gemm_and_bias_fc1_wgrad:
                        fc1_bias_grad = None
                else:
-                   fc1_wgrad_outputs = general_gemm_fc1_wgrad(ln_out_total, dact)
-                   clear_tensor_data(ln_out_total, dact)
+                   # Call wgrad GEMM now
+                   fc1_wgrad_outputs = fc1_wgrad_gemm(ln_out_total, dact)
                    if fuse_gemm_and_bias_fc1_wgrad:
-                       fc1_wgrad, fc1_bias_grad, *_ = fc1_wgrad_outputs
+                       fc1_wgrad, fc1_bias_grad = fc1_wgrad_outputs
                    else:
-                       fc1_wgrad, *_ = fc1_wgrad_outputs
+                       fc1_wgrad, _ = fc1_wgrad_outputs
+                   # Deallocate tensors if permitted
+                   clear_tensor_data(dact)
+                   if not ctx.return_layernorm_output_gathered:
+                       clear_tensor_data(ln_out_total)

                # Update grad input if overlapping reduce-scatter with wgrad GEMM
                if ctx.ub_bulk_wgrad:
                    if ub_obj_fc1_wgrad.is_fp8_ubuf():
-                       fc1_dgrad = fc1_dgrad_rs_out
+                       fc1_dgrad = reduce_scatter_out
                    else:
-                       fc1_dgrad = ub_obj_fc1_wgrad.get_buffer(None, local_chunk=True)
+                       fc1_dgrad = ub_obj_fc1_wgrad.get_buffer(local_chunk=True).clone()
+           # ----- Finished FC1 WGRAD... -----

            # Make sure all tensor-parallel communication is finished
            if ln_out_total_work is not None:
...
...
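Editorial aside: the bulk-overlap path above rides the FC1 input all-gather on the Userbuffers dgrad GEMM. The same idea, stripped down to plain PyTorch collectives (function name hypothetical, initialized torch.distributed process group assumed): launch the gather asynchronously, run the dgrad GEMM while it is in flight, and await the gathered input only right before the wgrad GEMM.

import torch
import torch.distributed as dist

def fc1_backward_with_overlap(ln_out_local, dact, weight, tp_group):
    # ln_out_local: [n/tp, in], dact: [n, ffn], weight: [ffn, in]
    tp_size = dist.get_world_size(tp_group)
    ln_out_total = torch.empty((ln_out_local.shape[0] * tp_size, *ln_out_local.shape[1:]),
                               dtype=ln_out_local.dtype, device=ln_out_local.device)
    gather_work = dist.all_gather_into_tensor(ln_out_total, ln_out_local.contiguous(),
                                              group=tp_group, async_op=True)
    dgrad = dact @ weight              # dgrad GEMM overlaps with the all-gather
    gather_work.wait()
    wgrad = dact.t() @ ln_out_total    # wgrad GEMM consumes the gathered input
    return dgrad, wgrad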
@@ -1748,7 +1863,7 @@ class LayerNormMLP(TransformerEngineBaseModule):
        ) = [None] * 12

        if self.fp8:
            fc1_input_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_INPUT]
-           fc1_input_quantizer.internal = False  # temporary
+           fc1_input_quantizer.internal = True
            fc1_weight_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_WEIGHT]
            fc1_weight_quantizer.internal = True
            fc2_input_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM2_INPUT]
...
...
@@ -1756,6 +1871,7 @@ class LayerNormMLP(TransformerEngineBaseModule):
                rowwise=True,
                columnwise=isinstance(fc2_input_quantizer, (MXFP8Quantizer, Float8BlockQuantizer)),
            )
            fc1_input_quantizer.internal = True
            fc2_weight_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM2_WEIGHT]
            fc2_weight_quantizer.internal = True
            if fp8_output:
...
...
@@ -1764,11 +1880,11 @@ class LayerNormMLP(TransformerEngineBaseModule):
            ]
            if torch.is_grad_enabled():
                fc2_grad_output_quantizer = self.quantizers["scaling_bwd"][
-                   tex.FP8BwdTensors.GRAD_OUTPUT1
+                   tex.FP8BwdTensors.GRAD_OUTPUT2
                ]
                fc2_grad_output_quantizer.internal = True
                fc1_grad_output_quantizer = self.quantizers["scaling_bwd"][
-                   tex.FP8BwdTensors.GRAD_INPUT1
+                   tex.FP8BwdTensors.GRAD_OUTPUT1
                ]
                fc1_grad_output_quantizer.internal = True
...
...
@@ -1853,25 +1969,25 @@ class LayerNormMLP(TransformerEngineBaseModule):
            else:
                # fc2_grad_output_quantizer: set configs about amax epsilon and power_2_scale for fc2_grad_output_quantizer
                self.quantizers["scaling_bwd"][
-                   tex.FP8BwdTensors.GRAD_OUTPUT1
+                   tex.FP8BwdTensors.GRAD_OUTPUT2
                ].force_pow_2_scales = recipe.fp8_quant_bwd_grad.power_2_scale
                self.quantizers["scaling_bwd"][
-                   tex.FP8BwdTensors.GRAD_OUTPUT1
+                   tex.FP8BwdTensors.GRAD_OUTPUT2
                ].amax_epsilon = recipe.fp8_quant_bwd_grad.amax_epsilon
                # fc1_grad_output_quantizer: also set numerical configs for fc1_grad_output_quantizer
                self.quantizers["scaling_bwd"][
-                   tex.FP8BwdTensors.GRAD_INPUT1
+                   tex.FP8BwdTensors.GRAD_OUTPUT1
                ].force_pow_2_scales = recipe.fp8_quant_bwd_grad.power_2_scale
                self.quantizers["scaling_bwd"][
-                   tex.FP8BwdTensors.GRAD_INPUT1
+                   tex.FP8BwdTensors.GRAD_OUTPUT1
                ].amax_epsilon = recipe.fp8_quant_bwd_grad.amax_epsilon
            if self.sequence_parallel and self.set_parallel_mode:
                # fc2_grad_output_quantizer: customize grad_output_quantizer with amax reduction TP group, row parallel + sequence parallel here
                self.quantizers["scaling_bwd"][
-                   tex.FP8BwdTensors.GRAD_OUTPUT1
+                   tex.FP8BwdTensors.GRAD_OUTPUT2
                ].with_amax_reduction = True
                self.quantizers["scaling_bwd"][
-                   tex.FP8BwdTensors.GRAD_OUTPUT1
+                   tex.FP8BwdTensors.GRAD_OUTPUT2
                ].amax_reduction_group = self.tp_group

    def backward_dw(self):
...
...
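The recipe plumbing above copies `power_2_scale` and `amax_epsilon` from `recipe.fp8_quant_bwd_grad` onto the backward grad-output quantizers. A toy illustration of what those two knobs mean when deriving an FP8 scale from an amax (assumed semantics, not the TE kernel; 448 is used here as the FP8 E4M3 maximum):

# Toy sketch, not TE code: how amax_epsilon and power-of-2 scale forcing
# affect the scale factor derived from an observed amax.
import math

def fp8_scale_from_amax(amax: float, fp8_max: float = 448.0,
                        force_pow_2_scales: bool = True,
                        amax_epsilon: float = 0.0) -> float:
    """Map |x| <= amax onto the representable FP8 range."""
    amax = max(amax, amax_epsilon)  # amax_epsilon acts as a floor on tiny amaxes
    if amax == 0.0:
        return 1.0
    scale = fp8_max / amax
    if force_pow_2_scales:
        # Round the scale down to a power of two so scaling/unscaling
        # multiplications are exact in binary floating point.
        scale = 2.0 ** math.floor(math.log2(scale))
    return scale

print(fp8_scale_from_amax(3.7))  # 64.0 with pow-2 forcing, ~121.08 without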
transformer_engine/pytorch/module/linear.py  View file @ f8c2af4c
...
...
@@ -6,24 +6,26 @@
from typing import Callable, Dict, Optional, Tuple, Union
from functools import reduce
from operator import mul as multiply_op
import warnings
import functools
import torch
import transformer_engine_torch as tex
from transformer_engine.common.recipe import Recipe
from transformer_engine.pytorch import torch_version
from .base import (
    get_workspace,
    fill_userbuffers_buffer_for_all_gather,
    get_dummy_wgrad,
    get_ub,
    get_workspace,
    TransformerEngineBaseModule,
    get_dummy_wgrad,
    _2X_ACC_FPROP,
    _2X_ACC_DGRAD,
    _2X_ACC_WGRAD,
)
-from ._common import noop_cat, _fix_gathered_fp8_transpose, WeightGradStore
+from ._common import noop_cat, WeightGradStore
from ..fp8 import FP8GlobalStateManager
from ..utils import (
    cast_if_needed,
...
...
@@ -32,7 +34,6 @@ from ..utils import (
    init_method_constant,
    requires_grad,
    needs_quantized_gemm,
    non_tn_fp8_gemm_supported,
    assert_dim_for_fp8_exec,
    nvtx_range_pop,
    nvtx_range_push,
...
...
@@ -57,6 +58,7 @@ from ..jit import no_torch_dynamo
from ..graph import is_graph_capturing
from ..tensor.quantized_tensor import (
    QuantizedTensor,
+   QuantizedTensorBase,
    Quantizer,
    prepare_for_saving,
    restore_from_saved,
...
...
@@ -125,88 +127,100 @@ class _Linear(torch.autograd.Function):
# Make sure input dimensions are compatible
out_features
,
in_features
=
weight
.
shape
inp_shape
=
inp
.
shape
assert
inp_shape
[
-
1
]
==
in_features
,
"GEMM not possible"
assert
inp
.
shape
[
-
1
]
==
in_features
,
"GEMM not possible"
# Configure tensor-parallel communication
tp_world_size
=
get_distributed_world_size
(
tp_group
)
backward_needs_input
=
is_grad_enabled
and
weight
.
requires_grad
# Prepare input tensor
# Note: Cast to expected dtype and perform tensor-parallel communication
nvtx_range_push
(
f
"
{
nvtx_label
}
.input_cast_comm"
)
inputmat
=
inp
.
view
(
-
1
,
in_features
)
inputmat_total
=
None
with_input_all_gather_nccl
=
(
parallel_mode
==
"column"
and
sequence_parallel
and
not
ub_overlap_ag_fprop
)
own_quantized_input
=
False
# TODO(kwyss): Support FP8 allgather for FP8 block quantization.
# Do TP communication in high precision if quantized format
# does not support communication
force_hp_input_gather
=
(
fp8
and
with_input_all_gather_nccl
and
isinstance
(
input_quantizer
,
Float8BlockQuantizer
)
)
# Perform TP communication in high precision.
)
# Configure Userbuffers communication (comm+GEMM overlap)
ub_obj
=
None
ub_type
=
None
if
ub_overlap_rs_fprop
:
ub_obj
=
get_ub
(
ub_name
+
"_fprop"
)
ub_type
=
tex
.
CommOverlapType
.
RS
elif
ub_overlap_ag_fprop
:
ub_obj
=
get_ub
(
ub_name
+
"_fprop"
)
ub_type
=
tex
.
CommOverlapType
.
AG
# ------------------------------------------------------
# Prepare input tensor
# Note: Cast to expected dtype and perform tensor-parallel communication
# ------------------------------------------------------
nvtx_range_push
(
f
"
{
nvtx_label
}
.input_cast_comm"
)
inputmat
=
inp
# Input tensor to save for backward (maybe sharded)
inputmat_total
=
None
# Input tensor to pass to GEMM (gathered)
own_quantized_input
=
False
if
fp8
:
assert_dim_for_fp8_exec
(
inputmat
,
weight
)
if
any
([
ub_overlap_ag_fprop
,
ub_overlap_rs_fprop
])
and
not
(
FP8GlobalStateManager
.
get_fp8_recipe
().
float8_per_tensor_scaling
()
):
raise
NotImplementedError
(
"Comm+GEMM overlap is only supported with FP8 delayed scaling or per-tensor"
" current scaling"
)
if
fp8
or
debug
:
if
input_quantizer
is
None
:
raise
ValueError
(
"Missing quantizer for input tensor"
)
if
with_input_all_gather_nccl
:
if
force_hp_input_gather
:
input_quantizer
.
set_usage
(
rowwise
=
True
,
columnwise
=
False
)
inputmat_total
,
_
=
gather_along_first_dim
(
inputmat
,
tp_group
,
quantizer
=
input_quantizer
)
else
:
if
not
isinstance
(
inputmat
,
QuantizedTensor
):
columnwise_usage
=
backward_needs_input
and
isinstance
(
input_quantizer
,
MXFP8Quantizer
)
# force_hp_input_gather should enforce this
assert
not
isinstance
(
input_quantizer
,
Float8BlockQuantizer
)
input_quantizer
.
set_usage
(
rowwise
=
True
,
columnwise
=
columnwise_usage
)
inputmat
=
input_quantizer
(
inputmat
)
own_quantized_input
=
True
input_quantizer
.
set_usage
(
rowwise
=
True
,
columnwise
=
False
)
inputmat_total
,
_
=
gather_along_first_dim
(
inputmat
,
tp_group
,
quantizer
=
input_quantizer
,
)
if
with_input_all_gather_nccl
or
ub_overlap_ag_fprop
:
# All-gather input tensor
# Cast local input tensor if needed
if
fp8
or
debug
:
if
input_quantizer
is
None
:
raise
ValueError
(
"Missing quantizer for input tensor"
)
if
not
force_hp_input_gather
and
not
isinstance
(
inputmat
,
QuantizedTensorBase
):
input_quantizer
.
set_usage
(
rowwise
=
True
,
columnwise
=
backward_needs_input
)
if
isinstance
(
input_quantizer
,
(
Float8Quantizer
,
Float8CurrentScalingQuantizer
)
):
# All-gather is not supported with FP8 column-wise data
input_quantizer
.
set_usage
(
columnwise
=
False
)
inputmat
=
input_quantizer
(
inputmat
)
own_quantized_input
=
True
else
:
if
(
FP8GlobalStateManager
.
get_fp8_recipe
().
float8_per_tensor_scaling
()
and
ub_bulk_dgrad
):
# reduce duplicated transpose in `_fix_gathered_fp8_transpose`
input_quantizer
.
set_usage
(
rowwise
=
True
,
columnwise
=
False
)
inputmat
=
cast_if_needed
(
inp
,
activation_dtype
)
# Cast for AMP
# Initialize gathered input tensor
quantizer
=
None
if
fp8
or
debug
:
quantizer
=
input_quantizer
quantizer
.
set_usage
(
rowwise
=
True
,
columnwise
=
False
)
if
with_input_all_gather_nccl
:
# Perform NCCL all-gather
inputmat_total
,
_
=
gather_along_first_dim
(
inputmat
,
tp_group
,
quantizer
=
quantizer
,
)
elif
ub_overlap_ag_fprop
:
# Initialize Userbuffers all-gather
inputmat_total
,
_
=
fill_userbuffers_buffer_for_all_gather
(
ub_obj
,
inputmat
,
quantizer
,
tp_group
,
)
else
:
# Do not all-gather input tensor
if
fp8
or
debug
:
if
isinstance
(
inputmat
,
QuantizedTensorBase
):
inputmat
.
update_usage
(
rowwise_usage
=
True
)
else
:
input_quantizer
.
set_usage
(
rowwise
=
True
,
columnwise
=
backward_needs_input
,
)
if
not
isinstance
(
inputmat
,
QuantizedTensor
):
if
input_quantizer
is
None
:
raise
ValueError
(
"Missing quantizer for input tensor"
)
input_quantizer
.
set_usage
(
rowwise
=
True
,
columnwise
=
backward_needs_input
)
inputmat
=
input_quantizer
(
inputmat
)
own_quantized_input
=
True
elif
backward_needs_input
:
inputmat
.
update_usage
(
rowwise_usage
=
True
,
columnwise_usage
=
True
)
inputmat_total
=
inputmat
else
:
inputmat
=
cast_if_needed
(
inp
,
activation_dtype
)
if
with_input_all_gather_nccl
:
inputmat_total
,
_
=
gather_along_first_dim
(
inputmat
,
tp_group
)
else
:
inputmat_total
=
inputmat
inputmat
=
cast_if_needed
(
inp
,
activation_dtype
)
# Cast for AMP
inputmat_total
=
inputmat
nvtx_range_pop
(
f
"
{
nvtx_label
}
.input_cast_comm"
)
# ------------------------------------------------------
# Input tensor is ready for GEMM...
# ------------------------------------------------------
# Cast weight to expected dtype
# ------------------------------------------------------
# Prepare weight tensor
# ------------------------------------------------------
weightmat
=
weight
if
fp8
or
debug
:
# Configure quantizer
if
weight_quantizer
is
not
None
:
...
...
@@ -217,7 +231,8 @@ class _Linear(torch.autograd.Function):
and
not
in_fp8_activation_recompute_phase
()
)
weight_quantizer
.
set_usage
(
rowwise
=
True
,
columnwise
=
columnwise_usage
)
# FP8 cast to workspace buffer
# Get quantized weight
update_workspace
=
is_first_microbatch
is
None
or
is_first_microbatch
weightmat
=
module
.
get_weight_workspace
(
tensor
=
weight
,
...
...
@@ -228,19 +243,21 @@ class _Linear(torch.autograd.Function):
fsdp_group
=
fsdp_group
,
workspace_dtype
=
activation_dtype
,
)
weightmat
.
update_usage
(
rowwise_usage
=
True
)
else
:
weightmat
=
cast_if_needed
(
weightmat
,
activation_dtype
)
weightmat
=
cast_if_needed
(
weightmat
,
activation_dtype
)
# Cast for AMP
# ------------------------------------------------------
# Weight tensor is ready for GEMM...
# ------------------------------------------------------
# Cast bias to expected dtype
bias_dtype
=
activation_dtype
if
needs_quantized_gemm
(
inputmat_total
)
and
activation_dtype
==
torch
.
float32
:
# cuBLAS does not support FP8 GEMM with FP32 bias, so we cast to BF16
bias_dtype
=
torch
.
bfloat16
bias
=
cast_if_needed
(
bias
,
bias_dtype
)
if
bias
is
not
None
else
bias
# Configure output quantizer
if
output_quantizer
is
not
None
:
output_quantizer
.
set_usage
(
rowwise
=
True
,
columnwise
=
False
)
# Calibrate quantizers if needed
if
not
fp8
and
fp8_calibration
:
if
input_quantizer
is
not
None
:
...
...
@@ -248,44 +265,74 @@ class _Linear(torch.autograd.Function):
if
weight_quantizer
is
not
None
:
weight_quantizer
.
calibrate
(
weight
)
ub_obj
=
None
ub_type
=
None
rs_out
=
None
out_dtype
=
activation_dtype
if
ub_overlap_rs_fprop
:
ub_obj
=
get_ub
(
ub_name
+
"_fprop"
)
ub_type
=
tex
.
CommOverlapType
.
RS
out_shape
=
[
reduce
(
multiply_op
,
inp_shape
[:
-
1
])
//
tp_world_size
,
out_features
]
rs_out
=
torch
.
empty
(
out_shape
,
dtype
=
activation_dtype
,
device
=
inputmat_total
.
device
)
elif
ub_overlap_ag_fprop
:
ub_obj
=
get_ub
(
ub_name
+
"_fprop"
)
ub_type
=
tex
.
CommOverlapType
.
AG
if
fp8
:
assert
ub_obj
.
is_fp8_ubuf
(),
"AG overlap with FP8 GEMM inputs requires FP8 buffer."
ub_obj
.
copy_into_buffer
(
inputmat_total
,
input_quantizer
,
local_chunk
=
True
)
inputmat_total
=
ub_obj
.
get_buffer
(
input_quantizer
)
nvtx_range_push
(
f
"
{
nvtx_label
}
.gemm"
)
fprop_gemm_use_split_accumulator
=
_2X_ACC_FPROP
# Choose whether to use GEMM kernel with split accumulator
use_split_accumulator
=
_2X_ACC_FPROP
if
fp8
:
recipe
=
FP8GlobalStateManager
.
get_fp8_recipe
()
if
hasattr
(
recipe
,
"fp8_gemm_fprop"
):
fprop_gemm_use_split_accumulator
=
recipe
.
fp8_gemm_fprop
.
use_split_accumulator
use_split_accumulator
=
recipe
.
fp8_gemm_fprop
.
use_split_accumulator
# Configure output quantizer
if
output_quantizer
is
not
None
:
output_quantizer
.
set_usage
(
rowwise
=
True
,
columnwise
=
False
)
out
,
*
_
,
rs_out
=
general_gemm
(
# Output buffer for Userbuffers reduce-scatter
reduce_scatter_out
=
None
if
ub_overlap_rs_fprop
:
out_shape
=
list
(
inp
.
shape
)
out_shape
[
0
]
//=
tp_world_size
out_shape
[
-
1
]
=
out_features
reduce_scatter_out
=
torch
.
empty
(
out_shape
,
dtype
=
activation_dtype
,
device
=
inp
.
device
)
# ------------------------------------------------------
# Forward GEMM
# Note: y = x * w^T
# ------------------------------------------------------
nvtx_range_push
(
f
"
{
nvtx_label
}
.gemm"
)
gemm_out
,
*
_
,
reduce_scatter_out
=
general_gemm
(
weightmat
,
inputmat_total
,
get_workspace
(),
quantization_params
=
output_quantizer
,
out_dtype
=
out
_dtype
,
out_dtype
=
activation
_dtype
,
bias
=
bias
,
use_split_accumulator
=
fprop_gemm_
use_split_accumulator
,
use_split_accumulator
=
use_split_accumulator
,
ub
=
ub_obj
,
ub_type
=
ub_type
,
extra_output
=
r
s
_out
,
extra_output
=
r
educe_scatter
_out
,
)
nvtx_range_pop
(
f
"
{
nvtx_label
}
.gemm"
)
# ------------------------------------------------------
# Finished forward GEMM...
# ------------------------------------------------------
# ------------------------------------------------------
# Prepare output tensor
# Note: Perform tensor-parallel communication
# ------------------------------------------------------
out
=
None
if
ub_overlap_rs_fprop
:
out
=
reduce_scatter_out
elif
parallel_mode
==
"row"
and
tp_size
>
1
:
nvtx_range_push
(
f
"
{
nvtx_label
}
.row_parallel_comm"
)
out
=
gemm_out
if
sequence_parallel
:
out
,
_
=
reduce_scatter_along_first_dim
(
out
,
tp_group
)
elif
tensor_parallel
:
if
symmetric_ar_type
is
not
None
:
out
,
_
=
symmetric_all_reduce
(
out
,
tp_group
,
all_reduce_type
=
symmetric_ar_type
)
else
:
out
,
_
=
allreduce
(
out
,
tp_group
)
nvtx_range_pop
(
f
"
{
nvtx_label
}
.row_parallel_comm"
)
else
:
out
=
gemm_out
# ------------------------------------------------------
# Output tensor is ready to return...
# ------------------------------------------------------
# ------------------------------------------------------
# Cache state for backward pass
# ------------------------------------------------------
if
is_grad_enabled
:
ctx
.
weight_quantizer
=
weight_quantizer
...
...
@@ -296,19 +343,19 @@ class _Linear(torch.autograd.Function):
)
if
backward_needs_input
:
if
own_quantized_input
and
isinstance
(
inputmat
,
QuantizedTensor
):
if
own_quantized_input
and
isinstance
(
inputmat
,
QuantizedTensor
Base
):
# For sequence parallel in vanilla FP8, rowwise data is
# to gather the input. For MXFP8, columnwise only data
# can be allgathered.
if
isinstance
(
inputmat
,
MXFP8TensorBase
)
or
not
ctx
.
backward_input_needs_gather
:
inputmat
.
update_usage
(
rowwise_usage
=
False
,
columnwise_usage
=
True
)
if
force_hp_input_gather
:
assert
not
isinstance
(
inputmat
,
QuantizedTensor
)
assert
not
isinstance
(
inputmat
,
QuantizedTensor
Base
)
saved_inputmat
=
inputmat
# Weight with column-wise usage is needed for dgrad GEMM.
if
inp
.
requires_grad
:
if
isinstance
(
weightmat
,
QuantizedTensor
):
if
isinstance
(
weightmat
,
QuantizedTensor
Base
):
weightmat
.
update_usage
(
columnwise_usage
=
True
)
if
cpu_offloading
and
saved_inputmat
is
not
None
:
...
...
@@ -321,7 +368,7 @@ class _Linear(torch.autograd.Function):
ctx
.
fsdp_shapes
=
_fsdp_scatter_tensors
(
fsdp_group
,
saved_inputmat
,
weightmat
if
fp8
and
not
isinstance
(
weight
,
QuantizedTensor
)
else
None
,
weightmat
if
fp8
and
not
isinstance
(
weight
,
QuantizedTensor
Base
)
else
None
,
)
nvtx_range_pop
(
f
"
{
nvtx_label
}
.fsdp_scatter"
)
...
...
@@ -364,7 +411,7 @@ class _Linear(torch.autograd.Function):
ctx
.
use_bias
=
bias
is
not
None
ctx
.
sequence_parallel
=
sequence_parallel
ctx
.
tensor_parallel
=
tensor_parallel
ctx
.
inp_shape
=
inp
_
shape
ctx
.
inp_shape
=
inp
.
shape
ctx
.
parallel_mode
=
parallel_mode
ctx
.
tp_group
=
tp_group
ctx
.
ub_overlap_ag
=
ub_overlap_ag_dgrad
...
...
@@ -376,6 +423,7 @@ class _Linear(torch.autograd.Function):
ctx
.
requires_dgrad
=
inp
.
requires_grad
ctx
.
requires_wgrad
=
weight
.
requires_grad
ctx
.
reduce_and_update_bwd_fp8_tensors
=
False
ctx
.
owns_input
=
saved_inputmat
is
not
inp
if
ctx
.
fp8
and
requires_grad
(
inp
,
weight
,
bias
):
_first_fp8_module
=
FP8GlobalStateManager
.
IS_FIRST_FP8_MODULE
...
...
@@ -384,21 +432,10 @@ class _Linear(torch.autograd.Function):
FP8GlobalStateManager
.
IS_FIRST_FP8_MODULE
=
_first_fp8_module
ctx
.
wgrad_store
=
wgrad_store
# Row Parallel Linear
if
ub_overlap_rs_fprop
:
out
=
rs_out
elif
parallel_mode
==
"row"
:
nvtx_range_push
(
f
"
{
nvtx_label
}
.row_parallel_comm"
)
if
sequence_parallel
:
out
,
_
=
reduce_scatter_along_first_dim
(
out
,
tp_group
)
elif
tensor_parallel
:
if
symmetric_ar_type
is
not
None
:
out
,
_
=
symmetric_all_reduce
(
out
,
tp_group
,
all_reduce_type
=
symmetric_ar_type
)
else
:
out
,
_
=
allreduce
(
out
,
tp_group
)
nvtx_range_pop
(
f
"
{
nvtx_label
}
.row_parallel_comm"
)
# ------------------------------------------------------
# Cached state for backward pass is ready...
# ------------------------------------------------------
out
=
out
.
view
(
-
1
,
*
inp_shape
[
1
:
-
1
],
out_features
)
return
out
@
staticmethod
...
...
@@ -411,28 +448,11 @@ class _Linear(torch.autograd.Function):
        nvtx_label = f"{nvtx_label}.{ctx.ub_name}"
        with torch.cuda.nvtx.range("_Linear_backward"):
-           if (
-               ctx.fp8
-               and any(
-                   [
-                       ctx.ub_overlap_ag,
-                       ctx.ub_overlap_rs_dgrad,
-                       ctx.ub_bulk_dgrad,
-                       ctx.ub_bulk_wgrad,
-                   ]
-               )
-               and (ctx.fp8_recipe is not None)
-           ):
-               if not ctx.fp8_recipe.float8_per_tensor_scaling():
-                   raise NotImplementedError(
-                       "Comm+GEMM overlap is only supported with FP8 delayed scaling or per-tensor"
-                       " current scaling"
-                   )
            saved_tensors = ctx.saved_tensors
            inputmat, weight_fp8, weight, bias = (  # pylint: disable=unbalanced-tuple-unpacking
                restore_from_saved(ctx.tensor_objects, saved_tensors)
            )
            # Delete the references to tensor objects once they've been consumed
            # by the `restore_from_saved` method to construct back the actual tensors.
            ctx.tensor_objects = None
...
...
@@ -462,69 +482,55 @@ class _Linear(torch.autograd.Function):
)
nvtx_range_pop
(
f
"
{
nvtx_label
}
.fsdp_gather"
)
# Configure Userbuffers communication (comm+GEMM overlap)
ctx
.
ub_obj_gradout
=
None
ub_obj_dgrad
=
None
ub_obj_wgrad
=
None
ub_type_dgrad
=
None
ub_type_wgrad
=
None
dgrad_shape
=
[
reduce
(
multiply_op
,
ctx
.
inp_shape
[:
-
1
]),
ctx
.
inp_shape
[
-
1
]]
rs_out
=
None
dgrad_bulk
=
None
if
ctx
.
ub_overlap_ag
:
# Overlap grad_output all-gather with dgrad compute
ctx
.
ub_obj_gradout
=
get_ub
(
ctx
.
ub_name
+
"_dgrad"
)
ub_obj_dgrad
=
ctx
.
ub_obj_gradout
ub_type_dgrad
=
tex
.
CommOverlapType
.
AG
elif
ctx
.
ub_overlap_rs_dgrad
:
# Overlap dgrad reduce-scatter with dgrad compute
ctx
.
ub_obj_gradout
=
get_ub
(
ctx
.
ub_name
+
"_dgrad"
)
ub_obj_dgrad
=
ctx
.
ub_obj_gradout
ub_type_dgrad
=
tex
.
CommOverlapType
.
RS
rs_out
=
torch
.
empty
(
dgrad_shape
,
dtype
=
ctx
.
activation_dtype
,
device
=
grad_output
.
device
)
else
:
if
ctx
.
ub_bulk_dgrad
:
# Overlap inputmat all-gather with dgrad compute
# NOTE: Copying into communication buffer will always prefer rowwise data,
# and will copy columnwise data if rowwise does not exist. In that case,
# the all-gather will apply to the leading dimension of the transpose,
# which then needs to be interleaved correctly before WGRAD.
ctx
.
ub_obj_gradout
=
get_ub
(
ctx
.
ub_name
+
"_dgrad"
)
ub_obj_dgrad
=
ctx
.
ub_obj_gradout
ub_type_dgrad
=
tex
.
CommOverlapType
.
AG
ub_obj_dgrad
.
copy_into_buffer
(
inputmat
,
ctx
.
input_quantizer
,
local_chunk
=
True
)
if
ctx
.
ub_bulk_wgrad
:
# Overlap dgrad reduce-scatter with wgrad compute
ub_obj_wgrad
=
get_ub
(
ctx
.
ub_name
+
"_wgrad"
)
ub_type_wgrad
=
tex
.
CommOverlapType
.
RS
ub_obj_wgrad
.
set_buffer_params
(
ctx
.
grad_input_quantizer
)
dgrad_bulk
=
ub_obj_wgrad
.
get_buffer
(
ctx
.
grad_input_quantizer
)
# --------------------------------------------------
# Prepare grad output tensor
# Note: Cast to expected dtype and perform tensor-parallel communication
# --------------------------------------------------
# Unmodified grad output tensor
grad_output_arg
=
grad_output
# Configure quantizer for grad output tensor
# Note: dgrad GEMM requires row-wise usage, wgrad GEMM
# requires column-wise usage
if
ctx
.
grad_output_quantizer
is
not
None
:
rowwise_usage
=
True
columnwise_usage
=
True
if
ctx
.
ub_overlap_ag
and
isinstance
(
ctx
.
grad_output_quantizer
,
(
Float8Quantizer
,
Float8CurrentScalingQuantizer
),
):
# If data is in FP8 and communication is handled
# with Userbuffers, we compute FP8 transposes
# manually
columnwise_usage
=
False
ctx
.
grad_output_quantizer
.
set_usage
(
rowwise
=
rowwise_usage
,
columnwise
=
columnwise_usage
,
)
quantizer
=
ctx
.
grad_output_quantizer
quantizer
.
set_usage
(
rowwise
=
True
,
columnwise
=
True
)
if
ctx
.
ub_overlap_ag
:
# Userbuffers only supports communication for one
# tensor usage at a time. Configure quantizer with
# usage for only dgrad GEMM.
quantizer
.
set_usage
(
columnwise
=
False
)
# Prepare grad output tensor
# Note: Cast to expected dtype and perform tensor-parallel communication
nvtx_range_push
(
f
"
{
nvtx_label
}
.grad_output_preprocess"
)
(
grad_output
,
...
...
@@ -537,12 +543,21 @@ class _Linear(torch.autograd.Function):
)
nvtx_range_pop
(
f
"
{
nvtx_label
}
.grad_output_preprocess"
)
# Launch tensor-parallel communication for input tensor
# --------------------------------------------------
# Grad output tensor is ready for computing grad input...
# --------------------------------------------------
# --------------------------------------------------
# Prepare input tensor
# Note: Input tensor is needed for wgrad GEMM.
# Tensor-parallel communication is overlapped with dgrad
# GEMM.
# --------------------------------------------------
inputmat_total
=
None
inputmat_total_work
=
None
if
ctx
.
backward_input_needs_gather
and
not
ctx
.
ub_bulk_dgrad
:
if
ctx
.
backward_input_needs_gather
:
quantizer
=
None
if
ctx
.
fp8
or
ctx
.
debug
:
if
(
ctx
.
fp8
or
ctx
.
debug
)
and
not
ctx
.
force_hp_input_gather
:
quantizer
=
ctx
.
input_quantizer
if
isinstance
(
quantizer
,
(
Float8Quantizer
,
Float8CurrentScalingQuantizer
)):
# If data is in FP8, we compute FP8 transposes manually
...
...
@@ -550,72 +565,92 @@ class _Linear(torch.autograd.Function):
else
:
# wgrad GEMM requires input with column-wise usage
quantizer
.
set_usage
(
rowwise
=
False
,
columnwise
=
True
)
nvtx_range_push
(
f
"
{
nvtx_label
}
.column_parallel_comm_input"
)
gather_quantizer
=
None
if
ctx
.
force_hp_input_gather
else
quantizer
inputmat_total
,
inputmat_total_work
=
gather_along_first_dim
(
inputmat
,
ctx
.
tp_group
,
async_op
=
True
,
quantizer
=
gather_quantizer
,
)
nvtx_range_pop
(
f
"
{
nvtx_label
}
.column_parallel_comm_input"
)
if
ctx
.
ub_bulk_dgrad
:
inputmat_total
,
_
=
fill_userbuffers_buffer_for_all_gather
(
ub_obj_dgrad
,
inputmat
,
quantizer
,
ctx
.
tp_group
,
)
else
:
nvtx_range_push
(
f
"
{
nvtx_label
}
.column_parallel_comm_input"
)
inputmat_total
,
inputmat_total_work
=
gather_along_first_dim
(
inputmat
,
ctx
.
tp_group
,
async_op
=
True
,
quantizer
=
quantizer
,
)
nvtx_range_pop
(
f
"
{
nvtx_label
}
.column_parallel_comm_input"
)
else
:
inputmat_total
=
inputmat
# --------------------------------------------------
# Input tensor is ready for computing grad weight...
# --------------------------------------------------
# Check whether to output wgrad GEMM directly into main grad
if
ctx
.
is_first_microbatch
is
not
None
:
accumulate_wgrad_into_param_main_grad
=
(
ctx
.
fuse_wgrad_accumulation
and
not
ctx
.
is_first_microbatch
)
else
:
accumulate_wgrad_into_param_main_grad
=
ctx
.
fuse_wgrad_accumulation
# --------------------------------------------------
# Compute grad input tensor
# --------------------------------------------------
dgrad
=
None
dgrad_work
=
None
if
ctx
.
requires_dgrad
:
# Update quantizer
if
ctx
.
grad_input_quantizer
is
not
None
:
ctx
.
grad_input_quantizer
.
set_usage
(
rowwise
=
True
,
columnwise
=
False
)
# dgrad GEMM
nvtx_range_push
(
f
"
{
nvtx_label
}
.dgrad_gemm"
)
dgrad_gemm_use_split_accumulator
=
_2X_ACC_DGRAD
# Make sure required data is available
if
isinstance
(
grad_output
,
QuantizedTensorBase
):
grad_output
.
update_usage
(
rowwise_usage
=
True
)
if
ctx
.
weight_quantizer
is
not
None
and
isinstance
(
weight_fp8
,
QuantizedTensorBase
):
weight_fp8
.
update_usage
(
columnwise_usage
=
True
)
# Choose whether to use GEMM kernel with split accumulator
use_split_accumulator
=
_2X_ACC_DGRAD
if
ctx
.
fp8
:
recipe
=
ctx
.
fp8_recipe
if
hasattr
(
recipe
,
"fp8_gemm_dgrad"
):
dgrad_gemm_use_split_accumulator
=
(
recipe
.
fp8_gemm_dgrad
.
use_split_accumulator
)
use_split_accumulator
=
recipe
.
fp8_gemm_dgrad
.
use_split_accumulator
if
ctx
.
weight_quantizer
is
not
None
and
isinstance
(
weight_fp8
,
QuantizedTensor
):
weight_fp8
.
update_usage
(
rowwise_usage
=
ctx
.
weight_quantizer
.
rowwise_usage
,
columnwise_usage
=
ctx
.
weight_quantizer
.
columnwise_usage
,
# Update grad input quantizer
if
ctx
.
grad_input_quantizer
is
not
None
:
ctx
.
grad_input_quantizer
.
set_usage
(
rowwise
=
True
,
columnwise
=
False
)
# Output buffers for Userbuffers reduce-scatter
gemm_out
=
None
reduce_scatter_out
=
None
if
ctx
.
ub_overlap_rs_dgrad
:
reduce_scatter_out
=
torch
.
empty
(
dgrad_shape
,
dtype
=
ctx
.
activation_dtype
,
device
=
grad_output_arg
.
device
)
elif
ctx
.
ub_bulk_wgrad
:
gemm_out
=
ub_obj_wgrad
.
get_buffer
(
local_chunk
=
False
)
dgrad
,
*
_
,
rs_out
=
general_gemm
(
# dgrad GEMM
# Note: dx = dy * w
nvtx_range_push
(
f
"
{
nvtx_label
}
.dgrad_gemm"
)
gemm_out
,
*
_
,
reduce_scatter_out
=
general_gemm
(
weight_fp8
,
grad_output
,
get_workspace
(),
layout
=
"NN"
,
grad
=
True
,
quantization_params
=
ctx
.
grad_input_quantizer
,
out
=
dgrad_bulk
,
out
=
gemm_out
,
out_dtype
=
ctx
.
activation_dtype
,
use_split_accumulator
=
dgrad_gemm_
use_split_accumulator
,
use_split_accumulator
=
use_split_accumulator
,
ub
=
ub_obj_dgrad
,
ub_type
=
ub_type_dgrad
,
extra_output
=
r
s
_out
,
extra_output
=
r
educe_scatter
_out
,
bulk_overlap
=
ctx
.
ub_bulk_dgrad
,
)
nvtx_range_pop
(
f
"
{
nvtx_label
}
.dgrad_gemm"
)
# Launch tensor-parallel communication
# Prepare grad input tensor
# Note: Perform tensor-parallel communication
if
ctx
.
ub_overlap_rs_dgrad
:
dgrad
=
rs_out
elif
ctx
.
parallel_mode
==
"column"
and
not
ctx
.
ub_bulk_wgrad
:
dgrad
=
reduce_scatter_out
elif
ctx
.
ub_bulk_wgrad
:
dgrad
=
ub_obj_wgrad
.
get_buffer
(
local_chunk
=
True
)
elif
ctx
.
parallel_mode
==
"column"
and
ctx
.
tp_size
>
1
:
nvtx_range_push
(
f
"
{
nvtx_label
}
.column_parallel_comm_dgrad"
)
dgrad
=
gemm_out
if
ctx
.
sequence_parallel
:
dgrad
,
dgrad_work
=
reduce_scatter_along_first_dim
(
dgrad
,
...
...
@@ -625,41 +660,55 @@ class _Linear(torch.autograd.Function):
else
:
dgrad
,
dgrad_work
=
allreduce
(
dgrad
,
ctx
.
tp_group
,
async_op
=
True
)
nvtx_range_pop
(
f
"
{
nvtx_label
}
.column_parallel_comm_dgrad"
)
else
:
dgrad
=
gemm_out
# --------------------------------------------------
# Grad input tensor has been computed...
# --------------------------------------------------
# --------------------------------------------------
# Compute grad weight
# --------------------------------------------------
# Compute grad weight tensor
wgrad
=
None
if
ctx
.
requires_wgrad
:
# Synchronize tensor-parallel communication for input tensor
if
ctx
.
ub_bulk_dgrad
:
inputmat_total
=
ub_obj_dgrad
.
get_buffer
(
ctx
.
input_quantizer
)
if
ctx
.
fp8
:
if
inputmat
.
_data
is
None
:
# All-gather executed on columnwise data and result is in rowwise data,
# so we need to fix the interleaving before WGRAD.
inputmat_total
=
_fix_gathered_fp8_transpose
(
inputmat_total
,
ctx
.
tp_size
)
elif
not
non_tn_fp8_gemm_supported
():
# FP8 GEMM on Hopper only supports TN layout so the gathered input must
# have a valid transpose.
inputmat_total
.
_create_transpose
()
# Prepare input tensor
# Note: Synchronize tensor-parallel communication and
# make sure required data is available
if
inputmat_total_work
is
not
None
:
inputmat_total_work
.
wait
()
inputmat_total_work
=
None
if
ctx
.
input_quantizer
is
not
None
and
not
isinstance
(
inputmat_total
,
QuantizedTensor
):
# Async gather in BF16 does not asynchronously
# call quantizer after gather.
ctx
.
input_quantizer
.
set_usage
(
rowwise
=
False
,
columnwise
=
True
)
inputmat_total
=
ctx
.
input_quantizer
(
inputmat_total
)
# Make sure GEMM inputs have required data
if
isinstance
(
inputmat_total
,
QuantizedTensor
):
inputmat_total
.
update_usage
(
columnwise_usage
=
True
)
if
isinstance
(
grad_output
,
QuantizedTensor
):
grad_output
.
update_usage
(
columnwise_usage
=
True
)
if
ctx
.
fp8
or
ctx
.
debug
:
if
isinstance
(
inputmat_total
,
QuantizedTensorBase
):
inputmat_total
.
update_usage
(
columnwise_usage
=
True
)
else
:
ctx
.
input_quantizer
.
set_usage
(
rowwise
=
False
,
columnwise
=
True
)
inputmat_total
=
ctx
.
input_quantizer
(
inputmat_total
)
# Prepare grad output tensor
# Note: Synchronize tensor-parallel communication and
# make sure required data is available
if
ctx
.
ub_overlap_ag
and
isinstance
(
ctx
.
grad_output_quantizer
,
MXFP8Quantizer
):
# UB does not support overlapping grad output
# all-gather with wgrad GEMM. Also, we can't
# convert row-scaled MXFP8 to column-scaled, so we
# can't reuse the grad output that was gathered
# for the dgrad GEMM. We work around with blocking
# all-gather for column-scaled MXFP8 data.
ctx
.
grad_output_quantizer
.
set_usage
(
rowwise
=
False
,
columnwise
=
True
)
grad_output
,
_
=
gather_along_first_dim
(
grad_output_arg
,
ctx
.
tp_group
,
quantizer
=
ctx
.
grad_output_quantizer
,
)
if
ctx
.
fp8
or
ctx
.
debug
:
if
isinstance
(
grad_output
,
QuantizedTensorBase
):
grad_output
.
update_usage
(
columnwise_usage
=
True
)
else
:
ctx
.
grad_output_quantizer
.
set_usage
(
rowwise
=
False
,
columnwise
=
True
)
grad_output
=
ctx
.
grad_output_quantizer
(
grad_output
)
# Figure out whether to use split accumulator
use_split_accumulator
=
_2X_ACC_WGRAD
...
...
@@ -668,54 +717,95 @@ class _Linear(torch.autograd.Function):
                    if hasattr(recipe, "fp8_gemm_wgrad"):
                        use_split_accumulator = recipe.fp8_gemm_wgrad.use_split_accumulator

-               # Output buffer for overlapping grad input
+               # Figure out whether to output wgrad GEMM directly into main grad
+               if ctx.is_first_microbatch is not None:
+                   accumulate_wgrad_into_param_main_grad = (
+                       ctx.fuse_wgrad_accumulation and not ctx.is_first_microbatch
+                   )
+               else:
+                   accumulate_wgrad_into_param_main_grad = ctx.fuse_wgrad_accumulation

+               # Output buffer for overlapping FP8 grad input
                # reduce-scatter with wgrad GEMM
                reduce_scatter_out = None
                if ctx.ub_bulk_wgrad and ub_obj_wgrad.is_fp8_ubuf():
-                   rs_out = torch.empty(
-                       dgrad_shape, dtype=ctx.activation_dtype, device=grad_output.device
-                   )
+                   reduce_scatter_out = torch.empty(
+                       dgrad_shape, dtype=ctx.activation_dtype, device=grad_output_arg.device
+                   )

                # wgrad GEMM
                # Note: Fuse with bgrad computation if needed
-               nvtx_range_push(f"{nvtx_label}.wgrad_gemm")
-               general_gemm_wgrad = functools.partial(
-                   general_gemm,
-                   out_dtype=(
+               # Arguments to include in wgrad GEMM closure
+               wgrad_gemm_kwargs = {
+                   "workspace": get_workspace(),
+                   "out_dtype": (
                        main_grad.dtype if ctx.fuse_wgrad_accumulation else ctx.activation_dtype
                    ),
-                   workspace=get_workspace(),
-                   layout="NT",
-                   grad=True,
-                   bias=(bias if (grad_bias is None and not ctx.fp8) else None),
-                   out=main_grad if ctx.fuse_wgrad_accumulation else None,
-                   use_split_accumulator=use_split_accumulator,
-                   accumulate=accumulate_wgrad_into_param_main_grad,
-                   quantization_params=ctx.grad_weight_quantizer,
-                   ub=ub_obj_wgrad,
-                   ub_type=ub_type_wgrad,
-                   extra_output=rs_out,
-                   bulk_overlap=ctx.ub_bulk_wgrad,
-               )
+                   "quantization_params": ctx.grad_weight_quantizer,
+                   "accumulate": accumulate_wgrad_into_param_main_grad,
+                   "layout": "NT",
+                   "out": main_grad if ctx.fuse_wgrad_accumulation else None,
+                   "bias": (bias if (grad_bias is None and not ctx.fp8) else None),
+                   "use_split_accumulator": use_split_accumulator,
+                   "grad": True,
+                   "ub": ub_obj_wgrad,
+                   "ub_type": ub_type_wgrad,
+                   "extra_output": reduce_scatter_out,
+                   "bulk_overlap": ctx.ub_bulk_wgrad,
+               }

+               def wgrad_gemm(
+                   x: torch.Tensor,
+                   dy: torch.Tensor,
+               ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+                   """Perform wgrad GEMM: dw = dy^T * x
+
+                   May be fused with bgrad computation.
+
+                   May be called outside of this function to enable
+                   some advanced communication/compute overlapping.
+                   """
+                   nvtx_range_push(f"{nvtx_label}.wgrad_gemm")
+                   dw, db, *_ = general_gemm(x, dy, **wgrad_gemm_kwargs)
+                   nvtx_range_pop(f"{nvtx_label}.wgrad_gemm")
+                   return dw, db

                # Choose whether to call wgrad GEMM now or delay
                if ctx.wgrad_store is not None and ctx.wgrad_store.delay_wgrad_compute():
-                   ctx.wgrad_store.put([inputmat_total, grad_output], general_gemm_wgrad)
+                   if (
+                       wgrad_gemm_kwargs["ub"] is not None
+                       or wgrad_gemm_kwargs["ub_type"] is not None
+                       or wgrad_gemm_kwargs["extra_output"] is not None
+                       or wgrad_gemm_kwargs["bulk_overlap"]
+                   ):
+                       raise NotImplementedError(
+                           "Delayed weight grad computation is not supported "
+                           "with Userbuffers (tensor-parallel communication overlapping)"
+                       )
+                   ctx.wgrad_store.put([inputmat_total, grad_output], wgrad_gemm)
                else:
-                   wgrad, grad_bias_, _, rs_out = general_gemm_wgrad(inputmat_total, grad_output)
+                   # Call wgrad GEMM now
+                   wgrad, grad_bias_ = wgrad_gemm(inputmat_total, grad_output)

                    # Update grad bias if needed
                    if grad_bias is None:
                        grad_bias = grad_bias_
                    del grad_bias_

-                   # Deallocate input tensor
+                   # Deallocate input tensor if permitted
                    if ctx.owns_input:
                        clear_tensor_data(inputmat_total)
-                   nvtx_range_pop(f"{nvtx_label}.wgrad_gemm")

                # Update grad input if overlapping reduce-scatter with wgrad GEMM
                if ctx.ub_bulk_wgrad:
                    if ub_obj_wgrad.is_fp8_ubuf():
-                       dgrad = rs_out
+                       dgrad = reduce_scatter_out
                    else:
-                       dgrad = ub_obj_wgrad.get_buffer(ctx.grad_input_quantizer, local_chunk=True)
+                       dgrad = ub_obj_wgrad.get_buffer(local_chunk=True).clone()

            # --------------------------------------------------
            # Grad weight has been computed...
            # --------------------------------------------------

            # Don't return grad bias if not needed
            if not ctx.use_bias:
...
...
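The `wgrad_gemm` closure above documents the math as `dw = dy^T * x`, optionally fused with the bias gradient. A minimal pure-PyTorch sketch of the same math for `y = x @ w.T + b`, cross-checked against autograd (illustrative shapes and names, not the TE kernel interface):

# Pure-PyTorch sketch of the wgrad/bgrad math behind the fused GEMM above.
import torch

x = torch.randn(32, 64, requires_grad=True)    # input, [tokens, in_features]
w = torch.randn(128, 64, requires_grad=True)   # weight, [out_features, in_features]
b = torch.randn(128, requires_grad=True)
y = torch.nn.functional.linear(x, w, b)        # y = x @ w.T + b
dy = torch.randn_like(y)                       # incoming grad_output

# Manual grads; the "NT" layout in the GEMM call corresponds to dy being transposed
dw_manual = dy.T @ x
db_manual = dy.sum(dim=0)

# Cross-check against autograd
y.backward(dy)
assert torch.allclose(dw_manual, w.grad, atol=1e-4)
assert torch.allclose(db_manual, b.grad, atol=1e-4)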
@@ -753,13 +843,14 @@ class _Linear(torch.autograd.Function):
            else:
                wgrad = None

            # Update FP8 scaling factors if needed
            if ctx.reduce_and_update_bwd_fp8_tensors and not is_graph_capturing():
                nvtx_range_push(f"{nvtx_label}.reduce_and_update_fp8_tensors")
                FP8GlobalStateManager.reduce_and_update_fp8_tensors(forward=False)
                nvtx_range_pop(f"{nvtx_label}.reduce_and_update_fp8_tensors")

            # Scatter fp8 weight buffers
-           if ctx.fp8 and not isinstance(weight, QuantizedTensor):
+           if ctx.fp8 and not isinstance(weight, QuantizedTensorBase):
                _fsdp_scatter_tensors(ctx.fsdp_group, weight_fp8)

        return (
            wgrad,
...
...
@@ -1207,7 +1298,12 @@ class Linear(TransformerEngineBaseModule):
                        "Splitting QuantizedTensor into multiple params is not supported"
                    )
            else:
+               warnings.warn(
+                   "You are using quantized weights without quantized compute. "
+                   "Please make sure this is intentional."
+               )
                unfused_weights = [w.dequantize() for w in unfused_weights]
            weight_tensor = noop_cat(unfused_weights)
        if self.use_bias:
            bias_tensor = noop_cat([getattr(self, name) for name in self.bias_names])
...
...
@@ -1302,7 +1398,7 @@ class Linear(TransformerEngineBaseModule):
        grad_output_quantizer = None
        output_quantizer = None
        input_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_INPUT]
-       input_quantizer.internal = False
+       input_quantizer.internal = True
        weight_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_WEIGHT]
        weight_quantizer.internal = True
        if fp8_output:
...
...
transformer_engine/pytorch/module/rmsnorm.py  View file @ f8c2af4c
...
...
@@ -98,6 +98,9 @@ class RMSNorm(_RMSNormOp):
            )
        kwargs["dtype"] = params_dtype

+       # Flag for sequence parallelism (custom Megatron-LM integration)
+       self.sequence_parallel: Optional[bool] = sequence_parallel
+       # Initialize RMSNorm operation
        super().__init__(
            normalized_shape,
...
...
@@ -106,11 +109,6 @@ class RMSNorm(_RMSNormOp):
            **kwargs,
        )

-       # Flag for sequence parallelism (custom Megatron-LM integration)
-       self.sequence_parallel: Optional[bool] = sequence_parallel
        if sequence_parallel is not None:
            self.weight.sequence_parallel = sequence_parallel

    def reset_rms_norm_parameters(self) -> None:
        """Deprecated"""
        warnings.warn(
...
...
@@ -139,7 +137,7 @@ class RMSNorm(_RMSNormOp):
        super().reset_parameters()

        # Flag for sequence parallelism (custom Megatron-LM integration)
-       if getattr(self, "sequence_parallel", None) is not None:
+       if self.sequence_parallel is not None:
            self.weight.sequence_parallel = self.sequence_parallel

    @property
...
...
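The RMSNorm change above only moves where the `sequence_parallel` flag is attached so it exists before `reset_parameters()` runs; the flag itself is a Megatron-LM style convention on parameters. A hedged sketch of how a training loop typically consumes it (assumed usage and a hypothetical helper name, not a TE API):

# Hypothetical helper illustrating the Megatron-LM convention: gradients of
# parameters tagged with `.sequence_parallel` must be all-reduced across the
# tensor-parallel group, since each rank only saw its shard of the sequence.
import torch
import torch.distributed as dist

def allreduce_sequence_parallel_grads(module: torch.nn.Module, tp_group) -> None:
    for p in module.parameters():
        if getattr(p, "sequence_parallel", False) and p.grad is not None:
            dist.all_reduce(p.grad, group=tp_group)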
transformer_engine/pytorch/ops/basic/basic_linear.py  View file @ f8c2af4c
...
...
@@ -534,7 +534,9 @@ class BasicLinear(BasicOperation):
        # Configure input tensor for backward pass
        if with_quantized_compute and isinstance(x_local, QuantizedTensor):
-           x_local.update_usage(rowwise_usage=False, columnwise_usage=True)
+           if not (isinstance(x_local, Float8TensorBase) and with_x_all_gather):
+               # FP8 does not support all-gather of transpose data
+               x_local.update_usage(rowwise_usage=False, columnwise_usage=True)

        return y, x_local, w
...
...
@@ -622,7 +624,10 @@ class BasicLinear(BasicOperation):
        # Check datatype
        if dtype is None:
-           dtype = weight.dtype
+           if weight is not None:
+               dtype = weight.dtype
+           else:
+               dtype = grad_output.dtype
        dtype = canonicalize_dtype(dtype)
        if dtype not in (torch.float32, torch.float16, torch.bfloat16):
            raise ValueError(f"Supported dtypes are float32, float16, bfloat16 (got {dtype})")
...
...
@@ -814,7 +819,7 @@ class BasicLinear(BasicOperation):
        x_async = None
        dy_async = None

-       # Check grad input tensor
+       # Check grad weight tensor
        dw = grad_weight
        dw_dtype = dtype
        if dw is None:
...
...
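The dtype fallback added to `BasicLinear` above lets the functional backward infer its compute dtype from the grad output when no weight tensor is passed in. A condensed sketch of the resolution order (hypothetical helper name, same logic as the hunk):

# Hypothetical helper mirroring the fallback introduced above.
def resolve_backward_dtype(dtype, weight, grad_output):
    if dtype is None:
        dtype = weight.dtype if weight is not None else grad_output.dtype
    return dtype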
transformer_engine/pytorch/ops/fused/userbuffers_backward_linear.py  View file @ f8c2af4c
...
...
@@ -4,30 +4,27 @@
"""Linear layer backward with Userbuffers communication."""

# pylint: skip-file  ### TODO Debug Userbuffers support

from __future__ import annotations
from collections.abc import Iterable
-from typing import Any, Optional
+from typing import Optional
import warnings

import torch

-from transformer_engine_torch import CommOverlapAlgo
+from transformer_engine_torch import CommOverlapType
from ...cpp_extensions import general_gemm
-from ...distributed import get_distributed_world_size
-from ...float8_tensor import Float8Tensor
-from ...fp8 import FP8GlobalStateManager, get_fp8_te_dtype
-from ...module.base import get_ub, get_workspace
+from ...distributed import gather_along_first_dim, get_distributed_world_size
+from ...module.base import (
+    fill_userbuffers_buffer_for_all_gather,
+    get_ub,
+    get_workspace,
+    _2X_ACC_DGRAD,
+    _2X_ACC_WGRAD,
+)
+from ...tensor.quantized_tensor import QuantizedTensorBase, Quantizer
+from ...tensor.mxfp8_tensor import MXFP8Quantizer
from ...utils import canonicalize_device, canonicalize_dtype, clear_tensor_data
from ..basic import BasicLinear, Bias, ReduceScatter
from ..op import FusedOperation, FusibleOperation, OperationContext
from .._common import (
    convert_tensor,
    get_fp8_meta_from_fp8_tensor,
    is_float8_tensor,
    reshape,
)


class UserbuffersBackwardLinear(FusedOperation):
...
...
@@ -47,9 +44,6 @@ class UserbuffersBackwardLinear(FusedOperation):
reduce_scatter
:
Optional
[
ReduceScatter
],
)
->
None
:
### TODO Debug Userbuffers support
raise
NotImplementedError
(
"Userbuffers support has been broken by recent refactors"
)
# Basic operations that comprise this fused operation
op_idxs
=
{
"linear"
:
None
,
"bias"
:
None
,
"reduce_scatter"
:
None
}
ops
=
[]
...
...
@@ -89,9 +83,8 @@ class UserbuffersBackwardLinear(FusedOperation):
grad_output
:
torch
.
Tensor
,
input
:
Optional
[
torch
.
Tensor
],
# pylint: disable=redefined-builtin
weight
:
Optional
[
torch
.
Tensor
],
input_dims
:
Iterable
[
int
],
weight_dims
:
Iterable
[
int
],
*
,
input_requires_grad
:
bool
=
True
,
weight_requires_grad
:
bool
=
True
,
bias_requires_grad
:
bool
=
False
,
device
:
Optional
[
torch
.
device
]
=
None
,
...
...
@@ -102,11 +95,11 @@ class UserbuffersBackwardLinear(FusedOperation):
tensor_parallel_group
:
Optional
[
torch
.
distributed
.
ProcessGroup
]
=
None
,
tensor_parallel_size
:
Optional
[
int
]
=
None
,
sequence_parallel
:
bool
=
False
,
with_
fp8
_compute
:
bool
=
False
,
input_
fp8_meta
:
Optional
[
dict
[
str
,
Any
]
]
=
None
,
weight_
fp8_meta
:
Optional
[
dict
[
str
,
Any
]
]
=
None
,
grad_output_
fp8_meta
:
Optional
[
dict
[
str
,
Any
]
]
=
None
,
grad_input_
fp8_meta
:
Optional
[
dict
[
str
,
Any
]
]
=
None
,
with_
quantized
_compute
:
bool
=
False
,
input_
quantizer
:
Optional
[
Quantizer
]
=
None
,
weight_
quantizer
:
Optional
[
Quantizer
]
=
None
,
grad_output_
quantizer
:
Optional
[
Quantizer
]
=
None
,
grad_input_
quantizer
:
Optional
[
Quantizer
]
=
None
,
ub_comm_name
:
str
,
)
->
tuple
[
torch
.
Tensor
,
Optional
[
torch
.
Tensor
],
dict
]:
"""Functional API for backward pass
...
...
@@ -121,10 +114,6 @@ class UserbuffersBackwardLinear(FusedOperation):
weight: torch.Tensor, optional
Weight tensor. Required to compute loss gradient w.r.t.
input.
input_dims: iterable of int
Input tensor dimensions
weight_dims: iterable of int
Weight tensor dimensions
weight_requires_grad: bool
Whether to compute loss gradient w.r.t. weight tensor
bias_requires_grad: bool
...
...
@@ -146,21 +135,18 @@ class UserbuffersBackwardLinear(FusedOperation):
parallelism, i.e. distributing input or output tensors
along outer dimension (sequence or batch dim) when not
distributing along inner dimension (embedding dim)
with_fp8_compute: bool, default = `False`
Whether to perform compute in FP8
input_fp8_meta: dict, optional
FP8 metadata for casting input tensor to FP8. Required for
FP8 compute if input is not already in FP8.
weight_fp8_meta: dict, optional
FP8 metadata for casting weight tensor to FP8. Required for
FP8 compute if weight is not already in FP8.
grad_output_fp8_meta: dict, optional
FP8 metadata for casting loss gradient w.r.t. output
tensor to FP8. Required if output grad is not already in
FP8.
grad_input_fp8_meta: dict, optional
FP8 metadata for casting loss gradient w.r.t. input
tensor to FP8
with_quantized_compute: bool, default = `False`
Whether to perform compute with quantized data.
input_quantizer: Quantizer, optional
Builder class for quantized input tensor.
weight_quantizer: Quantizer, optional
Builder class for quantized weight tensor.
grad_output_quantizer: Quantizer, optional
Builder class for quantized loss gradient w.r.t. output
tensor.
grad_input_quantizer: Quantizer, optional
Builder class for quantized loss gradient w.r.t. input
tensor.
ub_comm_name: str
Layer type (e.g. "qkv", "proj", "fc1", "fc2"). This is
used to access the corresponding Userbuffers communicators
...
...
@@ -183,37 +169,24 @@ class UserbuffersBackwardLinear(FusedOperation):
# Check device
if
device
is
None
:
device
=
weight
.
device
if
weight
is
not
None
:
device
=
weight
.
device
else
:
device
=
grad_output
.
device
device
=
canonicalize_device
(
device
)
if
device
.
type
!=
"cuda"
:
raise
ValueError
(
f
"Only CUDA devices are supported (got
{
device
}
)"
)
# Check datatype
if
dtype
is
None
:
dtype
=
weight
.
dtype
if
weight
is
not
None
:
dtype
=
weight
.
dtype
else
:
dtype
=
grad_output
.
dtype
dtype
=
canonicalize_dtype
(
dtype
)
if
dtype
not
in
(
torch
.
float32
,
torch
.
float16
,
torch
.
bfloat16
):
raise
ValueError
(
f
"Supported dtypes are float32, float16, bfloat16 (got
{
dtype
}
)"
)
# Input tensor dims
output_dims
=
tuple
(
grad_output
.
size
())
input_dims
=
tuple
(
input_dims
)
weight_dims
=
tuple
(
weight_dims
)
if
len
(
weight_dims
)
!=
2
:
raise
ValueError
(
f
"Weight tensor is not 2D (shape=
{
weight_dims
}
)"
)
if
len
(
input_dims
)
==
0
or
weight_dims
[
1
]
!=
input_dims
[
-
1
]:
raise
ValueError
(
f
"Input tensor (shape=
{
input_dims
}
) "
f
"and weight tensor (shape=
{
weight_dims
}
) "
"are not compatible"
)
if
weight_dims
[
0
]
!=
output_dims
[
-
1
]:
raise
ValueError
(
f
"Grad output tensor (shape=
{
output_dims
}
) "
f
"and weight tensor (shape=
{
weight_dims
}
) "
"are not compatible"
)
# Check tensor parallel group
if
tensor_parallel_size
is
None
:
tensor_parallel_size
=
get_distributed_world_size
(
tensor_parallel_group
)
...
...
@@ -227,373 +200,283 @@ class UserbuffersBackwardLinear(FusedOperation):
if
not
sequence_parallel
:
raise
RuntimeError
(
f
"Invalid configuration for Userbuffers (
{
sequence_parallel
=
}
)"
)
# Check if FP8 is enabled
if
with_fp8_compute
:
if
grad_output_fp8_meta
is
None
and
not
is_float8_tensor
(
grad_output
):
raise
ValueError
(
"No FP8 metadata was provided for casting output gradient to FP8"
)
# dgrad GEMM is required
if
not
input_requires_grad
:
warnings
.
warn
(
"Linear input doesn't require gradient, "
"but Userbuffers implementation requires dgrad GEMM."
)
input_requires_grad
=
True
# Check quantizers
if
with_quantized_compute
:
if
weight_requires_grad
and
input_quantizer
is
None
:
raise
ValueError
(
"Missing quantizer for input tensor"
)
if
input_requires_grad
and
weight_quantizer
is
None
:
raise
ValueError
(
"Missing quantizer for weight tensor"
)
if
grad_output_quantizer
is
None
:
raise
ValueError
(
"Missing quantizer for grad output tensor"
)
if
grad_input_quantizer
is
not
None
:
raise
ValueError
(
"Quantized grad input is not supported"
)
else
:
input_fp8_meta
=
None
weight_fp8_meta
=
None
grad_output_fp8_meta
=
None
grad_input_fp8_meta
=
None
with_fp8_grad_input
=
(
with_fp8_compute
and
tensor_parallel_mode
!=
"column"
and
grad_input_fp8_meta
is
not
None
)
input_quantizer
=
None
weight_quantizer
=
None
grad_output_quantizer
=
None
grad_input_quantizer
=
None
# Get Userbuffers communicators
and algorithms
# Note:
c
ommunication patterns are (1) overlap dy all-gather
# Get Userbuffers communicators
# Note:
C
ommunication patterns are (1) overlap dy all-gather
# with dgrad GEMM, (2) overlap x all-gather with dgrad GEMM
# and dx reduce-scatter with wgrad GEMM, (3) overlap dx
# reduce-scatter with dgrad GEMM
.
with_ub_all_gather_dy
=
Fals
e
with_ub_reduce_scatter_dx
=
Fals
e
with_ub_all_gather_x
=
Fals
e
ub_
comm_dy
=
None
ub_comm_dx
=
Non
e
ub_comm_x
=
Non
e
ub_algo
_d
y
=
Non
e
ub_algo_dx
=
Non
e
ub_algo_x
=
Non
e
# reduce-scatter with dgrad GEMM
ub_comm_dgrad
=
Non
e
ub_comm_wgrad
=
Non
e
ub_type_dgrad
=
Non
e
ub_
type_wgrad
=
None
with_bulk_overlap
=
Fals
e
with_dgrad_all_gather_dy
=
Fals
e
with_dgrad_reduce_scatter
_d
x
=
Fals
e
with_dgrad_all_gather_x
=
Fals
e
with_wgrad_reduce_scatter_dx
=
Fals
e
if
tensor_parallel_mode
==
"row"
:
with_ub_all_gather_dy
=
True
ub_comm_dy
=
get_ub
(
ub_comm_name
+
"_dgrad"
)
if
with_fp8_compute
and
ub_comm_dy
.
is_atomic_gemm
():
ub_algo_dy
=
CommOverlapAlgo
.
ATOMIC_GEMM_AG_P2P
else
:
ub_algo_dy
=
CommOverlapAlgo
.
SPLIT_PIPELINED_AG_P2P
ub_comm_dgrad
=
get_ub
(
ub_comm_name
+
"_dgrad"
)
ub_type_dgrad
=
CommOverlapType
.
AG
with_dgrad_all_gather_dy
=
True
elif
tensor_parallel_mode
==
"column"
:
with_ub_reduce_scatter_dx
=
True
if
weight_requires_grad
:
with_ub_all_gather_x
=
True
ub_comm_dx
=
get_ub
(
ub_comm_name
+
"_wgrad"
)
ub_comm_x
=
get_ub
(
ub_comm_name
+
"_dgrad"
)
ub_algo_dx
=
CommOverlapAlgo
.
BULK_OVERLAP_RS
ub_algo_x
=
CommOverlapAlgo
.
BULK_OVERLAP_AG
if
input_requires_grad
and
weight_requires_grad
:
with_bulk_overlap
=
True
ub_comm_dgrad
=
get_ub
(
ub_comm_name
+
"_dgrad"
)
ub_type_dgrad
=
CommOverlapType
.
AG
with_dgrad_all_gather_x
=
True
ub_comm_wgrad
=
get_ub
(
ub_comm_name
+
"_wgrad"
)
ub_type_wgrad
=
CommOverlapType
.
RS
with_wgrad_reduce_scatter_dx
=
True
if
ub_comm_wgrad
.
is_fp8_ubuf
():
raise
RuntimeError
(
"Userbuffers reduce-scatter is not supported with FP8 buffers"
)
else
:
with_ub_all_gather_x
=
False
ub_comm_dx
=
get_ub
(
ub_comm_name
+
"_dgrad"
)
is_atomic_gemm
=
with_fp8_compute
and
ub_comm_dx
.
is_atomic_gemm
()
ub_algo_dx
=
{
(
True
,
True
):
CommOverlapAlgo
.
ATOMIC_GEMM_RS_P2P
,
(
True
,
False
):
CommOverlapAlgo
.
SPLIT_PIPELINED_RS_P2P
,
(
False
,
True
):
CommOverlapAlgo
.
ATOMIC_GEMM_RS
,
(
False
,
False
):
CommOverlapAlgo
.
SPLIT_PIPELINED_RS
,
}[(
ub_comm_dx
.
is_p2p_overlap
(),
is_atomic_gemm
)]
# Check grad output tensor
# Note: Possibly fuse cast with computing grad bias
dy_local
=
reshape
(
grad_output
,
(
-
1
,
output_dims
[
-
1
]),
device
=
device
,
dtype
=
dtype
,
)
ub_comm_dgrad
=
get_ub
(
ub_comm_name
+
"_dgrad"
)
ub_type_dgrad
=
CommOverlapType
.
RS
with_dgrad_reduce_scatter_dx
=
True
if
ub_comm_dgrad
.
is_fp8_ubuf
():
raise
RuntimeError
(
"Userbuffers reduce-scatter is not supported with FP8 buffers"
)
# Compute grad bias if needed
db
=
None
db_async
=
None
if
bias_requires_grad
and
with_fp8_compute
and
with_ub_all_gather_dy
:
# We don't have a grad bias impl that takes FP8 input. For
# cases where we cast to FP8 and all-gather, it's better
# to compute the grad bias on ungathered, non-FP8 values.
db
=
dy_local
.
sum
(
dim
=
0
)
db_async
=
torch
.
distributed
.
all_reduce
(
db
,
group
=
tensor_parallel_group
,
async_op
=
True
,
)
if
with_fp8_compute
and
not
is_float8_tensor
(
dy_local
):
fp8_dtype
=
get_fp8_te_dtype
(
grad_output_fp8_meta
[
"recipe"
],
fprop_tensor
=
False
,
)
if
bias_requires_grad
and
db
is
None
:
# Fused cast-transpose-bgrad
fp8_meta_key
=
FP8GlobalStateManager
.
get_meta_tensor_key
(
forward
=
False
)
fp8_scale_inv
=
torch
.
empty
([
1
],
dtype
=
torch
.
float32
,
device
=
device
)
db
,
data
,
data_transpose
=
fp8_cast_transpose_bgrad_fused
(
dy_local
,
grad_output_fp8_meta
[
fp8_meta_key
],
0
,
fp8_dtype
,
scale_inv
=
fp8_scale_inv
,
)
if
with_ub_all_gather_dy
:
data
=
ub_comm_dy
.
get_ubuf_output
(
0
).
copy_
(
data
)
dy_local
=
Float8Tensor
(
data
=
data
,
fp8_meta
=
grad_output_fp8_meta
,
fp8_meta_forward
=
False
,
fp8_meta_index
=
0
,
fp8_dtype
=
fp8_dtype
,
fp8_scale_inv
=
fp8_scale_inv
,
dtype
=
dtype
,
data_transpose
=
data_transpose
,
if
bias_requires_grad
:
db
=
grad_output
.
sum
(
tuple
(
range
(
grad_output
.
dim
()
-
1
)))
if
tensor_parallel_mode
==
"row"
:
db_async
=
torch
.
distributed
.
all_reduce
(
db
,
group
=
tensor_parallel_group
,
async_op
=
True
,
)
else
:
dy_local
=
Float8Tensor
.
to_float8
(
dy_local
,
fp8_meta
=
grad_output_fp8_meta
,
fp8_meta_forward
=
False
,
fp8_meta_index
=
0
,
fp8_dtype
=
fp8_dtype
,
data
=
(
ub_comm_dy
.
get_ubuf_output
(
0
)
if
with_ub_all_gather_dy
else
None
),
with_transpose_cache
=
(
not
with_ub_all_gather_dy
),
# Cast grad output tensor dtype if needed
dy_local
=
grad_output
if
with_quantized_compute
:
if
not
isinstance
(
dy_local
,
QuantizedTensorBase
):
with_columnwise
=
weight_requires_grad
if
(
with_columnwise
and
with_dgrad_all_gather_dy
and
not
isinstance
(
grad_output_quantizer
,
MXFP8Quantizer
)
):
with_columnwise
=
False
grad_output_quantizer
.
set_usage
(
rowwise
=
True
,
columnwise
=
with_columnwise
,
)
elif
not
with_fp8_compute
and
is_float8_tensor
(
dy_local
):
if
with_ub_all_gather_dy
:
ub_local_buffer
=
ub_comm_dy
.
get_ubuf_output
(
0
)
dy_local
=
ub_local_buffer
.
copy_
(
dy_local
)
else
:
dy_local
=
dy_local
.
dequantize
()
if
bias_requires_grad
and
db
is
None
and
with_fp8_compute
and
with_ub_all_gather_dy
:
# We don't have a fused grad bias impl that takes FP8
# input. For cases where we cast to FP8 and all-gather,
# it's better to compute the grad bias on ungathered,
# non-FP8 values.
db
=
dy_local
.
sum
(
dim
=
0
)
db_async
=
torch
.
distributed
.
all_reduce
(
db
,
group
=
tensor_parallel_group
,
async_op
=
True
,
)
dy_local
=
grad_output_quantizer
(
dy_local
)
else
:
if
isinstance
(
dy_local
,
QuantizedTensorBase
):
dy_local
=
dy_local
.
dequantize
(
dtype
=
dtype
)
elif
dy_local
.
dtype
!=
dtype
:
dy_local
=
dy_local
.
to
(
dtype
=
dtype
)
# Cast weight tensor dtype if needed
if
weight
is
None
:
raise
ValueError
(
"Weight tensor is required to compute input grad"
)
w
=
weight
if
with_quantized_compute
:
if
not
isinstance
(
w
,
QuantizedTensorBase
):
weight_quantizer
.
set_usage
(
columnwise
=
True
)
w
=
weight_quantizer
(
w
)
else
:
if
isinstance
(
w
,
QuantizedTensorBase
):
w
=
w
.
dequantize
(
dtype
=
dtype
)
elif
w
.
dtype
!=
dtype
:
w
=
w
.
to
(
dtype
=
dtype
)
# C
heck
input tensor
# C
ast
input tensor
dtype if needed
x_local
=
None
if
weight_requires_grad
:
x_local
=
reshape
(
input
,
(
-
1
,
input_dims
[
-
1
]),
device
=
device
,
dtype
=
dtype
,
)
if
with_fp8_compute
and
not
is_float8_tensor
(
x_local
):
fp8_dtype
=
get_fp8_te_dtype
(
input_fp8_meta
[
"recipe"
],
fprop_tensor
=
True
,
if
input
is
None
:
raise
ValueError
(
"Input tensor is required to compute weight grad"
)
x_local
=
input
if
with_quantized_compute
:
if
not
isinstance
(
x_local
,
QuantizedTensorBase
):
input_quantizer
.
set_usage
(
columnwise
=
True
)
x_local
=
input_quantizer
(
x_local
)
else
:
if
isinstance
(
x_local
,
QuantizedTensorBase
):
x_local
=
x_local
.
dequantize
(
dtype
=
dtype
)
elif
x_local
.
dtype
!=
dtype
:
x_local
=
x_local
.
to
(
dtype
=
dtype
)
# dgrad GEMM
dx_local
=
None
dx
=
None
dy
=
None
x
=
None
if
input_requires_grad
:
# Initialize grad output
if
with_dgrad_all_gather_dy
:
if
grad_output_quantizer
is
not
None
:
grad_output_quantizer
.
set_usage
(
rowwise
=
True
,
columnwise
=
False
)
dy
,
_
=
fill_userbuffers_buffer_for_all_gather
(
ub_comm_dgrad
,
dy_local
,
grad_output_quantizer
,
tensor_parallel_group
,
)
x_local
=
Float8Tensor
.
to_float8
(
else
:
dy
=
dy_local
# Construct grad input tensor if needed
if
with_dgrad_reduce_scatter_dx
or
with_wgrad_reduce_scatter_dx
:
dx_size
=
list
(
dy
.
size
())
dx_size
[
-
1
]
=
w
.
size
(
-
1
)
dx_local_size
=
list
(
dx_size
)
dx_local_size
[
0
]
//=
tensor_parallel_size
if
with_dgrad_reduce_scatter_dx
:
dx_local
=
torch
.
empty
(
dx_local_size
,
dtype
=
dtype
,
device
=
device
,
)
elif
with_wgrad_reduce_scatter_dx
:
dx_local
=
ub_comm_wgrad
.
get_buffer
(
local_chunk
=
True
,
shape
=
dx_local_size
,
)
dx
=
ub_comm_wgrad
.
get_buffer
(
local_chunk
=
False
,
shape
=
dx_size
,
)
# Initialize input tensor if needed
if
with_dgrad_all_gather_x
:
if
input_quantizer
is
not
None
:
input_quantizer
.
set_usage
(
rowwise
=
False
,
columnwise
=
True
)
x
,
_
=
fill_userbuffers_buffer_for_all_gather
(
ub_comm_dgrad
,
x_local
,
fp8_meta
=
input_fp8_meta
,
fp8_meta_forward
=
True
,
fp8_meta_index
=
0
,
fp8_dtype
=
fp8_dtype
,
data
=
(
ub_comm_x
.
get_ubuf_output
(
0
)
if
with_ub_all_gather_x
else
None
),
with_transpose_cache
=
(
not
with_ub_all_gather_x
),
input_quantizer
,
tensor_parallel_group
,
)
elif
not
with_fp8_compute
and
is_float8_tensor
(
x_local
):
if
with_ub_all_gather_x
:
ub_local_buffer
=
ub_comm_x
.
get_ubuf_output
(
0
)
x_local
=
ub_local_buffer
.
copy_
(
x_local
)
else
:
x_local
=
x_local
.
dequantize
()
# Check weight tensor
w
=
convert_tensor
(
weight
,
device
=
device
,
dtype
=
dtype
,
memory_format
=
torch
.
contiguous_format
,
)
if
with_fp8_compute
and
not
is_float8_tensor
(
w
):
fp8_dtype
=
get_fp8_te_dtype
(
weight_fp8_meta
[
"recipe"
],
fprop_tensor
=
True
,
)
w
=
Float8Tensor
.
to_float8
(
# Perform dgrad GEMM
dx
,
*
_
=
general_gemm
(
w
,
fp8_meta
=
weight_fp8_meta
,
fp8_meta_forward
=
True
,
fp8_meta_index
=
0
,
fp8_dtype
=
fp8_dtype
,
with_transpose_cache
=
True
,
dy
,
get_workspace
(),
out_dtype
=
dtype
,
quantization_params
=
grad_input_quantizer
,
layout
=
"NN"
,
out
=
dx
,
use_split_accumulator
=
_2X_ACC_DGRAD
,
grad
=
True
,
ub
=
ub_comm_dgrad
,
ub_type
=
ub_type_dgrad
,
extra_output
=
dx_local
if
with_dgrad_reduce_scatter_dx
else
None
,
bulk_overlap
=
with_bulk_overlap
,
)
elif
not
with_fp8_compute
and
is_float8_tensor
(
w
):
w
=
w
.
dequantize
()
# Initialize buffers for UB all-gather if needed
dy
=
dy_local
x
=
x_local
if
with_ub_all_gather_dy
:
ub_local_buffer
=
ub_comm_dy
.
get_ubuf_output
(
0
)
ub_global_buffer
=
ub_comm_dy
.
get_ubuf_output
(
1
)
if
with_fp8_compute
:
dy
=
Float8Tensor
.
make_like
(
dy_local
,
data
=
ub_global_buffer
)
if
dy_local
.
_data
.
data_ptr
()
!=
ub_local_buffer
.
data_ptr
():
ub_local_buffer
.
copy_
(
dy_local
.
_data
)
else
:
dy
=
ub_global_buffer
if
dy_local
.
data_ptr
()
!=
ub_local_buffer
.
data_ptr
():
ub_local_buffer
.
copy_
(
dy_local
)
if
with_ub_all_gather_x
:
ub_local_buffer
=
ub_comm_x
.
get_ubuf_output
(
0
)
ub_global_buffer
=
ub_comm_x
.
get_ubuf_output
(
1
)
if
with_fp8_compute
:
x
=
Float8Tensor
.
make_like
(
x_local
,
data
=
ub_global_buffer
)
if
x_local
.
_data
.
data_ptr
()
!=
ub_local_buffer
.
data_ptr
():
ub_local_buffer
.
copy_
(
x_local
.
_data
)
else
:
x
=
ub_global_buffer
if
x_local
.
data_ptr
()
!=
ub_local_buffer
.
data_ptr
():
ub_local_buffer
.
copy_
(
x_local
)
if
not
(
with_dgrad_reduce_scatter_dx
or
with_wgrad_reduce_scatter_dx
):
dx_local
=
dx
# Construct grad input tensor
dx = None
dx_local = None
if with_ub_reduce_scatter_dx:
    # Initialize buffers for UB reduce-scatter
    dx = ub_comm_dx.get_ubuf_output(1)
    ub_local_buffer = ub_comm_dx.get_ubuf_output(0)
    if with_ub_all_gather_x:
        dx_local = ub_local_buffer
    else:
        dx_local = torch.empty_like(ub_local_buffer)
else:
    # Allocate grad input tensor
    if with_fp8_grad_input:
        fp8_dtype = get_fp8_te_dtype(
            grad_input_fp8_meta["recipe"],
            fprop_tensor=False,
        )
        data = torch.empty(
            (dy.size(0), w.size(-1)),
            dtype=torch.uint8,
            device=device,
        )
        dx = Float8Tensor(
            data=data,
            fp8_meta=grad_input_fp8_meta,
            fp8_meta_forward=False,
            fp8_meta_index=0,
            fp8_dtype=fp8_dtype,
            dtype=dtype,
# wgrad GEMM
dw = None
if weight_requires_grad:
    # Initialize grad output
    if tensor_parallel_mode == "row" and isinstance(grad_output_quantizer, MXFP8Quantizer):
        # UB does not support overlapping grad output
        # all-gather with wgrad GEMM. Also, MXFP8 does not
        # allow reusing the grad output that was gathered for
        # the dgrad GEMM. We work around with blocking
        # all-gather for column-scaled MXFP8 data.
        grad_output_quantizer.set_usage(rowwise=False, columnwise=True)
        dy, _ = gather_along_first_dim(
            grad_output,
            tensor_parallel_group,
            quantizer=grad_output_quantizer,
        )
    else:
        dx = torch.empty(
            (dy.size(0), w.size(-1)),
            dtype=dtype,
            device=device,
        if tensor_parallel_mode == "column":
            dy = dy_local
        if dy is None:
            raise RuntimeError(
                "wgrad GEMM requires grad output tensor, which has not been initialized"
            )
        dx_local = dx
    if isinstance(dy, QuantizedTensorBase):
        dy.update_usage(rowwise_usage=False, columnwise_usage=True)
    # Allocate grad input tensor
if grad_weight is None:
    if accumulate_into_grad_weight:
        raise ValueError(
            "Attempted to accumulate into grad weight buffer "
            "without providing grad weight"
    # Initialize input tensor
    if tensor_parallel_mode == "row":
        x = x_local
    if x is None:
        raise RuntimeError(
            "wgrad GEMM requires input tensor, which has not been initialized"
        )
    grad_weight = torch.empty(
        weight_dims,
        dtype=dtype,
        device=device,
        memory_format=torch.contiguous_format,
    )
if isinstance(x, QuantizedTensorBase):
    x.update_usage(rowwise_usage=False, columnwise_usage=True)

# Check grad weight tensor
dw = grad_weight
dw_dtype = dtype
if dw is None:
    if accumulate_into_grad_weight:
        raise ValueError(
            "Attempted to accumulate into grad weight tensor "
            "without providing grad weight tensor"
        )
else:
    dw_dtype = dw.dtype
# Perform dgrad GEMM
if with_fp8_compute:
    kwargs = {"out": dx, "use_split_accumulator": True}
    if with_ub_all_gather_dy:
        kwargs["ub_algo"] = ub_algo_dy
        kwargs["ub"] = ub_comm_dy
    elif with_ub_all_gather_x:
        kwargs["ub_algo"] = ub_algo_x
        kwargs["ub"] = ub_comm_x
    elif with_ub_reduce_scatter_dx:
        kwargs["ub_algo"] = ub_algo_dx
        kwargs["ub"] = ub_comm_dx
        kwargs["extra_output_tensor"] = dx_local
    if with_fp8_grad_input:
        fp8_meta, fp8_meta_index = get_fp8_meta_from_fp8_tensor(dx)
        kwargs.update(
            {
                "out": dx._data,
                "out_index": fp8_meta_index,
                "fp8_meta_tensor": fp8_meta,
                "D_dtype": dx._fp8_dtype,
            }
        )
    fp8_gemm(
        w.transpose_2d(),
        w._scale_inv,
        0,
        w._fp8_dtype,
        dy._data,
        dy._scale_inv,
        0,
        dy._fp8_dtype,
        dy.dtype,
        get_workspace(),
        **kwargs,
    )
else:
    kwargs = {"grad": True, "layout": "NN", "out": dx}
    if with_ub_all_gather_dy:
        kwargs["ub_algo"] = ub_algo_dy
        kwargs["ub"] = ub_comm_dy
    elif with_ub_all_gather_x:
        kwargs["ub_algo"] = ub_algo_x
        kwargs["ub"] = ub_comm_x
    elif with_ub_reduce_scatter_dx:
        kwargs["ub_algo"] = ub_algo_dx
        kwargs["ub"] = ub_comm_dx
        kwargs["extra_output_tensor"] = dx_local
    gemm(w, dy, dx.dtype, get_workspace(), **kwargs)
grad_input = reshape(dx_local, input_dims)
# Perform wgrad GEMM
if not weight_requires_grad:
    pass
elif with_fp8_compute:
    kwargs = {
        "accumulate": accumulate_into_grad_weight,
        "out": grad_weight,
        "use_split_accumulator": True,
    }
    if with_ub_reduce_scatter_dx:
        kwargs["ub_algo"] = ub_algo_dx
        kwargs["ub"] = ub_comm_dx
    fp8_gemm(
        x.transpose_2d(),
        x._scale_inv,
        0,
        x._fp8_dtype,
        dy.transpose_2d(),
        dy._scale_inv,
        0,
        dy._fp8_dtype,
        grad_weight.dtype,
        get_workspace(),
        **kwargs,
    )
else:
    kwargs = {
        "accumulate": accumulate_into_grad_weight,
        "layout": "NT",
        "grad": True,
        "use_bias": bias_requires_grad,
        "out": grad_weight,
    }
    if with_ub_reduce_scatter_dx:
        kwargs["ub_algo"] = ub_algo_dx
        kwargs["ub"] = ub_comm_dx
    grad_weight, db, _ = gemm(
    # Perform wgrad GEMM
    dw, *_ = general_gemm(
        x,
        dy,
        grad_weight.dtype,
        get_workspace(),
        **kwargs,
        out_dtype=dw_dtype,
        accumulate=accumulate_into_grad_weight,
        layout="NT",
        out=dw,
        use_split_accumulator=_2X_ACC_WGRAD,
        grad=True,
        ub=ub_comm_wgrad,
        ub_type=ub_type_wgrad,
        bulk_overlap=with_bulk_overlap,
    )
# Bulk overlap reduce-scatter with non-FP8 buffer is
# in-place. Need to copy grad input tensor to avoid data
# corruption in Userbuffers buffer.
if with_wgrad_reduce_scatter_dx:
    dx_local = dx_local.clone()
# Compute grad bias if needed
if db_async is not None:
    db_async.wait()
if bias_requires_grad:
    if db is None:
        db = dy.sum(dim=0)
    extra_outputs["grad_bias"] = db
return grad_input, grad_weight, extra_outputs
return dx_local, dw, extra_outputs

def fuser_backward(
    self,
...
...
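For reference, the dgrad and wgrad GEMMs above implement the usual linear backward math: the "NN" GEMM computes dx = dy @ W, the "NT" GEMM computes dw = dy^T @ x, and the bias gradient is a sum over the leading dimension. A minimal plain-PyTorch sketch of the same computation, without the Userbuffers overlap, FP8/quantized paths, or tensor-parallel communication (the helper and tensor names below are illustrative, not part of this diff):

import torch

def linear_backward_reference(dy, x, w, bias_requires_grad=True):
    # dy: (n, out_features) grad output, x: (n, in_features) input,
    # w: (out_features, in_features) weight
    dx = dy @ w                   # dgrad GEMM, layout "NN"
    dw = dy.t() @ x               # wgrad GEMM, layout "NT"
    db = dy.sum(dim=0) if bias_requires_grad else None
    return dx, dw, db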
@@ -633,40 +516,24 @@ class UserbuffersBackwardLinear(FusedOperation):
        else:
            accumulate_into_main_grad = False

        # Hackily workaround Userbuffers bug with non-FP8 dgrad
        # reduce-scatter overlap
        weight_requires_grad = linear_op_ctx.weight_requires_grad
        if not linear_op_ctx.with_fp8_compute and not weight_requires_grad:
            warnings.warn(
                "There is a correctness bug when using Userbuffers "
                "to overlap a dgrad reduce-scatter with a non-FP8 dgrad GEMM. "
                "Hackily working around by overlapping dgrad reduce-scatter "
                "with wgrad GEMM, even though wgrad isn't needed. "
                "Please contact Transformer Engine team "
                "if you encounter this use-case."
            )
            weight_requires_grad = True

        # Linear backward pass
        retval = UserbuffersBackwardLinear._functional_backward(
            grad_output=grad_output,
            input=x_local,
            weight=linear_op.weight,
            input_dims=linear_op_ctx.input_dims,
            weight_dims=linear_op.weight.size(),
            weight_requires_grad=weight_requires_grad,
            weight_requires_grad=linear_op_ctx.weight_requires_grad,
            bias_requires_grad=(bias_op is not None),
            device=linear_op.device,
            dtype=linear_op_ctx.dtype,
            grad_weight=grad_weight,
            accumulate_into_grad_weight=accumulate_into_main_grad,
            tensor_parallel_mode=self.tensor_parallel_mode,
            tensor_parallel_group=self.tensor_parallel_group,
            sequence_parallel=self.sequence_parallel,
            with_fp8_compute=linear_op_ctx.with_fp8_compute,
            weight_fp8_meta=linear_op_ctx.weight_fp8_meta,
            grad_output_fp8_meta=linear_op_ctx.grad_output_fp8_meta,
            grad_input_fp8_meta=linear_op_ctx.grad_input_fp8_meta,
            with_quantized_compute=linear_op_ctx.with_quantized_compute,
            input_quantizer=linear_op_ctx.input_quantizer,
            weight_quantizer=linear_op_ctx.weight_quantizer,
            grad_output_quantizer=linear_op_ctx.grad_output_quantizer,
            grad_input_quantizer=None,  # Not supported
            ub_comm_name=linear_op._userbuffers_options["comm_name"],
        )
        grad_input, grad_weight, extra_outputs = retval
...
...
@@ -707,8 +574,6 @@ def fuse_userbuffers_backward_linear(
"""
return
ops
### TODO Debug Userbuffers support
# Return immediately if environment is not distributed
if
not
torch
.
distributed
.
is_initialized
()
or
torch
.
distributed
.
get_world_size
()
==
1
:
return
ops
...
...
transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py
View file @ f8c2af4c
...
...
@@ -4,20 +4,25 @@
"""Linear layer forward with Userbuffers communication."""
# pylint: skip-file ### TODO Debug Userbuffers support
from __future__ import annotations
from collections.abc import Iterable
from typing import Any, Optional

import torch

from transformer_engine_torch import CommOverlapAlgo
from transformer_engine_torch import CommOverlapType
from ...cpp_extensions import general_gemm
from ...distributed import get_distributed_world_size
from ...float8_tensor import Float8Tensor
from ...fp8 import FP8GlobalStateManager, get_fp8_te_dtype
from ...module.base import get_ub, get_workspace
from ...fp8 import FP8GlobalStateManager
from ...module.base import (
    fill_userbuffers_buffer_for_all_gather,
    get_ub,
    get_workspace,
    _2X_ACC_FPROP,
)
from ...tensor.quantized_tensor import QuantizedTensorBase, Quantizer
from ...tensor.float8_tensor import Float8Quantizer
from ...tensor._internal.float8_tensor_base import Float8TensorBase
from ...utils import canonicalize_device, canonicalize_dtype
from ..basic import BasicLinear, Bias, ReduceScatter
from ..op import (
...
...
@@ -26,12 +31,6 @@ from ..op import (
    FusibleOperation,
    OperationContext,
)
from .._common import (
    convert_tensor,
    get_fp8_meta_from_fp8_tensor,
    is_float8_tensor,
    reshape,
)


class UserbuffersForwardLinear(FusedOperation):
...
...
@@ -51,9 +50,6 @@ class UserbuffersForwardLinear(FusedOperation):
        reduce_scatter: Optional[ReduceScatter],
    ) -> None:
        ### TODO Debug Userbuffers support
        raise NotImplementedError("Userbuffers support has been broken by recent refactors")

        # Basic operations that comprise this fused operation
        op_idxs = {"linear": 0, "bias": None, "reduce_scatter": None}
        ops = [linear]
...
...
@@ -98,10 +94,10 @@ class UserbuffersForwardLinear(FusedOperation):
        tensor_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
        tensor_parallel_size: Optional[int] = None,
        sequence_parallel: bool = False,
        with_fp8_compute: bool = False,
        input_fp8_meta: Optional[dict[str, Any]] = None,
        weight_fp8_meta: Optional[dict[str, Any]] = None,
        output_fp8_meta: Optional[dict[str, Any]] = None,
        with_quantized_compute: bool = False,
        input_quantizer: Optional[Quantizer] = None,
        weight_quantizer: Optional[Quantizer] = None,
        output_quantizer: Optional[Quantizer] = None,
        ub_comm_name: str,
    ) -> tuple[torch.Tensor, dict]:
"""Functional API for forward pass
...
...
@@ -127,16 +123,14 @@ class UserbuffersForwardLinear(FusedOperation):
parallelism, i.e. distributing input or output tensors
along outer dimension (sequence or batch dim) when not
distributing along inner dimension (embedding dim)
with_fp8_compute: bool, default = `False`
Whether to perform compute in FP8
input_fp8_meta: dict, optional
FP8 metadata for casting input tensor to FP8. Required for
FP8 compute if input is not already in FP8.
weight_fp8_meta: dict, optional
FP8 metadata for casting weight tensor to FP8. Required for
FP8 compute if weight is not already in FP8.
output_fp8_meta: dict, optional
FP8 metadata for casting output tensor to FP8
with_quantized_compute: bool, default = `False`
Whether to perform compute with quantized data.
input_quantizer: Quantizer, optional
Builder class for quantized input tensor.
weight_quantizer: Quantizer, optional
Builder class for quantized weight tensor.
output_quantizer: Quantizer, optional
Builder class for quantized output tensor.
ub_comm_name: str
Layer type (e.g. "qkv", "proj", "fc1", "fc2"). This is
used to access the corresponding Userbuffers communicators
...
...
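As the docstring notes, Userbuffers requires sequence parallelism: a column-parallel layer all-gathers its sequence-parallel input before the GEMM, while a row-parallel layer reduce-scatters the GEMM output, and the fused op overlaps that collective with the GEMM itself. A minimal, non-overlapped reference for the same communication pattern with plain torch.distributed collectives (a sketch only; it assumes the sequence dimension is the leading dimension, and the helper name is not part of this diff):

import torch
import torch.distributed as dist

def tp_linear_forward_reference(x_local, w, tp_group, tensor_parallel_mode):
    # x_local: local input shard, w: local weight shard of shape (out, in)
    tp_size = dist.get_world_size(tp_group)
    if tensor_parallel_mode == "column":
        # All-gather the sequence-parallel input, then run the local GEMM
        x = torch.empty(
            (x_local.size(0) * tp_size, x_local.size(1)),
            dtype=x_local.dtype, device=x_local.device,
        )
        dist.all_gather_into_tensor(x, x_local.contiguous(), group=tp_group)
        return x @ w.t()
    # Row-parallel: the local GEMM yields a partial sum, then reduce-scatter
    y_partial = x_local @ w.t()
    y_local = torch.empty(
        (y_partial.size(0) // tp_size, y_partial.size(1)),
        dtype=y_partial.dtype, device=y_partial.device,
    )
    dist.reduce_scatter_tensor(y_local, y_partial.contiguous(), group=tp_group)
    return y_local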
@@ -166,23 +160,6 @@ class UserbuffersForwardLinear(FusedOperation):
        if dtype not in (torch.float32, torch.float16, torch.bfloat16):
            raise ValueError(f"Supported dtypes are float32, float16, bfloat16 (got {dtype})")

        # Input tensor dims
        input_dims = tuple(input.size())
        weight_dims = tuple(weight.size())
        if len(weight_dims) != 2:
            raise ValueError(f"Weight tensor is not 2D (shape={weight_dims})")
        if len(input_dims) == 0 or weight_dims[1] != input_dims[-1]:
            raise ValueError(
                f"Input tensor (shape={input_dims}) "
                f"and weight tensor (shape={weight_dims}) "
                "are not compatible"
            )

        # Output tensor dims
        output_dims = list(input_dims)
        output_dims[0] = -1
        output_dims[-1] = weight_dims[0]

        # Check tensor parallel group
        if tensor_parallel_size is None:
            tensor_parallel_size = get_distributed_world_size(tensor_parallel_group)
...
...
@@ -196,235 +173,106 @@ class UserbuffersForwardLinear(FusedOperation):
if not sequence_parallel:
    raise RuntimeError(f"Invalid configuration for Userbuffers ({sequence_parallel=})")

# Check if FP8 is enabled
if with_fp8_compute:
    if input_fp8_meta is None and not is_float8_tensor(input):
        raise ValueError("No FP8 metadata was provided for casting input to FP8")
    if weight_fp8_meta is None and not is_float8_tensor(weight):
        raise ValueError("No FP8 metadata was provided for casting weight to FP8")
# Check quantizers
if with_quantized_compute:
    if input_quantizer is None:
        raise ValueError("Missing quantizer for input tensor")
    if weight_quantizer is None:
        raise ValueError("Missing quantizer for weight tensor")
    if output_quantizer is not None:
        raise ValueError("FP8 output is not supported")
else:
    input_fp8_meta = None
    weight_fp8_meta = None
    output_fp8_meta = None
    with_fp8_output = (
        with_fp8_compute and tensor_parallel_mode != "row" and output_fp8_meta is not None
    )
    input_quantizer = None
    weight_quantizer = None
    output_quantizer = None
# Get Userbuffers communicator
ub_comm = get_ub(ub_comm_name + "_fprop")
ub_local_buffer = ub_comm.get_ubuf_output(0)
ub_global_buffer = ub_comm.get_ubuf_output(1)
with_ub_all_gather = tensor_parallel_mode == "column"
with_ub_reduce_scatter = tensor_parallel_mode == "row"
ub_type = CommOverlapType.AG if with_ub_all_gather else CommOverlapType.RS

# Choose Userbuffers communication algorithm
ub_algo = None
# Initialize input tensor
x_local = input
x = None
if with_ub_all_gather:
    if with_fp8_compute and ub_comm.is_atomic_gemm():
        ub_algo = CommOverlapAlgo.ATOMIC_GEMM_AG_P2P
    else:
        ub_algo = CommOverlapAlgo.SPLIT_PIPELINED_AG_P2P
elif with_ub_reduce_scatter:
    is_atomic_gemm = with_fp8_compute and ub_comm.is_atomic_gemm()
    ub_algo = {
        (True, True): CommOverlapAlgo.ATOMIC_GEMM_RS_P2P,
        (True, False): CommOverlapAlgo.SPLIT_PIPELINED_RS_P2P,
        (False, True): CommOverlapAlgo.ATOMIC_GEMM_RS,
        (False, False): CommOverlapAlgo.SPLIT_PIPELINED_RS,
    }[(ub_comm.is_p2p_overlap(), is_atomic_gemm)]
else:
    raise RuntimeError("Could not choose Userbuffers communication algorithm")
# Cast input tensor to correct dtype
x_local = reshape(
    input,
    (-1, input_dims[-1]),
    device=device,
    dtype=dtype,
)
if with_fp8_compute and not is_float8_tensor(x_local):
    fp8_dtype = get_fp8_te_dtype(
        input_fp8_meta["recipe"],
        fprop_tensor=True,
    )
    with_transpose_cache = weight.requires_grad
    if tensor_parallel_mode == "column" and sequence_parallel:
        with_transpose_cache = False
    x_local = Float8Tensor.to_float8(
if input_quantizer is not None:
    if not isinstance(x_local, QuantizedTensorBase):
        input_quantizer.set_usage(rowwise=True, columnwise=True)
        if isinstance(input_quantizer, Float8Quantizer):
            input_quantizer.set_usage(columnwise=False)
        x_local = input_quantizer(x_local)
    input_quantizer.set_usage(rowwise=True, columnwise=False)
x, x_local = fill_userbuffers_buffer_for_all_gather(
    ub_comm,
    x_local,
        fp8_meta=input_fp8_meta,
        fp8_meta_forward=True,
        fp8_meta_index=0,
        fp8_dtype=fp8_dtype,
        data=(ub_local_buffer if with_ub_all_gather else None),
        with_transpose_cache=with_transpose_cache,
    input_quantizer,
    tensor_parallel_group,
)
elif not with_fp8_compute and is_float8_tensor(x_local):
    if with_ub_all_gather:
        x_local = ub_local_buffer.copy_(x_local)
    else:
        x_local = x_local.dequantize()
# Initialize buffers for UB all-gather if needed
x = x_local
if with_ub_all_gather:
    if with_fp8_compute:
        x = Float8Tensor.make_like(x_local, data=ub_global_buffer)
        if x_local._data.data_ptr() != ub_local_buffer.data_ptr():
            ub_local_buffer.copy_(x_local._data)
        else:
            x_local._data = torch.empty_like(x_local._data)
    else:
        if with_quantized_compute:
            if not isinstance(x_local, QuantizedTensorBase):
                input_quantizer.set_usage(rowwise=True, columnwise=True)
                x_local = input_quantizer(x_local)
        else:
            x = ub_global_buffer
            if x_local.data_ptr() != ub_local_buffer.data_ptr():
                ub_local_buffer.copy_(x_local)
            else:
                x_local = torch.empty_like(x_local)
# Check weight tensor
w = convert_tensor(
    weight,
    device=device,
    dtype=dtype,
    memory_format=torch.contiguous_format,
)
if with_fp8_compute and not is_float8_tensor(w):
    fp8_dtype = get_fp8_te_dtype(
        weight_fp8_meta["recipe"],
        fprop_tensor=True,
    )
    w = Float8Tensor.to_float8(
        w,
        fp8_meta=weight_fp8_meta,
        fp8_meta_index=0,
        fp8_dtype=fp8_dtype,
    )
elif not with_fp8_compute and is_float8_tensor(w):
    if isinstance(x_local, QuantizedTensorBase):
        x_local = x_local.dequantize(dtype=dtype)
    if x_local.dtype != dtype:
        x_local = x_local.to(dtype=dtype)
x = x_local

# Initialize weight tensor
w = weight
w_is_quantized = isinstance(w, QuantizedTensorBase)
if with_quantized_compute and not w_is_quantized:
    weight_quantizer.set_usage(rowwise=True)
    w = weight_quantizer(w)
elif not with_quantized_compute and w_is_quantized:
    w = w.dequantize()
if not with_quantized_compute and w.dtype != dtype:
    w = w.to(dtype=dtype)
# Check bias tensor
b = None
if bias is not None:
    b = convert_tensor(
        bias,
        device=device,
        dtype=dtype,
        memory_format=torch.contiguous_format,
    )

# Construct output tensor
y = None
y_local = None
# Construct output tensor if needed
reduce_scatter_output = None
if with_ub_reduce_scatter:
    # Initialize buffers for UB reduce-scatter
    if with_fp8_output:
        fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key(forward=True)
        fp8_dtype = get_fp8_te_dtype(
            output_fp8_meta["recipe"],
            fprop_tensor=True,
        )
        y = Float8Tensor(
            data=ub_global_buffer,
            fp8_meta=output_fp8_meta,
            fp8_meta_forward=True,
            fp8_meta_index=0,
            fp8_dtype=fp8_dtype,
            fp8_scale_inv=output_fp8_meta[fp8_meta_key].scale_inv[0],
            dtype=dtype,
        )
        ub_comm.set_ubuf_scale_inv(y._scale_inv)
    else:
        y = ub_global_buffer
    y_local = torch.empty(
        (x.size(0) // tensor_parallel_size, weight_dims[0]),
        dtype=dtype,
        device=device,
    )
else:
    # Allocate output tensor
    if with_fp8_output:
        fp8_dtype = get_fp8_te_dtype(
            output_fp8_meta["recipe"],
            fprop_tensor=True,
        )
        data = torch.empty(
            (x.size(0), weight_dims[0]),
            dtype=torch.uint8,
            device=device,
        )
        y = Float8Tensor(
            data=data,
            fp8_meta=output_fp8_meta,
            fp8_meta_forward=True,
            fp8_meta_index=0,
            fp8_dtype=fp8_dtype,
            dtype=dtype,
        )
    else:
        y = torch.empty(
            (x.size(0), weight_dims[0]),
            dtype=dtype,
            device=device,
        )
    y_local = y
    y_local_size = list(x.size())
    y_local_size[0] //= tensor_parallel_size
    y_local_size[-1] = w.size(0)
    reduce_scatter_output = torch.empty(y_local_size, dtype=dtype, device=device)
# Perform GEMM
if with_fp8_compute:
    kwargs = {
        "out": y,
        "bias": b,
        "use_bias": (b is not None),
        "use_split_accumulator": False,
        "ub_algo": ub_algo,
        "ub": ub_comm,
    }
    if with_ub_all_gather:
        kwargs["extra_output_tensor"] = x_local._data
    if with_ub_reduce_scatter:
        kwargs["extra_output_tensor"] = y_local
    if with_fp8_output:
        fp8_meta, fp8_meta_index = get_fp8_meta_from_fp8_tensor(y)
        kwargs.update(
            {
                "out": y._data,
                "out_index": fp8_meta_index,
                "fp8_meta_tensor": fp8_meta,
                "D_dtype": y._fp8_dtype,
            }
        )
    fp8_gemm(
        w._data,
        w._scale_inv,
        0,
        w._fp8_dtype,
        x._data,
        x._scale_inv,
        0,
        x._fp8_dtype,
        y.dtype,
        get_workspace(),
        **kwargs,
    )
gemm_output, *_, reduce_scatter_output = general_gemm(
    w,
    x,
    get_workspace(),
    out_dtype=dtype,
    quantization_params=output_quantizer,
    bias=bias,
    use_split_accumulator=_2X_ACC_FPROP,
    ub=ub_comm,
    ub_type=ub_type,
    extra_output=reduce_scatter_output,
)
if with_ub_reduce_scatter:
    y_local = reduce_scatter_output
else:
    kwargs = {
        "out": y,
        "bias": b,
        "use_bias": (b is not None),
        "ub_algo": ub_algo,
        "ub": ub_comm,
    }
    if with_ub_all_gather:
        kwargs["extra_output_tensor"] = x_local
    if with_ub_reduce_scatter:
        kwargs["extra_output_tensor"] = y_local
    gemm(w, x, y.dtype, get_workspace(), **kwargs)
# Reshape output tensor
out = reshape(y_local, output_dims)
y_local = gemm_output
# Detach input tensor if needed
# Note: PyTorch autograd produces esoteric errors if we save
# input tensor as context for backward pass.
if x_local is input:
    x_local = x_local.detach()

# Configure input tensor for backward pass
if with_quantized_compute and isinstance(x_local, QuantizedTensorBase):
    if not (isinstance(x_local, Float8TensorBase) and with_ub_all_gather):
        # FP8 does not support all-gather of transpose data
        x_local.update_usage(rowwise_usage=False, columnwise_usage=True)

# Return cast tensors
extra_outputs = {"input": x_local, "weight": w}
return out, extra_outputs
return y_local, extra_outputs

def fuser_forward(
    self,
...
...
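For orientation, a hedged sketch of how the functional forward above might be invoked, mirroring the keyword arguments passed by fuser_forward later in this diff; the tensors, process group, and the "proj" communicator name are placeholders, and as the diff itself notes the Userbuffers path is currently disabled pending debugging:

# Hypothetical invocation; values are placeholders, not taken from the diff.
out, extra_outputs = UserbuffersForwardLinear._functional_forward(
    input=x,                              # sequence-parallel input shard
    weight=weight,
    bias=None,
    device=x.device,
    dtype=torch.bfloat16,
    tensor_parallel_mode="row",
    tensor_parallel_group=tp_group,
    tensor_parallel_size=tp_size,
    sequence_parallel=True,
    with_quantized_compute=False,
    input_quantizer=None,
    weight_quantizer=None,
    output_quantizer=None,
    ub_comm_name="proj",
)
x_saved = extra_outputs["input"]          # cast input kept for the backward pass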
@@ -450,23 +298,22 @@ class UserbuffersForwardLinear(FusedOperation):
        if basic_op_kwargs[idx]:
            raise ValueError("Bias operation forward does not expect keyword arguments")

        # FP8 metadata
        with_fp8_compute = FP8GlobalStateManager.is_fp8_enabled()
        input_fp8_meta = None
        weight_fp8_meta = None
        output_fp8_meta = None
        grad_output_fp8_meta = None
        grad_input_fp8_meta = None
        if with_fp8_compute:
            input_fp8_meta = linear_op.get_fp8_meta("input")
            weight_fp8_meta = linear_op.get_fp8_meta("param")
            next_op = basic_op_next_ops[-1]
            if next_op is not None and next_op.num_fp8_scales("input") > 0:
                output_fp8_meta = next_op.get_fp8_meta("input")
            grad_output_fp8_meta = linear_op.get_fp8_meta("grad_output")
        # Quantization metadata
        with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled()
        input_quantizer = None
        weight_quantizer = None
        grad_output_quantizer = None
        grad_input_quantizer = None
        if with_quantized_compute:
            recipe = FP8GlobalStateManager.get_fp8_recipe()
            if not recipe.delayed() and not recipe.mxfp8():
                raise RuntimeError("Userbuffers is only supported with FP8 delayed scaling recipe")
            input_quantizer = linear_op.get_quantizer("forward", 0)
            weight_quantizer = linear_op.get_quantizer("forward", 1)
            grad_output_quantizer = linear_op.get_quantizer("backward", 0)
            prev_op = basic_op_prev_ops[0]
            if prev_op is not None and prev_op.num_fp8_scales("grad_output") > 0:
                grad_input_fp8_meta = prev_op.get_fp8_meta("grad_output")
            if prev_op is not None and prev_op.num_quantizers("backward") > 0 and recipe.delayed():
                grad_input_quantizer = prev_op.get_quantizer("backward", 0)

        # Get autocast dtype if needed
        dtype = None
...
...
@@ -482,26 +329,26 @@ class UserbuffersForwardLinear(FusedOperation):
            input=input_,
            weight=linear_op.weight,
            bias=bias,
            device=linear_op.device,
            dtype=dtype,
            tensor_parallel_mode=self.tensor_parallel_mode,
            tensor_parallel_group=self.tensor_parallel_group,
            tensor_parallel_size=self.tensor_parallel_size,
            sequence_parallel=self.sequence_parallel,
            with_fp8_compute=with_fp8_compute,
            input_fp8_meta=input_fp8_meta,
            weight_fp8_meta=weight_fp8_meta,
            output_fp8_meta=output_fp8_meta,
            with_quantized_compute=with_quantized_compute,
            input_quantizer=input_quantizer,
            weight_quantizer=weight_quantizer,
            output_quantizer=None,  # Not supported
            ub_comm_name=linear_op._userbuffers_options["comm_name"],
        )
        x_local = extra_outputs["input"]

        # Save state for backward pass
        linear_op_ctx.save_for_backward(x_local)
        linear_op_ctx.with_fp8_compute = with_fp8_compute
        linear_op_ctx.weight_fp8_meta = weight_fp8_meta
        linear_op_ctx.grad_output_fp8_meta = grad_output_fp8_meta
        linear_op_ctx.grad_input_fp8_meta = grad_input_fp8_meta
        linear_op_ctx.with_quantized_compute = with_quantized_compute
        linear_op_ctx.input_quantizer = input_quantizer
        linear_op_ctx.weight_quantizer = weight_quantizer
        linear_op_ctx.grad_output_quantizer = grad_output_quantizer
        linear_op_ctx.grad_input_quantizer = grad_input_quantizer
        linear_op_ctx.dtype = dtype
        linear_op_ctx.input_dims = input_.size()
        linear_op_ctx.input_requires_grad = input_.requires_grad
...
...
@@ -529,8 +376,6 @@ def fuse_userbuffers_forward_linear(
"""
return
ops
### TODO Debug Userbuffers support
# Return immediately if environment is not distributed
if
not
torch
.
distributed
.
is_initialized
()
or
torch
.
distributed
.
get_world_size
()
==
1
:
return
ops
...
...
transformer_engine/pytorch/setup.py
View file @ f8c2af4c
...
...
@@ -55,7 +55,17 @@ if __name__ == "__main__":
    description="Transformer acceleration library - Torch Lib",
    ext_modules=ext_modules,
    cmdclass={"build_ext": CMakeBuildExtension},
    install_requires=["torch"],
    setup_requires=[
        "torch>=2.1",
        "nvidia-cuda-runtime-cu12",
        "nvidia-cublas-cu12",
        "nvidia-cudnn-cu12",
        "nvidia-cuda-cccl-cu12",
        "nvidia-cuda-nvcc-cu12",
        "nvidia-nvtx-cu12",
        "nvidia-cuda-nvrtc-cu12",
    ],
    install_requires=["torch>=2.1"],
    tests_require=["numpy", "torchvision"],
)
if any(x in sys.argv for x in (".", "sdist", "bdist_wheel")):
...
...