OpenDAS / TransformerEngine / Commits

Commit f8c2af4c, authored May 21, 2025 by yuguo

Merge commit '1d903f5e' of https://github.com/NVIDIA/TransformerEngine

Parents: e92773a3, 1d903f5e
Changes: 211
Showing 20 changed files with 1856 additions and 1953 deletions (+1856, -1953)
transformer_engine/pytorch/csrc/type_converters.cpp                   +0    -0
transformer_engine/pytorch/csrc/type_shim.h                           +0    -359
transformer_engine/pytorch/csrc/util.cpp                              +77   -0
transformer_engine/pytorch/csrc/util.h                                +0    -2
transformer_engine/pytorch/distributed.py                             +48   -18
transformer_engine/pytorch/fp8.py                                     +2    -2
transformer_engine/pytorch/graph.py                                   +3    -1
transformer_engine/pytorch/jit.py                                     +11   -10
transformer_engine/pytorch/module/_common.py                          +0    -37
transformer_engine/pytorch/module/base.py                             +189  -33
transformer_engine/pytorch/module/grouped_linear.py                   +13   -7
transformer_engine/pytorch/module/layernorm.py                        +4    -7
transformer_engine/pytorch/module/layernorm_linear.py                 +349  -252
transformer_engine/pytorch/module/layernorm_mlp.py                    +363  -247
transformer_engine/pytorch/module/linear.py                           +362  -266
transformer_engine/pytorch/module/rmsnorm.py                          +4    -6
transformer_engine/pytorch/ops/basic/basic_linear.py                  +8    -3
transformer_engine/pytorch/ops/fused/userbuffers_backward_linear.py   +285  -420
transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py    +127  -282
transformer_engine/pytorch/setup.py                                   +11   -1
transformer_engine/pytorch/csrc/extensions/type_converters.cpp → transformer_engine/pytorch/csrc/type_converters.cpp
View file @ f8c2af4c
File moved
transformer_engine/pytorch/csrc/type_shim.h
deleted 100644 → 0
View file @ e92773a3
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#pragma once
#include <ATen/ATen.h>
#include "common/utils.cuh"
// Forward/backward compatibility hack around
// https://github.com/pytorch/pytorch/commit/3aeb78079bcd68282fe9117088e138b77318e288
// pending more future-proof guidance from upstream.
// struct TypeShim
// {
// const at::Type& payload;
// TypeShim(const at::Type& type) : payload(type) {}
// // Enable trivial conversion to a const at::Type& for pre-3aeb78
// operator const at::Type&(){ return payload; };
// // Enable dispatch switch statements to take *this directly for post-3aeb78
// //operator at::ScalarType(){ return payload.; };
// };
#define DISPATCH_FLOAT_AND_HALF(TYPE, LEVEL, NAME, ...) \
switch (TYPE) { \
case at::ScalarType::Float: { \
using scalar_t_##LEVEL = float; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Half: { \
using scalar_t_##LEVEL = at::Half; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}
#define DISPATCH_FLOAT_HALF_AND_BFLOAT(TYPE, LEVEL, NAME, ...) \
switch (TYPE) { \
case at::ScalarType::Float: { \
using scalar_t_##LEVEL = float; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Half: { \
using scalar_t_##LEVEL = at::Half; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::BFloat16: { \
using scalar_t_##LEVEL = at::BFloat16; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}
#define DISPATCH_FLOAT_HALF_AND_BYTE(TYPE, LEVEL, NAME, ...) \
switch (TYPE) { \
case at::ScalarType::Float: { \
using scalar_t_##LEVEL = float; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Half: { \
using scalar_t_##LEVEL = at::Half; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Byte: { \
using scalar_t_##LEVEL = uint8_t; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}
#define DISPATCH_DOUBLE_FLOAT_AND_HALF(TYPE, LEVEL, NAME, ...) \
switch (TYPE) { \
case at::ScalarType::Double: { \
using scalar_t_##LEVEL = double; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Float: { \
using scalar_t_##LEVEL = float; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Half: { \
using scalar_t_##LEVEL = at::Half; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}
#define DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT(TYPE, LEVEL, NAME, ...) \
switch (TYPE) { \
case at::ScalarType::Double: { \
using scalar_t_##LEVEL = double; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Float: { \
using scalar_t_##LEVEL = float; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Half: { \
using scalar_t_##LEVEL = at::Half; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::BFloat16: { \
using scalar_t_##LEVEL = at::BFloat16; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}
#define DISPATCH_DOUBLE_AND_FLOAT(TYPE, LEVEL, NAME, ...) \
switch (TYPE) { \
case at::ScalarType::Double: { \
using scalar_t_##LEVEL = double; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Float: { \
using scalar_t_##LEVEL = float; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}
#define DISPATCH_HALF_AND_BFLOAT(TYPE, NAME, ...) \
switch (TYPE) { \
case at::ScalarType::Half: { \
using scalar_t = at::Half; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::BFloat16: { \
using scalar_t = at::BFloat16; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}
#define DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(TYPEIN, TYPEOUT, NAME, ...) \
switch (TYPEIN) { \
case at::ScalarType::Float: { \
using scalar_t_in = float; \
switch (TYPEOUT) { \
case at::ScalarType::Float: { \
using scalar_t_out = float; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Half: { \
using scalar_t_out = at::Half; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::BFloat16: { \
using scalar_t_out = at::BFloat16; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPEOUT), "'"); \
} \
break; \
} \
case at::ScalarType::Half: { \
using scalar_t_in = at::Half; \
using scalar_t_out = at::Half; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::BFloat16: { \
using scalar_t_in = at::BFloat16; \
using scalar_t_out = at::BFloat16; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPEIN), "'"); \
}
#define DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(TYPEIN, TYPEOUT, NAME, ...) \
switch (TYPEIN) { \
case at::ScalarType::Double: { \
using scalar_t_in = double; \
switch (TYPEOUT) { \
case at::ScalarType::Double: { \
using scalar_t_out = double; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Float: { \
using scalar_t_out = float; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Half: { \
using scalar_t_out = at::Half; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::BFloat16: { \
using scalar_t_out = at::BFloat16; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPEOUT), "'"); \
} \
break; \
} \
case at::ScalarType::Float: { \
using scalar_t_in = float; \
switch (TYPEOUT) { \
case at::ScalarType::Float: { \
using scalar_t_out = float; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Half: { \
using scalar_t_out = at::Half; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::BFloat16: { \
using scalar_t_out = at::BFloat16; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPEOUT), "'"); \
} \
break; \
} \
case at::ScalarType::Half: { \
using scalar_t_in = at::Half; \
using scalar_t_out = at::Half; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::BFloat16: { \
using scalar_t_in = at::BFloat16; \
using scalar_t_out = at::BFloat16; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPEIN), "'"); \
}
template <typename T>
__device__ __forceinline__ T reduce_block_into_lanes(T *x, T val, int lanes = 1,
                                                     bool share_result = false) {
  // lanes is intended to be <= 32.
  int tid = threadIdx.x + threadIdx.y * blockDim.x;
  int blockSize = blockDim.x * blockDim.y;
  // blockSize is intended to be a multiple of 32.

  if (blockSize >= 64) {
    x[tid] = val;
    __syncthreads();
  }

#pragma unroll
  for (int i = (blockSize >> 1); i >= 64; i >>= 1) {
    if (tid < i) x[tid] = x[tid] + x[tid + i];
    __syncthreads();
  }

  T final;

  if (tid < 32) {
    if (blockSize >= 64)
      final = x[tid] + x[tid + 32];
    else
      final = val;
    // __SYNCWARP();

#pragma unroll
#ifdef __HIP_PLATFORM_AMD__
    for (int i = 16; i >= lanes; i >>= 1)
      final = final + __shfl_down(final, i, THREADS_PER_WARP);
#else
    for (int i = 16; i >= lanes; i >>= 1)
      final = final + __shfl_down_sync(0xffffffff, final, i);
#endif
  }

  if (share_result) {
    if (tid < lanes) x[tid] = final;  // EpilogueOp
    // Make sure the smem result is visible to all warps.
  }
  __syncthreads();
  // Avoid potential write before read race when reduce_block_into_lanes is called back to back

  return final;
}

template <typename T>
__device__ __forceinline__ T reduce_block_into_lanes_max_op(T *x, T val, int lanes = 1,
                                                            bool share_result = false) {
  // lanes is intended to be <= 32.
  int tid = threadIdx.x + threadIdx.y * blockDim.x;
  int blockSize = blockDim.x * blockDim.y;
  // blockSize is intended to be a multiple of 32.

  if (blockSize >= 64) {
    x[tid] = val;
    __syncthreads();
  }

#pragma unroll
  for (int i = (blockSize >> 1); i >= 64; i >>= 1) {
    if (tid < i) x[tid] = fmaxf(fabsf(x[tid]), fabsf(x[tid + i]));
    __syncthreads();
  }

  T final;

  if (tid < 32) {
    if (blockSize >= 64)
      final = fmaxf(fabsf(x[tid]), fabsf(x[tid + 32]));
    else
      final = val;
    // __SYNCWARP();

#pragma unroll
    for (int i = 16; i >= lanes; i >>= 1)
#ifdef __HIP_PLATFORM_AMD__
      final = fmaxf(fabsf(final), fabsf(__shfl_down(final, i, THREADS_PER_WARP)));
#else
      final = fmaxf(fabsf(final), fabsf(__shfl_down_sync(0xffffffff, final, i)));
#endif
  }

  if (share_result) {
    if (tid < lanes) x[tid] = final;  // EpilogueOp
    // Make sure the smem result is visible to all warps.
    __syncthreads();
  }

  return final;
}
transformer_engine/pytorch/csrc/extensions/swizzle.cpp → transformer_engine/pytorch/csrc/util.cpp
View file @ f8c2af4c
...
@@ -4,10 +4,10 @@
  * See LICENSE for license information.
  ************************************************************************/

-#include "extensions.h"
+#include "transformer_engine/transformer_engine.h"
 #include "util.h"
+#include "common.h"

 std::optional<at::Tensor> swizzle_scaling_factors(transformer_engine::TensorWrapper &input,
                                                   bool rowwise) {
   using namespace transformer_engine::pytorch;
...
@@ -45,80 +45,33 @@ std::optional<at::Tensor> swizzle_scaling_factors(transformer_engine::TensorWrap
   transformer_engine::TensorWrapper input_cu(NVTE_MXFP8_1D_SCALING);
   transformer_engine::TensorWrapper output_cu(NVTE_MXFP8_1D_SCALING);

   if (rowwise) {
-    input_cu.set_rowwise_data(input.dptr(), DType::kFloat8E4M3, input_shape);
-    input_cu.set_rowwise_scale_inv(scale_inv_dptr, DType::kFloat8E8M0, scale_inv_shape);
-    output_cu.set_rowwise_data(input.dptr(), DType::kFloat8E4M3, input_shape);
-    output_cu.set_rowwise_scale_inv(swizzled_scale_inv_dptr, DType::kFloat8E8M0, scale_inv_shape);
+    input_cu.set_rowwise_data(input.dptr(), transformer_engine::DType::kFloat8E4M3, input_shape);
+    input_cu.set_rowwise_scale_inv(scale_inv_dptr, transformer_engine::DType::kFloat8E8M0,
+                                   scale_inv_shape);
+    output_cu.set_rowwise_data(input.dptr(), transformer_engine::DType::kFloat8E4M3, input_shape);
+    output_cu.set_rowwise_scale_inv(swizzled_scale_inv_dptr, transformer_engine::DType::kFloat8E8M0,
+                                    scale_inv_shape);
   } else {
-    input_cu.set_columnwise_data(input.columnwise_dptr(), DType::kFloat8E4M3, input_shape);
-    input_cu.set_columnwise_scale_inv(scale_inv_dptr, DType::kFloat8E8M0, scale_inv_shape);
-    output_cu.set_columnwise_data(input.columnwise_dptr(), DType::kFloat8E4M3, input_shape);
-    output_cu.set_columnwise_scale_inv(swizzled_scale_inv_dptr, DType::kFloat8E8M0, scale_inv_shape);
+    input_cu.set_columnwise_data(input.columnwise_dptr(), transformer_engine::DType::kFloat8E4M3,
+                                 input_shape);
+    input_cu.set_columnwise_scale_inv(scale_inv_dptr, transformer_engine::DType::kFloat8E8M0,
+                                      scale_inv_shape);
+    output_cu.set_columnwise_data(input.columnwise_dptr(), transformer_engine::DType::kFloat8E4M3,
+                                  input_shape);
+    output_cu.set_columnwise_scale_inv(swizzled_scale_inv_dptr,
+                                       transformer_engine::DType::kFloat8E8M0, scale_inv_shape);
   }

   // Launch kernel
   nvte_swizzle_scaling_factors(input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream());

   if (rowwise) {
-    input.set_rowwise_scale_inv(swizzled_scale_inv_dptr, DType::kFloat8E8M0, scale_inv_shape);
+    input.set_rowwise_scale_inv(swizzled_scale_inv_dptr, transformer_engine::DType::kFloat8E8M0,
+                                scale_inv_shape);
   } else {
-    input.set_columnwise_scale_inv(swizzled_scale_inv_dptr, DType::kFloat8E8M0, scale_inv_shape);
+    input.set_columnwise_scale_inv(swizzled_scale_inv_dptr, transformer_engine::DType::kFloat8E8M0,
+                                   scale_inv_shape);
   }

   return swizzled_scale_inv;
 }
-
-at::Tensor rowwise_swizzle(at::Tensor input, at::Tensor scale_inv) {
-  using namespace transformer_engine::pytorch;
-
-  NVTE_CHECK(input.element_size() == 1, "8-bit input required for swizzling scaling factors.");
-
-  auto options = at::TensorOptions().dtype(scale_inv.dtype()).device(torch::kCUDA);
-  auto swizzled_scale_inv = at::empty_like(scale_inv, options);
-  void *scale_inv_dptr = getDataPtr(scale_inv, 0);
-  void *swizzled_scale_inv_dptr = getDataPtr(swizzled_scale_inv, 0);
-
-  auto input_cu = makeTransformerEngineTensor(input.data_ptr(), getTensorShape(input),
-                                              DType::kFloat8E4M3, nullptr, nullptr, scale_inv_dptr,
-                                              getTensorShape(scale_inv), NVTE_MXFP8_1D_SCALING);
-  auto output_cu = makeTransformerEngineTensor(
-      input.data_ptr(), getTensorShape(input), DType::kFloat8E4M3, nullptr, nullptr,
-      swizzled_scale_inv_dptr, getTensorShape(swizzled_scale_inv), NVTE_MXFP8_1D_SCALING);
-
-  // Launch kernel
-  nvte_swizzle_scaling_factors(input_cu.data(), output_cu.data(),
-                               at::cuda::getCurrentCUDAStream());
-
-  return swizzled_scale_inv;
-}
-
-at::Tensor columnwise_swizzle(at::Tensor input, at::Tensor scale_inv) {
-  using namespace transformer_engine::pytorch;
-
-  NVTE_CHECK(input.element_size() == 1, "8-bit input required for swizzling scaling factors.");
-
-  auto options = at::TensorOptions().dtype(scale_inv.dtype()).device(torch::kCUDA);
-  auto swizzled_scale_inv = at::empty_like(scale_inv, options);
-
-  // Return immediately if tensor is empty
-  if (scale_inv.numel() == 0) {
-    return swizzled_scale_inv;
-  }
-
-  void *scale_inv_dptr = getDataPtr(scale_inv, 0);
-  void *swizzled_scale_inv_dptr = getDataPtr(swizzled_scale_inv, 0);
-
-  auto input_cu = makeTransformerEngineTensor(
-      nullptr, input.data_ptr(), {1}, getTensorShape(input), DType::kFloat8E4M3, nullptr, nullptr,
-      nullptr, scale_inv_dptr, {1}, getTensorShape(scale_inv), NVTE_MXFP8_1D_SCALING);
-  auto output_cu = makeTransformerEngineTensor(
-      nullptr, input.data_ptr(), {1}, getTensorShape(input), DType::kFloat8E4M3, nullptr, nullptr,
-      nullptr, swizzled_scale_inv_dptr, {1}, getTensorShape(swizzled_scale_inv),
-      NVTE_MXFP8_1D_SCALING);
-
-  // Launch kernel
-  nvte_swizzle_scaling_factors(input_cu.data(), output_cu.data(),
-                               at::cuda::getCurrentCUDAStream());
-
-  return swizzled_scale_inv;
-}
transformer_engine/pytorch/csrc/util.h
View file @ f8c2af4c

...
@@ -13,8 +13,6 @@
 #include "transformer_engine/transformer_engine.h"

-bool non_tn_fp8_gemm_supported();
-
 /* Swizzle the scaling factor of the input tensor.
  *
  * The returned swizzled scaling factor tensor should be kept alive during the GEMM.
...
transformer_engine/pytorch/distributed.py
View file @ f8c2af4c

...
@@ -19,7 +19,12 @@ from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 from torch.distributed.fsdp._common_utils import _get_module_fsdp_state
 from torch.distributed.fsdp._traversal_utils import _get_fsdp_states_with_modules
-from .utils import non_tn_fp8_gemm_supported, safely_set_viewless_tensor_data, needs_quantized_gemm
+from . import torch_version
+from .utils import (
+    is_non_tn_fp8_gemm_supported,
+    safely_set_viewless_tensor_data,
+    needs_quantized_gemm,
+)
 from .constants import dist_group_type
 from .fp8 import FP8GlobalStateManager, fp8_autocast
 from .tensor.float8_tensor import Float8Quantizer, Float8Tensor, Float8CurrentScalingQuantizer
...
@@ -267,17 +272,36 @@ def _get_active_autocast_contexts():
     """
     autocast_cached = torch.is_autocast_cache_enabled()

-    gpu_autocast_enabled = torch.is_autocast_enabled()
-    gpu_autocast_dtype = torch.get_autocast_gpu_dtype()
-    gpu_autocast_ctx = torch.cuda.amp.autocast(
-        gpu_autocast_enabled, gpu_autocast_dtype, autocast_cached
-    )
-
-    cpu_autocast_enabled = torch.is_autocast_cpu_enabled()
-    cpu_autocast_dtype = torch.get_autocast_cpu_dtype()
-    cpu_autocast_ctx = torch.cpu.amp.autocast(
-        cpu_autocast_enabled, cpu_autocast_dtype, autocast_cached
-    )
+    if torch_version() >= (2, 4, 0):
+        gpu_autocast_enabled = torch.is_autocast_enabled("cuda")
+        gpu_autocast_dtype = torch.get_autocast_dtype("cuda")
+        gpu_autocast_ctx = torch.amp.autocast(
+            "cuda",
+            enabled=gpu_autocast_enabled,
+            dtype=gpu_autocast_dtype,
+            cache_enabled=autocast_cached,
+        )
+
+        cpu_autocast_enabled = torch.is_autocast_enabled("cpu")
+        cpu_autocast_dtype = torch.get_autocast_dtype("cpu")
+        cpu_autocast_ctx = torch.amp.autocast(
+            "cpu",
+            enabled=cpu_autocast_enabled,
+            dtype=cpu_autocast_dtype,
+            cache_enabled=autocast_cached,
+        )
+    else:
+        gpu_autocast_enabled = torch.is_autocast_enabled()
+        gpu_autocast_dtype = torch.get_autocast_gpu_dtype()
+        gpu_autocast_ctx = torch.cuda.amp.autocast(
+            gpu_autocast_enabled, gpu_autocast_dtype, autocast_cached
+        )
+
+        cpu_autocast_enabled = torch.is_autocast_cpu_enabled()
+        cpu_autocast_dtype = torch.get_autocast_cpu_dtype()
+        cpu_autocast_ctx = torch.cpu.amp.autocast(
+            cpu_autocast_enabled, cpu_autocast_dtype, autocast_cached
+        )

     return gpu_autocast_ctx, cpu_autocast_ctx
...
@@ -561,7 +585,9 @@ def has_te_modules(network):
     """
     from .module import LayerNorm, RMSNorm
     from .module.base import TransformerEngineBaseModule
-    from .attention import UnfusedDotProductAttention, DotProductAttention, MultiheadAttention
+    from .attention.dot_product_attention.backends import UnfusedDotProductAttention
+    from .attention.dot_product_attention.dot_product_attention import DotProductAttention
+    from .attention.multi_head_attention import MultiheadAttention
     from .transformer import TransformerLayer

     te_classes_list = [
...
@@ -893,8 +919,10 @@ def _all_gather_fp8(
         # Note: We cannot directly all-gather the transposed FP8 tensor,
         # so temporarily modify quantizer to avoid creating FP8 transpose.
         if not isinstance(inp, Float8TensorBase):
-            assert isinstance(quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer))
+            if quantizer is None:
+                raise ValueError("Input tensor is not FP8 and no quantizer was provided")
+            # we cannot directly gather the transposed fp8 tensor
+            # so we need to disable columnwise usage for the quantizer
+            # and then set it back to the original value after quantizing
             init_rowwise_usage = quantizer.rowwise_usage
             init_columnwise_usage = quantizer.columnwise_usage
             quantizer.set_usage(rowwise=True, columnwise=False)
...
@@ -938,7 +966,7 @@ def _all_gather_fp8(
     # Make sure FP8 transpose is populated if needed
     needs_transpose = (
-        quantizer is not None and quantizer.columnwise_usage and not non_tn_fp8_gemm_supported()
+        quantizer is not None and quantizer.columnwise_usage and not is_non_tn_fp8_gemm_supported()
     )
     if needs_transpose:
         if handle is not None:
...
@@ -1037,11 +1065,11 @@ def _all_gather_mxfp8(
         dtype = inp.dtype
     elif isinstance(inp, MXFP8TensorBase):
         if inp._rowwise_data is not None:
-            in_shape = inp._rowwise_data.device.size()
+            in_shape = inp._rowwise_data.size()
             device = inp._rowwise_data.device
             dtype = inp._rowwise_data.dtype
         elif inp._columnwise_data is not None:
-            in_shape = inp._columnwise_data.device.size()
+            in_shape = inp._columnwise_data.size()
             device = inp._columnwise_data.device
             dtype = inp._columnwise_data.dtype
         else:
...
@@ -1474,7 +1502,9 @@ def _is_te_module(module):
     """
     from .module import LayerNorm, RMSNorm
     from .module.base import TransformerEngineBaseModule
-    from .attention import UnfusedDotProductAttention, DotProductAttention, MultiheadAttention
+    from .attention.dot_product_attention.dot_product_attention import DotProductAttention
+    from .attention.dot_product_attention.backends import UnfusedDotProductAttention
+    from .attention.multi_head_attention import MultiheadAttention
     from .transformer import TransformerLayer

     te_classes_list = [
...
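Note on the autocast hunk above: PyTorch 2.4 deprecated torch.cuda.amp.autocast and the per-device query helpers in favor of the device-string API, which is what the new torch_version() branch selects. A minimal standalone illustration of the same pattern (hypothetical helper, not code from this diff):

    import torch

    def make_cuda_autocast_ctx(enabled, dtype, cache_enabled):
        # Assumed equivalent of the branch added to _get_active_autocast_contexts.
        if tuple(int(x) for x in torch.__version__.split(".")[:2]) >= (2, 4):
            # New-style API: the device type is passed as a string.
            return torch.amp.autocast(
                "cuda", enabled=enabled, dtype=dtype, cache_enabled=cache_enabled
            )
        # Legacy API for older PyTorch releases.
        return torch.cuda.amp.autocast(enabled, dtype, cache_enabled)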
transformer_engine/pytorch/fp8.py
View file @ f8c2af4c

...
@@ -520,8 +520,8 @@ class FP8GlobalStateManager:
             return

         # Store updated amaxes and scales from phase 1 post forward.
-        fp8_meta["updated_amax_history_fwd"] = fp8_meta["scaling_fwd"].amax_history
-        fp8_meta["updated_scale_fwd"] = fp8_meta["scaling_fwd"].scale
+        fp8_meta["updated_amax_history_fwd"] = fp8_meta["scaling_fwd"].amax_history.clone()
+        fp8_meta["updated_scale_fwd"] = fp8_meta["scaling_fwd"].scale.clone()

         # Retrieve stashed amaxes and scales from phase 1 pre forward.
         buffer_position_key = "global_fp8_buffer_pos_fwd_recompute"
...
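The .clone() calls added above avoid aliasing the live scaling tensors, so the stored entries keep the values from this point instead of silently tracking later in-place updates. A small, self-contained illustration of the aliasing hazard (generic PyTorch, not TE code):

    import torch

    scale = torch.ones(4)
    stash_alias = scale          # same storage: later updates leak into the stash
    stash_copy = scale.clone()   # snapshot of the values at this point

    scale.mul_(2.0)              # e.g. a subsequent amax/scale update

    assert torch.equal(stash_alias, scale)         # alias followed the update
    assert torch.equal(stash_copy, torch.ones(4))  # clone kept the old values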
transformer_engine/pytorch/graph.py
View file @ f8c2af4c

...
@@ -536,7 +536,9 @@ def _make_graphed_callables(
                 # Only Set the FP8 meta for the modules included by forward
                 continue
             fp8_recipe = FP8GlobalStateManager.get_fp8_recipe()
-            from transformer_engine.pytorch.attention import DotProductAttention
+            from transformer_engine.pytorch.attention.dot_product_attention import (
+                DotProductAttention,
+            )

             if (
                 isinstance(m, DotProductAttention)
...
transformer_engine/pytorch/jit.py
View file @ f8c2af4c

...
@@ -8,6 +8,9 @@ from functools import wraps
 from typing import Callable, Optional, Tuple

 import torch
+
+from . import torch_version
+from .utils import gpu_autocast_ctx
 from torch.utils.cpp_extension import IS_HIP_EXTENSION

 # pylint: disable=unnecessary-lambda-assignment
...
@@ -32,13 +35,13 @@ def lazy_compile(func):
 jit_fuser = lambda func: func
-if torch.__version__ >= "2" and bool(int(os.getenv("NVTE_TORCH_COMPILE", "1"))):
+if torch_version() >= (2, 0, 0) and bool(int(os.getenv("NVTE_TORCH_COMPILE", "1"))):
     jit_fuser = lazy_compile

 # See: https://github.com/NVIDIA/TransformerEngine/issues/597
 dropout_fuser = torch.jit.script
-if torch.__version__ >= "2.2" and bool(int(os.getenv("NVTE_TORCH_COMPILE", "1"))):
+if torch_version() >= (2, 2, 0) and bool(int(os.getenv("NVTE_TORCH_COMPILE", "1"))):
     dropout_fuser = lazy_compile
...
@@ -51,11 +54,9 @@ def set_jit_fusion_options() -> None:
     if not IS_HIP_EXTENSION:
         """Set PyTorch JIT layer fusion options."""
         # flags required to enable jit fusion kernels
-        TORCH_MAJOR = int(torch.__version__.split(".")[0])
-        TORCH_MINOR = int(torch.__version__.split(".")[1])
-        if TORCH_MAJOR == 2 and TORCH_MINOR >= 2:
+        if torch_version() >= (2, 2, 0):
             pass
-        elif (TORCH_MAJOR == 2) or (TORCH_MAJOR == 1 and TORCH_MINOR >= 10):
+        elif torch_version() >= (1, 10, 0):
             # nvfuser
             torch._C._jit_set_profiling_executor(True)
             torch._C._jit_set_profiling_mode(True)
...
@@ -124,7 +125,7 @@ def dgelu_fused_(grad_output: torch.Tensor, inp: torch.Tensor) -> torch.Tensor:
 def bias_gelu_fused(inp: torch.Tensor, bias: torch.Tensor) -> torch.Tensor:
     """Disable native AMP for bias_gelu_fused_"""
-    with torch.cuda.amp.autocast(enabled=False):
+    with gpu_autocast_ctx(enabled=False):
         if bias is not None and bias.numel() != 0:
             return bias_gelu_fused_(inp, bias)
         return gelu_fused_(inp)
...
@@ -134,7 +135,7 @@ def bgrad_dgelu_fused(
     grad_output: torch.Tensor, inp: torch.Tensor, bias: torch.Tensor
 ) -> Tuple[Optional[torch.Tensor], torch.Tensor]:
     """Disable native AMP for `bgrad_dgelu_fused_`"""
-    with torch.cuda.amp.autocast(enabled=False):
+    with gpu_autocast_ctx(enabled=False):
         if bias is not None and bias.numel() != 0:
             return bgrad_dgelu_fused_(grad_output, inp, bias)
         return None, dgelu_fused_(grad_output, inp)
...
@@ -175,7 +176,7 @@ def bias_dropout_add_fused_train(
 ) -> torch.Tensor:
     """Disable native AMP and enable grad for BDA"""
     with torch.enable_grad():
-        with torch.cuda.amp.autocast(enabled=False):
+        with gpu_autocast_ctx(enabled=False):
             return bias_dropout_add_fused_train_(x, bias, residual, prob)
...
@@ -191,7 +192,7 @@ def bias_dropout_add_fused_inference(
     x: torch.Tensor, bias: torch.Tensor, residual: torch.Tensor, prob: float
 ) -> torch.Tensor:
     """Disable native AMP for BDA"""
-    with torch.cuda.amp.autocast(enabled=False):
+    with gpu_autocast_ctx(enabled=False):
         return bias_dropout_add_fused_inference_(x, bias, residual, prob)
...
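One effect of switching jit.py from string comparisons on torch.__version__ to torch_version() tuples: string comparison is lexicographic and misorders multi-digit minor versions. A quick illustration (plain Python, not code from this diff):

    # Lexicographic string comparison: a hypothetical "2.10.0" sorts *below* "2.2".
    assert not ("2.10.0" >= "2.2")

    # Tuple comparison, as used by torch_version(), behaves as intended.
    assert (2, 10, 0) >= (2, 2, 0)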
transformer_engine/pytorch/module/_common.py
View file @ f8c2af4c

...
@@ -6,8 +6,6 @@
 from typing import Any, List, Optional, Tuple, Union, Callable
 from dataclasses import dataclass
-from functools import reduce
-from operator import mul as multiply_op
 import queue

 import torch
...
@@ -15,7 +13,6 @@ import torch
 from .. import cpp_extensions as tex
 from ..constants import TE_DType
 from ..utils import get_default_init_method
-from ..tensor.float8_tensor import Float8Tensor
 import warnings
 try:
     from lightop import rmsnorm_forward, rmsnorm_backward
...
@@ -24,7 +21,6 @@ except ImportError:
     enable_lightop = False
     warnings.warn("Failed to import lightop module. Falling back to alternative implementation.", UserWarning)

 def _get_normalization_func(normalization: str, forward: bool):
     fwd_normalization_funcs = {
         "LayerNorm": tex.layernorm_fwd,
...
@@ -40,39 +36,6 @@ def _get_normalization_func(normalization: str, forward: bool):
     return bwd_normalization_funcs[normalization]


-def _fix_gathered_fp8_transpose(fp8_tensor: Float8Tensor, tp_size: int) -> Float8Tensor:
-    """Reorder FP8 transposes after Userbuffers gather.
-
-    The all-gather is performed in-place in the Float8Tensor's
-    row-wise data, and afterwards we need to do a transpose to get the
-    correct ordering. This misuses data fields in Float8Tensor and
-    should be considered an evil hack. It would be best to move
-    transpose logic into CommOverlap::get_buffer.
-
-    Responsibility for fixing: adener, tmoon
-
-    """
-    assert isinstance(fp8_tensor, Float8Tensor), "Tensor is not a Float8Tensor"
-    assert tp_size > 1, "The tensor transpose cannot be interleaved when TP size is 1"
-    assert fp8_tensor._data is not None, "The tensor does not hold any rowwise data"
-    assert (
-        fp8_tensor._data.shape[0] % tp_size == 0
-    ), "Leading dimension of data is not divisble by TP size"
-    data = fp8_tensor._data
-    batched_size = reduce(multiply_op, data.shape[1:])
-    interleaved_shape = [tp_size, data.shape[0] // tp_size, batched_size]
-    transposed_shape = [data.shape[0] // tp_size, batched_size * tp_size]
-    fp8_tensor._transpose = (
-        data.view(interleaved_shape).transpose(0, 1).contiguous().view(transposed_shape)
-    )
-    fp8_tensor._transpose_invalid = False
-    fp8_tensor._data = None
-    return fp8_tensor
-
-
 def apply_normalization(
     inputmat: torch.Tensor,
     ln_out: torch.Tensor,
...
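For context on the removed _fix_gathered_fp8_transpose helper: the reshuffle it applied to the gathered row-wise data is the same view/transpose trick shown below on a plain tensor (made-up sizes, illustration only):

    import torch

    tp_size, rows_per_rank, cols = 2, 3, 4
    # Gathered row-wise data: each rank's shard stacked along the leading dim.
    data = torch.arange(tp_size * rows_per_rank * cols, dtype=torch.uint8).view(
        tp_size * rows_per_rank, cols
    )

    batched_size = data.shape[1]
    interleaved_shape = [tp_size, data.shape[0] // tp_size, batched_size]
    transposed_shape = [data.shape[0] // tp_size, batched_size * tp_size]
    fixed = data.view(interleaved_shape).transpose(0, 1).contiguous().view(transposed_shape)
    print(fixed.shape)  # torch.Size([3, 8])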
transformer_engine/pytorch/module/base.py
View file @ f8c2af4c

...
@@ -4,6 +4,7 @@
 """Base modules and utilities for TransformerEngine PyTorch API"""
 import io
+import math
 import os
 import pickle
 import warnings
...
@@ -35,10 +36,13 @@ from ..distributed import (
     _fsdp_gather_tensors,
 )
 from ..constants import dist_group_type
-from ..tensor import QuantizedTensor, Quantizer
+from ..tensor.quantized_tensor import QuantizedTensor, QuantizedTensorBase, Quantizer
+from ..tensor.float8_tensor import Float8Quantizer, Float8CurrentScalingQuantizer
+from ..tensor.mxfp8_tensor import MXFP8Quantizer
 from ..tensor.float8_blockwise_tensor import Float8BlockQuantizer
 from ..tensor._internal.float8_tensor_base import Float8TensorBase
 from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase
+from ..utils import torch_get_autocast_gpu_dtype
 from ..tensor._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase
 from ...common.recipe import Recipe
 from ...debug.pytorch.debug_state import TEDebugState
...
@@ -451,6 +455,142 @@ def destroy_ub():
     layers_atomic_ring_exchange = []


+def fill_userbuffers_buffer_for_all_gather(
+    comm,
+    local_tensor: torch.Tensor,
+    quantizer: Optional[Quantizer],
+    process_group,
+) -> tuple[torch.Tensor | QuantizedTensorBase, torch.Tensor | QuantizedTensorBase]:
+    """Fill local shard of Userbuffers buffer with data for all-gather
+
+    Returns the full tensor and the local shard, both using the
+    Userbuffers buffer as their underlying data. These tensors should
+    be used carefully (e.g. only immediately before and after a
+    Userbuffers operation) since the underlying data may be
+    overwritten by other Userbuffers operations.
+
+    May perform blocking communication if needed for the gathered
+    tensor's metadata, e.g. scaling factors.
+
+    """
+
+    # Tensor dimensions
+    local_shape = local_tensor.size()
+    if not local_shape:
+        raise ValueError(f"Invalid local tensor (shape={tuple(local_shape)})")
+    process_group_size = torch.distributed.get_world_size(process_group)
+    global_shape = list(local_shape)
+    global_shape[0] *= process_group_size
+
+    # Unquantized data
+    if quantizer is None:
+        if isinstance(local_tensor, QuantizedTensorBase):
+            local_tensor = local_tensor.dequantize()
+        if comm.is_fp8_ubuf():
+            raise RuntimeError(
+                "Attempting to all-gather unquantized tensor, "
+                "but Userbuffers is initialized with FP8 buffers"
+            )
+        comm.copy_into_buffer(local_tensor, local_chunk=True)
+        global_tensor = comm.get_buffer(shape=global_shape)
+        return global_tensor, local_tensor
+
+    # FP8 data
+    if isinstance(quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)):
+        if not isinstance(local_tensor, Float8TensorBase):
+            if isinstance(local_tensor, QuantizedTensorBase):
+                local_tensor.dequantize()
+            quantizer.set_usage(rowwise=True, columnwise=False)
+            local_tensor = quantizer(local_tensor)
+        if not comm.is_fp8_ubuf():
+            raise RuntimeError(
+                "Attempting to all-gather FP8 tensor, "
+                "but Userbuffers is not initialized with FP8 buffers"
+            )
+        comm.copy_into_buffer(local_tensor._data, local_chunk=True)
+        global_tensor_data = comm.get_buffer(shape=global_shape)
+        global_tensor = Float8TensorBase(
+            data=global_tensor_data,
+            fp8_scale_inv=local_tensor._scale_inv,
+            fp8_dtype=local_tensor._fp8_dtype,
+            quantizer=quantizer,
+        )
+        return global_tensor, local_tensor
+
+    # MXFP8 data
+    if isinstance(quantizer, MXFP8Quantizer):
+
+        # Cast to MXFP8 if needed
+        if not isinstance(local_tensor, MXFP8TensorBase):
+            if isinstance(local_tensor, QuantizedTensorBase):
+                local_tensor.dequantize()
+            local_tensor = quantizer(local_tensor)
+        if not comm.is_fp8_ubuf():
+            raise RuntimeError(
+                "Attempting to all-gather MXFP8 tensor, "
+                "but Userbuffers is not initialized with FP8 buffers"
+            )
+
+        # Check which MXFP8 buffer to communicate
+        if quantizer.rowwise_usage == quantizer.columnwise_usage:
+            raise ValueError(
+                "Userbuffers can only communicate one MXFP8 buffer at a time, "
+                f"but quantizer has rowwise_usage={quantizer.rowwise_usage}, "
+                f"columnwise_usage={quantizer.columnwise_usage}"
+            )
+        with_rowwise_data = quantizer.rowwise_usage
+
+        # Copy MXFP8 data to local chunk of Userbuffers buffer
+        local_data = (
+            local_tensor._rowwise_data if with_rowwise_data else local_tensor._columnwise_data
+        )
+        comm.copy_into_buffer(local_data, local_chunk=True)
+
+        # Gather scaling-inverses
+        if math.prod(local_shape[:-1]) % 128 != 0:
+            raise ValueError(
+                "Userbuffers requires MXFP8 tensor dims that are divisible by 128, "
+                f"but got MXFP8 tensor with shape={tuple(local_shape)}"
+            )
+        local_scale_inv = (
+            local_tensor._rowwise_scale_inv
+            if with_rowwise_data
+            else local_tensor._columnwise_scale_inv
+        )
+        local_scale_inv_size = list(local_scale_inv.size())
+        global_scale_inv = torch.empty(
+            [process_group_size * local_scale_inv_size[0]] + local_scale_inv_size[1:],
+            dtype=local_scale_inv.dtype,
+            device=local_scale_inv.device,
+        )
+        torch.distributed.all_gather_into_tensor(
+            global_scale_inv,
+            local_scale_inv,
+            group=process_group,
+        )
+
+        # Construct MXFP8 tensor with Userbuffers buffer
+        rowwise_data, rowwise_scale_inv = None, None
+        columnwise_data, columnwise_scale_inv = None, None
+        global_data = comm.get_buffer(shape=global_shape)
+        if with_rowwise_data:
+            rowwise_data, rowwise_scale_inv = global_data, global_scale_inv
+        else:
+            columnwise_data, columnwise_scale_inv = global_data, global_scale_inv
+        global_tensor = MXFP8TensorBase(
+            rowwise_data=rowwise_data,
+            rowwise_scale_inv=rowwise_scale_inv,
+            columnwise_data=columnwise_data,
+            columnwise_scale_inv=columnwise_scale_inv,
+            fp8_dtype=local_tensor._fp8_dtype,
+            quantizer=quantizer,
+        )
+        return global_tensor, local_tensor
+
+    # Unsupported data format
+    raise ValueError(f"Unsupported quantizer for Userbuffers ({quantizer})")
+
+
 class TransformerEngineBaseModule(torch.nn.Module, ABC):
     """Base TE module."""
...
@@ -625,7 +765,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
         reset("scaling_fwd")
         reset("scaling_bwd")

-    def get_extra_state(self) -> torch.Tensor:
+    def get_extra_state(self) -> Optional[torch.Tensor]:
         """Save before checkpointing."""

         # This implementation is working around a few issues:
...
@@ -659,25 +799,26 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
         # Store FP8 state if needed
         state = None
         fp8_checkpoint = self.fp8_meta["fp8_checkpoint"] or self.fp8 or self.fp8_calibration
-        if fp8_checkpoint:
-            # Copy tensors to CPU and store
-            state = {}
-            state["recipe"] = self.fp8_meta["recipe"]
-            if state["recipe"].delayed():
-                state["scale_fwd"] = to_cpu(self.fp8_meta["scaling_fwd"].scale)
-                state["amax_history_fwd"] = to_cpu(self.fp8_meta["scaling_fwd"].amax_history)
-                state["scale_bwd"] = to_cpu(self.fp8_meta["scaling_bwd"].scale)
-                state["amax_history_bwd"] = to_cpu(self.fp8_meta["scaling_bwd"].amax_history)
-
-            # Store other pickelable values
-            extra = {}
-            for k, v in self.fp8_meta.items():
-                if k != "buffer_index_and_autocast_key" and isinstance(
-                    v, (bool, int, float, str, tuple, list)
-                ):
-                    extra[k] = v
-            state["extra_fp8_variables"] = extra
+        if not fp8_checkpoint:
+            return None
+
+        # Copy tensors to CPU and store
+        state = {}
+        state["recipe"] = self.fp8_meta["recipe"]
+        if state["recipe"].delayed():
+            state["scale_fwd"] = to_cpu(self.fp8_meta["scaling_fwd"].scale)
+            state["amax_history_fwd"] = to_cpu(self.fp8_meta["scaling_fwd"].amax_history)
+            state["scale_bwd"] = to_cpu(self.fp8_meta["scaling_bwd"].scale)
+            state["amax_history_bwd"] = to_cpu(self.fp8_meta["scaling_bwd"].amax_history)
+
+        # Store other pickelable values
+        extra = {}
+        for k, v in self.fp8_meta.items():
+            if k != "buffer_index_and_autocast_key" and isinstance(
+                v, (bool, int, float, str, tuple, list)
+            ):
+                extra[k] = v
+        state["extra_fp8_variables"] = extra

         # Serialize state into byte tensor
         torch.cuda.synchronize()
...
@@ -685,7 +826,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
         state_serialized = torch.frombuffer(state_serialized, dtype=torch.uint8)
         return state_serialized

-    def set_extra_state(self, state: torch.Tensor) -> None:
+    def set_extra_state(self, state: Optional[torch.Tensor]) -> None:
         """Load previous state."""
         if state is None:
             return
...
@@ -734,7 +875,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
         """Get activation data type for AMP."""
         # Native AMP (`torch.autocast`) gets highest priority
         if torch.is_autocast_enabled():
-            self.activation_dtype = torch.get_autocast_gpu_dtype()
+            self.activation_dtype = torch_get_autocast_gpu_dtype()
             return

         # All checks after this have already been performed once, thus skip
...
@@ -898,11 +1039,15 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
        # Non-FP8 case: bgrad is fused with wgrad for this case.
        if not ctx.fp8 and not ctx.debug:
            if gather_grad_output:
                if not ctx.ub_overlap_ag or ctx.ub_obj_gradout is None:
+                    # Perform NCCL all-gather
                    grad_output, _ = gather_along_first_dim(grad_output, ctx.tp_group)
                else:
-                    ctx.ub_obj_gradout.copy_into_buffer(grad_output, quantizer, local_chunk=True)
-                    grad_output = ctx.ub_obj_gradout.get_buffer(quantizer)
+                    # Initialize Userbuffers all-gather
+                    grad_output, _ = fill_userbuffers_buffer_for_all_gather(
+                        ctx.ub_obj_gradout,
+                        grad_output,
+                        None,
+                        ctx.tp_group,
+                    )
            return grad_output, None

        # FP8 with all-gather: unfused bgrad, fused cast + transpose
...
@@ -925,8 +1070,12 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
                    grad_output = quantizer(grad_output)

                # Copy into communication buffer, and replace original gradient with it
-                ctx.ub_obj_gradout.copy_into_buffer(grad_output, quantizer, local_chunk=True)
-                grad_output = ctx.ub_obj_gradout.get_buffer(quantizer)
+                grad_output, _ = fill_userbuffers_buffer_for_all_gather(
+                    ctx.ub_obj_gradout,
+                    grad_output,
+                    quantizer,
+                    ctx.tp_group,
+                )
            else:
                grad_output, _ = gather_along_first_dim(
                    grad_output,
...
@@ -1140,7 +1289,16 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
            raise ValueError(
                "tensor and quantizer kwargs must be provided to construct FP8 workspace"
            )
+        if cache_name is not None:
+            # Ensure the tensor in the cache is an instance of torch.Tensor,
+            # as it persists beyond a single forward pass.
+            # Setting internal=True would cause the data to be removed in prepare_for_saving(...).
+            quantizer_internal = quantizer.internal
+            quantizer.internal = False
        out = quantizer.quantize(tensor, dtype=workspace_dtype)
+        if cache_name is not None:
+            quantizer.internal = quantizer_internal

        # Update cache
        if cache_name is not None:
...
@@ -1188,7 +1346,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
        if self.wgrad_store is None or not self.wgrad_store.delay_wgrad_compute():
            return
        with torch.cuda.nvtx.range(f"_{self.__class__.__name__}_wgrad"):
-            (wgrad, grad_bias_, _, _), _ = self.wgrad_store.pop()
+            (wgrad, bgrad), _ = self.wgrad_store.pop()
            if not self.fuse_wgrad_accumulation:
                unfused_weights = [getattr(self, name) for name in self.weight_names]
                weight_tensor = noop_cat(unfused_weights)
...
@@ -1197,9 +1355,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
            if self.use_bias:
                bias_tensor = noop_cat([getattr(self, name) for name in self.bias_names])
                if bias_tensor.grad is None:
-                    bias_tensor.grad = grad_bias_.to(bias_tensor.dtype)
-            del grad_bias_
-            del wgrad
+                    bias_tensor.grad = bgrad.to(bias_tensor.dtype)

    def _validate_name(self):
        """
...
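Rough usage sketch for the new fill_userbuffers_buffer_for_all_gather helper added above (placeholder names; ub_comm, shard, quantizer and tp_group are not identifiers from this diff):

    # Stage the local shard in the Userbuffers workspace and obtain a full-size
    # view backed by the same buffer; the data all-gather itself is presumably
    # performed later by the overlapped Userbuffers GEMM.
    global_view, local_view = fill_userbuffers_buffer_for_all_gather(
        ub_comm,     # Userbuffers communicator, e.g. from get_ub(...)
        shard,       # local tensor, plain or quantized
        quantizer,   # None for unquantized data, or a Float8/MXFP8 quantizer
        tp_group,    # torch.distributed process group to gather over
    )
    # Use global_view immediately; other Userbuffers ops may overwrite the buffer.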
transformer_engine/pytorch/module/grouped_linear.py
View file @ f8c2af4c

...
@@ -4,6 +4,7 @@
 """GroupedLinear API"""
 from typing import Union, Optional, Callable, Tuple, List
+import warnings
 import functools

 import torch
...
@@ -43,7 +44,7 @@ from ..graph import is_graph_capturing
 from ..cpu_offload import is_cpu_offload_enabled
 from ..tensor.quantized_tensor import (
-    QuantizedTensor,
+    QuantizedTensorBase,
     Quantizer,
     prepare_for_saving,
     restore_from_saved,
...
@@ -182,11 +183,11 @@ class _GroupedLinear(torch.autograd.Function):
             # TODO: update after #1638 is merged. # pylint: disable=fixme
             if weight_requires_grad:
                 for inputmat in inputmats:
-                    if isinstance(inputmat, QuantizedTensor):
+                    if isinstance(inputmat, QuantizedTensorBase):
                         inputmat.update_usage(rowwise_usage=False, columnwise_usage=True)
             if inp.requires_grad:
                 for weight in weights_fp8:
-                    if isinstance(weight, QuantizedTensor):
+                    if isinstance(weight, QuantizedTensorBase):
                         weight.update_usage(columnwise_usage=True)

             tensors_to_save, tensor_objects = prepare_for_saving(
...
@@ -299,7 +300,7 @@ class _GroupedLinear(torch.autograd.Function):
             )
             for weight, quantizer in zip(weights, ctx.weight_quantizers):
-                if quantizer is not None and isinstance(weight, QuantizedTensor):
+                if quantizer is not None and isinstance(weight, QuantizedTensorBase):
                     weight.update_usage(
                         rowwise_usage=quantizer.rowwise_usage,
                         columnwise_usage=quantizer.columnwise_usage,
...
@@ -663,7 +664,7 @@ class GroupedLinear(TransformerEngineBaseModule):
               produced)
         """
         assert not isinstance(
-            inp, QuantizedTensor
+            inp, QuantizedTensorBase
         ), "GroupedLinear doesn't support input tensor in FP8."
         assert len(m_splits) == self.num_gemms, "Number of splits should match number of GEMMs."
...
@@ -675,9 +676,14 @@ class GroupedLinear(TransformerEngineBaseModule):
         weight_tensors = [getattr(self, f"weight{i}") for i in range(self.num_gemms)]
         bias_tensors = [getattr(self, f"bias{i}") for i in range(self.num_gemms)]
-        if not self.fp8:
+        if not self.fp8 and any(isinstance(w, QuantizedTensorBase) for w in weight_tensors):
+            warnings.warn(
+                "You are using quantized weights without quantized compute. "
+                "Please make sure this is intentional."
+            )
             weight_tensors = [
-                w.dequantize() if isinstance(w, QuantizedTensor) else w for w in weight_tensors
+                w.dequantize() if isinstance(w, QuantizedTensorBase) else w for w in weight_tensors
             ]

         input_quantizers, weight_quantizers, output_quantizers = (
...
transformer_engine/pytorch/module/layernorm.py
View file @ f8c2af4c

...
@@ -94,6 +94,9 @@ class LayerNorm(_LayerNormOp):
             )
             kwargs["dtype"] = params_dtype

+        # Flag for sequence parallelism (custom Megatron-LM integration)
+        self.sequence_parallel: Optional[bool] = sequence_parallel
+
         # Initialize layer norm operation
         super().__init__(
             normalized_shape,
...
@@ -102,12 +105,6 @@ class LayerNorm(_LayerNormOp):
             **kwargs,
         )

-        # Flag for sequence parallelism (custom Megatron-LM integration)
-        self.sequence_parallel: Optional[bool] = sequence_parallel
-        if sequence_parallel is not None:
-            self.weight.sequence_parallel = sequence_parallel
-            self.bias.sequence_parallel = sequence_parallel
-
     def reset_layer_norm_parameters(self) -> None:
         """Init LN params"""
         warnings.warn(
...
@@ -136,7 +133,7 @@ class LayerNorm(_LayerNormOp):
         super().reset_parameters()

         # Set flag for sequence parallelism (custom Megatron-LM integration)
-        if getattr(self, "sequence_parallel", None) is not None:
+        if self.sequence_parallel is not None:
             self.weight.sequence_parallel = self.sequence_parallel
             self.bias.sequence_parallel = self.sequence_parallel
...
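Note on the LayerNorm hunks above: the sequence_parallel flag now has to exist before super().__init__() runs, which is consistent with reset_parameters() reading self.sequence_parallel directly (last hunk) instead of going through getattr. A minimal sketch of that initialization-order constraint (hypothetical classes, not TE code):

    class Base:
        def __init__(self):
            self.reset_parameters()  # base initializer calls back into the subclass

    class Child(Base):
        def __init__(self, flag=None):
            self.flag = flag         # must be set *before* super().__init__() ...
            super().__init__()

        def reset_parameters(self):
            if self.flag is not None:  # ... or this attribute lookup would fail
                pass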
transformer_engine/pytorch/module/layernorm_linear.py  View file @ f8c2af4c
...
@@ -9,7 +9,6 @@ from typing import Callable, Dict, Optional, Tuple, Union
 from functools import reduce
 from operator import mul as multiply_op
-import functools

 import torch
 from torch.nn import init
...
@@ -18,6 +17,7 @@ import transformer_engine_torch as tex
 from transformer_engine.common.recipe import Recipe
 from transformer_engine.pytorch import torch_version
 from .base import (
+    fill_userbuffers_buffer_for_all_gather,
     get_workspace,
     get_ub,
     TransformerEngineBaseModule,
...
@@ -53,9 +53,10 @@ from ..distributed import (
 from ..constants import GemmParallelModes, dist_group_type
 from ..jit import no_torch_dynamo
 from ..graph import is_graph_capturing
-from ._common import apply_normalization, noop_cat, _fix_gathered_fp8_transpose, WeightGradStore
+from ._common import apply_normalization, noop_cat, WeightGradStore
 from ..tensor.quantized_tensor import (
     QuantizedTensor,
+    QuantizedTensorBase,
     Quantizer,
     prepare_for_saving,
     restore_from_saved,
...
@@ -144,8 +145,10 @@ class _LayerNormLinear(torch.autograd.Function):
         # Make sure input dimensions are compatible
         out_features, in_features = weight.shape
         inp_shape = inp.shape
+        inp_requires_grad = inp.requires_grad
         assert inp_shape[-1] == in_features, "GEMM not possible"
-        inputmat = inp.view((-1, in_features))
+        inp = inp.view((-1, in_features))
+        inputmat = inp
         if fp8:
             assert_dim_for_fp8_exec(inputmat, weight)
...
@@ -158,42 +161,43 @@ class _LayerNormLinear(torch.autograd.Function):
         nvtx_range_pop(f"{nvtx_label}.norm_input_cast")

         tp_world_size = get_distributed_world_size(tp_group)
-        ub_overlap_ag_fprop = (
-            ub_overlap_ag_fprop and is_grad_enabled and not return_layernorm_output
-        )
         weight_requires_grad = weight.requires_grad
         backward_needs_input = is_grad_enabled and weight_requires_grad
         with_input_all_gather = parallel_mode == "column" and sequence_parallel

-        # Check if Userbuffers is supported
-        if fp8:
-            if any([ub_overlap_ag_fprop, ub_overlap_rs_fprop]) and not (
-                FP8GlobalStateManager.get_fp8_recipe().float8_per_tensor_scaling()
-            ):
-                raise NotImplementedError(
-                    "Comm+GEMM overlap is only supported with FP8 delayed scaling or per-tensor"
-                    " current scaling"
-                )
+        # Configure Userbuffers communication (comm+GEMM overlap)
+        ub_obj = None
+        ub_type = None
+        ub_overlap_ag_fprop = (
+            ub_overlap_ag_fprop and is_grad_enabled and not return_layernorm_output
+        )
+        if ub_overlap_rs_fprop:
+            ub_obj = get_ub(ub_name + "_fprop")
+            ub_type = tex.CommOverlapType.RS
+        elif ub_overlap_ag_fprop:
+            ub_obj = get_ub(ub_name + "_fprop")
+            ub_type = tex.CommOverlapType.AG

         # Configure quantizer for norm output
         if fp8:
             if input_quantizer is None:
                 raise ValueError("Missing quantizer for input tensor")
-            columnwise_usage = backward_needs_input
-            if (
-                columnwise_usage
-                and with_input_all_gather
-                and not isinstance(input_quantizer, MXFP8Quantizer)
-            ):
-                columnwise_usage = False
-            input_quantizer.set_usage(rowwise=True, columnwise=columnwise_usage)
+            input_quantizer.set_usage(rowwise=True, columnwise=backward_needs_input)
+            if with_input_all_gather and isinstance(
+                input_quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)
+            ):
+                # All-gather is not supported with FP8 column-wise data
+                input_quantizer.set_usage(columnwise=False)

-        # Avoid quantized norm kernel if norm output will be returned
-        # or if a gather of ln_out must be in high precision.
+        # Do TP communication in high precision if quantized format
+        # does not support communication
         force_hp_blockwise_ln_out_gather = (
             fp8 and with_input_all_gather and isinstance(input_quantizer, Float8BlockQuantizer)
         )
-        # Perform TP communication in high precision.

+        # Avoid quantized norm kernel if norm output will be returned
+        # or if a gather of ln_out must be in high precision.
         with_quantized_norm = (
             fp8
             and not return_layernorm_output
...
@@ -215,16 +219,19 @@ class _LayerNormLinear(torch.autograd.Function):
             fwd_ln_sm_margin,
             zero_centered_gamma,
         )
-        nvtx_range_pop(f"{nvtx_label}.norm")
+
+        # Store unquantized layer norm output if we need to return it
         ln_out_return = None
         if return_layernorm_output or return_layernorm_output_gathered:
             ln_out_return = ln_out
+        nvtx_range_pop(f"{nvtx_label}.norm")

-        # Prepare GEMM input
+        # ------------------------------------------------------
+        # Prepare GEMM input tensor
         # Note: Cast to expected dtype and perform tensor-parallel communication
+        # ------------------------------------------------------
         nvtx_range_push(f"{nvtx_label}.gemm_input_cast_comm")
         ln_out_total = None
-        ub_obj_fprop = None
         if with_input_all_gather:
             if return_layernorm_output_gathered:
                 # Perform all-gather in high precision if gathered
...
@@ -232,47 +239,53 @@ class _LayerNormLinear(torch.autograd.Function):
                 ln_out_total, _ = gather_along_first_dim(ln_out, tp_group)
                 ln_out_return = ln_out_total
                 if fp8 or debug:
-                    ln_out = input_quantizer(ln_out)
+                    if not force_hp_blockwise_ln_out_gather:
+                        ln_out = input_quantizer(ln_out)
                     input_quantizer.set_usage(rowwise=True, columnwise=False)
                     ln_out_total = input_quantizer(ln_out_total)
             else:
+                quantizer = None
                 if fp8 or debug:
+                    quantizer = input_quantizer
                     if not with_quantized_norm and not force_hp_blockwise_ln_out_gather:
-                        ln_out = input_quantizer(ln_out)
-                        input_quantizer.set_usage(rowwise=True, columnwise=False)
+                        ln_out = quantizer(ln_out)
+                        quantizer.set_usage(rowwise=True, columnwise=False)
                 if ub_overlap_ag_fprop:
-                    # Copy into Userbuffers buffer
-                    ub_obj_fprop = get_ub(ub_name + "_fprop")
-                    ub_obj_fprop.get_buffer(input_quantizer, local_chunk=True).copy_(ln_out)
-                    ln_out_total = ub_obj_fprop.get_buffer(input_quantizer)
+                    # Initialize Userbuffers all-gather
+                    ln_out_total, _ = fill_userbuffers_buffer_for_all_gather(
+                        ub_obj,
+                        ln_out,
+                        quantizer,
+                        tp_group,
+                    )
                 else:
-                    # All-gather with NCCL
+                    # Perform NCCL all-gather
                     ln_out_total, _ = gather_along_first_dim(
                         ln_out,
                         tp_group,
-                        quantizer=(input_quantizer if fp8 or debug else None),
+                        quantizer=quantizer,
                     )
         else:
             if (fp8 or debug) and not with_quantized_norm:
                 ln_out = input_quantizer(ln_out)
             ln_out_total = ln_out
         nvtx_range_pop(f"{nvtx_label}.gemm_input_cast_comm")
+        # ------------------------------------------------------
+        # GEMM input tensor is ready...
+        # ------------------------------------------------------

-        # Cast weight to expected dtype
+        # ------------------------------------------------------
+        # Prepare weight tensor
+        # ------------------------------------------------------
         weightmat = weight
         quantized_weight = False
-        if not fp8 and not debug:
-            weightmat = cast_if_needed(weightmat, activation_dtype)
-        else:
-            quantized_weight = not isinstance(weight, QuantizedTensor)
+        if fp8 or debug:
+            quantized_weight = not isinstance(weight, QuantizedTensorBase)

             # Configure quantizer
             if weight_quantizer is not None:
                 weight_quantizer.set_usage(rowwise=True, columnwise=True)

-            # FP8 cast to workspace buffer
+            # Get quantized weight
             update_workspace = is_first_microbatch is None or is_first_microbatch
             weightmat = module.get_weight_workspace(
                 tensor=weight,
                 quantizer=weight_quantizer,
...
@@ -282,17 +295,21 @@ class _LayerNormLinear(torch.autograd.Function):
                 fsdp_group=fsdp_group,
                 workspace_dtype=activation_dtype,
             )
+            weightmat.update_usage(rowwise_usage=True)
+        else:
+            weightmat = cast_if_needed(weightmat, activation_dtype)  # Cast for AMP
+        # ------------------------------------------------------
+        # Weight tensor is ready for GEMM...
+        # ------------------------------------------------------

         # Cast bias to expected dtype
         bias_dtype = activation_dtype
         if needs_quantized_gemm(ln_out_total) and activation_dtype == torch.float32:
+            # cuBLAS does not support FP8 GEMM with FP32 bias, so we cast to BF16
             bias_dtype = torch.bfloat16
         bias = cast_if_needed(bias, bias_dtype) if bias is not None else bias

-        # Configure output quantizer
-        if output_quantizer is not None:
-            output_quantizer.set_usage(rowwise=True, columnwise=False)
-
         # Calibrate quantizers if needed
         if not fp8 and fp8_calibration:
             if input_quantizer is not None:
...
@@ -300,47 +317,80 @@ class _LayerNormLinear(torch.autograd.Function):
             if weight_quantizer is not None:
                 weight_quantizer.calibrate(weight)

-        ub_obj = None
-        ub_type = None
-        rs_out = None
-        if ub_overlap_rs_fprop:
-            ub_obj = get_ub(ub_name + "_fprop")
-            ub_type = tex.CommOverlapType.RS
-            out_shape = [reduce(multiply_op, inp_shape[:-1]) // tp_world_size, out_features]
-            rs_out = torch.empty(out_shape, dtype=activation_dtype, device=ln_out_total.device)
-        elif ub_overlap_ag_fprop:
-            ub_obj = get_ub(ub_name + "_fprop")
-            ub_type = tex.CommOverlapType.AG
-            if fp8:
-                assert ub_obj.is_fp8_ubuf(), "AG overlap with FP8 GEMM inputs requires FP8 buffer."
-            ln_out_total = ub_obj.get_buffer(input_quantizer)
-
-        nvtx_range_push(f"{nvtx_label}.gemm")
-        fprop_gemm_use_split_accumulator = _2X_ACC_FPROP
+        # Choose whether to use GEMM kernel with split accumulator
+        use_split_accumulator = _2X_ACC_FPROP
         if fp8:
             recipe = FP8GlobalStateManager.get_fp8_recipe()
             if hasattr(recipe, "fp8_gemm_fprop"):
-                fprop_gemm_use_split_accumulator = recipe.fp8_gemm_fprop.use_split_accumulator
+                use_split_accumulator = recipe.fp8_gemm_fprop.use_split_accumulator

-        out, *_, rs_out = general_gemm(
+        # Configure output quantizer
+        if output_quantizer is not None:
+            output_quantizer.set_usage(rowwise=True, columnwise=False)
+
+        # Output buffer for Userbuffers reduce-scatter
+        reduce_scatter_out = None
+        if ub_overlap_rs_fprop:
+            out_shape = list(inp_shape)
+            out_shape[0] //= tp_world_size
+            out_shape[-1] = out_features
+            reduce_scatter_out = torch.empty(out_shape, dtype=activation_dtype, device=inp.device)
+
+        # ------------------------------------------------------
+        # Forward GEMM
+        # Note: y = x * w^T
+        # ------------------------------------------------------
+        nvtx_range_push(f"{nvtx_label}.gemm")
+        gemm_out, *_, reduce_scatter_out = general_gemm(
             weightmat,
             ln_out_total,
             get_workspace(),
             quantization_params=output_quantizer,
             out_dtype=activation_dtype,
             bias=bias,
-            use_split_accumulator=fprop_gemm_use_split_accumulator,
+            use_split_accumulator=use_split_accumulator,
             ub=ub_obj,
             ub_type=ub_type,
-            extra_output=rs_out,
+            extra_output=reduce_scatter_out,
         )
         nvtx_range_pop(f"{nvtx_label}.gemm")
+        # ------------------------------------------------------
+        # Finished forward GEMM...
+        # ------------------------------------------------------

-        if not weight.requires_grad:
-            if not return_layernorm_output:
-                ln_out = ln_out_total = None
-                clear_tensor_data(ln_out, ln_out_total)
+        # Deallocate GEMM input tensor if no longer needed
+        if not weight.requires_grad and not return_layernorm_output:
+            ln_out = ln_out_total = None
+            clear_tensor_data(ln_out, ln_out_total)
+
+        # ------------------------------------------------------
+        # Prepare output tensor
+        # Note: Perform tensor-parallel communication
+        # ------------------------------------------------------
+        out = None
+        if ub_overlap_rs_fprop:
+            out = reduce_scatter_out
+        elif parallel_mode == "row" and tp_size > 1:
+            nvtx_range_push(f"{nvtx_label}.row_parallel_comm")
+            out = gemm_out
+            if sequence_parallel:
+                out, _ = reduce_scatter_along_first_dim(out, tp_group)
+            elif tensor_parallel:
+                if symmetric_ar_type is not None:
+                    out, _ = symmetric_all_reduce(out, tp_group, all_reduce_type=symmetric_ar_type)
+                else:
+                    out, _ = allreduce(out, tp_group)
+            nvtx_range_pop(f"{nvtx_label}.row_parallel_comm")
+        else:
+            out = gemm_out
+        out = out.view(-1, *inp_shape[1:-1], out_features)
+        # ------------------------------------------------------
+        # Output tensor is ready to return...
+        # ------------------------------------------------------

+        # ------------------------------------------------------
+        # Cache state for backward pass
+        # ------------------------------------------------------
         if is_grad_enabled:
             ctx.weight_quantizer = weight_quantizer
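The refactored forward pass reads the split-accumulator setting from the active FP8 recipe when the recipe exposes one, and otherwise falls back to the module-level default. A short sketch of that lookup pattern, assuming stand-in recipe objects rather than the recipe returned by FP8GlobalStateManager.get_fp8_recipe():

from dataclasses import dataclass

_2X_ACC_FPROP = False  # module-level default used when no recipe overrides it

@dataclass
class GemmConfig:
    use_split_accumulator: bool

@dataclass
class ToyRecipe:  # stand-in for an FP8 recipe object
    fp8_gemm_fprop: GemmConfig

def fprop_split_accumulator(recipe) -> bool:
    """Pick the forward-GEMM split-accumulator setting from the recipe if present."""
    use_split_accumulator = _2X_ACC_FPROP
    if recipe is not None and hasattr(recipe, "fp8_gemm_fprop"):
        use_split_accumulator = recipe.fp8_gemm_fprop.use_split_accumulator
    return use_split_accumulator

print(fprop_split_accumulator(None))                                             # False
print(fprop_split_accumulator(ToyRecipe(GemmConfig(use_split_accumulator=True))))  # True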
...
@@ -351,19 +401,15 @@ class _LayerNormLinear(torch.autograd.Function):
             # Input with column-wise usage is needed for wgrad GEMM.
             if backward_needs_input:
-                if isinstance(ln_out, QuantizedTensor):
+                if isinstance(ln_out, QuantizedTensorBase):
                     # For sequence parallel in vanilla FP8, rowwise data is
                     # to gather the input. For MXFP8, columnwise only data
                     # can be allgathered.
                     if isinstance(ln_out, MXFP8TensorBase) or not ctx.ln_out_needs_gather:
                         ln_out.update_usage(rowwise_usage=False)
-                # For force_hp_blockwise_ln_out_gather, we should
-                # be saving the unquantized ln_out to ctx.
-                assert not force_hp_blockwise_ln_out_gather

             # Weight with column-wise usage is needed for dgrad GEMM.
-            if isinstance(weightmat, QuantizedTensor):
+            if isinstance(weightmat, QuantizedTensorBase):
                 weightmat.update_usage(columnwise_usage=True)

             if cpu_offloading:
...
@@ -406,7 +452,7 @@ class _LayerNormLinear(torch.autograd.Function):
             )
         ctx.save_for_backward(*tensors_to_save)
         ctx.tensor_objects = tensor_objects
-        ctx.requires_dgrad = inp.requires_grad
+        ctx.requires_dgrad = inp_requires_grad
         ctx.requires_wgrad = weight.requires_grad
         ctx.quantized_weight = quantized_weight
         if fuse_wgrad_accumulation and weight.requires_grad:
...
@@ -439,7 +485,7 @@ class _LayerNormLinear(torch.autograd.Function):
         ctx.ub_bulk_wgrad = ub_bulk_wgrad
         ctx.ub_bulk_dgrad = ub_bulk_dgrad
         ctx.ub_name = ub_name
-        ctx.requires_dgrad = inp.requires_grad
+        ctx.requires_dgrad = inp_requires_grad
         ctx.normalization = normalization
         ctx.reduce_and_update_bwd_fp8_tensors = False
         if ctx.fp8 and requires_grad(inp, ln_weight, ln_bias, weight, bias):
...
@@ -450,29 +496,16 @@ class _LayerNormLinear(torch.autograd.Function):
         ctx.wgrad_store = wgrad_store
         ctx.debug = debug
+        # ------------------------------------------------------
+        # Cached state for backward pass is ready...
+        # ------------------------------------------------------

-        # Row Parallel Linear
-        if ub_overlap_rs_fprop:
-            out = rs_out
-        elif parallel_mode == "row":
-            nvtx_range_push(f"{nvtx_label}.row_parallel_comm")
-            if sequence_parallel:
-                out, _ = reduce_scatter_along_first_dim(out, tp_group)
-            elif tensor_parallel:
-                if symmetric_ar_type is not None:
-                    out, _ = symmetric_all_reduce(out, tp_group, all_reduce_type=symmetric_ar_type)
-                else:
-                    out, _ = allreduce(out, tp_group)
-            nvtx_range_pop(f"{nvtx_label}.row_parallel_comm")
-
-        # [*, in_features] -> [*, out_features] except first dimension changes for SP
-        out = out.view(-1, *inp_shape[1:-1], out_features)
         if return_layernorm_output:
             if return_layernorm_output_gathered:
-                shape = list(inp.shape)
+                shape = list(inp_shape)
                 shape[0] *= tp_size
                 return out, ln_out_return.view(shape)
-            return out, ln_out_return.view_as(inp)
+            return out, ln_out_return.view(inp_shape)
         return out

     @staticmethod
...
@@ -487,24 +520,6 @@ class _LayerNormLinear(torch.autograd.Function):
         nvtx_label = f"{nvtx_label}.{ctx.ub_name}"

         with torch.cuda.nvtx.range("_LayerNormLinear_backward"):
-            if (
-                ctx.fp8
-                and any(
-                    [
-                        ctx.ub_overlap_ag,
-                        ctx.ub_overlap_rs_dgrad,
-                        ctx.ub_bulk_dgrad,
-                        ctx.ub_bulk_wgrad,
-                    ]
-                )
-                and (ctx.fp8_recipe is not None)
-            ):
-                if not ctx.fp8_recipe.float8_per_tensor_scaling():
-                    raise NotImplementedError(
-                        "Comm+GEMM overlap is only supported with FP8 delayed scaling or per-tensor"
-                        " current scaling"
-                    )
             saved_tensors = ctx.saved_tensors
             (  # pylint: disable=unbalanced-tuple-unpacking
                 inputmat,
...
@@ -549,66 +564,50 @@ class _LayerNormLinear(torch.autograd.Function):
                 if ctx.requires_wgrad and ctx.fuse_wgrad_accumulation:
                     origin_weight.main_grad = main_grad

+            # Configure Userbuffers communication (comm+GEMM overlap)
             ctx.ub_obj_gradout = None
             ub_obj_dgrad = None
             ub_obj_wgrad = None
             ub_type_dgrad = None
             ub_type_wgrad = None
             dgrad_shape = [reduce(multiply_op, ctx.inp_shape[:-1]), ctx.inp_shape[-1]]
-            rs_out = None
-            dgrad_bulk = None
             if ctx.ub_overlap_ag:
                 # Overlap grad_output all-gather with dgrad compute
                 ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad")
                 ub_obj_dgrad = ctx.ub_obj_gradout
                 ub_type_dgrad = tex.CommOverlapType.AG
             elif ctx.ub_overlap_rs_dgrad:
                 # Overlap dgrad reduce-scatter with dgrad compute
                 ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad")
                 ub_obj_dgrad = ctx.ub_obj_gradout
                 ub_type_dgrad = tex.CommOverlapType.RS
-                rs_out = torch.empty(
-                    dgrad_shape, dtype=ctx.activation_dtype, device=inputmat.device
-                )
             else:
                 if ctx.ub_bulk_dgrad:
                     # Overlap inputmat all-gather with dgrad compute
-                    # NOTE: Copying into communication buffer will always prefer rowwise data,
-                    # and will copy columnwise data if rowwise does not exist. In that case,
-                    # the all-gather will apply to the leading dimension of the transpose,
-                    # which then needs to be interleaved correctly before WGRAD.
                     ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad")
                     ub_obj_dgrad = ctx.ub_obj_gradout
                     ub_type_dgrad = tex.CommOverlapType.AG
-                    ub_obj_dgrad.copy_into_buffer(ln_out, ctx.input_quantizer, local_chunk=True)
                 if ctx.ub_bulk_wgrad:
                     # Overlap dgrad reduce-scatter with wgrad compute
                     ub_obj_wgrad = get_ub(ctx.ub_name + "_wgrad")
                     ub_type_wgrad = tex.CommOverlapType.RS
-                    ub_obj_wgrad.set_buffer_params(ctx.grad_input_quantizer)
-                    dgrad_bulk = ub_obj_wgrad.get_buffer(ctx.grad_input_quantizer)
+
+            # --------------------------------------------------
+            # Prepare grad output tensor
+            # Note: Cast to expected dtype and perform tensor-parallel communication
+            # --------------------------------------------------

             # Configure quantizer for grad output tensor
             # Note: dgrad GEMM requires row-wise usage, wgrad GEMM
             # requires column-wise usage
             if ctx.grad_output_quantizer is not None:
-                rowwise_usage = True
-                columnwise_usage = True
-                if ctx.ub_overlap_ag and isinstance(
-                    ctx.grad_output_quantizer,
-                    (Float8Quantizer, Float8CurrentScalingQuantizer),
-                ):
-                    # If data is in FP8 and communication is handled
-                    # with Userbuffers, we compute FP8 transposes
-                    # manually
-                    columnwise_usage = False
-                ctx.grad_output_quantizer.set_usage(
-                    rowwise=rowwise_usage,
-                    columnwise=columnwise_usage,
-                )
+                quantizer = ctx.grad_output_quantizer
+                quantizer.set_usage(rowwise=True, columnwise=True)
+                if ctx.ub_overlap_ag:
+                    # Userbuffers only supports communication for one
+                    # tensor usage at a time. Configure quantizer with
+                    # usage for only dgrad GEMM.
+                    quantizer.set_usage(columnwise=False)

             # Prepare grad output tensor
             # Note: Cast to expected dtype and perform tensor-parallel communication
...
@@ -624,12 +623,21 @@ class _LayerNormLinear(torch.autograd.Function):
             )
             nvtx_range_pop(f"{nvtx_label}.grad_output_preprocess")
+            # --------------------------------------------------
+            # Grad output tensor is ready for computing grad input...
+            # --------------------------------------------------

-            # Launch tensor-parallel communication for LayerNorm out tensor
+            # --------------------------------------------------
+            # Prepare GEMM input tensor
+            # Note: Input tensor is needed for wgrad GEMM.
+            #       Tensor-parallel communication is overlapped with dgrad
+            #       GEMM.
+            # --------------------------------------------------
             ln_out_total = None
             ln_out_total_work = None
-            if ctx.ln_out_needs_gather and not ctx.ub_bulk_dgrad:
+            if ctx.ln_out_needs_gather:
                 quantizer = None
-                if ctx.input_quantizer is not None:
+                if ctx.input_quantizer is not None and not ctx.force_hp_blockwise_ln_out_gather:
                     quantizer = ctx.input_quantizer
                     if isinstance(quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)):
                         # If data is in FP8, we compute FP8 transposes manually
...
@@ -637,70 +645,92 @@ class _LayerNormLinear(torch.autograd.Function):
                     else:
                         # wgrad GEMM requires input with column-wise usage
                         quantizer.set_usage(rowwise=False, columnwise=True)
-                nvtx_range_push(f"{nvtx_label}.column_parallel_comm_input")
-                # async_op is not compatible with high precision gather since
-                # gather_along_first_dim does not offer callback chaining.
-                gather_quantizer = None if ctx.force_hp_blockwise_ln_out_gather else quantizer
-                ln_out_total, ln_out_total_work = gather_along_first_dim(
-                    ln_out,
-                    ctx.tp_group,
-                    async_op=True,
-                    quantizer=gather_quantizer,
-                )
-                nvtx_range_pop(f"{nvtx_label}.column_parallel_comm_input")
+                if ctx.ub_bulk_dgrad:
+                    ln_out_total, _ = fill_userbuffers_buffer_for_all_gather(
+                        ub_obj_dgrad,
+                        ln_out,
+                        quantizer,
+                        ctx.tp_group,
+                    )
+                else:
+                    nvtx_range_push(f"{nvtx_label}.column_parallel_comm_input")
+                    ln_out_total, ln_out_total_work = gather_along_first_dim(
+                        ln_out,
+                        ctx.tp_group,
+                        async_op=True,
+                        quantizer=quantizer,
+                    )
+                    nvtx_range_pop(f"{nvtx_label}.column_parallel_comm_input")
             else:
                 ln_out_total = ln_out
+            # --------------------------------------------------
+            # Input tensor is ready for computing grad weight...
+            # --------------------------------------------------

-            # Check whether to output wgrad GEMM directly into main grad
-            if ctx.is_first_microbatch is not None:
-                accumulate_wgrad_into_param_main_grad = (
-                    ctx.fuse_wgrad_accumulation and not ctx.is_first_microbatch
-                )
-            else:
-                accumulate_wgrad_into_param_main_grad = ctx.fuse_wgrad_accumulation
-
-            # dgrad GEMM
-            # Update grad input quantizer
-            if ctx.grad_input_quantizer is not None:
-                ctx.grad_input_quantizer.set_usage(rowwise=True, columnwise=False)
-            if ctx.weight_quantizer is not None and isinstance(weight, QuantizedTensor):
-                weight.update_usage(
-                    rowwise_usage=ctx.weight_quantizer.rowwise_usage,
-                    columnwise_usage=ctx.weight_quantizer.columnwise_usage,
-                )
-            nvtx_range_push(f"{nvtx_label}.dgrad_gemm")
-            dgrad_gemm_use_split_accumulator = _2X_ACC_DGRAD
+            # --------------------------------------------------
+            # Compute grad input tensor
+            # Note: Gradient w.r.t. GEMM input (i.e. norm output).
+            # --------------------------------------------------
+
+            # Make sure required data is available
+            if ctx.grad_input_quantizer is not None:
+                ctx.grad_input_quantizer.set_usage(rowwise=True, columnwise=False)
+            if isinstance(grad_output, QuantizedTensorBase):
+                grad_output.update_usage(rowwise_usage=True)
+            if ctx.weight_quantizer is not None and isinstance(weight, QuantizedTensorBase):
+                weight.update_usage(columnwise_usage=True)
+
+            # Choose whether to use GEMM kernel with split accumulator
+            use_split_accumulator = _2X_ACC_DGRAD
             if ctx.fp8:
                 recipe = ctx.fp8_recipe
                 if hasattr(recipe, "fp8_gemm_dgrad"):
-                    dgrad_gemm_use_split_accumulator = recipe.fp8_gemm_dgrad.use_split_accumulator
+                    use_split_accumulator = recipe.fp8_gemm_dgrad.use_split_accumulator

-            dgrad, *_ = general_gemm(
+            # Output buffers for Userbuffers reduce-scatter
+            gemm_out = None
+            reduce_scatter_out = None
+            if ctx.ub_overlap_rs_dgrad:
+                reduce_scatter_out = torch.empty(
+                    dgrad_shape, dtype=ctx.activation_dtype, device=grad_outputs[0].device
+                )
+            elif ctx.ub_bulk_wgrad:
+                gemm_out = ub_obj_wgrad.get_buffer(local_chunk=False)
+
+            # dgrad GEMM
+            # Note: dx = dy * w
+            nvtx_range_push(f"{nvtx_label}.dgrad_gemm")
+            gemm_out, *_, reduce_scatter_out = general_gemm(
                 weight,
                 grad_output,
                 get_workspace(),
                 layout="NN",
                 grad=True,
                 quantization_params=ctx.grad_input_quantizer,
-                out=dgrad_bulk,
+                out=gemm_out,
                 out_dtype=ctx.activation_dtype,
-                use_split_accumulator=dgrad_gemm_use_split_accumulator,
+                use_split_accumulator=use_split_accumulator,
                 ub=ub_obj_dgrad,
                 ub_type=ub_type_dgrad,
-                extra_output=rs_out,
+                extra_output=reduce_scatter_out,
                 bulk_overlap=ctx.ub_bulk_dgrad,
             )
             nvtx_range_pop(f"{nvtx_label}.dgrad_gemm")

-            # Launch tensor-parallel communication
+            # Prepare grad input tensor
+            # Note: Perform tensor-parallel communication
+            dgrad = None
             dgrad_work = None
             if ctx.ub_overlap_rs_dgrad:
-                dgrad = rs_out
-            elif ctx.parallel_mode == "column" and not ctx.ub_bulk_wgrad:
+                dgrad = reduce_scatter_out
+            elif ctx.ub_bulk_wgrad:
+                dgrad = ub_obj_wgrad.get_buffer(local_chunk=True)
+            elif ctx.parallel_mode == "column" and ctx.tp_size > 1:
                 nvtx_range_push(f"{nvtx_label}.column_parallel_comm_dgrad")
+                dgrad = gemm_out
                 if ctx.sequence_parallel:
+                    if ctx.return_layernorm_output and ctx.return_layernorm_output_gathered:
+                        dgrad = dgrad + grad_outputs[1].view_as(dgrad)
                     dgrad, dgrad_work = reduce_scatter_along_first_dim(
                         dgrad,
                         ctx.tp_group,
...
@@ -709,41 +739,55 @@ class _LayerNormLinear(torch.autograd.Function):
                     else:
                         dgrad, dgrad_work = allreduce(dgrad, ctx.tp_group, async_op=True)
                 nvtx_range_pop(f"{nvtx_label}.column_parallel_comm_dgrad")
+            else:
+                dgrad = gemm_out
+            # --------------------------------------------------
+            # Grad input tensor has been computed...
+            # --------------------------------------------------

-            # Compute grad weight tensor
+            # --------------------------------------------------
+            # Compute grad weight
+            # --------------------------------------------------
             wgrad = None
             if ctx.requires_wgrad:
-                # Synchronize tensor-parallel communication for input tensor
-                if ctx.ub_bulk_dgrad:
-                    ln_out_total = ub_obj_dgrad.get_buffer(ctx.input_quantizer)
-                    if ctx.fp8:
-                        if ln_out._data is None:
-                            # All-gather executed on columnwise data and result is in rowwise data,
-                            # so we need to fix the interleaving before WGRAD.
-                            ln_out_total = _fix_gathered_fp8_transpose(ln_out_total, ctx.tp_size)
-                        else:
-                            # FP8 GEMM on Hopper only supports TN layout so the gathered input must
-                            # have a valid transpose.
-                            ln_out_total._create_transpose()
+                # Prepare input tensor
+                # Note: Synchronize tensor-parallel communication and
+                # make sure required data is available
                 if ln_out_total_work is not None:
                     ln_out_total_work.wait()
                     ln_out_total_work = None
-                if ctx.input_quantizer is not None and not isinstance(
-                    ln_out_total, QuantizedTensor
-                ):
-                    # Async gather may have been done in BF16
-                    # call quantizer after gather.
-                    ctx.input_quantizer.set_usage(rowwise=False, columnwise=True)
-                    ln_out_total = ctx.input_quantizer(ln_out_total)
+                if ctx.fp8 or ctx.debug:
+                    if isinstance(ln_out_total, QuantizedTensorBase):
+                        ln_out_total.update_usage(columnwise_usage=True)
+                    else:
+                        ctx.input_quantizer.set_usage(rowwise=False, columnwise=True)
+                        ln_out_total = ctx.input_quantizer(ln_out_total)

-                # Prepare grad output tensor
-                # Note: Synchronize tensor-parallel communication and
-                # make sure required data is available
-                if isinstance(ln_out_total, QuantizedTensor):
-                    ln_out_total.update_usage(columnwise_usage=True)
-                if isinstance(grad_output, QuantizedTensor):
-                    grad_output.update_usage(columnwise_usage=True)
+                # Make sure GEMM inputs have required data
+                if ctx.ub_overlap_ag and isinstance(ctx.grad_output_quantizer, MXFP8Quantizer):
+                    # UB does not support overlapping grad output
+                    # all-gather with wgrad GEMM. Also, we can't
+                    # convert row-scaled MXFP8 to column-scaled, so we
+                    # can't reuse the grad output that was gathered
+                    # for the dgrad GEMM. We work around with blocking
+                    # all-gather for column-scaled MXFP8 data.
+                    ctx.grad_output_quantizer.set_usage(rowwise=False, columnwise=True)
+                    grad_output, _ = gather_along_first_dim(
+                        grad_outputs[0],
+                        ctx.tp_group,
+                        quantizer=ctx.grad_output_quantizer,
+                    )
+                if ctx.fp8 or ctx.debug:
+                    if isinstance(grad_output, QuantizedTensorBase):
+                        grad_output.update_usage(columnwise_usage=True)
+                    else:
+                        ctx.grad_output_quantizer.set_usage(rowwise=False, columnwise=True)
+                        grad_output = ctx.grad_output_quantizer(grad_output)

                 # Figure out whether to use split accumulator
                 use_split_accumulator = _2X_ACC_WGRAD
...
@@ -752,55 +796,95 @@ class _LayerNormLinear(torch.autograd.Function):
                 if hasattr(recipe, "fp8_gemm_wgrad"):
                     use_split_accumulator = recipe.fp8_gemm_wgrad.use_split_accumulator

-            # Output buffer for overlapping grad input
+            # Figure out whether to output wgrad GEMM directly into main grad
+            if ctx.is_first_microbatch is not None:
+                accumulate_wgrad_into_param_main_grad = (
+                    ctx.fuse_wgrad_accumulation and not ctx.is_first_microbatch
+                )
+            else:
+                accumulate_wgrad_into_param_main_grad = ctx.fuse_wgrad_accumulation
+
+            # Output buffer for overlapping FP8 grad input
             # reduce-scatter with wgrad GEMM
+            reduce_scatter_out = None
             if ctx.ub_bulk_wgrad and ub_obj_wgrad.is_fp8_ubuf():
-                rs_out = torch.empty(
-                    dgrad_shape, dtype=ctx.activation_dtype, device=inputmat.device
+                reduce_scatter_out = torch.empty(
+                    dgrad_shape, dtype=ctx.activation_dtype, device=grad_outputs[0].device
                 )

-            # wgrad GEMM
-            nvtx_range_push(f"{nvtx_label}.wgrad_gemm")
-            general_gemm_wgrad = functools.partial(
-                general_gemm,
-                out_dtype=(
-                    main_grad.dtype if ctx.fuse_wgrad_accumulation else ctx.activation_dtype
-                ),
-                workspace=get_workspace(),
-                layout="NT",
-                grad=True,
-                bias=(bias if (grad_bias is None and not ctx.fp8) else None),
-                out=main_grad if ctx.fuse_wgrad_accumulation else None,
-                use_split_accumulator=use_split_accumulator,
-                accumulate=accumulate_wgrad_into_param_main_grad,
-                quantization_params=ctx.grad_weight_quantizer,
-                ub=ub_obj_wgrad,
-                ub_type=ub_type_wgrad,
-                extra_output=rs_out,
-                bulk_overlap=ctx.ub_bulk_wgrad,
-            )
+            # Arguments to include in wgrad GEMM closure
+            # Note: Fuse with bgrad computation if needed
+            wgrad_gemm_kwargs = {
+                "workspace": get_workspace(),
+                "out_dtype": (
+                    main_grad.dtype if ctx.fuse_wgrad_accumulation else ctx.activation_dtype
+                ),
+                "quantization_params": ctx.grad_weight_quantizer,
+                "accumulate": accumulate_wgrad_into_param_main_grad,
+                "layout": "NT",
+                "out": main_grad if ctx.fuse_wgrad_accumulation else None,
+                "bias": (bias if (grad_bias is None and not ctx.fp8) else None),
+                "use_split_accumulator": use_split_accumulator,
+                "grad": True,
+                "ub": ub_obj_wgrad,
+                "ub_type": ub_type_wgrad,
+                "extra_output": reduce_scatter_out,
+                "bulk_overlap": ctx.ub_bulk_wgrad,
+            }
+
+            def wgrad_gemm(
+                x: torch.Tensor,
+                dy: torch.Tensor,
+            ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+                """Perform wgrad GEMM: dw = dy^T * x
+
+                May be fused with bgrad computation.
+
+                May be called outside of this function to enable
+                some advanced communication/compute overlapping.
+                """
+                nvtx_range_push(f"{nvtx_label}.wgrad_gemm")
+                dw, db, *_ = general_gemm(x, dy, **wgrad_gemm_kwargs)
+                nvtx_range_pop(f"{nvtx_label}.wgrad_gemm")
+                return dw, db
+
+            # Choose whether to call wgrad GEMM now or delay
             if ctx.wgrad_store is not None and ctx.wgrad_store.delay_wgrad_compute():
-                ctx.wgrad_store.put([ln_out_total, grad_output], general_gemm_wgrad)
+                if (
+                    wgrad_gemm_kwargs["ub"] is not None
+                    or wgrad_gemm_kwargs["ub_type"] is not None
+                    or wgrad_gemm_kwargs["extra_output"] is not None
+                    or wgrad_gemm_kwargs["bulk_overlap"]
+                ):
+                    raise NotImplementedError(
+                        "Delayed weight grad computation is not supported "
+                        "with Userbuffers (tensor-parallel communication overlapping)"
+                    )
+                ctx.wgrad_store.put([ln_out_total, grad_output], wgrad_gemm)
             else:
-                wgrad, grad_bias_, _, rs_out = general_gemm_wgrad(ln_out_total, grad_output)
+                # Call wgrad GEMM now
+                wgrad, grad_bias_ = wgrad_gemm(ln_out_total, grad_output)
+
+                # Update grad bias if needed
                 if grad_bias is None:
                     grad_bias = grad_bias_
                 del grad_bias_

-                # Deallocate input tensor
+                # Deallocate input tensor if permitted
                 if not ctx.return_layernorm_output:
+                    # TODO (pgadzinski) - deallocate transpose only  # pylint: disable=fixme
                     clear_tensor_data(ln_out_total)
-            nvtx_range_pop(f"{nvtx_label}.wgrad_gemm")

+            # Update grad input if overlapping reduce-scatter with wgrad GEMM
             if ctx.ub_bulk_wgrad:
                 if ub_obj_wgrad.is_fp8_ubuf():
-                    dgrad = rs_out
+                    dgrad = reduce_scatter_out
                 else:
-                    dgrad = ub_obj_wgrad.get_buffer(None, local_chunk=True)
+                    dgrad = ub_obj_wgrad.get_buffer(local_chunk=True).clone()
+            # --------------------------------------------------
+            # Grad weight has been computed...
+            # --------------------------------------------------

         # Don't return grad bias if not needed
         if not ctx.use_bias:
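The rewritten wgrad path bundles the GEMM arguments into wgrad_gemm_kwargs and wraps the call in a wgrad_gemm closure, so the weight-gradient GEMM can either run immediately or be handed to a store and executed later. A minimal sketch of that delay mechanism; the ToyWeightGradStore below only mimics a put/pop interface suggested by the diff and is not the TransformerEngine WeightGradStore:

import torch

class ToyWeightGradStore:
    """Collects delayed weight-gradient computations and runs them later."""
    def __init__(self, delay: bool):
        self._delay = delay
        self._pending = []
    def delay_wgrad_compute(self) -> bool:
        return self._delay
    def put(self, tensors, func):
        self._pending.append((tensors, func))
    def pop(self):
        for tensors, func in self._pending:
            func(*tensors)
        self._pending.clear()

def make_wgrad_gemm(accumulate_into: torch.Tensor):
    wgrad_gemm_kwargs = {"out": accumulate_into}  # arguments shared by immediate and delayed calls
    def wgrad_gemm(x: torch.Tensor, dy: torch.Tensor) -> torch.Tensor:
        """Perform wgrad GEMM: dw = dy^T * x, accumulated into the main grad buffer."""
        wgrad_gemm_kwargs["out"].add_(dy.t() @ x)
        return wgrad_gemm_kwargs["out"]
    return wgrad_gemm

main_grad = torch.zeros(3, 4)
store = ToyWeightGradStore(delay=True)
wgrad_gemm = make_wgrad_gemm(main_grad)

x, dy = torch.randn(5, 4), torch.randn(5, 3)
if store.delay_wgrad_compute():
    store.put([x, dy], wgrad_gemm)   # defer the GEMM
else:
    wgrad_gemm(x, dy)                # run it now
store.pop()                          # later: flush the deferred weight gradients
print(main_grad.norm() > 0)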
...
@@ -879,7 +963,7 @@ class _LayerNormLinear(torch.autograd.Function):
             nvtx_range_pop(f"{nvtx_label}.reduce_and_update_fp8_tensors")

         # Scatter fp8 weight buffers
-        # if ctx.fp8 and not isinstance(weight, QuantizedTensor):
+        # if ctx.fp8 and not isinstance(weight, QuantizedTensorBase):
         #     _fsdp_scatter_tensors(ctx.fsdp_group, weight_fp8)
         return (
...
@@ -1405,6 +1489,10 @@ class LayerNormLinear(TransformerEngineBaseModule):
                     "Splitting QuantizedTensor into multiple params is not supported"
                 )
             else:
+                warnings.warn(
+                    "You are using quantized weights without quantized compute. "
+                    "Please make sure this is intentional."
+                )
                 unfused_weights = [w.dequantize() for w in unfused_weights]
         weight_tensor = noop_cat(unfused_weights)
...
@@ -1511,7 +1599,7 @@ class LayerNormLinear(TransformerEngineBaseModule):
         grad_output_quantizer = None
         output_quantizer = None
         input_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_INPUT]
-        input_quantizer.internal = False
+        input_quantizer.internal = True
         weight_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_WEIGHT]
         weight_quantizer.internal = True
         if fp8_output:
...
@@ -1579,3 +1667,12 @@ class LayerNormLinear(TransformerEngineBaseModule):
         self.quantizers["scaling_bwd"][
             tex.FP8BwdTensors.GRAD_OUTPUT1
         ].amax_epsilon = recipe.fp8_quant_bwd_grad.amax_epsilon
+
+        # parallel related
+        if self.sequence_parallel and self.parallel_mode == "row":
+            # customize grad_output_quantizer with amax reduction TP group
+            self.quantizers["scaling_bwd"][
+                tex.FP8BwdTensors.GRAD_OUTPUT1
+            ].with_amax_reduction = True
+            self.quantizers["scaling_bwd"][
+                tex.FP8BwdTensors.GRAD_OUTPUT1
+            ].amax_reduction_group = self.tp_group
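The final hunk enables amax reduction on the backward grad-output quantizer only when the layer is row-parallel with sequence parallelism, since the incoming gradient is then sharded across the TP group. A compact sketch of that conditional configuration, assuming a plain stand-in object rather than a TransformerEngine quantizer:

from dataclasses import dataclass
from typing import Any, Optional

@dataclass
class ToyQuantizer:  # stand-in exposing only the two fields the hunk touches
    with_amax_reduction: bool = False
    amax_reduction_group: Optional[Any] = None

def configure_bwd_quantizer(quantizer, sequence_parallel: bool, parallel_mode: str, tp_group):
    """Reduce amax across the TP group when grad outputs are sharded (row-parallel + SP)."""
    if sequence_parallel and parallel_mode == "row":
        quantizer.with_amax_reduction = True
        quantizer.amax_reduction_group = tp_group
    return quantizer

q = configure_bwd_quantizer(ToyQuantizer(), sequence_parallel=True, parallel_mode="row", tp_group="tp0")
print(q.with_amax_reduction, q.amax_reduction_group)  # True tp0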
transformer_engine/pytorch/module/layernorm_mlp.py  View file @ f8c2af4c
...
@@ -8,7 +8,6 @@ import warnings
 from typing import Callable, Optional, Tuple, Union
 from functools import reduce
 from operator import mul as multiply_op
-import functools

 import torch
 from torch.nn.parameter import Parameter
...
@@ -20,6 +19,7 @@ import transformer_engine_torch as tex
 from transformer_engine.common.recipe import Recipe
 from transformer_engine.pytorch import torch_version
 from .base import (
+    fill_userbuffers_buffer_for_all_gather,
     get_workspace,
     _ub_communicators,
     get_ub,
...
@@ -43,7 +43,6 @@ from ..utils import (
     assert_dim_for_fp8_exec,
     clear_tensor_data,
     requires_grad,
-    non_tn_fp8_gemm_supported,
     needs_quantized_gemm,
 )
 from ..distributed import (
...
@@ -67,10 +66,11 @@ from ..tensor.float8_tensor import (
 )
 from ..tensor.mxfp8_tensor import MXFP8Quantizer
 from ..tensor.float8_blockwise_tensor import Float8BlockQuantizer
-from ._common import apply_normalization, _fix_gathered_fp8_transpose, WeightGradStore
+from ._common import apply_normalization, WeightGradStore
 from ..cpu_offload import is_cpu_offload_enabled, mark_activation_offload
 from ..tensor.quantized_tensor import (
     QuantizedTensor,
+    QuantizedTensorBase,
     Quantizer,
     prepare_for_saving,
     restore_from_saved,
...
@@ -201,24 +201,16 @@ class _LayerNormMLP(torch.autograd.Function):
     ) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]:
         # pylint: disable=missing-function-docstring

-        in_features, inp_shape = ln_weight.numel(), inp.shape
         # Make sure input dimensions are compatible
+        in_features, inp_shape = ln_weight.numel(), inp.shape
         assert inp_shape[-1] == in_features, "GEMM not possible"
         inputmat = inp.view((-1, in_features))
         if fp8:
             assert_dim_for_fp8_exec(inputmat, fc1_weight, fc2_weight)
-            if any([ub_overlap_ag, ub_overlap_rs]) and not (
-                FP8GlobalStateManager.get_fp8_recipe().float8_per_tensor_scaling()
-            ):
-                raise NotImplementedError(
-                    "Comm+GEMM overlap is only supported with FP8 delayed scaling or per-tensor"
-                    " current scaling"
-                )

         activation_func = _act_func(
             activation, FP8GlobalStateManager.get_fp8_recipe() if fp8 else None
         )[0]
-        device = inp.device

         # Cast for native AMP
         inputmat = cast_if_needed(inputmat, activation_dtype)
...
@@ -226,6 +218,38 @@ class _LayerNormMLP(torch.autograd.Function):
         if ln_bias is not None:
             ln_bias = cast_if_needed(ln_bias, activation_dtype)

+        tp_world_size = get_distributed_world_size(tp_group)
+        backwards_needs_fc1_input = is_grad_enabled and fc1_weight.requires_grad
+        device = inp.device
+
+        # Configure Userbuffers communication (comm+GEMM overlap)
+        ub_overlap_ag = ub_overlap_ag and is_grad_enabled and not return_layernorm_output_gathered
+        ub_overlap_rs = ub_overlap_rs and is_grad_enabled
+
+        # Choose whether to use GEMM kernel with split accumulator
+        use_split_accumulator = _2X_ACC_FPROP
+        if fp8:
+            recipe = FP8GlobalStateManager.get_fp8_recipe()
+            if hasattr(recipe, "fp8_gemm_fprop"):
+                use_split_accumulator = recipe.fp8_gemm_fprop.use_split_accumulator
+
+        # Configure quantizer for norm output
+        if fp8:
+            if fc1_input_quantizer is None:
+                raise ValueError("Missing quantizer for FC1 input tensor")
+            fc1_input_quantizer.set_usage(rowwise=True, columnwise=backwards_needs_fc1_input)
+            if sequence_parallel and isinstance(
+                fc1_input_quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)
+            ):
+                # All-gather is not supported with FP8 column-wise data
+                fc1_input_quantizer.set_usage(columnwise=False)
+
+        # Do TP communication in high precision if quantized format
+        # does not support communication
+        force_hp_fc1_input_gather = (
+            fp8 and sequence_parallel and isinstance(fc1_input_quantizer, Float8BlockQuantizer)
+        )
+
         # for fp8 DelayedScaling: layernorm output = FP8
         # only output of the linear is returned
         # for return_layernorm_output: layernorm output = High precision, then cast to FP8
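The added block decides which usages the FC1 input quantizer keeps: row-wise data is always needed for the forward GEMM, column-wise data only when the backward pass will need the input, and column-wise data is dropped when the sequence-parallel all-gather cannot carry it. A small sketch that paraphrases this logic with a stand-in quantizer (not the TransformerEngine Quantizer classes):

from dataclasses import dataclass

@dataclass
class ToyQuantizer:  # stand-in exposing only the usage flags set in the hunk
    rowwise: bool = True
    columnwise: bool = True
    supports_columnwise_all_gather: bool = False
    def set_usage(self, rowwise=None, columnwise=None):
        if rowwise is not None:
            self.rowwise = rowwise
        if columnwise is not None:
            self.columnwise = columnwise

def configure_fc1_input_quantizer(q, is_grad_enabled, weight_requires_grad, sequence_parallel):
    """Keep column-wise data only if the backward pass needs the FC1 input."""
    backwards_needs_input = is_grad_enabled and weight_requires_grad
    q.set_usage(rowwise=True, columnwise=backwards_needs_input)
    if sequence_parallel and not q.supports_columnwise_all_gather:
        # The all-gather cannot carry column-wise quantized data, so drop it here
        q.set_usage(columnwise=False)
    return q

q = configure_fc1_input_quantizer(ToyQuantizer(), True, True, sequence_parallel=True)
print(q.rowwise, q.columnwise)  # True False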
...
@@ -241,29 +265,6 @@ class _LayerNormMLP(torch.autograd.Function):
             # Kernels not available for norm fusion.
             with_quantized_norm = False

-        tp_world_size = get_distributed_world_size(tp_group)
-        ub_overlap_ag = ub_overlap_ag and is_grad_enabled and not return_layernorm_output_gathered
-        ub_overlap_rs = ub_overlap_rs and is_grad_enabled
-
-        backwards_needs_fc1_input = is_grad_enabled and fc1_weight.requires_grad
-        # TODO(kwyss): Support FP8 allgather of Float8BlockQuantizer recipe
-        force_hp_fc1_input_gather = (
-            fp8 and sequence_parallel and isinstance(fc1_input_quantizer, Float8BlockQuantizer)
-        )
-        # Perform TP communication in high precision.
-
-        # Configure quantizer for norm output
-        if fp8:
-            if fc1_input_quantizer is None:
-                raise ValueError("Missing quantizer for FC1 input tensor")
-            columnwise_usage = backwards_needs_fc1_input
-            if (
-                columnwise_usage
-                and sequence_parallel
-                and not isinstance(fc1_input_quantizer, MXFP8Quantizer)
-            ):
-                columnwise_usage = False
-            fc1_input_quantizer.set_usage(rowwise=True, columnwise=columnwise_usage)

         # Apply normalization
         ln_out, mu, rsigma = apply_normalization(
             inputmat,
...
@@ -297,39 +298,43 @@ class _LayerNormMLP(torch.autograd.Function):
                     fc1_input_quantizer.set_usage(rowwise=True, columnwise=False)
                     ln_out_total = fc1_input_quantizer(ln_out_total)
             else:
+                quantizer = None
                 if fp8 or debug:
+                    quantizer = fc1_input_quantizer
                     if not with_quantized_norm and not force_hp_fc1_input_gather:
                         ln_out = fc1_input_quantizer(ln_out)
                         fc1_input_quantizer.set_usage(rowwise=True, columnwise=False)
                 if ub_overlap_ag:
                     # Copy into Userbuffers buffer
                     ub_obj_lnout = get_ub("fc1_fprop")
-                    ub_obj_lnout.get_buffer(fc1_input_quantizer, local_chunk=True).copy_(ln_out)
-                    ln_out_total = ub_obj_lnout.get_buffer(fc1_input_quantizer)
+                    ln_out_total, _ = fill_userbuffers_buffer_for_all_gather(
+                        ub_obj_lnout,
+                        ln_out,
+                        quantizer,
+                        tp_group,
+                    )
                 else:
                     # All-gather with NCCL
                     ln_out_total, _ = gather_along_first_dim(
                         ln_out,
                         tp_group,
-                        quantizer=(fc1_input_quantizer if fp8 or debug else None),
+                        quantizer=quantizer,
                     )
         else:
-            if (fp8 or debug) and not with_quantized_norm:
+            # NOTE: force_hp_fc1_input_gather is redundant with else, but
+            # here for clarity. We should not quantize ln_out if bwd needs
+            # to gather in hp.
+            if (fp8 or debug) and not with_quantized_norm and not force_hp_fc1_input_gather:
                 ln_out = fc1_input_quantizer(ln_out)
             ln_out_total = ln_out

         # Cast weights to expected dtype
         fc1_weight_final = fc1_weight
         fc2_weight_final = fc2_weight
         if fp8 or debug:
             # If weights are not quantized, we call get_weight_workspace,
             # which handles weight caching etc.
             # FP8 cast to workspace buffer
             update_workspace = is_first_microbatch is None or is_first_microbatch
+            fc1_weight_quantizer.set_usage(rowwise=True, columnwise=True)
+            fc2_weight_quantizer.set_usage(rowwise=True, columnwise=True)
             fc1_weight_final = module.get_weight_workspace(
                 tensor=fc1_weight,
                 quantizer=fc1_weight_quantizer,
...
@@ -339,7 +344,6 @@ class _LayerNormMLP(torch.autograd.Function):
                 fsdp_group=fsdp_group,
                 workspace_dtype=activation_dtype,
             )
-            fc2_weight_quantizer.set_usage(rowwise=True, columnwise=True)
             fc2_weight_final = module.get_weight_workspace(
                 tensor=fc2_weight,
                 quantizer=fc2_weight_quantizer,
...
@@ -349,6 +353,8 @@ class _LayerNormMLP(torch.autograd.Function):
                 fsdp_group=fsdp_group,
                 workspace_dtype=activation_dtype,
             )
+            fc1_weight_final.update_usage(rowwise_usage=True)
+            fc2_weight_final.update_usage(rowwise_usage=True)
         else:
             fc1_weight_final = cast_if_needed(fc1_weight_final, activation_dtype)
             fc2_weight_final = cast_if_needed(fc2_weight_final, activation_dtype)
...
@@ -356,6 +362,7 @@ class _LayerNormMLP(torch.autograd.Function):
         # Cast biases to expected dtype
         bias_dtype = activation_dtype
         if needs_quantized_gemm(ln_out_total) and activation_dtype == torch.float32:
+            # cuBLAS does not support FP8 GEMM with FP32 bias, so we cast to BF16
             bias_dtype = torch.bfloat16
         if fc1_bias is not None:
             fc1_bias = cast_if_needed(fc1_bias, bias_dtype)
...
@@ -369,7 +376,9 @@ class _LayerNormMLP(torch.autograd.Function):
         if fc1_weight_quantizer is not None:
             fc1_weight_quantizer.calibrate(fc1_weight)

+        # ------------------------------------------------------
         # FC1 GEMM
+        # ------------------------------------------------------
         # There are 2 fusions possible:
         # - gemm_gelu_fusion - default for full precision, optional for fp8 - need to turn on gemm_gelu_fusion,
...
@@ -401,11 +410,16 @@ class _LayerNormMLP(torch.autograd.Function):
...
@@ -401,11 +410,16 @@ class _LayerNormMLP(torch.autograd.Function):
fc1_bias
if
not
bias_gelu_fusion
else
None
fc1_bias
if
not
bias_gelu_fusion
else
None
),
# otherwise bias is added later (fused with gelu)
),
# otherwise bias is added later (fused with gelu)
gelu
=
gemm_gelu_fusion
,
gelu
=
gemm_gelu_fusion
,
accumulate
=
_2X_ACC_FPROP
,
use_split_accumulator
=
use_split_accumulator
,
ub
=
ub_obj_lnout
,
ub
=
ub_obj_lnout
,
ub_type
=
tex
.
CommOverlapType
.
AG
if
ub_overlap_ag
else
None
,
ub_type
=
tex
.
CommOverlapType
.
AG
if
ub_overlap_ag
else
None
,
)
)
# ------------------------------------------------------
# Finished FC1 GEMM...
# ------------------------------------------------------
# Deallocate FC1 GEMM input tensor if no longer needed
if
not
is_grad_enabled
and
(
ln_out_total
is
not
ln_out_return
):
if
not
is_grad_enabled
and
(
ln_out_total
is
not
ln_out_return
):
clear_tensor_data
(
ln_out_total
)
clear_tensor_data
(
ln_out_total
)
...
@@ -439,45 +453,66 @@ class _LayerNormMLP(torch.autograd.Function):
                 fc2_input_quantizer.calibrate(act_out)
                 fc2_weight_quantizer.calibrate(fc2_weight)
+        # Configure Userbuffers reduce-scatter if needed
         ub_obj_fc2out = None
-        rs_out = None
-        fc2_out = None
+        reduce_scatter_out = None
         if ub_overlap_rs:
             ub_obj_fc2out = get_ub("fc2_fprop")
             dim_size = list(act_out.size())
-            dim_size[0] = dim_size[0] // tp_world_size
-            dim_size[1] = fc2_weight.size(0)
-            rs_out = torch.empty(dim_size, dtype=activation_dtype, device=device)
-        else:
-            dim_size = list(act_out.size())
-            dim_size[1] = fc2_weight.size(0)
-            fc2_out = torch.empty(dim_size, dtype=activation_dtype, device=device)
+            dim_size[0] //= tp_world_size
+            dim_size[-1] = fc2_weight.size(0)
+            reduce_scatter_out = torch.empty(dim_size, dtype=activation_dtype, device=device)
+        # ------------------------------------------------------
         # FC2 GEMM
-        _ = general_gemm(
+        # ------------------------------------------------------
+        gemm_out, *_, reduce_scatter_out = general_gemm(
             fc2_weight_final,
             act_out,
             get_workspace(),
             out_dtype=activation_dtype,
             bias=fc2_bias,
             quantization_params=fc2_output_quantizer,
-            out=fc2_out,
-            use_split_accumulator=_2X_ACC_FPROP,
+            use_split_accumulator=use_split_accumulator,
             ub=ub_obj_fc2out,
             ub_type=tex.CommOverlapType.RS if ub_overlap_rs else None,
-            extra_output=rs_out,
+            extra_output=reduce_scatter_out,
         )
+        # ------------------------------------------------------
+        # Finished FC2 GEMM...
+        # ------------------------------------------------------
+        # Deallocate tensors if no longer needed
+        if not is_grad_enabled:
+            clear_tensor_data(act_out, fc1_out_without_bias, fc1_out)
+        # Prepare output tensor
+        # Note: Perform tensor-parallel communication if needed
+        fc2_out = None
+        if ub_overlap_rs:
+            fc2_out = reduce_scatter_out
+        elif set_parallel_mode and sequence_parallel:
+            fc2_out, _ = reduce_scatter_along_first_dim(gemm_out, tp_group)
+        elif set_parallel_mode and tensor_parallel:
+            if symmetric_ar_type is not None:
+                fc2_out, _ = symmetric_all_reduce(gemm_out, tp_group, all_reduce_type=symmetric_ar_type)
+            else:
+                fc2_out, _ = allreduce(gemm_out, tp_group)
+        else:
+            fc2_out = gemm_out
+        fc2_out = fc2_out.view(-1, *inp_shape[1:-1], fc2_out.shape[-1])
-        # Weight with column-wise usage is needed for dgrad GEMM.
+        # Cache state for backward pass
         if is_grad_enabled:
-            if isinstance(fc1_weight_final, QuantizedTensor):
+            # Weight with column-wise usage is needed for dgrad GEMM.
+            if isinstance(fc1_weight_final, QuantizedTensorBase):
                 fc1_weight_final.update_usage(columnwise_usage=True)
-            if isinstance(fc2_weight_final, QuantizedTensor):
+            if isinstance(fc2_weight_final, QuantizedTensorBase):
                 fc2_weight_final.update_usage(columnwise_usage=True)
-        if not is_grad_enabled:
-            clear_tensor_data(act_out, fc1_out_without_bias, fc1_out)
-        else:
             if cpu_offloading:
                 mark_activation_offload(
                     inputmat, mu, rsigma, ln_out, fc1_out, fc1_out_without_bias, act_out
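When the FC2 GEMM overlaps with a reduce-scatter, the extra output buffer is sized like the activation but with the first (token) dimension divided by the tensor-parallel world size and the last dimension replaced by the output features. A small shape-only sketch (no communication is performed; tp_world_size and out_features are made-up values):

import torch

act_out = torch.randn(8, 32)       # [tokens, ffn_hidden]
tp_world_size = 4
out_features = 16                  # fc2_weight.size(0)

dim_size = list(act_out.size())
dim_size[0] //= tp_world_size      # each rank keeps 1/tp of the tokens
dim_size[-1] = out_features
reduce_scatter_out = torch.empty(dim_size, dtype=act_out.dtype)
assert reduce_scatter_out.shape == (2, 16)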
...
@@ -504,8 +539,6 @@ class _LayerNormMLP(torch.autograd.Function):
                 if not return_layernorm_output:
                     clear_tensor_data(ln_out)
                     ln_out = None
-                elif force_hp_fc1_input_gather:
-                    assert not isinstance(ln_out, QuantizedTensor)
                 if not fc2_weight.requires_grad:
                     clear_tensor_data(act_out)
                     act_out = None
...
@@ -592,28 +625,12 @@ class _LayerNormMLP(torch.autograd.Function):
             ctx.wgrad_store = wgrad_store
-        # Row Parallel Linear
-        if ub_overlap_rs:
-            fc2_out = rs_out
-        elif set_parallel_mode and sequence_parallel:
-            fc2_out, _ = reduce_scatter_along_first_dim(fc2_out, tp_group)
-        elif set_parallel_mode and tensor_parallel:
-            if symmetric_ar_type is not None:
-                fc2_out, _ = symmetric_all_reduce(fc2_out, tp_group, all_reduce_type=symmetric_ar_type)
-            else:
-                fc2_out, _ = allreduce(fc2_out, tp_group)
-        # [*, in_features] -> [*, out_features] except first dimension changes for SP
-        fc2_out = fc2_out.view(-1, *inp_shape[1:-1], fc2_out.shape[-1])
         if return_layernorm_output:
             if return_layernorm_output_gathered:
                 shape = list(inp_shape)
                 shape[0] *= tp_size
                 return fc2_out, ln_out_return.view(shape)
-            return fc2_out, ln_out_return.view_as(inp)
+            return fc2_out, ln_out_return.view(inp_shape)
         return fc2_out

     @staticmethod
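The forward pass flattens all leading dimensions of the input for the GEMMs and restores them on the way out with `view(-1, *inp_shape[1:-1], out_features)`; under sequence parallelism the first dimension may differ from the input's. A short sketch of the reshape itself:

import torch

inp_shape = (6, 3, 8)                        # [seq, batch, hidden]
out_features = 10
gemm_out = torch.randn(6 * 3, out_features)  # GEMM works on a 2-D view

out = gemm_out.view(-1, *inp_shape[1:-1], out_features)
assert out.shape == (6, 3, 10)               # leading dims restored, last dim replaced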
...
@@ -622,24 +639,6 @@ class _LayerNormMLP(torch.autograd.Function):
     ) -> Tuple[Union[torch.Tensor, None], ...]:
         # pylint: disable=missing-function-docstring
         with torch.cuda.nvtx.range("_LayerNormMLP_backward"):
-            if (
-                ctx.fp8
-                and any(
-                    [
-                        ctx.ub_overlap_ag,
-                        ctx.ub_overlap_rs_dgrad,
-                        ctx.ub_bulk_dgrad,
-                        ctx.ub_bulk_wgrad,
-                    ]
-                )
-                and (ctx.fp8_recipe is not None)
-            ):
-                if not ctx.fp8_recipe.float8_per_tensor_scaling():
-                    raise NotImplementedError(
-                        "Comm+GEMM overlap is only supported with FP8 delayed scaling or per-tensor"
-                        " current scaling"
-                    )
             saved_tensors = ctx.saved_tensors
             (  # pylint: disable=unbalanced-tuple-unpacking
                 inputmat,
...
@@ -699,6 +698,16 @@ class _LayerNormMLP(torch.autograd.Function):
             #     fc2_weight_fp8 if ctx.fp8 and not isinstance(fc2_weight, Float8Tensor) else None,
             # )
+            # Choose whether to use GEMM kernel with split accumulator
+            dgrad_use_split_accumulator = _2X_ACC_DGRAD
+            wgrad_use_split_accumulator = _2X_ACC_WGRAD
+            if ctx.fp8:
+                recipe = ctx.fp8_recipe
+                if hasattr(recipe, "fp8_gemm_dgrad"):
+                    dgrad_use_split_accumulator = recipe.fp8_gemm_dgrad.use_split_accumulator
+                if hasattr(recipe, "fp8_gemm_wgrad"):
+                    wgrad_use_split_accumulator = recipe.fp8_gemm_wgrad.use_split_accumulator
             # No need to do bulk DGRAD/WGRAD overlap if WGRAD is not required
             ctx.ub_bulk_dgrad = ctx.fc1_weight_requires_grad and ctx.ub_bulk_dgrad
             ctx.ub_bulk_wgrad = ctx.fc1_weight_requires_grad and ctx.ub_bulk_wgrad
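The added block reads per-GEMM `use_split_accumulator` overrides from the FP8 recipe with `hasattr`, falling back to the module-level `_2X_ACC_*` defaults. A hedged sketch of that lookup pattern with a stand-in recipe object (the field names mirror the diff, but the recipe object here is invented for illustration):

from types import SimpleNamespace

_2X_ACC_DGRAD = True
_2X_ACC_WGRAD = True

recipe = SimpleNamespace(
    fp8_gemm_dgrad=SimpleNamespace(use_split_accumulator=False),
    # no fp8_gemm_wgrad attribute -> keep the default
)

dgrad_use_split_accumulator = _2X_ACC_DGRAD
wgrad_use_split_accumulator = _2X_ACC_WGRAD
if hasattr(recipe, "fp8_gemm_dgrad"):
    dgrad_use_split_accumulator = recipe.fp8_gemm_dgrad.use_split_accumulator
if hasattr(recipe, "fp8_gemm_wgrad"):
    wgrad_use_split_accumulator = recipe.fp8_gemm_wgrad.use_split_accumulator
assert (dgrad_use_split_accumulator, wgrad_use_split_accumulator) == (False, True)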
...
@@ -707,20 +716,13 @@ class _LayerNormMLP(torch.autograd.Function):
             # Note: dgrad GEMM requires row-wise usage, wgrad GEMM
             # requires column-wise usage
             if ctx.fc2_grad_output_quantizer is not None:
-                rowwise_usage = True
-                columnwise_usage = True
-                if ctx.ub_overlap_ag and isinstance(
-                    ctx.fc2_grad_output_quantizer,
-                    (Float8Quantizer, Float8CurrentScalingQuantizer),
-                ):
-                    # If data is in FP8 and communication is handled
-                    # with Userbuffers, we compute FP8 transposes
-                    # manually
-                    columnwise_usage = False
-                ctx.fc2_grad_output_quantizer.set_usage(
-                    rowwise=rowwise_usage,
-                    columnwise=columnwise_usage,
-                )
+                quantizer = ctx.fc2_grad_output_quantizer
+                quantizer.set_usage(rowwise=True, columnwise=True)
+                if ctx.ub_overlap_ag:
+                    # Userbuffers only supports communication for one
+                    # tensor usage at a time. Configure quantizer with
+                    # usage for only dgrad GEMM.
+                    quantizer.set_usage(columnwise=False)
             # Prepare FC2 grad output tensor
             # Note: Cast to expected dtype and perform tensor-parallel communication
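The quantizer's `set_usage(rowwise=..., columnwise=...)` calls control which data layouts the quantized tensor keeps: row-wise data feeds the dgrad GEMM, column-wise data (the transpose) feeds the wgrad GEMM, and when Userbuffers handles the communication only one usage can be active. A toy quantizer sketch of the flag bookkeeping (not the TransformerEngine class):

class ToyQuantizer:
    """Tracks which layouts the next quantization should produce."""
    def __init__(self):
        self.rowwise_usage = True
        self.columnwise_usage = True

    def set_usage(self, rowwise=None, columnwise=None):
        if rowwise is not None:
            self.rowwise_usage = rowwise
        if columnwise is not None:
            self.columnwise_usage = columnwise

q = ToyQuantizer()
q.set_usage(rowwise=True, columnwise=True)   # default: serve both GEMMs
ub_overlap_ag = True
if ub_overlap_ag:
    q.set_usage(columnwise=False)            # Userbuffers path: dgrad-only usage
assert q.rowwise_usage and not q.columnwise_usage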
...
@@ -738,14 +740,10 @@ class _LayerNormMLP(torch.autograd.Function):
             # Launch tensor-parallel communication for FC1 GEMM input
             ln_out_total = None
             ln_out_total_work = None
-            if (
-                ctx.fc1_weight_requires_grad
-                and ctx.tensor_parallel
-                and ctx.sequence_parallel
-                and not ctx.ub_bulk_dgrad
-            ):
+            ub_obj_fc1_dgrad = None
+            if ctx.fc1_weight_requires_grad and ctx.tensor_parallel and ctx.sequence_parallel:
                 quantizer = None
-                if ctx.fp8 or ctx.debug and not ctx.force_hp_fc1_input_gather:
+                if ctx.fp8 or ctx.debug:
                     quantizer = ctx.fc1_input_quantizer
                     if isinstance(quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)):
                         # If data is in FP8, we compute FP8 transposes manually
...
@@ -753,13 +751,21 @@ class _LayerNormMLP(torch.autograd.Function):
                     else:
                         # wgrad GEMM requires input with column-wise usage
                         quantizer.set_usage(rowwise=False, columnwise=True)
-                gather_quantizer = None if ctx.force_hp_fc1_input_gather else quantizer
-                ln_out_total, ln_out_total_work = gather_along_first_dim(
-                    ln_out,
-                    ctx.tp_group,
-                    async_op=True,
-                    quantizer=gather_quantizer,
-                )
+                if ctx.ub_bulk_dgrad:
+                    ub_obj_fc1_dgrad = get_ub("fc1_dgrad")
+                    ln_out_total, _ = fill_userbuffers_buffer_for_all_gather(
+                        ub_obj_fc1_dgrad,
+                        ln_out,
+                        quantizer,
+                        ctx.tp_group,
+                    )
+                else:
+                    ln_out_total, ln_out_total_work = gather_along_first_dim(
+                        ln_out,
+                        ctx.tp_group,
+                        async_op=True,
+                        quantizer=quantizer,
+                    )
             else:
                 ln_out_total = ln_out
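`gather_along_first_dim(..., async_op=True)` returns the gathered tensor together with a work handle that is waited on just before the wgrad GEMM needs the data. A hedged sketch of the same overlap pattern written directly against torch.distributed (guarded so it is a no-op when no process group is initialized):

import torch
import torch.distributed as dist

def gather_first_dim_async(local: torch.Tensor):
    """All-gather along dim 0, returning (gathered, work_handle_or_None)."""
    if not (dist.is_available() and dist.is_initialized()):
        return local, None  # single-process fallback
    world = dist.get_world_size()
    out = torch.empty(world * local.shape[0], *local.shape[1:],
                      dtype=local.dtype, device=local.device)
    work = dist.all_gather_into_tensor(out, local.contiguous(), async_op=True)
    return out, work

x = torch.randn(4, 8)
total, work = gather_first_dim_async(x)
# ... overlap other compute here ...
if work is not None:
    work.wait()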
...
@@ -770,6 +776,11 @@ class _LayerNormMLP(torch.autograd.Function):
                 )
             else:
                 accumulate_wgrad_into_param_main_grad = ctx.fuse_wgrad_accumulation
+            # --------------------------------------------------
+            # FC2 DGRAD
+            # --------------------------------------------------
             # There are 6 possible fusion paths
             # 1 high-precision bias_gelu_fusion: gemm, FC1_bias + gelu,
             # 2 high-precision fc2_dgrad_gemm_gelu_fusion: gemm + gelu, FC1_bias + quantize
...
@@ -784,12 +795,15 @@ class _LayerNormMLP(torch.autograd.Function):
                 and (not ctx.debug)
             )
-            # FC2 DGRAD; Unconditional
-            if ctx.fc2_weight_quantizer is not None and isinstance(ctx.fc2_weight, QuantizedTensor):
-                ctx.fc2_weight.update_usage(
-                    rowwise_usage=ctx.fc2_weight_quantizer.rowwise_usage,
-                    columnwise_usage=ctx.fc2_weight_quantizer.columnwise_usage,
-                )
+            # Make sure required data is available
+            if isinstance(grad_output, QuantizedTensorBase):
+                grad_output.update_usage(rowwise_usage=True)
+            if ctx.fc2_weight_quantizer is not None and isinstance(
+                ctx.fc2_weight, QuantizedTensorBase
+            ):
+                ctx.fc2_weight.update_usage(columnwise_usage=True)
+            # Perform GEMM
             gemm_output, *_ = general_gemm(
                 fc2_weight,
                 grad_output,
...
@@ -804,52 +818,107 @@ class _LayerNormMLP(torch.autograd.Function):
                 out_dtype=ctx.activation_dtype,
                 gelu=fc2_dgrad_gemm_gelu_fusion,
                 gelu_in=fc1_out if fc2_dgrad_gemm_gelu_fusion else None,
-                use_split_accumulator=_2X_ACC_DGRAD,
+                use_split_accumulator=dgrad_use_split_accumulator,
                 ub=ub_obj_fc2_dgrad,
                 ub_type=tex.CommOverlapType.AG if ctx.ub_overlap_ag else None,
             )
+            # Prepare input grad tensor
+            dact = None
+            fc2_dgrad = None
             if fc2_dgrad_gemm_gelu_fusion:
                 dact = gemm_output
-                fc2_dgrad = None
             else:
                 fc2_dgrad = gemm_output
+            # --------------------------------------------------
+            # Finished FC2 DGRAD...
+            # --------------------------------------------------
+            # --------------------------------------------------
             # FC2 WGRAD
+            # --------------------------------------------------
+            fc2_wgrad = None
             if ctx.fc2_weight_requires_grad:
-                if isinstance(act_out, QuantizedTensor):
-                    act_out.update_usage(rowwise_usage=True, columnwise_usage=True)
-                if isinstance(grad_output, QuantizedTensor):
-                    grad_output.update_usage(rowwise_usage=True, columnwise_usage=True)
+                # Prepare input tensor
+                # Note: Synchronize tensor-parallel communication and
+                # make sure required data is available
+                if ctx.fp8 or ctx.debug:
+                    if isinstance(act_out, QuantizedTensorBase):
+                        act_out.update_usage(columnwise_usage=True)
+                    else:
+                        ctx.fc2_input_quantizer.set_usage(rowwise=False, columnwise=True)
+                        act_out = ctx.fc2_input_quantizer(act_out)
+                # Prepare grad output tensor
+                # Note: Synchronize tensor-parallel communication and
+                # make sure required data is available
+                if ctx.ub_overlap_ag and isinstance(ctx.fc2_grad_output_quantizer, MXFP8Quantizer):
+                    # UB does not support overlapping grad output
+                    # all-gather with wgrad GEMM. Also, we can't
+                    # convert row-scaled MXFP8 to column-scaled, so we
+                    # can't reuse the grad output that was gathered
+                    # for the dgrad GEMM. We work around with blocking
+                    # all-gather for column-scaled MXFP8 data.
+                    ctx.fc2_grad_output_quantizer.set_usage(rowwise=False, columnwise=True)
+                    grad_output, _ = gather_along_first_dim(
+                        grad_outputs[0],
+                        ctx.tp_group,
+                        quantizer=ctx.fc2_grad_output_quantizer,
+                    )
+                if ctx.fp8 or ctx.debug:
+                    if isinstance(grad_output, QuantizedTensorBase):
+                        grad_output.update_usage(columnwise_usage=True)
+                    else:
+                        ctx.fc2_grad_output_quantizer.set_usage(rowwise=False, columnwise=True)
+                        grad_output = ctx.fc2_grad_output_quantizer(grad_output)
+                # Whether to set grad arg in general_gemm
                 grad_arg = True
                 if ctx.fp8 and ctx.fp8_recipe.float8_block_scaling():
                     grad_arg = False
-                general_gemm_fc2_wgrad = functools.partial(
-                    general_gemm,
-                    out_dtype=(
-                        origin_fc2_weight.main_grad.dtype
-                        if ctx.fuse_wgrad_accumulation
-                        else ctx.activation_dtype
-                    ),
-                    workspace=get_workspace(),
-                    quantization_params=ctx.fc2_grad_weight_quantizer,  # wgrad in high precision
-                    layout="NT",
-                    grad=grad_arg,
-                    bias=fc2_bias if fc2_bias is not None and fc2_bias_grad is None else None,
-                    accumulate=accumulate_wgrad_into_param_main_grad,
-                    use_split_accumulator=_2X_ACC_WGRAD,
-                    out=origin_fc2_weight.main_grad if ctx.fuse_wgrad_accumulation else None,
-                )
+                # Arguments to include in wgrad GEMM closure
+                fc2_wgrad_gemm_kwargs = {
+                    "workspace": get_workspace(),
+                    "out_dtype": (
+                        origin_fc2_weight.main_grad.dtype
+                        if ctx.fuse_wgrad_accumulation
+                        else ctx.activation_dtype
+                    ),
+                    "quantization_params": ctx.fc2_grad_weight_quantizer,  # wgrad in high precision
+                    "accumulate": accumulate_wgrad_into_param_main_grad,
+                    "layout": "NT",
+                    "out": origin_fc2_weight.main_grad if ctx.fuse_wgrad_accumulation else None,
+                    "bias": fc2_bias if fc2_bias is not None and fc2_bias_grad is None else None,
+                    "use_split_accumulator": wgrad_use_split_accumulator,
+                    "grad": grad_arg,
+                }
+                def fc2_wgrad_gemm(
+                    x: torch.Tensor,
+                    dy: torch.Tensor,
+                ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+                    """Perform FC2 WGRAD GEMM
+                    May be called outside of this function to enable
+                    some advanced communication/compute overlapping.
+                    """
+                    dw, db, *_ = general_gemm(x, dy, **fc2_wgrad_gemm_kwargs)
+                    return dw, db
+                # Choose whether to call wgrad GEMM now or delay
                 if ctx.wgrad_store is not None and ctx.wgrad_store.delay_wgrad_compute():
-                    ctx.wgrad_store.put([act_out, grad_output], general_gemm_fc2_wgrad)
+                    ctx.wgrad_store.put([act_out, grad_output], fc2_wgrad_gemm)
+                    fc2_wgrad = None
                 else:
-                    fc2_wgrad, fc2_bias_grad_, *_ = general_gemm_fc2_wgrad(
-                        act_out,
-                        grad_output,
-                    )
+                    # Call wgrad GEMM now
+                    fc2_wgrad, fc2_bias_grad_ = fc2_wgrad_gemm(act_out, grad_output)
+                # Update grad bias if needed
                 if fc2_bias_grad is None:
                     if (
                         ctx.fp8
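The wgrad GEMM is wrapped in a small closure (`fc2_wgrad_gemm`) plus a kwargs dict so that, when `wgrad_store.delay_wgrad_compute()` is set, the inputs and the callable can be queued and executed later. A toy store that mimics that put/execute flow, with `torch.matmul` standing in for `general_gemm` (all names here are illustrative):

import torch

class ToyWgradStore:
    """Queues (tensors, fn) pairs and runs them on demand."""
    def __init__(self, delay: bool):
        self._delay = delay
        self._queue = []

    def delay_wgrad_compute(self) -> bool:
        return self._delay

    def put(self, tensors, fn):
        self._queue.append((tensors, fn))

    def pop_and_run(self):
        tensors, fn = self._queue.pop(0)
        return fn(*tensors)

def wgrad_gemm(x: torch.Tensor, dy: torch.Tensor):
    # Stand-in for general_gemm(x, dy, layout="NT", ...): dW = dy^T @ x, db = sum(dy)
    return dy.t() @ x, dy.sum(dim=0)

store = ToyWgradStore(delay=True)
x, dy = torch.randn(8, 4), torch.randn(8, 6)
if store.delay_wgrad_compute():
    store.put([x, dy], wgrad_gemm)    # defer the weight-grad GEMM
    dw = None
else:
    dw, db = wgrad_gemm(x, dy)
dw, db = store.pop_and_run()          # later, e.g. from backward_dw()
assert dw.shape == (6, 4)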
...
@@ -858,12 +927,17 @@ class _LayerNormMLP(torch.autograd.Function):
                     ):
                         # BGRAD not fused with GEMM for float8 blockwise gemm.
                         fc2_bias_grad_ = act_out.view(-1, act_out.shape[-1]).sum(dim=0)
                     fc2_bias_grad = fc2_bias_grad_
                     del fc2_bias_grad_
+                # Deallocate input tensor if permitted
                 if ctx.wgrad_store is not None and not ctx.wgrad_store.delay_wgrad_compute():
                     clear_tensor_data(act_out)
+            # --------------------------------------------------
+            # Finished FC2 WGRAD...
+            # --------------------------------------------------
             # bias computation
             fc1_bias_grad = None
             fuse_gemm_and_bias_fc1_wgrad = False
...
@@ -927,63 +1001,69 @@ class _LayerNormMLP(torch.autograd.Function):
             ub_obj_fc1_dgrad = None
             ub_obj_fc1_wgrad = None
             ub_type_fc1_dgrad = None
+            ub_type_fc1_wgrad = None
             fc1_dgrad_shape = [reduce(multiply_op, inputmat.shape[:-1]), inputmat.shape[-1]]
-            fc1_dgrad_rs_out = None
-            fc1_dgrad_bulk = None
             if ctx.ub_overlap_rs_dgrad:
                 # Overlap DGRAD+RS
                 ub_obj_fc1_dgrad = get_ub("fc1_dgrad")
                 ub_type_fc1_dgrad = tex.CommOverlapType.RS
-                fc1_dgrad_rs_out = torch.empty(
-                    fc1_dgrad_shape, dtype=ctx.activation_dtype, device="cuda"
-                )
             else:
                 if ctx.ub_bulk_dgrad:
                     # Overlap ln_out all-gather with DGRAD compute
-                    # NOTE: Copying into communication buffer will always prefer rowwise data,
-                    # and will copy columnwise data if rowwise does not exist. In that case,
-                    # the all-gather will apply to the leading dimension of the transpose,
-                    # which then needs to be interleaved correctly before WGRAD.
                     ub_obj_fc1_dgrad = get_ub("fc1_dgrad")
                     ub_type_fc1_dgrad = tex.CommOverlapType.AG
-                    ub_obj_fc1_dgrad.copy_into_buffer(ln_out, ctx.fc1_input_quantizer, local_chunk=True)
                 if ctx.ub_bulk_wgrad:
                     # Overlap FC1 DGRAD reduce-scatter with WGRAD compute
                     ub_obj_fc1_wgrad = get_ub("fc1_wgrad")
-                    fc1_dgrad_bulk = ub_obj_fc1_wgrad.get_buffer(None)
+                    ub_type_fc1_wgrad = tex.CommOverlapType.RS
-            # FC1 DGRAD: Unconditional
+            # --------------------------------------------------
+            # FC1 DGRAD
+            # --------------------------------------------------
+            # Make sure required data is available
             if ctx.fc1_weight_quantizer is not None and isinstance(
-                ctx.fc1_weight_quantizer, QuantizedTensor
+                ctx.fc1_weight_quantizer, QuantizedTensorBase
             ):
-                ctx.fc1_weight.update_usage(
-                    rowwise_usage=ctx.fc1_weight_quantizer.rowwise_usage,
-                    columnwise_usage=ctx.fc1_weight_quantizer.columnwise_usage,
-                )
+                ctx.fc1_weight.update_usage(columnwise_usage=True)
+            # Output buffers for Userbuffers reduce-scatter
+            gemm_out = None
+            reduce_scatter_out = None
+            if ctx.ub_overlap_rs_dgrad:
+                reduce_scatter_out = torch.empty(
+                    fc1_dgrad_shape, dtype=ctx.activation_dtype, device="cuda"
+                )
+            if ctx.ub_bulk_wgrad:
+                gemm_out = ub_obj_fc1_wgrad.get_buffer(local_chunk=False)
+            # dgrad GEMM
-            fc1_dgrad, *_, fc1_dgrad_rs_out = general_gemm(
+            gemm_out, *_, reduce_scatter_out = general_gemm(
                 fc1_weight,
                 dact,
                 get_workspace(),
-                out=fc1_dgrad_bulk,
+                out=gemm_out,
                 out_dtype=ctx.activation_dtype,
                 quantization_params=ctx.fc1_grad_input_quantizer,
                 layout="NN",
                 grad=True,
+                use_split_accumulator=dgrad_use_split_accumulator,
                 ub=ub_obj_fc1_dgrad,
                 ub_type=ub_type_fc1_dgrad,
-                extra_output=fc1_dgrad_rs_out,
+                extra_output=reduce_scatter_out,
                 bulk_overlap=ctx.ub_bulk_dgrad,
             )
-            # Overlap dgrad-RS/AR with wgrad
+            # Prepare grad input tensor
+            # Note: Perform tensor-parallel communication
+            fc1_dgrad = None
             fc1_dgrad_work = None
             if ctx.ub_overlap_rs_dgrad:
-                fc1_dgrad = fc1_dgrad_rs_out
+                fc1_dgrad = reduce_scatter_out
+            elif ctx.ub_bulk_wgrad:
+                fc1_dgrad = ub_obj_fc1_wgrad.get_buffer(local_chunk=True)
             elif ctx.set_parallel_mode and not ctx.ub_bulk_wgrad:
+                fc1_dgrad = gemm_out
                 if ctx.sequence_parallel:
                     if ctx.return_layernorm_output and ctx.return_layernorm_output_gathered:
                         fc1_dgrad = fc1_dgrad + grad_outputs[1].view_as(fc1_dgrad)
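`fc1_dgrad_shape` collapses every leading dimension of the input into one with `reduce(mul, shape[:-1])`, since the communication buffers and GEMMs operate on 2-D `[tokens, hidden]` views. A one-liner sketch of that shape computation:

from functools import reduce
from operator import mul as multiply_op
import torch

inputmat = torch.randn(6, 3, 8)   # e.g. [seq, batch, hidden]
dgrad_shape = [reduce(multiply_op, inputmat.shape[:-1]), inputmat.shape[-1]]
assert dgrad_shape == [18, 8]
buf = torch.empty(dgrad_shape, dtype=inputmat.dtype)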
...
@@ -994,90 +1074,125 @@ class _LayerNormMLP(torch.autograd.Function):
                         )
                 elif ctx.tensor_parallel:
                     fc1_dgrad, fc1_dgrad_work = allreduce(fc1_dgrad, ctx.tp_group, async_op=True)
+            else:
+                fc1_dgrad = gemm_out
+            # --------------------------------------------------
+            # Finished FC1 DGRAD...
+            # --------------------------------------------------
+            # --------------------------------------------------
             # FC1 WGRAD
+            # --------------------------------------------------
             fc1_wgrad = None
             if ctx.fc1_weight_requires_grad:
-                # Synchronize tensor-parallel communication for FC1 GEMM input tensor
-                if ctx.ub_bulk_dgrad:
-                    ln_out_total = ub_obj_fc1_dgrad.get_buffer(ctx.fc1_input_quantizer)
-                    if ctx.fp8:
-                        if ln_out._data is None:
-                            # All-gather executed on columnwise data and result is in rowwise data,
-                            # so we need to fix the interleaving before WGRAD.
-                            ln_out_total = _fix_gathered_fp8_transpose(ln_out_total, ctx.tp_size)
-                        elif not non_tn_fp8_gemm_supported():
-                            # FP8 GEMM on Hopper only supports TN layout so the gathered input must
-                            # have a valid transpose.
-                            ln_out_total._create_transpose()
+                # Prepare input tensor
+                # Note: Synchronize tensor-parallel communication and
+                # make sure required data is available
                 if ln_out_total_work is not None:
                     ln_out_total_work.wait()
                     ln_out_total_work = None
-                if ctx.fc1_input_quantizer is not None and not isinstance(
-                    ln_out_total, QuantizedTensor
-                ):
-                    # Async gather in BF16 does not asynchronously
-                    # call quantizer after gather.
-                    ctx.fc1_input_quantizer.set_usage(rowwise=False, columnwise=True)
-                    ln_out_total = ctx.fc1_input_quantizer(ln_out_total)
-                # Make sure GEMM inputs have required data
-                if isinstance(ln_out_total, QuantizedTensor):
-                    ln_out_total.update_usage(columnwise_usage=True)
-                if isinstance(dact, QuantizedTensor):
-                    dact.update_usage(columnwise_usage=True)
+                if ctx.fp8 or ctx.debug:
+                    if isinstance(ln_out_total, QuantizedTensorBase):
+                        ln_out_total.update_usage(columnwise_usage=True)
+                    else:
+                        ctx.fc1_input_quantizer.set_usage(rowwise=False, columnwise=True)
+                        ln_out_total = ctx.fc1_input_quantizer(ln_out_total)
+                # Prepare grad output tensor
+                # Note: Synchronize tensor-parallel communication and
+                # make sure required data is available
+                if ctx.fp8 or ctx.debug:
+                    if isinstance(dact, QuantizedTensorBase):
+                        dact.update_usage(columnwise_usage=True)
+                    else:
+                        ctx.fc1_grad_output_quantizer.set_usage(rowwise=False, columnwise=True)
+                        dact = ctx.fc1_grad_output_quantizer(dact)
                 # Output buffer for overlapping grad input
                 # reduce-scatter with wgrad GEMM
+                reduce_scatter_out = None
                 if ctx.ub_bulk_wgrad and ub_obj_fc1_wgrad.is_fp8_ubuf():
-                    fc1_dgrad_rs_out = torch.empty(
+                    reduce_scatter_out = torch.empty(
                         fc1_dgrad_shape, dtype=ctx.activation_dtype, device="cuda"
                     )
-                # wgrad GEMM
-                general_gemm_fc1_wgrad = functools.partial(
-                    general_gemm,
-                    out_dtype=(
-                        origin_fc1_weight.main_grad.dtype
-                        if ctx.fuse_wgrad_accumulation
-                        else ctx.activation_dtype
-                    ),
-                    workspace=get_workspace(),
-                    layout="NT",
-                    quantization_params=ctx.fc1_grad_weight_quantizer,
-                    grad=fuse_gemm_and_bias_fc1_wgrad,
-                    bias=fc1_bias if fuse_gemm_and_bias_fc1_wgrad else None,
-                    accumulate=accumulate_wgrad_into_param_main_grad,
-                    out=origin_fc1_weight.main_grad if ctx.fuse_wgrad_accumulation else None,
-                    ub=ub_obj_fc1_wgrad,
-                    ub_type=tex.CommOverlapType.RS if ctx.ub_bulk_wgrad else None,
-                    extra_output=fc1_dgrad_rs_out,
-                    bulk_overlap=ctx.ub_bulk_wgrad,
-                )
+                # Arguments to include in wgrad GEMM closure
+                fc1_wgrad_gemm_kwargs = {
+                    "workspace": get_workspace(),
+                    "out_dtype": (
+                        origin_fc1_weight.main_grad.dtype
+                        if ctx.fuse_wgrad_accumulation
+                        else ctx.activation_dtype
+                    ),
+                    "quantization_params": ctx.fc1_grad_weight_quantizer,
+                    "accumulate": accumulate_wgrad_into_param_main_grad,
+                    "layout": "NT",
+                    "out": origin_fc1_weight.main_grad if ctx.fuse_wgrad_accumulation else None,
+                    "bias": fc1_bias if fuse_gemm_and_bias_fc1_wgrad else None,
+                    "use_split_accumulator": wgrad_use_split_accumulator,
+                    "grad": fuse_gemm_and_bias_fc1_wgrad,
+                    "ub": ub_obj_fc1_wgrad,
+                    "ub_type": ub_type_fc1_wgrad,
+                    "extra_output": reduce_scatter_out,
+                    "bulk_overlap": ctx.ub_bulk_wgrad,
+                }
+                def fc1_wgrad_gemm(
+                    x: torch.Tensor,
+                    dy: torch.Tensor,
+                    _is_delayed: bool = True,
+                ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+                    """Perform FC1 WGRAD GEMM
+                    May be called outside of this function to enable
+                    some advanced communication/compute overlapping.
+                    """
+                    dw, db, *_ = general_gemm(x, dy, **fc1_wgrad_gemm_kwargs)
+                    return dw, db
+                # Choose whether to call wgrad GEMM now or delay
                 if ctx.wgrad_store is not None and ctx.wgrad_store.delay_wgrad_compute():
-                    ctx.wgrad_store.put([ln_out_total, dact], general_gemm_fc1_wgrad)
+                    if (
+                        fc1_wgrad_gemm_kwargs["ub"] is not None
+                        or fc1_wgrad_gemm_kwargs["ub_type"] is not None
+                        or fc1_wgrad_gemm_kwargs["extra_output"] is not None
+                        or fc1_wgrad_gemm_kwargs["bulk_overlap"]
+                    ):
+                        raise NotImplementedError(
+                            "Delayed weight grad computation is not supported "
+                            "with Userbuffers (tensor-parallel communication overlapping)"
+                        )
+                    ctx.wgrad_store.put([ln_out_total, dact], fc1_wgrad_gemm)
                     fc1_wgrad = None
                     if fuse_gemm_and_bias_fc1_wgrad:
                         fc1_bias_grad = None
                 else:
-                    fc1_wgrad_outputs = general_gemm_fc1_wgrad(
-                        ln_out_total,
-                        dact,
-                    )
-                    clear_tensor_data(ln_out_total, dact)
+                    # Call wgrad GEMM now
+                    fc1_wgrad_outputs = fc1_wgrad_gemm(ln_out_total, dact)
                     if fuse_gemm_and_bias_fc1_wgrad:
-                        fc1_wgrad, fc1_bias_grad, *_ = fc1_wgrad_outputs
+                        fc1_wgrad, fc1_bias_grad = fc1_wgrad_outputs
                     else:
-                        fc1_wgrad, *_ = fc1_wgrad_outputs
+                        fc1_wgrad, _ = fc1_wgrad_outputs
+                    # Deallocate tensors if permitted
+                    clear_tensor_data(dact)
+                    if not ctx.return_layernorm_output_gathered:
+                        clear_tensor_data(ln_out_total)
+                # Update grad input if overlapping reduce-scatter with wgrad GEMM
                 if ctx.ub_bulk_wgrad:
                     if ub_obj_fc1_wgrad.is_fp8_ubuf():
-                        fc1_dgrad = fc1_dgrad_rs_out
+                        fc1_dgrad = reduce_scatter_out
                     else:
-                        fc1_dgrad = ub_obj_fc1_wgrad.get_buffer(None, local_chunk=True)
+                        fc1_dgrad = ub_obj_fc1_wgrad.get_buffer(local_chunk=True).clone()
+            # --------------------------------------------------
+            # Finished FC1 WGRAD...
+            # --------------------------------------------------
             # Make sure all tensor-parallel communication is finished
             if ln_out_total_work is not None:
...
@@ -1748,7 +1863,7 @@ class LayerNormMLP(TransformerEngineBaseModule):
         ) = [None] * 12
         if self.fp8:
             fc1_input_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_INPUT]
-            fc1_input_quantizer.internal = False  # temporary
+            fc1_input_quantizer.internal = True
             fc1_weight_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_WEIGHT]
             fc1_weight_quantizer.internal = True
             fc2_input_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM2_INPUT]
...
@@ -1756,6 +1871,7 @@ class LayerNormMLP(TransformerEngineBaseModule):
                 rowwise=True,
                 columnwise=isinstance(fc2_input_quantizer, (MXFP8Quantizer, Float8BlockQuantizer)),
             )
+            fc1_input_quantizer.internal = True
             fc2_weight_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM2_WEIGHT]
             fc2_weight_quantizer.internal = True
             if fp8_output:
...
@@ -1764,11 +1880,11 @@ class LayerNormMLP(TransformerEngineBaseModule):
             ]
             if torch.is_grad_enabled():
                 fc2_grad_output_quantizer = self.quantizers["scaling_bwd"][
-                    tex.FP8BwdTensors.GRAD_OUTPUT1
+                    tex.FP8BwdTensors.GRAD_OUTPUT2
                 ]
                 fc2_grad_output_quantizer.internal = True
                 fc1_grad_output_quantizer = self.quantizers["scaling_bwd"][
-                    tex.FP8BwdTensors.GRAD_INPUT1
+                    tex.FP8BwdTensors.GRAD_OUTPUT1
                 ]
                 fc1_grad_output_quantizer.internal = True
...
@@ -1853,25 +1969,25 @@ class LayerNormMLP(TransformerEngineBaseModule):
             else:
                 # fc2_grad_output_quantizer: set configs about amax epsilon and power_2_scale for fc2_grad_output_quantizer
                 self.quantizers["scaling_bwd"][
-                    tex.FP8BwdTensors.GRAD_OUTPUT1
+                    tex.FP8BwdTensors.GRAD_OUTPUT2
                 ].force_pow_2_scales = recipe.fp8_quant_bwd_grad.power_2_scale
                 self.quantizers["scaling_bwd"][
-                    tex.FP8BwdTensors.GRAD_OUTPUT1
+                    tex.FP8BwdTensors.GRAD_OUTPUT2
                 ].amax_epsilon = recipe.fp8_quant_bwd_grad.amax_epsilon
                 # fc1_grad_output_quantizer: also set numerical configs for fc1_grad_output_quantizer
                 self.quantizers["scaling_bwd"][
-                    tex.FP8BwdTensors.GRAD_INPUT1
+                    tex.FP8BwdTensors.GRAD_OUTPUT1
                 ].force_pow_2_scales = recipe.fp8_quant_bwd_grad.power_2_scale
                 self.quantizers["scaling_bwd"][
-                    tex.FP8BwdTensors.GRAD_INPUT1
+                    tex.FP8BwdTensors.GRAD_OUTPUT1
                 ].amax_epsilon = recipe.fp8_quant_bwd_grad.amax_epsilon
             if self.sequence_parallel and self.set_parallel_mode:
                 # fc2_grad_output_quantizer: customize grad_output_quantizer with amax reduction TP group, row parallel + sequence parallel here
                 self.quantizers["scaling_bwd"][
-                    tex.FP8BwdTensors.GRAD_OUTPUT1
+                    tex.FP8BwdTensors.GRAD_OUTPUT2
                 ].with_amax_reduction = True
                 self.quantizers["scaling_bwd"][
-                    tex.FP8BwdTensors.GRAD_OUTPUT1
+                    tex.FP8BwdTensors.GRAD_OUTPUT2
                 ].amax_reduction_group = self.tp_group

     def backward_dw(self):
...
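The module-level changes above re-index the backward quantizers and then push recipe options such as power-of-two scales and an amax epsilon onto each of them. A toy sketch of that kind of per-tensor quantizer table and option fan-out (the dataclass and the keys are invented for illustration and are not the TransformerEngine enums):

from dataclasses import dataclass

@dataclass
class ToyQuantizerCfg:
    internal: bool = False
    force_pow_2_scales: bool = False
    amax_epsilon: float = 0.0

# Hypothetical stand-ins for tex.FP8BwdTensors.* enum members
GRAD_OUTPUT1, GRAD_OUTPUT2 = "grad_output1", "grad_output2"

scaling_bwd = {GRAD_OUTPUT1: ToyQuantizerCfg(), GRAD_OUTPUT2: ToyQuantizerCfg()}

# Apply recipe-style numerical options to both backward quantizers
for key in (GRAD_OUTPUT1, GRAD_OUTPUT2):
    scaling_bwd[key].internal = True
    scaling_bwd[key].force_pow_2_scales = True
    scaling_bwd[key].amax_epsilon = 1e-12

assert scaling_bwd[GRAD_OUTPUT2].force_pow_2_scales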
transformer_engine/pytorch/module/linear.py  View file @ f8c2af4c
...
@@ -6,24 +6,26 @@
 from typing import Callable, Dict, Optional, Tuple, Union
 from functools import reduce
 from operator import mul as multiply_op
-import functools
+import warnings
 import torch
 import transformer_engine_torch as tex
 from transformer_engine.common.recipe import Recipe
 from transformer_engine.pytorch import torch_version
 from .base import (
-    get_workspace,
+    fill_userbuffers_buffer_for_all_gather,
+    get_dummy_wgrad,
     get_ub,
+    get_workspace,
     TransformerEngineBaseModule,
-    get_dummy_wgrad,
     _2X_ACC_FPROP,
     _2X_ACC_DGRAD,
     _2X_ACC_WGRAD,
 )
-from ._common import noop_cat, _fix_gathered_fp8_transpose, WeightGradStore
+from ._common import noop_cat, WeightGradStore
 from ..fp8 import FP8GlobalStateManager
 from ..utils import (
     cast_if_needed,
...
@@ -32,7 +34,6 @@ from ..utils import (
     init_method_constant,
     requires_grad,
     needs_quantized_gemm,
-    non_tn_fp8_gemm_supported,
     assert_dim_for_fp8_exec,
     nvtx_range_pop,
     nvtx_range_push,
...
@@ -57,6 +58,7 @@ from ..jit import no_torch_dynamo
 from ..graph import is_graph_capturing
 from ..tensor.quantized_tensor import (
     QuantizedTensor,
+    QuantizedTensorBase,
     Quantizer,
     prepare_for_saving,
     restore_from_saved,
...
@@ -125,88 +127,100 @@ class _Linear(torch.autograd.Function):
         # Make sure input dimensions are compatible
         out_features, in_features = weight.shape
-        inp_shape = inp.shape
-        assert inp_shape[-1] == in_features, "GEMM not possible"
+        assert inp.shape[-1] == in_features, "GEMM not possible"
+        # Configure tensor-parallel communication
         tp_world_size = get_distributed_world_size(tp_group)
         backward_needs_input = is_grad_enabled and weight.requires_grad
-        # Prepare input tensor
-        # Note: Cast to expected dtype and perform tensor-parallel communication
-        nvtx_range_push(f"{nvtx_label}.input_cast_comm")
-        inputmat = inp.view(-1, in_features)
-        inputmat_total = None
         with_input_all_gather_nccl = (
             parallel_mode == "column" and sequence_parallel and not ub_overlap_ag_fprop
         )
-        own_quantized_input = False
-        # TODO(kwyss): Support FP8 allgather for FP8 block quantization.
+        # Do TP communication in high precision if quantized format
+        # does not support communication
         force_hp_input_gather = (
             fp8 and with_input_all_gather_nccl and isinstance(input_quantizer, Float8BlockQuantizer)
-        )  # Perform TP communication in high precision.
+        )
+        # Configure Userbuffers communication (comm+GEMM overlap)
+        ub_obj = None
+        ub_type = None
+        if ub_overlap_rs_fprop:
+            ub_obj = get_ub(ub_name + "_fprop")
+            ub_type = tex.CommOverlapType.RS
+        elif ub_overlap_ag_fprop:
+            ub_obj = get_ub(ub_name + "_fprop")
+            ub_type = tex.CommOverlapType.AG
+        # ------------------------------------------------------
+        # Prepare input tensor
+        # Note: Cast to expected dtype and perform tensor-parallel communication
+        # ------------------------------------------------------
+        nvtx_range_push(f"{nvtx_label}.input_cast_comm")
+        inputmat = inp  # Input tensor to save for backward (maybe sharded)
+        inputmat_total = None  # Input tensor to pass to GEMM (gathered)
+        own_quantized_input = False
         if fp8:
             assert_dim_for_fp8_exec(inputmat, weight)
-            if any([ub_overlap_ag_fprop, ub_overlap_rs_fprop]) and not (
-                FP8GlobalStateManager.get_fp8_recipe().float8_per_tensor_scaling()
-            ):
-                raise NotImplementedError(
-                    "Comm+GEMM overlap is only supported with FP8 delayed scaling or per-tensor"
-                    " current scaling"
-                )
-        if fp8 or debug:
-            if input_quantizer is None:
-                raise ValueError("Missing quantizer for input tensor")
-            if with_input_all_gather_nccl:
-                if force_hp_input_gather:
-                    input_quantizer.set_usage(rowwise=True, columnwise=False)
-                    inputmat_total, _ = gather_along_first_dim(
-                        inputmat, tp_group, quantizer=input_quantizer
-                    )
-                else:
-                    if not isinstance(inputmat, QuantizedTensor):
-                        columnwise_usage = backward_needs_input and isinstance(input_quantizer, MXFP8Quantizer)
-                        # force_hp_input_gather should enforce this
-                        assert not isinstance(input_quantizer, Float8BlockQuantizer)
-                        input_quantizer.set_usage(rowwise=True, columnwise=columnwise_usage)
-                        inputmat = input_quantizer(inputmat)
-                        own_quantized_input = True
-                    input_quantizer.set_usage(rowwise=True, columnwise=False)
-                    inputmat_total, _ = gather_along_first_dim(
-                        inputmat,
-                        tp_group,
-                        quantizer=input_quantizer,
-                    )
-            else:
-                if (
-                    FP8GlobalStateManager.get_fp8_recipe().float8_per_tensor_scaling()
-                    and ub_bulk_dgrad
-                ):
-                    # reduce duplicated transpose in `_fix_gathered_fp8_transpose`
-                    input_quantizer.set_usage(rowwise=True, columnwise=False)
-                else:
-                    input_quantizer.set_usage(
-                        rowwise=True,
-                        columnwise=backward_needs_input,
-                    )
-                if not isinstance(inputmat, QuantizedTensor):
-                    inputmat = input_quantizer(inputmat)
-                    own_quantized_input = True
-                elif backward_needs_input:
-                    inputmat.update_usage(rowwise_usage=True, columnwise_usage=True)
-                inputmat_total = inputmat
-        else:
-            inputmat = cast_if_needed(inp, activation_dtype)
-            if with_input_all_gather_nccl:
-                inputmat_total, _ = gather_along_first_dim(inputmat, tp_group)
-            else:
-                inputmat_total = inputmat
+        if with_input_all_gather_nccl or ub_overlap_ag_fprop:
+            # All-gather input tensor
+            # Cast local input tensor if needed
+            if fp8 or debug:
+                if input_quantizer is None:
+                    raise ValueError("Missing quantizer for input tensor")
+                if not force_hp_input_gather and not isinstance(inputmat, QuantizedTensorBase):
+                    input_quantizer.set_usage(rowwise=True, columnwise=backward_needs_input)
+                    if isinstance(
+                        input_quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)
+                    ):
+                        # All-gather is not supported with FP8 column-wise data
+                        input_quantizer.set_usage(columnwise=False)
+                    inputmat = input_quantizer(inputmat)
+                    own_quantized_input = True
+            else:
+                inputmat = cast_if_needed(inp, activation_dtype)  # Cast for AMP
+            # Initialize gathered input tensor
+            quantizer = None
+            if fp8 or debug:
+                quantizer = input_quantizer
+                quantizer.set_usage(rowwise=True, columnwise=False)
+            if with_input_all_gather_nccl:
+                # Perform NCCL all-gather
+                inputmat_total, _ = gather_along_first_dim(
+                    inputmat,
+                    tp_group,
+                    quantizer=quantizer,
+                )
+            elif ub_overlap_ag_fprop:
+                # Initialize Userbuffers all-gather
+                inputmat_total, _ = fill_userbuffers_buffer_for_all_gather(
+                    ub_obj,
+                    inputmat,
+                    quantizer,
+                    tp_group,
+                )
+        else:
+            # Do not all-gather input tensor
+            if fp8 or debug:
+                if isinstance(inputmat, QuantizedTensorBase):
+                    inputmat.update_usage(rowwise_usage=True)
+                else:
+                    if input_quantizer is None:
+                        raise ValueError("Missing quantizer for input tensor")
+                    input_quantizer.set_usage(rowwise=True, columnwise=backward_needs_input)
+                    inputmat = input_quantizer(inputmat)
+                    own_quantized_input = True
+            else:
+                inputmat = cast_if_needed(inp, activation_dtype)  # Cast for AMP
+            inputmat_total = inputmat
         nvtx_range_pop(f"{nvtx_label}.input_cast_comm")
+        # ------------------------------------------------------
+        # Input tensor is ready for GEMM...
+        # ------------------------------------------------------
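`cast_if_needed` is used throughout the AMP path above to avoid redundant copies: the tensor is returned as-is when it already has the requested dtype. A minimal sketch of such a helper (an illustration of the behavior relied on here, not the library's exact implementation):

import torch

def cast_if_needed_sketch(t: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    """Return t unchanged if it already has `dtype`, otherwise cast."""
    return t if t.dtype == dtype else t.to(dtype)

x = torch.randn(4, dtype=torch.bfloat16)
assert cast_if_needed_sketch(x, torch.bfloat16) is x          # no copy
assert cast_if_needed_sketch(x, torch.float32).dtype == torch.float32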
-        # Cast weight to expected dtype
+        # ------------------------------------------------------
+        # Prepare weight tensor
+        # ------------------------------------------------------
         weightmat = weight
         if fp8 or debug:
             # Configure quantizer
             if weight_quantizer is not None:
...
@@ -217,7 +231,8 @@ class _Linear(torch.autograd.Function):
                     and not in_fp8_activation_recompute_phase()
                 )
                 weight_quantizer.set_usage(rowwise=True, columnwise=columnwise_usage)
-            # FP8 cast to workspace buffer
+            # Get quantized weight
             update_workspace = is_first_microbatch is None or is_first_microbatch
             weightmat = module.get_weight_workspace(
                 tensor=weight,
...
@@ -228,19 +243,21 @@ class _Linear(torch.autograd.Function):
                 fsdp_group=fsdp_group,
                 workspace_dtype=activation_dtype,
             )
+            weightmat.update_usage(rowwise_usage=True)
         else:
-            weightmat = cast_if_needed(weightmat, activation_dtype)
+            weightmat = cast_if_needed(weightmat, activation_dtype)  # Cast for AMP
+        # ------------------------------------------------------
+        # Weight tensor is ready for GEMM...
+        # ------------------------------------------------------
         # Cast bias to expected dtype
         bias_dtype = activation_dtype
         if needs_quantized_gemm(inputmat_total) and activation_dtype == torch.float32:
+            # cuBLAS does not support FP8 GEMM with FP32 bias, so we cast to BF16
             bias_dtype = torch.bfloat16
         bias = cast_if_needed(bias, bias_dtype) if bias is not None else bias
-        # Configure output quantizer
-        if output_quantizer is not None:
-            output_quantizer.set_usage(rowwise=True, columnwise=False)
         # Calibrate quantizers if needed
         if not fp8 and fp8_calibration:
             if input_quantizer is not None:
...
@@ -248,44 +265,74 @@ class _Linear(torch.autograd.Function):
             if weight_quantizer is not None:
                 weight_quantizer.calibrate(weight)
-        ub_obj = None
-        ub_type = None
-        rs_out = None
-        out_dtype = activation_dtype
-        if ub_overlap_rs_fprop:
-            ub_obj = get_ub(ub_name + "_fprop")
-            ub_type = tex.CommOverlapType.RS
-            out_shape = [reduce(multiply_op, inp_shape[:-1]) // tp_world_size, out_features]
-            rs_out = torch.empty(out_shape, dtype=activation_dtype, device=inputmat_total.device)
-        elif ub_overlap_ag_fprop:
-            ub_obj = get_ub(ub_name + "_fprop")
-            ub_type = tex.CommOverlapType.AG
-            if fp8:
-                assert ub_obj.is_fp8_ubuf(), "AG overlap with FP8 GEMM inputs requires FP8 buffer."
-            ub_obj.copy_into_buffer(inputmat_total, input_quantizer, local_chunk=True)
-            inputmat_total = ub_obj.get_buffer(input_quantizer)
-        nvtx_range_push(f"{nvtx_label}.gemm")
-        fprop_gemm_use_split_accumulator = _2X_ACC_FPROP
+        # Choose whether to use GEMM kernel with split accumulator
+        use_split_accumulator = _2X_ACC_FPROP
         if fp8:
             recipe = FP8GlobalStateManager.get_fp8_recipe()
             if hasattr(recipe, "fp8_gemm_fprop"):
-                fprop_gemm_use_split_accumulator = recipe.fp8_gemm_fprop.use_split_accumulator
+                use_split_accumulator = recipe.fp8_gemm_fprop.use_split_accumulator
+        # Configure output quantizer
+        if output_quantizer is not None:
+            output_quantizer.set_usage(rowwise=True, columnwise=False)
+        # Output buffer for Userbuffers reduce-scatter
+        reduce_scatter_out = None
+        if ub_overlap_rs_fprop:
+            out_shape = list(inp.shape)
+            out_shape[0] //= tp_world_size
+            out_shape[-1] = out_features
+            reduce_scatter_out = torch.empty(out_shape, dtype=activation_dtype, device=inp.device)
+        # ------------------------------------------------------
+        # Forward GEMM
+        # Note: y = x * w^T
+        # ------------------------------------------------------
+        nvtx_range_push(f"{nvtx_label}.gemm")
-        out, *_, rs_out = general_gemm(
+        gemm_out, *_, reduce_scatter_out = general_gemm(
             weightmat,
             inputmat_total,
             get_workspace(),
             quantization_params=output_quantizer,
-            out_dtype=out_dtype,
+            out_dtype=activation_dtype,
             bias=bias,
-            use_split_accumulator=fprop_gemm_use_split_accumulator,
+            use_split_accumulator=use_split_accumulator,
             ub=ub_obj,
             ub_type=ub_type,
-            extra_output=rs_out,
+            extra_output=reduce_scatter_out,
         )
         nvtx_range_pop(f"{nvtx_label}.gemm")
+        # ------------------------------------------------------
+        # Finished forward GEMM...
+        # ------------------------------------------------------
+        # ------------------------------------------------------
+        # Prepare output tensor
+        # Note: Perform tensor-parallel communication
+        # ------------------------------------------------------
+        out = None
+        if ub_overlap_rs_fprop:
+            out = reduce_scatter_out
+        elif parallel_mode == "row" and tp_size > 1:
+            nvtx_range_push(f"{nvtx_label}.row_parallel_comm")
+            out = gemm_out
+            if sequence_parallel:
+                out, _ = reduce_scatter_along_first_dim(out, tp_group)
+            elif tensor_parallel:
+                if symmetric_ar_type is not None:
+                    out, _ = symmetric_all_reduce(out, tp_group, all_reduce_type=symmetric_ar_type)
+                else:
+                    out, _ = allreduce(out, tp_group)
+            nvtx_range_pop(f"{nvtx_label}.row_parallel_comm")
+        else:
+            out = gemm_out
+        # ------------------------------------------------------
+        # Output tensor is ready to return...
+        # ------------------------------------------------------
+        # ------------------------------------------------------
+        # Cache state for backward pass
+        # ------------------------------------------------------
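For a row-parallel linear the partial GEMM outputs must be combined across ranks: with sequence parallelism the result is reduce-scattered along the first dimension, otherwise it is all-reduced. A hedged torch.distributed sketch of that branch (a no-op when no process group is initialized):

import torch
import torch.distributed as dist

def finish_row_parallel(partial: torch.Tensor, sequence_parallel: bool) -> torch.Tensor:
    if not (dist.is_available() and dist.is_initialized()):
        return partial  # single-process fallback
    if sequence_parallel:
        world = dist.get_world_size()
        out = torch.empty(partial.shape[0] // world, *partial.shape[1:],
                          dtype=partial.dtype, device=partial.device)
        dist.reduce_scatter_tensor(out, partial.contiguous())  # sum + scatter along dim 0
        return out
    dist.all_reduce(partial)                                    # plain tensor-parallel sum
    return partial

y = finish_row_parallel(torch.randn(8, 16), sequence_parallel=False)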
         if is_grad_enabled:
             ctx.weight_quantizer = weight_quantizer
...
@@ -296,19 +343,19 @@ class _Linear(torch.autograd.Function):
             )
             if backward_needs_input:
-                if own_quantized_input and isinstance(inputmat, QuantizedTensor):
+                if own_quantized_input and isinstance(inputmat, QuantizedTensorBase):
                     # For sequence parallel in vanilla FP8, rowwise data is
                     # to gather the input. For MXFP8, columnwise only data
                     # can be allgathered.
                     if isinstance(inputmat, MXFP8TensorBase) or not ctx.backward_input_needs_gather:
                         inputmat.update_usage(rowwise_usage=False, columnwise_usage=True)
                 if force_hp_input_gather:
-                    assert not isinstance(inputmat, QuantizedTensor)
+                    assert not isinstance(inputmat, QuantizedTensorBase)
                 saved_inputmat = inputmat
             # Weight with column-wise usage is needed for dgrad GEMM.
             if inp.requires_grad:
-                if isinstance(weightmat, QuantizedTensor):
+                if isinstance(weightmat, QuantizedTensorBase):
                     weightmat.update_usage(columnwise_usage=True)
             if cpu_offloading and saved_inputmat is not None:
...
@@ -321,7 +368,7 @@ class _Linear(torch.autograd.Function):
                 ctx.fsdp_shapes = _fsdp_scatter_tensors(
                     fsdp_group,
                     saved_inputmat,
-                    weightmat if fp8 and not isinstance(weight, QuantizedTensor) else None,
+                    weightmat if fp8 and not isinstance(weight, QuantizedTensorBase) else None,
                 )
                 nvtx_range_pop(f"{nvtx_label}.fsdp_scatter")
...
@@ -364,7 +411,7 @@ class _Linear(torch.autograd.Function):
             ctx.use_bias = bias is not None
             ctx.sequence_parallel = sequence_parallel
             ctx.tensor_parallel = tensor_parallel
-            ctx.inp_shape = inp_shape
+            ctx.inp_shape = inp.shape
             ctx.parallel_mode = parallel_mode
             ctx.tp_group = tp_group
             ctx.ub_overlap_ag = ub_overlap_ag_dgrad
...
@@ -376,6 +423,7 @@ class _Linear(torch.autograd.Function):
             ctx.requires_dgrad = inp.requires_grad
             ctx.requires_wgrad = weight.requires_grad
             ctx.reduce_and_update_bwd_fp8_tensors = False
             ctx.owns_input = saved_inputmat is not inp
             if ctx.fp8 and requires_grad(inp, weight, bias):
                 _first_fp8_module = FP8GlobalStateManager.IS_FIRST_FP8_MODULE
...
@@ -384,21 +432,10 @@ class _Linear(torch.autograd.Function):
                 FP8GlobalStateManager.IS_FIRST_FP8_MODULE = _first_fp8_module
             ctx.wgrad_store = wgrad_store
-        # Row Parallel Linear
-        if ub_overlap_rs_fprop:
-            out = rs_out
-        elif parallel_mode == "row":
-            nvtx_range_push(f"{nvtx_label}.row_parallel_comm")
-            if sequence_parallel:
-                out, _ = reduce_scatter_along_first_dim(out, tp_group)
-            elif tensor_parallel:
-                if symmetric_ar_type is not None:
-                    out, _ = symmetric_all_reduce(out, tp_group, all_reduce_type=symmetric_ar_type)
-                else:
-                    out, _ = allreduce(out, tp_group)
-            nvtx_range_pop(f"{nvtx_label}.row_parallel_comm")
-        out = out.view(-1, *inp_shape[1:-1], out_features)
+        # ------------------------------------------------------
+        # Cached state for backward pass is ready...
+        # ------------------------------------------------------
         return out

     @staticmethod
...
@@ -411,28 +448,11 @@ class _Linear(torch.autograd.Function):
...
@@ -411,28 +448,11 @@ class _Linear(torch.autograd.Function):
nvtx_label
=
f
"
{
nvtx_label
}
.
{
ctx
.
ub_name
}
"
nvtx_label
=
f
"
{
nvtx_label
}
.
{
ctx
.
ub_name
}
"
with
torch
.
cuda
.
nvtx
.
range
(
"_Linear_backward"
):
with
torch
.
cuda
.
nvtx
.
range
(
"_Linear_backward"
):
if
(
ctx
.
fp8
and
any
(
[
ctx
.
ub_overlap_ag
,
ctx
.
ub_overlap_rs_dgrad
,
ctx
.
ub_bulk_dgrad
,
ctx
.
ub_bulk_wgrad
,
]
)
and
(
ctx
.
fp8_recipe
is
not
None
)
):
if
not
ctx
.
fp8_recipe
.
float8_per_tensor_scaling
():
raise
NotImplementedError
(
"Comm+GEMM overlap is only supported with FP8 delayed scaling or per-tensor"
" current scaling"
)
saved_tensors
=
ctx
.
saved_tensors
saved_tensors
=
ctx
.
saved_tensors
inputmat
,
weight_fp8
,
weight
,
bias
=
(
# pylint: disable=unbalanced-tuple-unpacking
inputmat
,
weight_fp8
,
weight
,
bias
=
(
# pylint: disable=unbalanced-tuple-unpacking
restore_from_saved
(
ctx
.
tensor_objects
,
saved_tensors
)
restore_from_saved
(
ctx
.
tensor_objects
,
saved_tensors
)
)
)
# Delete the references to tensor objects once they've been consumed
# Delete the references to tensor objects once they've been consumed
# by the `restore_from_saved` method to construct back the actual tensors.
# by the `restore_from_saved` method to construct back the actual tensors.
ctx
.
tensor_objects
=
None
ctx
.
tensor_objects
=
None
...
@@ -462,69 +482,55 @@ class _Linear(torch.autograd.Function):
...
@@ -462,69 +482,55 @@ class _Linear(torch.autograd.Function):
)
)
nvtx_range_pop
(
f
"
{
nvtx_label
}
.fsdp_gather"
)
nvtx_range_pop
(
f
"
{
nvtx_label
}
.fsdp_gather"
)
# Configure Userbuffers communication (comm+GEMM overlap)
ctx
.
ub_obj_gradout
=
None
ctx
.
ub_obj_gradout
=
None
ub_obj_dgrad
=
None
ub_obj_dgrad
=
None
ub_obj_wgrad
=
None
ub_obj_wgrad
=
None
ub_type_dgrad
=
None
ub_type_dgrad
=
None
ub_type_wgrad
=
None
ub_type_wgrad
=
None
dgrad_shape
=
[
reduce
(
multiply_op
,
ctx
.
inp_shape
[:
-
1
]),
ctx
.
inp_shape
[
-
1
]]
dgrad_shape
=
[
reduce
(
multiply_op
,
ctx
.
inp_shape
[:
-
1
]),
ctx
.
inp_shape
[
-
1
]]
rs_out
=
None
dgrad_bulk
=
None
if
ctx
.
ub_overlap_ag
:
if
ctx
.
ub_overlap_ag
:
# Overlap grad_output all-gather with dgrad compute
# Overlap grad_output all-gather with dgrad compute
ctx
.
ub_obj_gradout
=
get_ub
(
ctx
.
ub_name
+
"_dgrad"
)
ctx
.
ub_obj_gradout
=
get_ub
(
ctx
.
ub_name
+
"_dgrad"
)
ub_obj_dgrad
=
ctx
.
ub_obj_gradout
ub_obj_dgrad
=
ctx
.
ub_obj_gradout
ub_type_dgrad
=
tex
.
CommOverlapType
.
AG
ub_type_dgrad
=
tex
.
CommOverlapType
.
AG
elif
ctx
.
ub_overlap_rs_dgrad
:
elif
ctx
.
ub_overlap_rs_dgrad
:
# Overlap dgrad reduce-scatter with dgrad compute
# Overlap dgrad reduce-scatter with dgrad compute
ctx
.
ub_obj_gradout
=
get_ub
(
ctx
.
ub_name
+
"_dgrad"
)
ctx
.
ub_obj_gradout
=
get_ub
(
ctx
.
ub_name
+
"_dgrad"
)
ub_obj_dgrad
=
ctx
.
ub_obj_gradout
ub_obj_dgrad
=
ctx
.
ub_obj_gradout
ub_type_dgrad
=
tex
.
CommOverlapType
.
RS
ub_type_dgrad
=
tex
.
CommOverlapType
.
RS
rs_out
=
torch
.
empty
(
dgrad_shape
,
dtype
=
ctx
.
activation_dtype
,
device
=
grad_output
.
device
)
else
:
else
:
if
ctx
.
ub_bulk_dgrad
:
if
ctx
.
ub_bulk_dgrad
:
# Overlap inputmat all-gather with dgrad compute
# Overlap inputmat all-gather with dgrad compute
# NOTE: Copying into communication buffer will always prefer rowwise data,
# and will copy columnwise data if rowwise does not exist. In that case,
# the all-gather will apply to the leading dimension of the transpose,
# which then needs to be interleaved correctly before WGRAD.
ctx
.
ub_obj_gradout
=
get_ub
(
ctx
.
ub_name
+
"_dgrad"
)
ctx
.
ub_obj_gradout
=
get_ub
(
ctx
.
ub_name
+
"_dgrad"
)
ub_obj_dgrad
=
ctx
.
ub_obj_gradout
ub_obj_dgrad
=
ctx
.
ub_obj_gradout
ub_type_dgrad
=
tex
.
CommOverlapType
.
AG
ub_type_dgrad
=
tex
.
CommOverlapType
.
AG
ub_obj_dgrad
.
copy_into_buffer
(
inputmat
,
ctx
.
input_quantizer
,
local_chunk
=
True
)
if
ctx
.
ub_bulk_wgrad
:
if
ctx
.
ub_bulk_wgrad
:
# Overlap dgrad reduce-scatter with wgrad compute
# Overlap dgrad reduce-scatter with wgrad compute
ub_obj_wgrad
=
get_ub
(
ctx
.
ub_name
+
"_wgrad"
)
ub_obj_wgrad
=
get_ub
(
ctx
.
ub_name
+
"_wgrad"
)
ub_type_wgrad
=
tex
.
CommOverlapType
.
RS
ub_type_wgrad
=
tex
.
CommOverlapType
.
RS
ub_obj_wgrad
.
set_buffer_params
(
ctx
.
grad_input_quantizer
)
dgrad_bulk
=
ub_obj_wgrad
.
get_buffer
(
ctx
.
grad_input_quantizer
)
# --------------------------------------------------
# Prepare grad output tensor
# Note: Cast to expected dtype and perform tensor-parallel communication
# --------------------------------------------------
# Unmodified grad output tensor
grad_output_arg
=
grad_output
# Configure quantizer for grad output tensor
# Configure quantizer for grad output tensor
# Note: dgrad GEMM requires row-wise usage, wgrad GEMM
# Note: dgrad GEMM requires row-wise usage, wgrad GEMM
# requires column-wise usage
# requires column-wise usage
if
ctx
.
grad_output_quantizer
is
not
None
:
if
ctx
.
grad_output_quantizer
is
not
None
:
rowwise_usage
=
True
quantizer
=
ctx
.
grad_output_quantizer
columnwise_usage
=
True
quantizer
.
set_usage
(
rowwise
=
True
,
columnwise
=
True
)
if
ctx
.
ub_overlap_ag
and
isinstance
(
if
ctx
.
ub_overlap_ag
:
ctx
.
grad_output_quantizer
,
# Userbuffers only supports communication for one
(
Float8Quantizer
,
Float8CurrentScalingQuantizer
),
# tensor usage at a time. Configure quantizer with
):
# usage for only dgrad GEMM.
# If data is in FP8 and communication is handled
quantizer
.
set_usage
(
columnwise
=
False
)
# with Userbuffers, we compute FP8 transposes
# manually
columnwise_usage
=
False
ctx
.
grad_output_quantizer
.
set_usage
(
rowwise
=
rowwise_usage
,
columnwise
=
columnwise_usage
,
)
# Prepare grad output tensor
# Prepare grad output tensor
# Note: Cast to expected dtype and perform tensor-parallel communication
nvtx_range_push
(
f
"
{
nvtx_label
}
.grad_output_preprocess"
)
nvtx_range_push
(
f
"
{
nvtx_label
}
.grad_output_preprocess"
)
(
(
grad_output
,
grad_output
,
...
@@ -537,12 +543,21 @@ class _Linear(torch.autograd.Function):
...
@@ -537,12 +543,21 @@ class _Linear(torch.autograd.Function):
)
)
nvtx_range_pop
(
f
"
{
nvtx_label
}
.grad_output_preprocess"
)
nvtx_range_pop
(
f
"
{
nvtx_label
}
.grad_output_preprocess"
)
# Launch tensor-parallel communication for input tensor
# --------------------------------------------------
# Grad output tensor is ready for computing grad input...
# --------------------------------------------------
# --------------------------------------------------
# Prepare input tensor
# Note: Input tensor is needed for wgrad GEMM.
# Tensor-parallel communication is overlapped with dgrad
# GEMM.
# --------------------------------------------------
inputmat_total
=
None
inputmat_total
=
None
inputmat_total_work
=
None
inputmat_total_work
=
None
if
ctx
.
backward_input_needs_gather
and
not
ctx
.
ub_bulk_dgrad
:
if
ctx
.
backward_input_needs_gather
:
quantizer
=
None
quantizer
=
None
if
ctx
.
fp8
or
ctx
.
debug
:
if
(
ctx
.
fp8
or
ctx
.
debug
)
and
not
ctx
.
force_hp_input_gather
:
quantizer
=
ctx
.
input_quantizer
quantizer
=
ctx
.
input_quantizer
if
isinstance
(
quantizer
,
(
Float8Quantizer
,
Float8CurrentScalingQuantizer
)):
if
isinstance
(
quantizer
,
(
Float8Quantizer
,
Float8CurrentScalingQuantizer
)):
# If data is in FP8, we compute FP8 transposes manually
# If data is in FP8, we compute FP8 transposes manually
...
@@ -550,72 +565,92 @@ class _Linear(torch.autograd.Function):
...
@@ -550,72 +565,92 @@ class _Linear(torch.autograd.Function):
else
:
else
:
# wgrad GEMM requires input with column-wise usage
# wgrad GEMM requires input with column-wise usage
quantizer
.
set_usage
(
rowwise
=
False
,
columnwise
=
True
)
quantizer
.
set_usage
(
rowwise
=
False
,
columnwise
=
True
)
nvtx_range_push
(
f
"
{
nvtx_label
}
.column_parallel_comm_input"
)
if
ctx
.
ub_bulk_dgrad
:
gather_quantizer
=
None
if
ctx
.
force_hp_input_gather
else
quantizer
inputmat_total
,
_
=
fill_userbuffers_buffer_for_all_gather
(
inputmat_total
,
inputmat_total_work
=
gather_along_first_dim
(
ub_obj_dgrad
,
inputmat
,
inputmat
,
ctx
.
tp_group
,
quantizer
,
async_op
=
True
,
ctx
.
tp_group
,
quantizer
=
gather_quantizer
,
)
)
else
:
nvtx_range_pop
(
f
"
{
nvtx_label
}
.column_parallel_comm_input"
)
nvtx_range_push
(
f
"
{
nvtx_label
}
.column_parallel_comm_input"
)
inputmat_total
,
inputmat_total_work
=
gather_along_first_dim
(
inputmat
,
ctx
.
tp_group
,
async_op
=
True
,
quantizer
=
quantizer
,
)
nvtx_range_pop
(
f
"
{
nvtx_label
}
.column_parallel_comm_input"
)
else
:
else
:
inputmat_total
=
inputmat
inputmat_total
=
inputmat
# --------------------------------------------------
# Input tensor is ready for computing grad weight...
# --------------------------------------------------
# Check whether to output wgrad GEMM directly into main grad
# --------------------------------------------------
if
ctx
.
is_first_microbatch
is
not
None
:
accumulate_wgrad_into_param_main_grad
=
(
ctx
.
fuse_wgrad_accumulation
and
not
ctx
.
is_first_microbatch
)
else
:
accumulate_wgrad_into_param_main_grad
=
ctx
.
fuse_wgrad_accumulation
# Compute grad input tensor
# Compute grad input tensor
# --------------------------------------------------
dgrad
=
None
dgrad
=
None
dgrad_work
=
None
dgrad_work
=
None
if
ctx
.
requires_dgrad
:
if
ctx
.
requires_dgrad
:
# Update quantizer
# Make sure required data is available
if
ctx
.
grad_input_quantizer
is
not
None
:
if
isinstance
(
grad_output
,
QuantizedTensorBase
):
ctx
.
grad_input_quantizer
.
set_usage
(
rowwise
=
True
,
columnwise
=
False
)
grad_output
.
update_usage
(
rowwise_usage
=
True
)
# dgrad GEMM
if
ctx
.
weight_quantizer
is
not
None
and
isinstance
(
weight_fp8
,
QuantizedTensorBase
):
nvtx_range_push
(
f
"
{
nvtx_label
}
.dgrad_gemm"
)
weight_fp8
.
update_usage
(
columnwise_usage
=
True
)
dgrad_gemm_use_split_accumulator
=
_2X_ACC_DGRAD
# Choose whether to use GEMM kernel with split accumulator
use_split_accumulator
=
_2X_ACC_DGRAD
if
ctx
.
fp8
:
if
ctx
.
fp8
:
recipe
=
ctx
.
fp8_recipe
recipe
=
ctx
.
fp8_recipe
if
hasattr
(
recipe
,
"fp8_gemm_dgrad"
):
if
hasattr
(
recipe
,
"fp8_gemm_dgrad"
):
dgrad_gemm_use_split_accumulator
=
(
use_split_accumulator
=
recipe
.
fp8_gemm_dgrad
.
use_split_accumulator
recipe
.
fp8_gemm_dgrad
.
use_split_accumulator
)
if
ctx
.
weight_quantizer
is
not
None
and
isinstance
(
weight_fp8
,
QuantizedTensor
):
# Update grad input quantizer
weight_fp8
.
update_usage
(
if
ctx
.
grad_input_quantizer
is
not
None
:
rowwise_usage
=
ctx
.
weight_quantizer
.
rowwise_usage
,
ctx
.
grad_input_quantizer
.
set_usage
(
rowwise
=
True
,
columnwise
=
False
)
columnwise_usage
=
ctx
.
weight_quantizer
.
columnwise_usage
,
# Output buffers for Userbuffers reduce-scatter
gemm_out
=
None
reduce_scatter_out
=
None
if
ctx
.
ub_overlap_rs_dgrad
:
reduce_scatter_out
=
torch
.
empty
(
dgrad_shape
,
dtype
=
ctx
.
activation_dtype
,
device
=
grad_output_arg
.
device
)
)
elif
ctx
.
ub_bulk_wgrad
:
gemm_out
=
ub_obj_wgrad
.
get_buffer
(
local_chunk
=
False
)
dgrad
,
*
_
,
rs_out
=
general_gemm
(
# dgrad GEMM
# Note: dx = dy * w
nvtx_range_push
(
f
"
{
nvtx_label
}
.dgrad_gemm"
)
gemm_out
,
*
_
,
reduce_scatter_out
=
general_gemm
(
weight_fp8
,
weight_fp8
,
grad_output
,
grad_output
,
get_workspace
(),
get_workspace
(),
layout
=
"NN"
,
layout
=
"NN"
,
grad
=
True
,
grad
=
True
,
quantization_params
=
ctx
.
grad_input_quantizer
,
quantization_params
=
ctx
.
grad_input_quantizer
,
out
=
dgrad_bulk
,
out
=
gemm_out
,
out_dtype
=
ctx
.
activation_dtype
,
out_dtype
=
ctx
.
activation_dtype
,
use_split_accumulator
=
dgrad_gemm_
use_split_accumulator
,
use_split_accumulator
=
use_split_accumulator
,
ub
=
ub_obj_dgrad
,
ub
=
ub_obj_dgrad
,
ub_type
=
ub_type_dgrad
,
ub_type
=
ub_type_dgrad
,
extra_output
=
r
s
_out
,
extra_output
=
r
educe_scatter
_out
,
bulk_overlap
=
ctx
.
ub_bulk_dgrad
,
bulk_overlap
=
ctx
.
ub_bulk_dgrad
,
)
)
nvtx_range_pop
(
f
"
{
nvtx_label
}
.dgrad_gemm"
)
nvtx_range_pop
(
f
"
{
nvtx_label
}
.dgrad_gemm"
)
# Launch tensor-parallel communication
# Prepare grad input tensor
# Note: Perform tensor-parallel communication
if
ctx
.
ub_overlap_rs_dgrad
:
if
ctx
.
ub_overlap_rs_dgrad
:
dgrad
=
rs_out
dgrad
=
reduce_scatter_out
elif
ctx
.
parallel_mode
==
"column"
and
not
ctx
.
ub_bulk_wgrad
:
elif
ctx
.
ub_bulk_wgrad
:
dgrad
=
ub_obj_wgrad
.
get_buffer
(
local_chunk
=
True
)
elif
ctx
.
parallel_mode
==
"column"
and
ctx
.
tp_size
>
1
:
nvtx_range_push
(
f
"
{
nvtx_label
}
.column_parallel_comm_dgrad"
)
nvtx_range_push
(
f
"
{
nvtx_label
}
.column_parallel_comm_dgrad"
)
dgrad
=
gemm_out
if
ctx
.
sequence_parallel
:
if
ctx
.
sequence_parallel
:
dgrad
,
dgrad_work
=
reduce_scatter_along_first_dim
(
dgrad
,
dgrad_work
=
reduce_scatter_along_first_dim
(
dgrad
,
dgrad
,
...
@@ -625,41 +660,55 @@ class _Linear(torch.autograd.Function):
...
@@ -625,41 +660,55 @@ class _Linear(torch.autograd.Function):
else
:
else
:
dgrad
,
dgrad_work
=
allreduce
(
dgrad
,
ctx
.
tp_group
,
async_op
=
True
)
dgrad
,
dgrad_work
=
allreduce
(
dgrad
,
ctx
.
tp_group
,
async_op
=
True
)
nvtx_range_pop
(
f
"
{
nvtx_label
}
.column_parallel_comm_dgrad"
)
nvtx_range_pop
(
f
"
{
nvtx_label
}
.column_parallel_comm_dgrad"
)
else
:
dgrad
=
gemm_out
# --------------------------------------------------
# Grad input tensor has been computed...
# --------------------------------------------------
# --------------------------------------------------
# Compute grad weight
# --------------------------------------------------
# Compute grad weight tensor
wgrad
=
None
wgrad
=
None
if
ctx
.
requires_wgrad
:
if
ctx
.
requires_wgrad
:
# Synchronize tensor-parallel communication for input tensor
# Prepare input tensor
if
ctx
.
ub_bulk_dgrad
:
# Note: Synchronize tensor-parallel communication and
inputmat_total
=
ub_obj_dgrad
.
get_buffer
(
ctx
.
input_quantizer
)
# make sure required data is available
if
ctx
.
fp8
:
if
inputmat
.
_data
is
None
:
# All-gather executed on columnwise data and result is in rowwise data,
# so we need to fix the interleaving before WGRAD.
inputmat_total
=
_fix_gathered_fp8_transpose
(
inputmat_total
,
ctx
.
tp_size
)
elif
not
non_tn_fp8_gemm_supported
():
# FP8 GEMM on Hopper only supports TN layout so the gathered input must
# have a valid transpose.
inputmat_total
.
_create_transpose
()
if
inputmat_total_work
is
not
None
:
if
inputmat_total_work
is
not
None
:
inputmat_total_work
.
wait
()
inputmat_total_work
.
wait
()
inputmat_total_work
=
None
inputmat_total_work
=
None
if
ctx
.
input_quantizer
is
not
None
and
not
isinstance
(
if
ctx
.
fp8
or
ctx
.
debug
:
inputmat_total
,
QuantizedTensor
if
isinstance
(
inputmat_total
,
QuantizedTensorBase
):
):
inputmat_total
.
update_usage
(
columnwise_usage
=
True
)
# Async gather in BF16 does not asynchronously
else
:
# call quantizer after gather.
ctx
.
input_quantizer
.
set_usage
(
rowwise
=
False
,
columnwise
=
True
)
ctx
.
input_quantizer
.
set_usage
(
rowwise
=
False
,
columnwise
=
True
)
inputmat_total
=
ctx
.
input_quantizer
(
inputmat_total
)
inputmat_total
=
ctx
.
input_quantizer
(
inputmat_total
)
# Prepare grad output tensor
# Make sure GEMM inputs have required data
# Note: Synchronize tensor-parallel communication and
if
isinstance
(
inputmat_total
,
QuantizedTensor
):
# make sure required data is available
inputmat_total
.
update_usage
(
columnwise_usage
=
True
)
if
ctx
.
ub_overlap_ag
and
isinstance
(
ctx
.
grad_output_quantizer
,
MXFP8Quantizer
):
if
isinstance
(
grad_output
,
QuantizedTensor
):
# UB does not support overlapping grad output
grad_output
.
update_usage
(
columnwise_usage
=
True
)
# all-gather with wgrad GEMM. Also, we can't
# convert row-scaled MXFP8 to column-scaled, so we
# can't reuse the grad output that was gathered
# for the dgrad GEMM. We work around with blocking
# all-gather for column-scaled MXFP8 data.
ctx
.
grad_output_quantizer
.
set_usage
(
rowwise
=
False
,
columnwise
=
True
)
grad_output
,
_
=
gather_along_first_dim
(
grad_output_arg
,
ctx
.
tp_group
,
quantizer
=
ctx
.
grad_output_quantizer
,
)
if
ctx
.
fp8
or
ctx
.
debug
:
if
isinstance
(
grad_output
,
QuantizedTensorBase
):
grad_output
.
update_usage
(
columnwise_usage
=
True
)
else
:
ctx
.
grad_output_quantizer
.
set_usage
(
rowwise
=
False
,
columnwise
=
True
)
grad_output
=
ctx
.
grad_output_quantizer
(
grad_output
)
# Figure out whether to use split accumulator
# Figure out whether to use split accumulator
use_split_accumulator
=
_2X_ACC_WGRAD
use_split_accumulator
=
_2X_ACC_WGRAD
...
@@ -668,54 +717,95 @@ class _Linear(torch.autograd.Function):
...
@@ -668,54 +717,95 @@ class _Linear(torch.autograd.Function):
if
hasattr
(
recipe
,
"fp8_gemm_wgrad"
):
if
hasattr
(
recipe
,
"fp8_gemm_wgrad"
):
use_split_accumulator
=
recipe
.
fp8_gemm_wgrad
.
use_split_accumulator
use_split_accumulator
=
recipe
.
fp8_gemm_wgrad
.
use_split_accumulator
# Output buffer for overlapping grad input
# Figure out whether to output wgrad GEMM directly into main grad
if
ctx
.
is_first_microbatch
is
not
None
:
accumulate_wgrad_into_param_main_grad
=
(
ctx
.
fuse_wgrad_accumulation
and
not
ctx
.
is_first_microbatch
)
else
:
accumulate_wgrad_into_param_main_grad
=
ctx
.
fuse_wgrad_accumulation
# Output buffer for overlapping FP8 grad input
# reduce-scatter with wgrad GEMM
# reduce-scatter with wgrad GEMM
reduce_scatter_out
=
None
if
ctx
.
ub_bulk_wgrad
and
ub_obj_wgrad
.
is_fp8_ubuf
():
if
ctx
.
ub_bulk_wgrad
and
ub_obj_wgrad
.
is_fp8_ubuf
():
r
s
_out
=
torch
.
empty
(
r
educe_scatter
_out
=
torch
.
empty
(
dgrad_shape
,
dtype
=
ctx
.
activation_dtype
,
device
=
grad_output
.
device
dgrad_shape
,
dtype
=
ctx
.
activation_dtype
,
device
=
grad_output
_arg
.
device
)
)
# wgrad GEMM
# Arguments to include in wgrad GEMM closure
# Note: Fuse with bgrad computation if needed
wgrad_gemm_kwargs
=
{
nvtx_range_push
(
f
"
{
nvtx_label
}
.wgrad_gemm"
)
"workspace"
:
get_workspace
(),
general_gemm_wgrad
=
functools
.
partial
(
"out_dtype"
:
(
general_gemm
,
out_dtype
=
(
main_grad
.
dtype
if
ctx
.
fuse_wgrad_accumulation
else
ctx
.
activation_dtype
main_grad
.
dtype
if
ctx
.
fuse_wgrad_accumulation
else
ctx
.
activation_dtype
),
),
workspace
=
get_workspace
(),
"quantization_params"
:
ctx
.
grad_weight_quantizer
,
layout
=
"NT"
,
"accumulate"
:
accumulate_wgrad_into_param_main_grad
,
grad
=
True
,
"layout"
:
"NT"
,
bias
=
(
bias
if
(
grad_bias
is
None
and
not
ctx
.
fp8
)
else
None
),
"out"
:
main_grad
if
ctx
.
fuse_wgrad_accumulation
else
None
,
out
=
main_grad
if
ctx
.
fuse_wgrad_accumulation
else
None
,
"bias"
:
(
bias
if
(
grad_bias
is
None
and
not
ctx
.
fp8
)
else
None
),
use_split_accumulator
=
use_split_accumulator
,
"use_split_accumulator"
:
use_split_accumulator
,
accumulate
=
accumulate_wgrad_into_param_main_grad
,
"grad"
:
True
,
quantization_params
=
ctx
.
grad_weight_quantizer
,
"ub"
:
ub_obj_wgrad
,
ub
=
ub_obj_wgrad
,
"ub_type"
:
ub_type_wgrad
,
ub_type
=
ub_type_wgrad
,
"extra_output"
:
reduce_scatter_out
,
extra_output
=
rs_out
,
"bulk_overlap"
:
ctx
.
ub_bulk_wgrad
,
bulk_overlap
=
ctx
.
ub_bulk_wgrad
,
}
)
def
wgrad_gemm
(
x
:
torch
.
Tensor
,
dy
:
torch
.
Tensor
,
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]:
"""Perform wgrad GEMM: dw = dy^T * x
May be fused with bgrad computation.
May be called outside of this function to enable
some advanced communication/compute overlapping.
"""
nvtx_range_push
(
f
"
{
nvtx_label
}
.wgrad_gemm"
)
dw
,
db
,
*
_
=
general_gemm
(
x
,
dy
,
**
wgrad_gemm_kwargs
)
nvtx_range_pop
(
f
"
{
nvtx_label
}
.wgrad_gemm"
)
return
dw
,
db
# Choose whether to call wgrad GEMM now or delay
if
ctx
.
wgrad_store
is
not
None
and
ctx
.
wgrad_store
.
delay_wgrad_compute
():
if
ctx
.
wgrad_store
is
not
None
and
ctx
.
wgrad_store
.
delay_wgrad_compute
():
ctx
.
wgrad_store
.
put
([
inputmat_total
,
grad_output
],
general_gemm_wgrad
)
if
(
wgrad_gemm_kwargs
[
"ub"
]
is
not
None
or
wgrad_gemm_kwargs
[
"ub_type"
]
is
not
None
or
wgrad_gemm_kwargs
[
"extra_output"
]
is
not
None
or
wgrad_gemm_kwargs
[
"bulk_overlap"
]
):
raise
NotImplementedError
(
"Delayed weight grad computation is not supported "
"with Userbuffers (tensor-parallel communication overlapping)"
)
ctx
.
wgrad_store
.
put
([
inputmat_total
,
grad_output
],
wgrad_gemm
)
else
:
else
:
wgrad
,
grad_bias_
,
_
,
rs_out
=
general_gemm_wgrad
(
inputmat_total
,
grad_output
)
# Call wgrad GEMM now
wgrad
,
grad_bias_
=
wgrad_gemm
(
inputmat_total
,
grad_output
)
# Update grad bias if needed
if
grad_bias
is
None
:
if
grad_bias
is
None
:
grad_bias
=
grad_bias_
grad_bias
=
grad_bias_
del
grad_bias_
del
grad_bias_
# Deallocate input tensor
# Deallocate input tensor
if permitted
if
ctx
.
owns_input
:
if
ctx
.
owns_input
:
clear_tensor_data
(
inputmat_total
)
clear_tensor_data
(
inputmat_total
)
nvtx_range_pop
(
f
"
{
nvtx_label
}
.wgrad_gemm"
)
# Update grad input if overlapping reduce-scatter with wgrad GEMM
if
ctx
.
ub_bulk_wgrad
:
if
ctx
.
ub_bulk_wgrad
:
if
ub_obj_wgrad
.
is_fp8_ubuf
():
if
ub_obj_wgrad
.
is_fp8_ubuf
():
dgrad
=
r
s
_out
dgrad
=
r
educe_scatter
_out
else
:
else
:
dgrad
=
ub_obj_wgrad
.
get_buffer
(
ctx
.
grad_input_quantizer
,
local_chunk
=
True
)
dgrad
=
ub_obj_wgrad
.
get_buffer
(
local_chunk
=
True
).
clone
()
# --------------------------------------------------
# Grad weight has been computed...
# --------------------------------------------------
# Don't return grad bias if not needed
# Don't return grad bias if not needed
if
not
ctx
.
use_bias
:
if
not
ctx
.
use_bias
:
...
@@ -753,13 +843,14 @@ class _Linear(torch.autograd.Function):
...
@@ -753,13 +843,14 @@ class _Linear(torch.autograd.Function):
else
:
else
:
wgrad
=
None
wgrad
=
None
# Update FP8 scaling factors if needed
if
ctx
.
reduce_and_update_bwd_fp8_tensors
and
not
is_graph_capturing
():
if
ctx
.
reduce_and_update_bwd_fp8_tensors
and
not
is_graph_capturing
():
nvtx_range_push
(
f
"
{
nvtx_label
}
.reduce_and_update_fp8_tensors"
)
nvtx_range_push
(
f
"
{
nvtx_label
}
.reduce_and_update_fp8_tensors"
)
FP8GlobalStateManager
.
reduce_and_update_fp8_tensors
(
forward
=
False
)
FP8GlobalStateManager
.
reduce_and_update_fp8_tensors
(
forward
=
False
)
nvtx_range_pop
(
f
"
{
nvtx_label
}
.reduce_and_update_fp8_tensors"
)
nvtx_range_pop
(
f
"
{
nvtx_label
}
.reduce_and_update_fp8_tensors"
)
# Scatter fp8 weight buffers
# Scatter fp8 weight buffers
if
ctx
.
fp8
and
not
isinstance
(
weight
,
QuantizedTensor
):
if
ctx
.
fp8
and
not
isinstance
(
weight
,
QuantizedTensor
Base
):
_fsdp_scatter_tensors
(
ctx
.
fsdp_group
,
weight_fp8
)
_fsdp_scatter_tensors
(
ctx
.
fsdp_group
,
weight_fp8
)
return
(
return
(
wgrad
,
wgrad
,
...
@@ -1207,7 +1298,12 @@ class Linear(TransformerEngineBaseModule):
...
@@ -1207,7 +1298,12 @@ class Linear(TransformerEngineBaseModule):
"Splitting QuantizedTensor into multiple params is not supported"
"Splitting QuantizedTensor into multiple params is not supported"
)
)
else
:
else
:
warnings
.
warn
(
"You are using quantized weights without quantized compute. "
"Please make sure this is intentional."
)
unfused_weights
=
[
w
.
dequantize
()
for
w
in
unfused_weights
]
unfused_weights
=
[
w
.
dequantize
()
for
w
in
unfused_weights
]
weight_tensor
=
noop_cat
(
unfused_weights
)
weight_tensor
=
noop_cat
(
unfused_weights
)
if
self
.
use_bias
:
if
self
.
use_bias
:
bias_tensor
=
noop_cat
([
getattr
(
self
,
name
)
for
name
in
self
.
bias_names
])
bias_tensor
=
noop_cat
([
getattr
(
self
,
name
)
for
name
in
self
.
bias_names
])
...
@@ -1302,7 +1398,7 @@ class Linear(TransformerEngineBaseModule):
...
@@ -1302,7 +1398,7 @@ class Linear(TransformerEngineBaseModule):
grad_output_quantizer
=
None
grad_output_quantizer
=
None
output_quantizer
=
None
output_quantizer
=
None
input_quantizer
=
self
.
quantizers
[
"scaling_fwd"
][
tex
.
FP8FwdTensors
.
GEMM1_INPUT
]
input_quantizer
=
self
.
quantizers
[
"scaling_fwd"
][
tex
.
FP8FwdTensors
.
GEMM1_INPUT
]
input_quantizer
.
internal
=
Fals
e
input_quantizer
.
internal
=
Tru
e
weight_quantizer
=
self
.
quantizers
[
"scaling_fwd"
][
tex
.
FP8FwdTensors
.
GEMM1_WEIGHT
]
weight_quantizer
=
self
.
quantizers
[
"scaling_fwd"
][
tex
.
FP8FwdTensors
.
GEMM1_WEIGHT
]
weight_quantizer
.
internal
=
True
weight_quantizer
.
internal
=
True
if
fp8_output
:
if
fp8_output
:
...
...
transformer_engine/pytorch/module/rmsnorm.py
View file @
f8c2af4c
...
@@ -98,6 +98,9 @@ class RMSNorm(_RMSNormOp):
...
@@ -98,6 +98,9 @@ class RMSNorm(_RMSNormOp):
)
)
kwargs
[
"dtype"
]
=
params_dtype
kwargs
[
"dtype"
]
=
params_dtype
# Flag for sequence parallelism (custom Megatron-LM integration)
self
.
sequence_parallel
:
Optional
[
bool
]
=
sequence_parallel
# Initialize RMSNorm operation
# Initialize RMSNorm operation
super
().
__init__
(
super
().
__init__
(
normalized_shape
,
normalized_shape
,
...
@@ -106,11 +109,6 @@ class RMSNorm(_RMSNormOp):
...
@@ -106,11 +109,6 @@ class RMSNorm(_RMSNormOp):
**
kwargs
,
**
kwargs
,
)
)
# Flag for sequence parallelism (custom Megatron-LM integration)
self
.
sequence_parallel
:
Optional
[
bool
]
=
sequence_parallel
if
sequence_parallel
is
not
None
:
self
.
weight
.
sequence_parallel
=
sequence_parallel
def
reset_rms_norm_parameters
(
self
)
->
None
:
def
reset_rms_norm_parameters
(
self
)
->
None
:
"""Deprecated"""
"""Deprecated"""
warnings
.
warn
(
warnings
.
warn
(
...
@@ -139,7 +137,7 @@ class RMSNorm(_RMSNormOp):
...
@@ -139,7 +137,7 @@ class RMSNorm(_RMSNormOp):
super
().
reset_parameters
()
super
().
reset_parameters
()
# Flag for sequence parallelism (custom Megatron-LM integration)
# Flag for sequence parallelism (custom Megatron-LM integration)
if
getattr
(
self
,
"
sequence_parallel
"
,
None
)
is
not
None
:
if
self
.
sequence_parallel
is
not
None
:
self
.
weight
.
sequence_parallel
=
self
.
sequence_parallel
self
.
weight
.
sequence_parallel
=
self
.
sequence_parallel
@
property
@
property
...
...
transformer_engine/pytorch/ops/basic/basic_linear.py
View file @
f8c2af4c
...
@@ -534,7 +534,9 @@ class BasicLinear(BasicOperation):
...
@@ -534,7 +534,9 @@ class BasicLinear(BasicOperation):
# Configure input tensor for backward pass
# Configure input tensor for backward pass
if
with_quantized_compute
and
isinstance
(
x_local
,
QuantizedTensor
):
if
with_quantized_compute
and
isinstance
(
x_local
,
QuantizedTensor
):
x_local
.
update_usage
(
rowwise_usage
=
False
,
columnwise_usage
=
True
)
if
not
(
isinstance
(
x_local
,
Float8TensorBase
)
and
with_x_all_gather
):
# FP8 does not support all-gather of transpose data
x_local
.
update_usage
(
rowwise_usage
=
False
,
columnwise_usage
=
True
)
return
y
,
x_local
,
w
return
y
,
x_local
,
w
...
@@ -622,7 +624,10 @@ class BasicLinear(BasicOperation):
...
@@ -622,7 +624,10 @@ class BasicLinear(BasicOperation):
# Check datatype
# Check datatype
if
dtype
is
None
:
if
dtype
is
None
:
dtype
=
weight
.
dtype
if
weight
is
not
None
:
dtype
=
weight
.
dtype
else
:
dtype
=
grad_output
.
dtype
dtype
=
canonicalize_dtype
(
dtype
)
dtype
=
canonicalize_dtype
(
dtype
)
if
dtype
not
in
(
torch
.
float32
,
torch
.
float16
,
torch
.
bfloat16
):
if
dtype
not
in
(
torch
.
float32
,
torch
.
float16
,
torch
.
bfloat16
):
raise
ValueError
(
f
"Supported dtypes are float32, float16, bfloat16 (got
{
dtype
}
)"
)
raise
ValueError
(
f
"Supported dtypes are float32, float16, bfloat16 (got
{
dtype
}
)"
)
...
@@ -814,7 +819,7 @@ class BasicLinear(BasicOperation):
...
@@ -814,7 +819,7 @@ class BasicLinear(BasicOperation):
x_async
=
None
x_async
=
None
dy_async
=
None
dy_async
=
None
# Check grad
inpu
t tensor
# Check grad
weigh
t tensor
dw
=
grad_weight
dw
=
grad_weight
dw_dtype
=
dtype
dw_dtype
=
dtype
if
dw
is
None
:
if
dw
is
None
:
...
...
transformer_engine/pytorch/ops/fused/userbuffers_backward_linear.py
View file @
f8c2af4c
...
@@ -4,30 +4,27 @@
...
@@ -4,30 +4,27 @@
"""Linear layer backward with Userbuffers communication."""
"""Linear layer backward with Userbuffers communication."""
# pylint: skip-file ### TODO Debug Userbuffers support
from
__future__
import
annotations
from
__future__
import
annotations
from
collections.abc
import
Iterable
from
typing
import
Optional
from
typing
import
Any
,
Optional
import
warnings
import
warnings
import
torch
import
torch
from
transformer_engine_torch
import
CommOverlap
Algo
from
transformer_engine_torch
import
CommOverlap
Type
from
...cpp_extensions
import
general_gemm
from
...cpp_extensions
import
general_gemm
from
...distributed
import
get_distributed_world_size
from
...distributed
import
gather_along_first_dim
,
get_distributed_world_size
from
...float8_tensor
import
Float8Tensor
from
...module.base
import
(
from
...fp8
import
FP8GlobalStateManager
,
get_fp8_te_dtype
fill_userbuffers_buffer_for_all_gather
,
from
...module.base
import
get_ub
,
get_workspace
get_ub
,
get_workspace
,
_2X_ACC_DGRAD
,
_2X_ACC_WGRAD
,
)
from
...tensor.quantized_tensor
import
QuantizedTensorBase
,
Quantizer
from
...tensor.mxfp8_tensor
import
MXFP8Quantizer
from
...utils
import
canonicalize_device
,
canonicalize_dtype
,
clear_tensor_data
from
...utils
import
canonicalize_device
,
canonicalize_dtype
,
clear_tensor_data
from
..basic
import
BasicLinear
,
Bias
,
ReduceScatter
from
..basic
import
BasicLinear
,
Bias
,
ReduceScatter
from
..op
import
FusedOperation
,
FusibleOperation
,
OperationContext
from
..op
import
FusedOperation
,
FusibleOperation
,
OperationContext
from
.._common
import
(
convert_tensor
,
get_fp8_meta_from_fp8_tensor
,
is_float8_tensor
,
reshape
,
)
class
UserbuffersBackwardLinear
(
FusedOperation
):
class
UserbuffersBackwardLinear
(
FusedOperation
):
...
@@ -47,9 +44,6 @@ class UserbuffersBackwardLinear(FusedOperation):
...
@@ -47,9 +44,6 @@ class UserbuffersBackwardLinear(FusedOperation):
reduce_scatter
:
Optional
[
ReduceScatter
],
reduce_scatter
:
Optional
[
ReduceScatter
],
)
->
None
:
)
->
None
:
### TODO Debug Userbuffers support
raise
NotImplementedError
(
"Userbuffers support has been broken by recent refactors"
)
# Basic operations that comprise this fused operation
# Basic operations that comprise this fused operation
op_idxs
=
{
"linear"
:
None
,
"bias"
:
None
,
"reduce_scatter"
:
None
}
op_idxs
=
{
"linear"
:
None
,
"bias"
:
None
,
"reduce_scatter"
:
None
}
ops
=
[]
ops
=
[]
...
@@ -89,9 +83,8 @@ class UserbuffersBackwardLinear(FusedOperation):
...
@@ -89,9 +83,8 @@ class UserbuffersBackwardLinear(FusedOperation):
grad_output
:
torch
.
Tensor
,
grad_output
:
torch
.
Tensor
,
input
:
Optional
[
torch
.
Tensor
],
# pylint: disable=redefined-builtin
input
:
Optional
[
torch
.
Tensor
],
# pylint: disable=redefined-builtin
weight
:
Optional
[
torch
.
Tensor
],
weight
:
Optional
[
torch
.
Tensor
],
input_dims
:
Iterable
[
int
],
weight_dims
:
Iterable
[
int
],
*
,
*
,
input_requires_grad
:
bool
=
True
,
weight_requires_grad
:
bool
=
True
,
weight_requires_grad
:
bool
=
True
,
bias_requires_grad
:
bool
=
False
,
bias_requires_grad
:
bool
=
False
,
device
:
Optional
[
torch
.
device
]
=
None
,
device
:
Optional
[
torch
.
device
]
=
None
,
...
@@ -102,11 +95,11 @@ class UserbuffersBackwardLinear(FusedOperation):
...
@@ -102,11 +95,11 @@ class UserbuffersBackwardLinear(FusedOperation):
tensor_parallel_group
:
Optional
[
torch
.
distributed
.
ProcessGroup
]
=
None
,
tensor_parallel_group
:
Optional
[
torch
.
distributed
.
ProcessGroup
]
=
None
,
tensor_parallel_size
:
Optional
[
int
]
=
None
,
tensor_parallel_size
:
Optional
[
int
]
=
None
,
sequence_parallel
:
bool
=
False
,
sequence_parallel
:
bool
=
False
,
with_
fp8
_compute
:
bool
=
False
,
with_
quantized
_compute
:
bool
=
False
,
input_
fp8_meta
:
Optional
[
dict
[
str
,
Any
]
]
=
None
,
input_
quantizer
:
Optional
[
Quantizer
]
=
None
,
weight_
fp8_meta
:
Optional
[
dict
[
str
,
Any
]
]
=
None
,
weight_
quantizer
:
Optional
[
Quantizer
]
=
None
,
grad_output_
fp8_meta
:
Optional
[
dict
[
str
,
Any
]
]
=
None
,
grad_output_
quantizer
:
Optional
[
Quantizer
]
=
None
,
grad_input_
fp8_meta
:
Optional
[
dict
[
str
,
Any
]
]
=
None
,
grad_input_
quantizer
:
Optional
[
Quantizer
]
=
None
,
ub_comm_name
:
str
,
ub_comm_name
:
str
,
)
->
tuple
[
torch
.
Tensor
,
Optional
[
torch
.
Tensor
],
dict
]:
)
->
tuple
[
torch
.
Tensor
,
Optional
[
torch
.
Tensor
],
dict
]:
"""Functional API for backward pass
"""Functional API for backward pass
...
@@ -121,10 +114,6 @@ class UserbuffersBackwardLinear(FusedOperation):
...
@@ -121,10 +114,6 @@ class UserbuffersBackwardLinear(FusedOperation):
weight: torch.Tensor, optional
weight: torch.Tensor, optional
Weight tensor. Required to compute loss gradient w.r.t.
Weight tensor. Required to compute loss gradient w.r.t.
input.
input.
input_dims: iterable of int
Input tensor dimensions
weight_dims: iterable of int
Weight tensor dimensions
weight_requires_grad: bool
weight_requires_grad: bool
Whether to compute loss gradient w.r.t. weight tensor
Whether to compute loss gradient w.r.t. weight tensor
bias_requires_grad: bool
bias_requires_grad: bool
...
@@ -146,21 +135,18 @@ class UserbuffersBackwardLinear(FusedOperation):
...
@@ -146,21 +135,18 @@ class UserbuffersBackwardLinear(FusedOperation):
parallelism, i.e. distributing input or output tensors
parallelism, i.e. distributing input or output tensors
along outer dimension (sequence or batch dim) when not
along outer dimension (sequence or batch dim) when not
distributing along inner dimension (embedding dim)
distributing along inner dimension (embedding dim)
with_fp8_compute: bool, default = `False`
with_quantized_compute: bool, default = `False`
Whether to perform compute in FP8
Whether to perform compute with quantized data.
input_fp8_meta: dict, optional
input_quantizer: Quantizer, optional
FP8 metadata for casting input tensor to FP8. Required for
Builder class for quantized input tensor.
FP8 compute if input is not already in FP8.
weight_quantizer: Quantizer, optional
weight_fp8_meta: dict, optional
Builder class for quantized weight tensor.
FP8 metadata for casting weight tensor to FP8. Required for
grad_output_quantizer: Quantizer, optional
FP8 compute if weight is not already in FP8.
Builder class for quantized loss gradient w.r.t. output
grad_output_fp8_meta: dict, optional
tensor.
FP8 metadata for casting loss gradient w.r.t. output
grad_input_quantizer: Quantizer, optional
tensor to FP8. Required if output grad is not already in
Builder class for quantized loss gradient w.r.t. input
FP8.
tensor.
grad_input_fp8_meta: dict, optional
FP8 metadata for casting loss gradient w.r.t. input
tensor to FP8
ub_comm_name: str
ub_comm_name: str
Layer type (e.g. "qkv", "proj", "fc1", "fc2"). This is
Layer type (e.g. "qkv", "proj", "fc1", "fc2"). This is
used to access the corresponding Userbuffers communicators
used to access the corresponding Userbuffers communicators
...
@@ -183,37 +169,24 @@ class UserbuffersBackwardLinear(FusedOperation):
...
@@ -183,37 +169,24 @@ class UserbuffersBackwardLinear(FusedOperation):
# Check device
# Check device
if
device
is
None
:
if
device
is
None
:
device
=
weight
.
device
if
weight
is
not
None
:
device
=
weight
.
device
else
:
device
=
grad_output
.
device
device
=
canonicalize_device
(
device
)
device
=
canonicalize_device
(
device
)
if
device
.
type
!=
"cuda"
:
if
device
.
type
!=
"cuda"
:
raise
ValueError
(
f
"Only CUDA devices are supported (got
{
device
}
)"
)
raise
ValueError
(
f
"Only CUDA devices are supported (got
{
device
}
)"
)
# Check datatype
# Check datatype
if
dtype
is
None
:
if
dtype
is
None
:
dtype
=
weight
.
dtype
if
weight
is
not
None
:
dtype
=
weight
.
dtype
else
:
dtype
=
grad_output
.
dtype
dtype
=
canonicalize_dtype
(
dtype
)
dtype
=
canonicalize_dtype
(
dtype
)
if
dtype
not
in
(
torch
.
float32
,
torch
.
float16
,
torch
.
bfloat16
):
if
dtype
not
in
(
torch
.
float32
,
torch
.
float16
,
torch
.
bfloat16
):
raise
ValueError
(
f
"Supported dtypes are float32, float16, bfloat16 (got
{
dtype
}
)"
)
raise
ValueError
(
f
"Supported dtypes are float32, float16, bfloat16 (got
{
dtype
}
)"
)
# Input tensor dims
output_dims
=
tuple
(
grad_output
.
size
())
input_dims
=
tuple
(
input_dims
)
weight_dims
=
tuple
(
weight_dims
)
if
len
(
weight_dims
)
!=
2
:
raise
ValueError
(
f
"Weight tensor is not 2D (shape=
{
weight_dims
}
)"
)
if
len
(
input_dims
)
==
0
or
weight_dims
[
1
]
!=
input_dims
[
-
1
]:
raise
ValueError
(
f
"Input tensor (shape=
{
input_dims
}
) "
f
"and weight tensor (shape=
{
weight_dims
}
) "
"are not compatible"
)
if
weight_dims
[
0
]
!=
output_dims
[
-
1
]:
raise
ValueError
(
f
"Grad output tensor (shape=
{
output_dims
}
) "
f
"and weight tensor (shape=
{
weight_dims
}
) "
"are not compatible"
)
# Check tensor parallel group
# Check tensor parallel group
if
tensor_parallel_size
is
None
:
if
tensor_parallel_size
is
None
:
tensor_parallel_size
=
get_distributed_world_size
(
tensor_parallel_group
)
tensor_parallel_size
=
get_distributed_world_size
(
tensor_parallel_group
)
...
@@ -227,373 +200,283 @@ class UserbuffersBackwardLinear(FusedOperation):
...
@@ -227,373 +200,283 @@ class UserbuffersBackwardLinear(FusedOperation):
if
not
sequence_parallel
:
if
not
sequence_parallel
:
raise
RuntimeError
(
f
"Invalid configuration for Userbuffers (
{
sequence_parallel
=
}
)"
)
raise
RuntimeError
(
f
"Invalid configuration for Userbuffers (
{
sequence_parallel
=
}
)"
)
# Check if FP8 is enabled
# dgrad GEMM is required
if
with_fp8_compute
:
if
not
input_requires_grad
:
if
grad_output_fp8_meta
is
None
and
not
is_float8_tensor
(
grad_output
):
warnings
.
warn
(
raise
ValueError
(
"No FP8 metadata was provided for casting output gradient to FP8"
)
"Linear input doesn't require gradient, "
"but Userbuffers implementation requires dgrad GEMM."
)
input_requires_grad
=
True
# Check quantizers
if
with_quantized_compute
:
if
weight_requires_grad
and
input_quantizer
is
None
:
raise
ValueError
(
"Missing quantizer for input tensor"
)
if
input_requires_grad
and
weight_quantizer
is
None
:
raise
ValueError
(
"Missing quantizer for weight tensor"
)
if
grad_output_quantizer
is
None
:
raise
ValueError
(
"Missing quantizer for grad output tensor"
)
if
grad_input_quantizer
is
not
None
:
raise
ValueError
(
"Quantized grad input is not supported"
)
else
:
else
:
input_fp8_meta
=
None
input_quantizer
=
None
weight_fp8_meta
=
None
weight_quantizer
=
None
grad_output_fp8_meta
=
None
grad_output_quantizer
=
None
grad_input_fp8_meta
=
None
grad_input_quantizer
=
None
with_fp8_grad_input
=
(
with_fp8_compute
and
tensor_parallel_mode
!=
"column"
and
grad_input_fp8_meta
is
not
None
)
# Get Userbuffers communicators
and algorithms
# Get Userbuffers communicators
# Note:
c
ommunication patterns are (1) overlap dy all-gather
# Note:
C
ommunication patterns are (1) overlap dy all-gather
# with dgrad GEMM, (2) overlap x all-gather with dgrad GEMM
# with dgrad GEMM, (2) overlap x all-gather with dgrad GEMM
# and dx reduce-scatter with wgrad GEMM, (3) overlap dx
# and dx reduce-scatter with wgrad GEMM, (3) overlap dx
# reduce-scatter with dgrad GEMM
.
# reduce-scatter with dgrad GEMM
with_ub_all_gather_dy
=
Fals
e
ub_comm_dgrad
=
Non
e
with_ub_reduce_scatter_dx
=
Fals
e
ub_comm_wgrad
=
Non
e
with_ub_all_gather_x
=
Fals
e
ub_type_dgrad
=
Non
e
ub_
comm_dy
=
None
ub_
type_wgrad
=
None
ub_comm_dx
=
Non
e
with_bulk_overlap
=
Fals
e
ub_comm_x
=
Non
e
with_dgrad_all_gather_dy
=
Fals
e
ub_algo
_d
y
=
Non
e
with_dgrad_reduce_scatter
_d
x
=
Fals
e
ub_algo_dx
=
Non
e
with_dgrad_all_gather_x
=
Fals
e
ub_algo_x
=
Non
e
with_wgrad_reduce_scatter_dx
=
Fals
e
if
tensor_parallel_mode
==
"row"
:
if
tensor_parallel_mode
==
"row"
:
with_ub_all_gather_dy
=
True
ub_comm_dgrad
=
get_ub
(
ub_comm_name
+
"_dgrad"
)
ub_comm_dy
=
get_ub
(
ub_comm_name
+
"_dgrad"
)
ub_type_dgrad
=
CommOverlapType
.
AG
if
with_fp8_compute
and
ub_comm_dy
.
is_atomic_gemm
():
with_dgrad_all_gather_dy
=
True
ub_algo_dy
=
CommOverlapAlgo
.
ATOMIC_GEMM_AG_P2P
else
:
ub_algo_dy
=
CommOverlapAlgo
.
SPLIT_PIPELINED_AG_P2P
elif
tensor_parallel_mode
==
"column"
:
elif
tensor_parallel_mode
==
"column"
:
with_ub_reduce_scatter_dx
=
True
if
input_requires_grad
and
weight_requires_grad
:
if
weight_requires_grad
:
with_bulk_overlap
=
True
with_ub_all_gather_x
=
True
ub_comm_dgrad
=
get_ub
(
ub_comm_name
+
"_dgrad"
)
ub_comm_dx
=
get_ub
(
ub_comm_name
+
"_wgrad"
)
ub_type_dgrad
=
CommOverlapType
.
AG
ub_comm_x
=
get_ub
(
ub_comm_name
+
"_dgrad"
)
with_dgrad_all_gather_x
=
True
ub_algo_dx
=
CommOverlapAlgo
.
BULK_OVERLAP_RS
ub_comm_wgrad
=
get_ub
(
ub_comm_name
+
"_wgrad"
)
ub_algo_x
=
CommOverlapAlgo
.
BULK_OVERLAP_AG
ub_type_wgrad
=
CommOverlapType
.
RS
with_wgrad_reduce_scatter_dx
=
True
if
ub_comm_wgrad
.
is_fp8_ubuf
():
raise
RuntimeError
(
"Userbuffers reduce-scatter is not supported with FP8 buffers"
)
else
:
else
:
with_ub_all_gather_x
=
False
ub_comm_dgrad
=
get_ub
(
ub_comm_name
+
"_dgrad"
)
ub_comm_dx
=
get_ub
(
ub_comm_name
+
"_dgrad"
)
ub_type_dgrad
=
CommOverlapType
.
RS
is_atomic_gemm
=
with_fp8_compute
and
ub_comm_dx
.
is_atomic_gemm
()
with_dgrad_reduce_scatter_dx
=
True
ub_algo_dx
=
{
if
ub_comm_dgrad
.
is_fp8_ubuf
():
(
True
,
True
):
CommOverlapAlgo
.
ATOMIC_GEMM_RS_P2P
,
raise
RuntimeError
(
(
True
,
False
):
CommOverlapAlgo
.
SPLIT_PIPELINED_RS_P2P
,
"Userbuffers reduce-scatter is not supported with FP8 buffers"
(
False
,
True
):
CommOverlapAlgo
.
ATOMIC_GEMM_RS
,
)
(
False
,
False
):
CommOverlapAlgo
.
SPLIT_PIPELINED_RS
,
}[(
ub_comm_dx
.
is_p2p_overlap
(),
is_atomic_gemm
)]
# Compute grad bias if needed
# Check grad output tensor
# Note: Possibly fuse cast with computing grad bias
dy_local
=
reshape
(
grad_output
,
(
-
1
,
output_dims
[
-
1
]),
device
=
device
,
dtype
=
dtype
,
)
db
=
None
db
=
None
db_async
=
None
db_async
=
None
if
bias_requires_grad
and
with_fp8_compute
and
with_ub_all_gather_dy
:
if
bias_requires_grad
:
# We don't have a grad bias impl that takes FP8 input. For
db
=
grad_output
.
sum
(
tuple
(
range
(
grad_output
.
dim
()
-
1
)))
# cases where we cast to FP8 and all-gather, it's better
if
tensor_parallel_mode
==
"row"
:
# to compute the grad bias on ungathered, non-FP8 values.
db_async
=
torch
.
distributed
.
all_reduce
(
db
=
dy_local
.
sum
(
dim
=
0
)
db
,
db_async
=
torch
.
distributed
.
all_reduce
(
group
=
tensor_parallel_group
,
db
,
async_op
=
True
,
group
=
tensor_parallel_group
,
async_op
=
True
,
)
if
with_fp8_compute
and
not
is_float8_tensor
(
dy_local
):
fp8_dtype
=
get_fp8_te_dtype
(
grad_output_fp8_meta
[
"recipe"
],
fprop_tensor
=
False
,
)
if
bias_requires_grad
and
db
is
None
:
# Fused cast-transpose-bgrad
fp8_meta_key
=
FP8GlobalStateManager
.
get_meta_tensor_key
(
forward
=
False
)
fp8_scale_inv
=
torch
.
empty
([
1
],
dtype
=
torch
.
float32
,
device
=
device
)
db
,
data
,
data_transpose
=
fp8_cast_transpose_bgrad_fused
(
dy_local
,
grad_output_fp8_meta
[
fp8_meta_key
],
0
,
fp8_dtype
,
scale_inv
=
fp8_scale_inv
,
)
if
with_ub_all_gather_dy
:
data
=
ub_comm_dy
.
get_ubuf_output
(
0
).
copy_
(
data
)
dy_local
=
Float8Tensor
(
data
=
data
,
fp8_meta
=
grad_output_fp8_meta
,
fp8_meta_forward
=
False
,
fp8_meta_index
=
0
,
fp8_dtype
=
fp8_dtype
,
fp8_scale_inv
=
fp8_scale_inv
,
dtype
=
dtype
,
data_transpose
=
data_transpose
,
)
)
else
:
dy_local
=
Float8Tensor
.
to_float8
(
# Cast grad output tensor dtype if needed
dy_local
,
dy_local
=
grad_output
fp8_meta
=
grad_output_fp8_meta
,
if
with_quantized_compute
:
fp8_meta_forward
=
False
,
if
not
isinstance
(
dy_local
,
QuantizedTensorBase
):
fp8_meta_index
=
0
,
with_columnwise
=
weight_requires_grad
fp8_dtype
=
fp8_dtype
,
if
(
data
=
(
ub_comm_dy
.
get_ubuf_output
(
0
)
if
with_ub_all_gather_dy
else
None
),
with_columnwise
with_transpose_cache
=
(
not
with_ub_all_gather_dy
),
and
with_dgrad_all_gather_dy
and
not
isinstance
(
grad_output_quantizer
,
MXFP8Quantizer
)
):
with_columnwise
=
False
grad_output_quantizer
.
set_usage
(
rowwise
=
True
,
columnwise
=
with_columnwise
,
)
)
elif
not
with_fp8_compute
and
is_float8_tensor
(
dy_local
):
dy_local
=
grad_output_quantizer
(
dy_local
)
if
with_ub_all_gather_dy
:
else
:
ub_local_buffer
=
ub_comm_dy
.
get_ubuf_output
(
0
)
if
isinstance
(
dy_local
,
QuantizedTensorBase
):
dy_local
=
ub_local_buffer
.
copy_
(
dy_local
)
dy_local
=
dy_local
.
dequantize
(
dtype
=
dtype
)
else
:
elif
dy_local
.
dtype
!=
dtype
:
dy_local
=
dy_local
.
dequantize
()
dy_local
=
dy_local
.
to
(
dtype
=
dtype
)
if
bias_requires_grad
and
db
is
None
and
with_fp8_compute
and
with_ub_all_gather_dy
:
# Cast weight tensor dtype if needed
# We don't have a fused grad bias impl that takes FP8
if
weight
is
None
:
# input. For cases where we cast to FP8 and all-gather,
raise
ValueError
(
"Weight tensor is required to compute input grad"
)
# it's better to compute the grad bias on ungathered,
w
=
weight
# non-FP8 values.
if
with_quantized_compute
:
db
=
dy_local
.
sum
(
dim
=
0
)
if
not
isinstance
(
w
,
QuantizedTensorBase
):
db_async
=
torch
.
distributed
.
all_reduce
(
weight_quantizer
.
set_usage
(
columnwise
=
True
)
db
,
w
=
weight_quantizer
(
w
)
group
=
tensor_parallel_group
,
else
:
async_op
=
True
,
if
isinstance
(
w
,
QuantizedTensorBase
):
)
w
=
w
.
dequantize
(
dtype
=
dtype
)
elif
w
.
dtype
!=
dtype
:
w
=
w
.
to
(
dtype
=
dtype
)
# C
heck
input tensor
# C
ast
input tensor
dtype if needed
x_local
=
None
x_local
=
None
if
weight_requires_grad
:
if
weight_requires_grad
:
x_local
=
reshape
(
if
input
is
None
:
input
,
raise
ValueError
(
"Input tensor is required to compute weight grad"
)
(
-
1
,
input_dims
[
-
1
]),
x_local
=
input
device
=
device
,
if
with_quantized_compute
:
dtype
=
dtype
,
if
not
isinstance
(
x_local
,
QuantizedTensorBase
):
)
input_quantizer
.
set_usage
(
columnwise
=
True
)
if
with_fp8_compute
and
not
is_float8_tensor
(
x_local
):
x_local
=
input_quantizer
(
x_local
)
fp8_dtype
=
get_fp8_te_dtype
(
else
:
input_fp8_meta
[
"recipe"
],
if
isinstance
(
x_local
,
QuantizedTensorBase
):
fprop_tensor
=
True
,
x_local
=
x_local
.
dequantize
(
dtype
=
dtype
)
elif
x_local
.
dtype
!=
dtype
:
x_local
=
x_local
.
to
(
dtype
=
dtype
)
# dgrad GEMM
dx_local
=
None
dx
=
None
dy
=
None
x
=
None
if
input_requires_grad
:
# Initialize grad output
if
with_dgrad_all_gather_dy
:
if
grad_output_quantizer
is
not
None
:
grad_output_quantizer
.
set_usage
(
rowwise
=
True
,
columnwise
=
False
)
dy
,
_
=
fill_userbuffers_buffer_for_all_gather
(
ub_comm_dgrad
,
dy_local
,
grad_output_quantizer
,
tensor_parallel_group
,
)
)
x_local
=
Float8Tensor
.
to_float8
(
else
:
dy
=
dy_local
# Construct grad input tensor if needed
if
with_dgrad_reduce_scatter_dx
or
with_wgrad_reduce_scatter_dx
:
dx_size
=
list
(
dy
.
size
())
dx_size
[
-
1
]
=
w
.
size
(
-
1
)
dx_local_size
=
list
(
dx_size
)
dx_local_size
[
0
]
//=
tensor_parallel_size
if
with_dgrad_reduce_scatter_dx
:
dx_local
=
torch
.
empty
(
dx_local_size
,
dtype
=
dtype
,
device
=
device
,
)
elif
with_wgrad_reduce_scatter_dx
:
dx_local
=
ub_comm_wgrad
.
get_buffer
(
local_chunk
=
True
,
shape
=
dx_local_size
,
)
dx
=
ub_comm_wgrad
.
get_buffer
(
local_chunk
=
False
,
shape
=
dx_size
,
)
# Initialize input tensor if needed
if
with_dgrad_all_gather_x
:
if
input_quantizer
is
not
None
:
input_quantizer
.
set_usage
(
rowwise
=
False
,
columnwise
=
True
)
x
,
_
=
fill_userbuffers_buffer_for_all_gather
(
ub_comm_dgrad
,
x_local
,
x_local
,
fp8_meta
=
input_fp8_meta
,
input_quantizer
,
fp8_meta_forward
=
True
,
tensor_parallel_group
,
fp8_meta_index
=
0
,
fp8_dtype
=
fp8_dtype
,
data
=
(
ub_comm_x
.
get_ubuf_output
(
0
)
if
with_ub_all_gather_x
else
None
),
with_transpose_cache
=
(
not
with_ub_all_gather_x
),
)
)
elif
not
with_fp8_compute
and
is_float8_tensor
(
x_local
):
if
with_ub_all_gather_x
:
# Perform dgrad GEMM
ub_local_buffer
=
ub_comm_x
.
get_ubuf_output
(
0
)
dx
,
*
_
=
general_gemm
(
x_local
=
ub_local_buffer
.
copy_
(
x_local
)
else
:
x_local
=
x_local
.
dequantize
()
# Check weight tensor
w
=
convert_tensor
(
weight
,
device
=
device
,
dtype
=
dtype
,
memory_format
=
torch
.
contiguous_format
,
)
if
with_fp8_compute
and
not
is_float8_tensor
(
w
):
fp8_dtype
=
get_fp8_te_dtype
(
weight_fp8_meta
[
"recipe"
],
fprop_tensor
=
True
,
)
w
=
Float8Tensor
.
to_float8
(
w
,
w
,
fp8_meta
=
weight_fp8_meta
,
dy
,
fp8_meta_forward
=
True
,
get_workspace
(),
fp8_meta_index
=
0
,
out_dtype
=
dtype
,
fp8_dtype
=
fp8_dtype
,
quantization_params
=
grad_input_quantizer
,
with_transpose_cache
=
True
,
layout
=
"NN"
,
out
=
dx
,
use_split_accumulator
=
_2X_ACC_DGRAD
,
grad
=
True
,
ub
=
ub_comm_dgrad
,
ub_type
=
ub_type_dgrad
,
extra_output
=
dx_local
if
with_dgrad_reduce_scatter_dx
else
None
,
bulk_overlap
=
with_bulk_overlap
,
)
)
elif
not
with_fp8_compute
and
is_float8_tensor
(
w
):
if
not
(
with_dgrad_reduce_scatter_dx
or
with_wgrad_reduce_scatter_dx
):
w
=
w
.
dequantize
()
dx_local
=
dx
# Initialize buffers for UB all-gather if needed
dy
=
dy_local
x
=
x_local
if
with_ub_all_gather_dy
:
ub_local_buffer
=
ub_comm_dy
.
get_ubuf_output
(
0
)
ub_global_buffer
=
ub_comm_dy
.
get_ubuf_output
(
1
)
if
with_fp8_compute
:
dy
=
Float8Tensor
.
make_like
(
dy_local
,
data
=
ub_global_buffer
)
if
dy_local
.
_data
.
data_ptr
()
!=
ub_local_buffer
.
data_ptr
():
ub_local_buffer
.
copy_
(
dy_local
.
_data
)
else
:
dy
=
ub_global_buffer
if
dy_local
.
data_ptr
()
!=
ub_local_buffer
.
data_ptr
():
ub_local_buffer
.
copy_
(
dy_local
)
if
with_ub_all_gather_x
:
ub_local_buffer
=
ub_comm_x
.
get_ubuf_output
(
0
)
ub_global_buffer
=
ub_comm_x
.
get_ubuf_output
(
1
)
if
with_fp8_compute
:
x
=
Float8Tensor
.
make_like
(
x_local
,
data
=
ub_global_buffer
)
if
x_local
.
_data
.
data_ptr
()
!=
ub_local_buffer
.
data_ptr
():
ub_local_buffer
.
copy_
(
x_local
.
_data
)
else
:
x
=
ub_global_buffer
if
x_local
.
data_ptr
()
!=
ub_local_buffer
.
data_ptr
():
ub_local_buffer
.
copy_
(
x_local
)
# Construct grad input tensor
# wgrad GEMM
dx
=
None
dw
=
None
dx_local
=
None
if
weight_requires_grad
:
if
with_ub_reduce_scatter_dx
:
# Initialize buffers for UB reduce-scatter
# Initialize grad output
dx
=
ub_comm_dx
.
get_ubuf_output
(
1
)
if
tensor_parallel_mode
==
"row"
and
isinstance
(
grad_output_quantizer
,
MXFP8Quantizer
):
ub_local_buffer
=
ub_comm_dx
.
get_ubuf_output
(
0
)
# UB does not support overlapping grad output
if
with_ub_all_gather_x
:
# all-gather with wgrad GEMM. Also, MXFP8 does not
dx_local
=
ub_local_buffer
# allow reusing the grad output that was gathered for
else
:
# the dgrad GEMM. We work around with blocking
dx_local
=
torch
.
empty_like
(
ub_local_buffer
)
# all-gather for column-scaled MXFP8 data.
else
:
grad_output_quantizer
.
set_usage
(
rowwise
=
False
,
columnwise
=
True
)
# Allocate grad input tensor
dy
,
_
=
gather_along_first_dim
(
if
with_fp8_grad_input
:
grad_output
,
fp8_dtype
=
get_fp8_te_dtype
(
tensor_parallel_group
,
grad_input_fp8_meta
[
"recipe"
],
quantizer
=
grad_output_quantizer
,
fprop_tensor
=
False
,
)
data
=
torch
.
empty
(
(
dy
.
size
(
0
),
w
.
size
(
-
1
)),
dtype
=
torch
.
uint8
,
device
=
device
,
)
dx
=
Float8Tensor
(
data
=
data
,
fp8_meta
=
grad_input_fp8_meta
,
fp8_meta_forward
=
False
,
fp8_meta_index
=
0
,
fp8_dtype
=
fp8_dtype
,
dtype
=
dtype
,
)
)
else
:
if
tensor_parallel_mode
==
"column"
:
d
x
=
torch
.
empty
(
d
y
=
dy_local
(
dy
.
size
(
0
),
w
.
size
(
-
1
)),
if
dy
is
None
:
dtype
=
dtype
,
raise
RuntimeError
(
device
=
device
,
"wgrad GEMM requires grad output tensor, which has not been initialized"
)
)
dx_local
=
dx
if
isinstance
(
dy
,
QuantizedTensorBase
):
dy
.
update_usage
(
rowwise_usage
=
False
,
columnwise_usage
=
True
)
# Allocate grad input tensor
# Initialize input tensor
if
grad_weight
is
None
:
if
tensor_parallel_mode
==
"row"
:
if
accumulate_into_grad_weight
:
x
=
x_local
raise
ValueError
(
if
x
is
None
:
"Attempted to accumulate into grad weight bufferwithout providing grad weight"
raise
RuntimeError
(
"wgrad GEMM requires input tensor, which has not been initialized"
)
)
grad_weight
=
torch
.
empty
(
if
isinstance
(
x
,
QuantizedTensorBase
):
weight_dims
,
x
.
update_usage
(
rowwise_usage
=
False
,
columnwise_usage
=
True
)
dtype
=
dtype
,
device
=
device
,
# Check grad weight tensor
memory_format
=
torch
.
contiguous_format
,
dw
=
grad_weight
)
dw_dtype
=
dtype
if
dw
is
None
:
if
accumulate_into_grad_weight
:
raise
ValueError
(
"Attempted to accumulate into grad weight tensor "
"without providing grad weight tensor"
)
else
:
dw_dtype
=
dw
.
dtype
# Perform dgrad GEMM
# Perform wgrad GEMM
if
with_fp8_compute
:
dw
,
*
_
=
general_gemm
(
kwargs
=
{
"out"
:
dx
,
"use_split_accumulator"
:
True
}
if
with_ub_all_gather_dy
:
kwargs
[
"ub_algo"
]
=
ub_algo_dy
kwargs
[
"ub"
]
=
ub_comm_dy
elif
with_ub_all_gather_x
:
kwargs
[
"ub_algo"
]
=
ub_algo_x
kwargs
[
"ub"
]
=
ub_comm_x
elif
with_ub_reduce_scatter_dx
:
kwargs
[
"ub_algo"
]
=
ub_algo_dx
kwargs
[
"ub"
]
=
ub_comm_dx
kwargs
[
"extra_output_tensor"
]
=
dx_local
if
with_fp8_grad_input
:
fp8_meta
,
fp8_meta_index
=
get_fp8_meta_from_fp8_tensor
(
dx
)
kwargs
.
update
(
{
"out"
:
dx
.
_data
,
"out_index"
:
fp8_meta_index
,
"fp8_meta_tensor"
:
fp8_meta
,
"D_dtype"
:
dx
.
_fp8_dtype
,
}
)
fp8_gemm
(
w
.
transpose_2d
(),
w
.
_scale_inv
,
0
,
w
.
_fp8_dtype
,
dy
.
_data
,
dy
.
_scale_inv
,
0
,
dy
.
_fp8_dtype
,
dy
.
dtype
,
get_workspace
(),
**
kwargs
,
)
else
:
kwargs
=
{
"grad"
:
True
,
"layout"
:
"NN"
,
"out"
:
dx
}
if
with_ub_all_gather_dy
:
kwargs
[
"ub_algo"
]
=
ub_algo_dy
kwargs
[
"ub"
]
=
ub_comm_dy
elif
with_ub_all_gather_x
:
kwargs
[
"ub_algo"
]
=
ub_algo_x
kwargs
[
"ub"
]
=
ub_comm_x
elif
with_ub_reduce_scatter_dx
:
kwargs
[
"ub_algo"
]
=
ub_algo_dx
kwargs
[
"ub"
]
=
ub_comm_dx
kwargs
[
"extra_output_tensor"
]
=
dx_local
gemm
(
w
,
dy
,
dx
.
dtype
,
get_workspace
(),
**
kwargs
)
grad_input
=
reshape
(
dx_local
,
input_dims
)
# Perform wgrad GEMM
if
not
weight_requires_grad
:
pass
elif
with_fp8_compute
:
kwargs
=
{
"accumulate"
:
accumulate_into_grad_weight
,
"out"
:
grad_weight
,
"use_split_accumulator"
:
True
,
}
if
with_ub_reduce_scatter_dx
:
kwargs
[
"ub_algo"
]
=
ub_algo_dx
kwargs
[
"ub"
]
=
ub_comm_dx
fp8_gemm
(
x
.
transpose_2d
(),
x
.
_scale_inv
,
0
,
x
.
_fp8_dtype
,
dy
.
transpose_2d
(),
dy
.
_scale_inv
,
0
,
dy
.
_fp8_dtype
,
grad_weight
.
dtype
,
get_workspace
(),
**
kwargs
,
)
else
:
kwargs
=
{
"accumulate"
:
accumulate_into_grad_weight
,
"layout"
:
"NT"
,
"grad"
:
True
,
"use_bias"
:
bias_requires_grad
,
"out"
:
grad_weight
,
}
if
with_ub_reduce_scatter_dx
:
kwargs
[
"ub_algo"
]
=
ub_algo_dx
kwargs
[
"ub"
]
=
ub_comm_dx
grad_weight
,
db
,
_
=
gemm
(
x
,
x
,
dy
,
dy
,
grad_weight
.
dtype
,
get_workspace
(),
get_workspace
(),
**
kwargs
,
out_dtype
=
dw_dtype
,
accumulate
=
accumulate_into_grad_weight
,
layout
=
"NT"
,
out
=
dw
,
use_split_accumulator
=
_2X_ACC_WGRAD
,
grad
=
True
,
ub
=
ub_comm_wgrad
,
ub_type
=
ub_type_wgrad
,
bulk_overlap
=
with_bulk_overlap
,
)
)
# Bulk overlap reduce-scatter with non-FP8 buffer is
# in-place. Need to copy grad input tensor to avoid data
# corruption in Userbuffers buffer.
if
with_wgrad_reduce_scatter_dx
:
dx_local
=
dx_local
.
clone
()
# Compute grad bias if needed
# Compute grad bias if needed
if
db_async
is
not
None
:
if
db_async
is
not
None
:
db_async
.
wait
()
db_async
.
wait
()
if
bias_requires_grad
:
if
bias_requires_grad
:
if
db
is
None
:
db
=
dy
.
sum
(
dim
=
0
)
extra_outputs
[
"grad_bias"
]
=
db
extra_outputs
[
"grad_bias"
]
=
db
return
grad_input
,
grad_weight
,
extra_outputs
return
dx_local
,
dw
,
extra_outputs
def
fuser_backward
(
def
fuser_backward
(
self
,
self
,
...
@@ -633,40 +516,24 @@ class UserbuffersBackwardLinear(FusedOperation):
...
@@ -633,40 +516,24 @@ class UserbuffersBackwardLinear(FusedOperation):
        else:
            accumulate_into_main_grad = False
-        # Hackily workaround Userbuffers bug with non-FP8 dgrad
-        # reduce-scatter overlap
-        weight_requires_grad = linear_op_ctx.weight_requires_grad
-        if not linear_op_ctx.with_fp8_compute and not weight_requires_grad:
-            warnings.warn(
-                "There is a correctness bug when using Userbuffers "
-                "to overlap a dgrad reduce-scatter with a non-FP8 dgrad GEMM. "
-                "Hackily working around by overlapping dgrad reduce-scatter "
-                "with wgrad GEMM, even though wgrad isn't needed. "
-                "Please contact Transformer Engine team "
-                "if you encounter this use-case."
-            )
-            weight_requires_grad = True
        # Linear backward pass
        retval = UserbuffersBackwardLinear._functional_backward(
            grad_output=grad_output,
            input=x_local,
            weight=linear_op.weight,
-            input_dims=linear_op_ctx.input_dims,
-            weight_dims=linear_op.weight.size(),
-            weight_requires_grad=weight_requires_grad,
+            weight_requires_grad=linear_op_ctx.weight_requires_grad,
            bias_requires_grad=(bias_op is not None),
-            device=linear_op.device,
            dtype=linear_op_ctx.dtype,
            grad_weight=grad_weight,
            accumulate_into_grad_weight=accumulate_into_main_grad,
            tensor_parallel_mode=self.tensor_parallel_mode,
            tensor_parallel_group=self.tensor_parallel_group,
            sequence_parallel=self.sequence_parallel,
-            with_fp8_compute=linear_op_ctx.with_fp8_compute,
-            weight_fp8_meta=linear_op_ctx.weight_fp8_meta,
-            grad_output_fp8_meta=linear_op_ctx.grad_output_fp8_meta,
-            grad_input_fp8_meta=linear_op_ctx.grad_input_fp8_meta,
+            with_quantized_compute=linear_op_ctx.with_quantized_compute,
+            input_quantizer=linear_op_ctx.input_quantizer,
+            weight_quantizer=linear_op_ctx.weight_quantizer,
+            grad_output_quantizer=linear_op_ctx.grad_output_quantizer,
+            grad_input_quantizer=None,  # Not supported
            ub_comm_name=linear_op._userbuffers_options["comm_name"],
        )
        grad_input, grad_weight, extra_outputs = retval
...
@@ -707,8 +574,6 @@ def fuse_userbuffers_backward_linear(
    """
-    return ops  ### TODO Debug Userbuffers support
    # Return immediately if environment is not distributed
    if not torch.distributed.is_initialized() or torch.distributed.get_world_size() == 1:
        return ops
...
...
transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py
View file @ f8c2af4c
...
@@ -4,20 +4,25 @@
"""Linear layer forward with Userbuffers communication."""
-# pylint: skip-file  ### TODO Debug Userbuffers support

from __future__ import annotations
from collections.abc import Iterable
from typing import Any, Optional

import torch

-from transformer_engine_torch import CommOverlapAlgo
+from transformer_engine_torch import CommOverlapType
from ...cpp_extensions import general_gemm
from ...distributed import get_distributed_world_size
-from ...float8_tensor import Float8Tensor
-from ...fp8 import FP8GlobalStateManager, get_fp8_te_dtype
-from ...module.base import get_ub, get_workspace
+from ...fp8 import FP8GlobalStateManager
+from ...module.base import (
+    fill_userbuffers_buffer_for_all_gather,
+    get_ub,
+    get_workspace,
+    _2X_ACC_FPROP,
+)
+from ...tensor.quantized_tensor import QuantizedTensorBase, Quantizer
+from ...tensor.float8_tensor import Float8Quantizer
+from ...tensor._internal.float8_tensor_base import Float8TensorBase
from ...utils import canonicalize_device, canonicalize_dtype
from ..basic import BasicLinear, Bias, ReduceScatter
from ..op import (
...
@@ -26,12 +31,6 @@ from ..op import (
    FusibleOperation,
    OperationContext,
)
-from .._common import (
-    convert_tensor,
-    get_fp8_meta_from_fp8_tensor,
-    is_float8_tensor,
-    reshape,
-)


class UserbuffersForwardLinear(FusedOperation):
...
@@ -51,9 +50,6 @@ class UserbuffersForwardLinear(FusedOperation):
        reduce_scatter: Optional[ReduceScatter],
    ) -> None:
-        ### TODO Debug Userbuffers support
-        raise NotImplementedError("Userbuffers support has been broken by recent refactors")

        # Basic operations that comprise this fused operation
        op_idxs = {"linear": 0, "bias": None, "reduce_scatter": None}
        ops = [linear]
...
@@ -98,10 +94,10 @@ class UserbuffersForwardLinear(FusedOperation):
        tensor_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
        tensor_parallel_size: Optional[int] = None,
        sequence_parallel: bool = False,
-        with_fp8_compute: bool = False,
-        input_fp8_meta: Optional[dict[str, Any]] = None,
-        weight_fp8_meta: Optional[dict[str, Any]] = None,
-        output_fp8_meta: Optional[dict[str, Any]] = None,
+        with_quantized_compute: bool = False,
+        input_quantizer: Optional[Quantizer] = None,
+        weight_quantizer: Optional[Quantizer] = None,
+        output_quantizer: Optional[Quantizer] = None,
        ub_comm_name: str,
    ) -> tuple[torch.Tensor, dict]:
        """Functional API for forward pass
...
@@ -127,16 +123,14 @@ class UserbuffersForwardLinear(FusedOperation):
            parallelism, i.e. distributing input or output tensors
            along outer dimension (sequence or batch dim) when not
            distributing along inner dimension (embedding dim)
-        with_fp8_compute: bool, default = `False`
-            Whether to perform compute in FP8
-        input_fp8_meta: dict, optional
-            FP8 metadata for casting input tensor to FP8. Required for
-            FP8 compute if input is not already in FP8.
-        weight_fp8_meta: dict, optional
-            FP8 metadata for casting weight tensor to FP8. Required for
-            FP8 compute if weight is not already in FP8.
-        output_fp8_meta: dict, optional
-            FP8 metadata for casting output tensor to FP8
+        with_quantized_compute: bool, default = `False`
+            Whether to perform compute with quantized data.
+        input_quantizer: Quantizer, optional
+            Builder class for quantized input tensor.
+        weight_quantizer: Quantizer, optional
+            Builder class for quantized weight tensor.
+        output_quantizer: Quantizer, optional
+            Builder class for quantized output tensor.
        ub_comm_name: str
            Layer type (e.g. "qkv", "proj", "fc1", "fc2"). This is
            used to access the corresponding Userbuffers communicators
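A quantizer here is essentially a configurable callable that turns a high-precision tensor into its quantized representation; that is how the code later uses it (for example `x_local = input_quantizer(x_local)` after `set_usage(...)`). A toy stand-in to illustrate the idea, not the Transformer Engine Quantizer API:

    import torch

    class ToyQuantizer:
        """Toy stand-in for a quantizer "builder class" (illustrative only)."""

        def __init__(self):
            self.rowwise = True
            self.columnwise = False

        def set_usage(self, rowwise=None, columnwise=None):
            # Mirrors the set_usage(rowwise=..., columnwise=...) calls in the diff.
            if rowwise is not None:
                self.rowwise = rowwise
            if columnwise is not None:
                self.columnwise = columnwise

        def __call__(self, tensor):
            # A real quantizer would produce an FP8/quantized tensor; this just
            # simulates a lossy round-trip to keep the example self-contained.
            return tensor.to(torch.float16).to(tensor.dtype)

    input_quantizer = ToyQuantizer()
    input_quantizer.set_usage(rowwise=True, columnwise=True)
    x_local = input_quantizer(torch.randn(8, 16))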
...
@@ -166,23 +160,6 @@ class UserbuffersForwardLinear(FusedOperation):
        if dtype not in (torch.float32, torch.float16, torch.bfloat16):
            raise ValueError(f"Supported dtypes are float32, float16, bfloat16 (got {dtype})")

-        # Input tensor dims
-        input_dims = tuple(input.size())
-        weight_dims = tuple(weight.size())
-        if len(weight_dims) != 2:
-            raise ValueError(f"Weight tensor is not 2D (shape={weight_dims})")
-        if len(input_dims) == 0 or weight_dims[1] != input_dims[-1]:
-            raise ValueError(
-                f"Input tensor (shape={input_dims}) "
-                f"and weight tensor (shape={weight_dims}) "
-                "are not compatible"
-            )
-
-        # Output tensor dims
-        output_dims = list(input_dims)
-        output_dims[0] = -1
-        output_dims[-1] = weight_dims[0]

        # Check tensor parallel group
        if tensor_parallel_size is None:
            tensor_parallel_size = get_distributed_world_size(tensor_parallel_group)
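The shape checks removed above encoded the usual linear-layer rule: the output keeps every leading dimension of the input and replaces the last one with the weight's output dimension. A small plain-PyTorch check with illustrative sizes:

    import torch

    inp = torch.randn(2, 32, 64)         # [..., in_features]
    weight = torch.randn(128, 64)        # [out_features, in_features]
    out = inp @ weight.T                 # same contraction as the fprop GEMM
    expected = list(inp.shape)
    expected[-1] = weight.shape[0]       # output_dims[-1] = weight_dims[0]
    assert list(out.shape) == expected   # torch.Size([2, 32, 128])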
...
@@ -196,235 +173,106 @@ class UserbuffersForwardLinear(FusedOperation):
            if not sequence_parallel:
                raise RuntimeError(f"Invalid configuration for Userbuffers ({sequence_parallel=})")

-        # Check if FP8 is enabled
-        if with_fp8_compute:
-            if input_fp8_meta is None and not is_float8_tensor(input):
-                raise ValueError("No FP8 metadata was provided for casting input to FP8")
-            if weight_fp8_meta is None and not is_float8_tensor(weight):
-                raise ValueError("No FP8 metadata was provided for casting weight to FP8")
+        # Check quantizers
+        if with_quantized_compute:
+            if input_quantizer is None:
+                raise ValueError("Missing quantizer for input tensor")
+            if weight_quantizer is None:
+                raise ValueError("Missing quantizer for weight tensor")
+            if output_quantizer is not None:
+                raise ValueError("FP8 output is not supported")
        else:
-            input_fp8_meta = None
-            weight_fp8_meta = None
-            output_fp8_meta = None
-        with_fp8_output = (
-            with_fp8_compute and tensor_parallel_mode != "row" and output_fp8_meta is not None
-        )
+            input_quantizer = None
+            weight_quantizer = None
+            output_quantizer = None

        # Get Userbuffers communicator
        ub_comm = get_ub(ub_comm_name + "_fprop")
-        ub_local_buffer = ub_comm.get_ubuf_output(0)
-        ub_global_buffer = ub_comm.get_ubuf_output(1)
        with_ub_all_gather = tensor_parallel_mode == "column"
        with_ub_reduce_scatter = tensor_parallel_mode == "row"
+        ub_type = CommOverlapType.AG if with_ub_all_gather else CommOverlapType.RS

-        # Choose Userbuffers communication algorithm
-        ub_algo = None
-        if with_ub_all_gather:
-            if with_fp8_compute and ub_comm.is_atomic_gemm():
-                ub_algo = CommOverlapAlgo.ATOMIC_GEMM_AG_P2P
-            else:
-                ub_algo = CommOverlapAlgo.SPLIT_PIPELINED_AG_P2P
-        elif with_ub_reduce_scatter:
-            is_atomic_gemm = with_fp8_compute and ub_comm.is_atomic_gemm()
-            ub_algo = {
-                (True, True): CommOverlapAlgo.ATOMIC_GEMM_RS_P2P,
-                (True, False): CommOverlapAlgo.SPLIT_PIPELINED_RS_P2P,
-                (False, True): CommOverlapAlgo.ATOMIC_GEMM_RS,
-                (False, False): CommOverlapAlgo.SPLIT_PIPELINED_RS,
-            }[(ub_comm.is_p2p_overlap(), is_atomic_gemm)]
-        else:
-            raise RuntimeError("Could not choose Userbuffers communication algorithm")

-        # Cast input tensor to correct dtype
-        x_local = reshape(
-            input,
-            (-1, input_dims[-1]),
-            device=device,
-            dtype=dtype,
-        )
-        if with_fp8_compute and not is_float8_tensor(x_local):
-            fp8_dtype = get_fp8_te_dtype(
-                input_fp8_meta["recipe"],
-                fprop_tensor=True,
-            )
-            with_transpose_cache = weight.requires_grad
-            if tensor_parallel_mode == "column" and sequence_parallel:
-                with_transpose_cache = False
-            x_local = Float8Tensor.to_float8(
-                x_local,
-                fp8_meta=input_fp8_meta,
-                fp8_meta_forward=True,
-                fp8_meta_index=0,
-                fp8_dtype=fp8_dtype,
-                data=(ub_local_buffer if with_ub_all_gather else None),
-                with_transpose_cache=with_transpose_cache,
-            )
-        elif not with_fp8_compute and is_float8_tensor(x_local):
-            if with_ub_all_gather:
-                x_local = ub_local_buffer.copy_(x_local)
-            else:
-                x_local = x_local.dequantize()
-
-        # Initialize buffers for UB all-gather if needed
-        x = x_local
-        if with_ub_all_gather:
-            if with_fp8_compute:
-                x = Float8Tensor.make_like(x_local, data=ub_global_buffer)
-                if x_local._data.data_ptr() != ub_local_buffer.data_ptr():
-                    ub_local_buffer.copy_(x_local._data)
-                else:
-                    x_local._data = torch.empty_like(x_local._data)
-            else:
-                x = ub_global_buffer
-                if x_local.data_ptr() != ub_local_buffer.data_ptr():
-                    ub_local_buffer.copy_(x_local)
-                else:
-                    x_local = torch.empty_like(x_local)
+        # Initialize input tensor
+        x_local = input
+        x = None
+        if with_ub_all_gather:
+            if input_quantizer is not None:
+                if not isinstance(x_local, QuantizedTensorBase):
+                    input_quantizer.set_usage(rowwise=True, columnwise=True)
+                    if isinstance(input_quantizer, Float8Quantizer):
+                        input_quantizer.set_usage(columnwise=False)
+                    x_local = input_quantizer(x_local)
+                input_quantizer.set_usage(rowwise=True, columnwise=False)
+            x, x_local = fill_userbuffers_buffer_for_all_gather(
+                ub_comm,
+                x_local,
+                input_quantizer,
+                tensor_parallel_group,
+            )
        else:
+            if with_quantized_compute:
+                if not isinstance(x_local, QuantizedTensorBase):
+                    input_quantizer.set_usage(rowwise=True, columnwise=True)
+                    x_local = input_quantizer(x_local)
+            else:
+                if isinstance(x_local, QuantizedTensorBase):
+                    x_local = x_local.dequantize(dtype=dtype)
+                if x_local.dtype != dtype:
+                    x_local = x_local.to(dtype=dtype)
            x = x_local

-        # Check weight tensor
-        w = convert_tensor(
-            weight,
-            device=device,
-            dtype=dtype,
-            memory_format=torch.contiguous_format,
-        )
-        if with_fp8_compute and not is_float8_tensor(w):
-            fp8_dtype = get_fp8_te_dtype(
-                weight_fp8_meta["recipe"],
-                fprop_tensor=True,
-            )
-            w = Float8Tensor.to_float8(
-                w,
-                fp8_meta=weight_fp8_meta,
-                fp8_meta_index=0,
-                fp8_dtype=fp8_dtype,
-            )
-        elif not with_fp8_compute and is_float8_tensor(w):
-            w = w.dequantize()
+        # Initialize weight tensor
+        w = weight
+        w_is_quantized = isinstance(w, QuantizedTensorBase)
+        if with_quantized_compute and not w_is_quantized:
+            weight_quantizer.set_usage(rowwise=True)
+            w = weight_quantizer(w)
+        elif not with_quantized_compute and w_is_quantized:
+            w = w.dequantize()
+        if not with_quantized_compute and w.dtype != dtype:
+            w = w.to(dtype=dtype)

-        # Check bias tensor
-        b = None
-        if bias is not None:
-            b = convert_tensor(
-                bias,
-                device=device,
-                dtype=dtype,
-                memory_format=torch.contiguous_format,
-            )
-
-        # Construct output tensor
-        y = None
-        y_local = None
-        if with_ub_reduce_scatter:
-            # Initialize buffers for UB reduce-scatter
-            if with_fp8_output:
-                fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key(forward=True)
-                fp8_dtype = get_fp8_te_dtype(
-                    output_fp8_meta["recipe"],
-                    fprop_tensor=True,
-                )
-                y = Float8Tensor(
-                    data=ub_global_buffer,
-                    fp8_meta=output_fp8_meta,
-                    fp8_meta_forward=True,
-                    fp8_meta_index=0,
-                    fp8_dtype=fp8_dtype,
-                    fp8_scale_inv=output_fp8_meta[fp8_meta_key].scale_inv[0],
-                    dtype=dtype,
-                )
-                ub_comm.set_ubuf_scale_inv(y._scale_inv)
-            else:
-                y = ub_global_buffer
-            y_local = torch.empty(
-                (x.size(0) // tensor_parallel_size, weight_dims[0]),
-                dtype=dtype,
-                device=device,
-            )
-        else:
-            # Allocate output tensor
-            if with_fp8_output:
-                fp8_dtype = get_fp8_te_dtype(
-                    output_fp8_meta["recipe"],
-                    fprop_tensor=True,
-                )
-                data = torch.empty(
-                    (x.size(0), weight_dims[0]),
-                    dtype=torch.uint8,
-                    device=device,
-                )
-                y = Float8Tensor(
-                    data=data,
-                    fp8_meta=output_fp8_meta,
-                    fp8_meta_forward=True,
-                    fp8_meta_index=0,
-                    fp8_dtype=fp8_dtype,
-                    dtype=dtype,
-                )
-            else:
-                y = torch.empty(
-                    (x.size(0), weight_dims[0]),
-                    dtype=dtype,
-                    device=device,
-                )
-            y_local = y
+        # Construct output tensor if needed
+        reduce_scatter_output = None
+        if with_ub_reduce_scatter:
+            y_local_size = list(x.size())
+            y_local_size[0] //= tensor_parallel_size
+            y_local_size[-1] = w.size(0)
+            reduce_scatter_output = torch.empty(y_local_size, dtype=dtype, device=device)

        # Perform GEMM
-        if with_fp8_compute:
-            kwargs = {
-                "out": y,
-                "bias": b,
-                "use_bias": (b is not None),
-                "use_split_accumulator": False,
-                "ub_algo": ub_algo,
-                "ub": ub_comm,
-            }
-            if with_ub_all_gather:
-                kwargs["extra_output_tensor"] = x_local._data
-            if with_ub_reduce_scatter:
-                kwargs["extra_output_tensor"] = y_local
-            if with_fp8_output:
-                fp8_meta, fp8_meta_index = get_fp8_meta_from_fp8_tensor(y)
-                kwargs.update(
-                    {
-                        "out": y._data,
-                        "out_index": fp8_meta_index,
-                        "fp8_meta_tensor": fp8_meta,
-                        "D_dtype": y._fp8_dtype,
-                    }
-                )
-            fp8_gemm(
-                w._data,
-                w._scale_inv,
-                0,
-                w._fp8_dtype,
-                x._data,
-                x._scale_inv,
-                0,
-                x._fp8_dtype,
-                y.dtype,
-                get_workspace(),
-                **kwargs,
-            )
-        else:
-            kwargs = {
-                "out": y,
-                "bias": b,
-                "use_bias": (b is not None),
-                "ub_algo": ub_algo,
-                "ub": ub_comm,
-            }
-            if with_ub_all_gather:
-                kwargs["extra_output_tensor"] = x_local
-            if with_ub_reduce_scatter:
-                kwargs["extra_output_tensor"] = y_local
-            gemm(w, x, y.dtype, get_workspace(), **kwargs)
-
-        # Reshape output tensor
-        out = reshape(y_local, output_dims)
+        gemm_output, *_, reduce_scatter_output = general_gemm(
+            w,
+            x,
+            get_workspace(),
+            out_dtype=dtype,
+            quantization_params=output_quantizer,
+            bias=bias,
+            use_split_accumulator=_2X_ACC_FPROP,
+            ub=ub_comm,
+            ub_type=ub_type,
+            extra_output=reduce_scatter_output,
+        )
+        if with_ub_reduce_scatter:
+            y_local = reduce_scatter_output
+        else:
+            y_local = gemm_output
+
+        # Detach input tensor if needed
+        # Note: PyTorch autograd produces esoteric errors if we save
+        # input tensor as context for backward pass.
+        if x_local is input:
+            x_local = x_local.detach()
+
+        # Configure input tensor for backward pass
+        if with_quantized_compute and isinstance(x_local, QuantizedTensorBase):
+            if not (isinstance(x_local, Float8TensorBase) and with_ub_all_gather):
+                # FP8 does not support all-gather of transpose data
+                x_local.update_usage(rowwise_usage=False, columnwise_usage=True)

        # Return cast tensors
        extra_outputs = {"input": x_local, "weight": w}
-        return out, extra_outputs
+        return y_local, extra_outputs
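The communication pattern chosen above follows directly from the tensor-parallel layout: column-parallel layers all-gather the sequence-sharded input before the GEMM, while row-parallel layers reduce-scatter the GEMM output. A hedged sketch of that decision in isolation (plain Python mirroring the flags in the code, not a call into Transformer Engine):

    def choose_userbuffers_overlap(tensor_parallel_mode: str, sequence_parallel: bool) -> str:
        """Return which communication is overlapped with the forward GEMM."""
        if not sequence_parallel:
            raise RuntimeError("Userbuffers overlap assumes sequence parallelism")
        if tensor_parallel_mode == "column":
            return "all-gather"      # CommOverlapType.AG in the code above
        if tensor_parallel_mode == "row":
            return "reduce-scatter"  # CommOverlapType.RS in the code above
        raise ValueError(f"Unsupported tensor_parallel_mode: {tensor_parallel_mode}")

    assert choose_userbuffers_overlap("column", True) == "all-gather"
    assert choose_userbuffers_overlap("row", True) == "reduce-scatter"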
    def fuser_forward(
        self,
...
@@ -450,23 +298,22 @@ class UserbuffersForwardLinear(FusedOperation):
            if basic_op_kwargs[idx]:
                raise ValueError("Bias operation forward does not expect keyword arguments")

-        # FP8 metadata
-        with_fp8_compute = FP8GlobalStateManager.is_fp8_enabled()
-        input_fp8_meta = None
-        weight_fp8_meta = None
-        output_fp8_meta = None
-        grad_output_fp8_meta = None
-        grad_input_fp8_meta = None
-        if with_fp8_compute:
-            input_fp8_meta = linear_op.get_fp8_meta("input")
-            weight_fp8_meta = linear_op.get_fp8_meta("param")
-            next_op = basic_op_next_ops[-1]
-            if next_op is not None and next_op.num_fp8_scales("input") > 0:
-                output_fp8_meta = next_op.get_fp8_meta("input")
-            grad_output_fp8_meta = linear_op.get_fp8_meta("grad_output")
-            prev_op = basic_op_prev_ops[0]
-            if prev_op is not None and prev_op.num_fp8_scales("grad_output") > 0:
-                grad_input_fp8_meta = prev_op.get_fp8_meta("grad_output")
+        # Quantization metadata
+        with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled()
+        input_quantizer = None
+        weight_quantizer = None
+        grad_output_quantizer = None
+        grad_input_quantizer = None
+        if with_quantized_compute:
+            recipe = FP8GlobalStateManager.get_fp8_recipe()
+            if not recipe.delayed() and not recipe.mxfp8():
+                raise RuntimeError("Userbuffers is only supported with FP8 delayed scaling recipe")
+            input_quantizer = linear_op.get_quantizer("forward", 0)
+            weight_quantizer = linear_op.get_quantizer("forward", 1)
+            grad_output_quantizer = linear_op.get_quantizer("backward", 0)
+            prev_op = basic_op_prev_ops[0]
+            if prev_op is not None and prev_op.num_quantizers("backward") > 0 and recipe.delayed():
+                grad_input_quantizer = prev_op.get_quantizer("backward", 0)

        # Get autocast dtype if needed
        dtype = None
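The new code above requires an FP8 delayed-scaling recipe before it will use Userbuffers. From user code this typically means enabling FP8 through fp8_autocast with a DelayedScaling recipe; a minimal hedged sketch (the layer and sizes are illustrative, not taken from this commit):

    import torch
    import transformer_engine.pytorch as te
    from transformer_engine.common.recipe import DelayedScaling, Format

    recipe = DelayedScaling(fp8_format=Format.HYBRID, amax_history_len=16)
    layer = te.Linear(1024, 1024, bias=True).cuda()
    x = torch.randn(32, 1024, device="cuda")

    with te.fp8_autocast(enabled=True, fp8_recipe=recipe):
        y = layer(x)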
...
@@ -482,26 +329,26 @@ class UserbuffersForwardLinear(FusedOperation):
            input=input_,
            weight=linear_op.weight,
            bias=bias,
-            device=linear_op.device,
            dtype=dtype,
            tensor_parallel_mode=self.tensor_parallel_mode,
            tensor_parallel_group=self.tensor_parallel_group,
            tensor_parallel_size=self.tensor_parallel_size,
            sequence_parallel=self.sequence_parallel,
-            with_fp8_compute=with_fp8_compute,
-            input_fp8_meta=input_fp8_meta,
-            weight_fp8_meta=weight_fp8_meta,
-            output_fp8_meta=output_fp8_meta,
+            with_quantized_compute=with_quantized_compute,
+            input_quantizer=input_quantizer,
+            weight_quantizer=weight_quantizer,
+            output_quantizer=None,  # Not supported
            ub_comm_name=linear_op._userbuffers_options["comm_name"],
        )
        x_local = extra_outputs["input"]

        # Save state for backward pass
        linear_op_ctx.save_for_backward(x_local)
-        linear_op_ctx.with_fp8_compute = with_fp8_compute
-        linear_op_ctx.weight_fp8_meta = weight_fp8_meta
-        linear_op_ctx.grad_output_fp8_meta = grad_output_fp8_meta
-        linear_op_ctx.grad_input_fp8_meta = grad_input_fp8_meta
+        linear_op_ctx.with_quantized_compute = with_quantized_compute
+        linear_op_ctx.input_quantizer = input_quantizer
+        linear_op_ctx.weight_quantizer = weight_quantizer
+        linear_op_ctx.grad_output_quantizer = grad_output_quantizer
+        linear_op_ctx.grad_input_quantizer = grad_input_quantizer
        linear_op_ctx.dtype = dtype
        linear_op_ctx.input_dims = input_.size()
        linear_op_ctx.input_requires_grad = input_.requires_grad
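Saving only the cast input tensor (plus lightweight attributes) on the operation context follows the usual PyTorch custom-autograd pattern; a minimal hedged sketch of that pattern outside Transformer Engine:

    import torch

    class ScaleByTwo(torch.autograd.Function):
        """Minimal custom-autograd example of the save-for-backward pattern."""

        @staticmethod
        def forward(ctx, x):
            ctx.save_for_backward(x)                # tensors go through save_for_backward
            ctx.x_requires_grad = x.requires_grad   # plain attributes are stored directly
            return 2 * x

        @staticmethod
        def backward(ctx, grad_output):
            (x,) = ctx.saved_tensors
            return 2 * grad_output if ctx.x_requires_grad else None

    x = torch.randn(4, requires_grad=True)
    y = ScaleByTwo.apply(x)
    y.sum().backward()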
...
@@ -529,8 +376,6 @@ def fuse_userbuffers_forward_linear(
    """
-    return ops  ### TODO Debug Userbuffers support
    # Return immediately if environment is not distributed
    if not torch.distributed.is_initialized() or torch.distributed.get_world_size() == 1:
        return ops
...
transformer_engine/pytorch/setup.py
View file @ f8c2af4c
...
@@ -55,7 +55,17 @@ if __name__ == "__main__":
        description="Transformer acceleration library - Torch Lib",
        ext_modules=ext_modules,
        cmdclass={"build_ext": CMakeBuildExtension},
-        install_requires=["torch"],
+        setup_requires=[
+            "torch>=2.1",
+            "nvidia-cuda-runtime-cu12",
+            "nvidia-cublas-cu12",
+            "nvidia-cudnn-cu12",
+            "nvidia-cuda-cccl-cu12",
+            "nvidia-cuda-nvcc-cu12",
+            "nvidia-nvtx-cu12",
+            "nvidia-cuda-nvrtc-cu12",
+        ],
+        install_requires=["torch>=2.1"],
        tests_require=["numpy", "torchvision"],
    )
    if any(x in sys.argv for x in (".", "sdist", "bdist_wheel")):
...