Unverified commit 3a9a4c83 authored by Frank Lin, committed by GitHub

[Paddle][CUDAGraph] 175B GPT-3 Hybrid-Parallel Training with CUDAGraph (#957)



* NVTE_OVERRIDE_MAX_SEQ_LEN
Signed-off-by: Frank Lin <eee4017@gmail.com>

* small fix
Signed-off-by: Frank Lin <eee4017@gmail.com>

* preserve both the old and the new amax_and_scale_update_inplace
Signed-off-by: Frank Lin <eee4017@gmail.com>

* remove useless code path; try to simplify logic within the baseline
Signed-off-by: Frank Lin <eee4017@gmail.com>

* simplify logic
Signed-off-by: Frank Lin <eee4017@gmail.com>

* small fix
Signed-off-by: Frank Lin <eee4017@gmail.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix comments from Timmoon
Signed-off-by: Frank Lin <eee4017@gmail.com>

* fix comments from Timmoon
Signed-off-by: Frank Lin <eee4017@gmail.com>

* Update transformer_engine/paddle/distributed.py
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Signed-off-by: Frank Lin <eee4017@gmail.com>

* disable bwd fp8 update
Signed-off-by: Frank Lin <eee4017@gmail.com>

* fix lint
Signed-off-by: Frank Lin <eee4017@gmail.com>

* fix ci error
Signed-off-by: Frank Lin <eee4017@gmail.com>

---------
Signed-off-by: Frank Lin <eee4017@gmail.com>
Co-authored-by: Frank Lin (Engrg-Hardware 1) <fralin@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
parent 67b67432
@@ -9,7 +9,7 @@ import paddle
 import paddle.nn.functional as F
 from transformer_engine import transformer_engine_paddle as tex
 from .constants import TE_DType, FusedAttnBackend, FP8FwdTensors, FP8BwdTensors
-from .fp8 import FP8TensorMeta
+from .fp8 import FP8TensorMeta, get_global_fp8_state

 BACKEND_F16m512_THREADS_PER_CTA = 128
 BACKEND_F16arb_ELTS_PER_THREADS = 16
@@ -526,6 +526,8 @@ def mask_to_cu_seqlens(
 ) -> paddle.Tensor:
     """Convert mask to cu_seqlens"""
     # mask shape: [b, 1, s_q, s_kv]
+    if get_global_fp8_state().is_cudagraph_enabled():
+        raise RuntimeError("mask_to_cu_seqlens is not supported with cuda graphs.")
     q_seqlen, kv_seqlen = mask.shape[2], mask.shape[3]
     q_cu_seqlens = paddle.empty(shape=[mask.shape[0] + 1], dtype=paddle.int32)
     q_cu_seqlens[0] = 0
...
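With CUDA graphs enabled, `mask_to_cu_seqlens` now raises: it performs data-dependent tensor construction that is unsafe to capture and replay. The intended workaround is to compute `cu_seqlens` eagerly before capture and feed the precomputed tensor into the graphed region. A minimal sketch, assuming `True` in the mask marks masked positions (the helper below is illustrative, not part of this PR):

```python
import paddle

def precompute_cu_seqlens(mask: paddle.Tensor) -> paddle.Tensor:
    """Illustrative: build cu_seqlens outside of graph capture.

    Assumes mask has shape [b, 1, s_q, s_kv] with True marking masked
    positions, per the comment in mask_to_cu_seqlens.
    """
    valid = paddle.logical_not(mask[:, 0, :, 0]).astype("int32")
    seqlens = paddle.sum(valid, axis=-1).astype("int32")  # length per batch entry
    zero = paddle.zeros([1], dtype="int32")
    return paddle.concat([zero, paddle.cumsum(seqlens)])  # [0, l0, l0+l1, ...]
```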
@@ -10,6 +10,7 @@
 #include "common.h"
 #include "common/common.h"
+#include "paddle/phi/backends/gpu/cuda/cuda_graph.h"

 namespace transformer_engine {
 namespace paddle_ext {
@@ -581,11 +582,44 @@ std::vector<paddle::Tensor> te_rmsnorm_bwd(const paddle::Tensor &dz, const paddl
   return {dx, dgamma};
 }

-__global__ void set_rng_state(std::pair<uint64_t, uint64_t> seed_offset, int64_t *rng_state_ptr) {
+__global__ void set_rng_state(
+    [[maybe_unused]] unsigned int
+        identifier,  // Relates the kernel launch to its CUDA graph node; see
+                     // https://github.com/PaddlePaddle/Paddle/pull/60516
+    std::pair<uint64_t, uint64_t> seed_offset, int64_t *rng_state_ptr) {
   rng_state_ptr[0] = static_cast<int64_t>(seed_offset.first);
   rng_state_ptr[1] = static_cast<int64_t>(seed_offset.second);
 }

+void UpdateRandomGenerator(phi::Place place, cudaStream_t stream, int rng_elts_per_thread,
+                           paddle::Tensor &rng_state) {
+  // Extract the random number generator seed and offset.
+  const phi::DeviceContext *dev_ctx =
+      paddle::experimental::DeviceContextPool::Instance().Get(place);
+  phi::Generator *gen_cuda = dev_ctx->GetGenerator();
+  auto seed_offset = gen_cuda->IncrementOffset(rng_elts_per_thread);
+  auto state_index = gen_cuda->GetStateIndex();
+  int64_t *rng_state_p = static_cast<int64_t *>(rng_state.data());
+
+  auto parameterSetter = [gen_cuda, state_index,
+                          rng_elts_per_thread](phi::backends::gpu::CUDAKernelParams &params) {
+    // Ensure the generator uses the correct state index.
+    gen_cuda->SetStateIndex(state_index);
+
+    auto seed_offset = gen_cuda->IncrementOffset(rng_elts_per_thread);
+    params.As<std::pair<int64_t, int64_t>>(1) = seed_offset;
+  };
+
+  phi::backends::gpu::CUDAGraphNodeLauncher::cudaKernelCallback_t cudaKernelCallback =
+      [=](unsigned int id) {
+        void *functionPtr = reinterpret_cast<void *>(&set_rng_state);
+        cudaFunction_t cudaFunc;
+        PADDLE_ENFORCE_GPU_SUCCESS(cudaGetFuncBySymbol(&cudaFunc, functionPtr));
+        set_rng_state<<<1, 1, 0, stream>>>(id, seed_offset, rng_state_p);
+        return cudaFunc;
+      };
+  phi::backends::gpu::CUDAGraphNodeLauncher::Instance().KernelNodeLaunch(parameterSetter,
+                                                                         cudaKernelCallback);
+}
+
 void te_fused_attn_fwd_qkvpacked(const paddle::Tensor &QKV, const paddle::Tensor &cu_seqlens,
                                  const paddle::optional<paddle::Tensor> &Bias,
                                  paddle::Tensor &O,  // NOLINT
@@ -623,12 +657,7 @@ void te_fused_attn_fwd_qkvpacked(const paddle::Tensor &QKV, const paddle::Tensor
   NVTE_Bias_Type bias_type_enum = get_nvte_bias_type(bias_type);
   NVTE_Mask_Type attn_mask_type_enum = get_nvte_mask_type(attn_mask_type);

-  // extract random number generator seed and offset
-  auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(QKV.place());
-  auto gen_cuda = dev_ctx->GetGenerator();
-  auto seed_offset = gen_cuda->IncrementOffset(rng_elts_per_thread);
-  set_rng_state<<<1, 1, 0, QKV.stream()>>>(seed_offset, static_cast<int64_t *>(rng_state.data()));
+  UpdateRandomGenerator(QKV.place(), QKV.stream(), rng_elts_per_thread, rng_state);

   auto te_rng_state = MakeNvteTensor(rng_state);

   // create auxiliary output tensors
@@ -799,10 +828,7 @@ void te_fused_attn_fwd_kvpacked(
   NVTE_Bias_Type bias_type_enum = get_nvte_bias_type(bias_type);
   NVTE_Mask_Type attn_mask_type_enum = get_nvte_mask_type(attn_mask_type);

-  auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(Q.place());
-  auto gen_cuda = dev_ctx->GetGenerator();
-  auto seed_offset = gen_cuda->IncrementOffset(rng_elts_per_thread);
-  set_rng_state<<<1, 1, 0, Q.stream()>>>(seed_offset, static_cast<int64_t *>(rng_state.data()));
+  UpdateRandomGenerator(Q.place(), Q.stream(), rng_elts_per_thread, rng_state);

   auto te_rng_state = MakeNvteTensor(rng_state);

   // create auxiliary output tensors
@@ -979,7 +1005,27 @@ void te_fused_attn_fwd(const paddle::Tensor &Q, const paddle::Tensor &K, const p
   auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(Q.place());
   auto gen_cuda = dev_ctx->GetGenerator();
   auto seed_offset = gen_cuda->IncrementOffset(rng_elts_per_thread);
-  set_rng_state<<<1, 1, 0, Q.stream()>>>(seed_offset, static_cast<int64_t *>(rng_state.data()));
+  auto state_index = gen_cuda->GetStateIndex();
+  auto rng_state_p = static_cast<int64_t *>(rng_state.data());
+  auto stream = Q.stream();
+
+  auto parameterSetter = [gen_cuda, state_index,
+                          rng_elts_per_thread](phi::backends::gpu::CUDAKernelParams &params) {
+    // Ensure the generator uses the correct state index.
+    gen_cuda->SetStateIndex(state_index);
+
+    auto seed_offset = gen_cuda->IncrementOffset(rng_elts_per_thread);
+    params.As<std::pair<int64_t, int64_t>>(1) = seed_offset;
+  };
+
+  phi::backends::gpu::CUDAGraphNodeLauncher::cudaKernelCallback_t cudaKernelCallback =
+      [=](unsigned int id) {
+        void *functionPtr = reinterpret_cast<void *>(&set_rng_state);
+        cudaFunction_t cudaFunc;
+        PADDLE_ENFORCE_GPU_SUCCESS(cudaGetFuncBySymbol(&cudaFunc, functionPtr));
+        set_rng_state<<<1, 1, 0, stream>>>(id, seed_offset, rng_state_p);
+        return cudaFunc;
+      };
+  phi::backends::gpu::CUDAGraphNodeLauncher::Instance().KernelNodeLaunch(parameterSetter,
+                                                                         cudaKernelCallback);

   auto te_rng_state = MakeNvteTensor(rng_state);
@@ -1260,6 +1306,29 @@ void te_scaled_upper_triang_masked_softmax_backward(paddle::Tensor &output_grads
                                                     softmax_results.stream());
 }

+__global__ void UpdateFP8MetaKernel(
+    [[maybe_unused]] unsigned int
+        identifier,  // Relates the kernel launch to its CUDA graph node; see
+                     // https://github.com/PaddlePaddle/Paddle/pull/60516
+    const float *amax, const float *rolled_amax_history, const bool *non_weight_mask,
+    float *amax_history, float *scale, float *scale_inv, bool update_weight_scale_inv,
+    float margin, float fp8_max, size_t history_numel, size_t amax_numel) {
+  int idx = blockIdx.x * blockDim.x + threadIdx.x;
+  if (idx >= history_numel) {
+    return;
+  }
+
+  amax_history[idx] = rolled_amax_history[idx];
+
+  if (idx < amax_numel) {
+    float sf = (fp8_max / amax[idx]) / powf(2.0f, margin);
+    float scale_reg = ((amax[idx] > 0.0f) && isfinite(amax[idx])) ? sf : scale[idx];
+    scale[idx] = scale_reg;
+    if (update_weight_scale_inv || non_weight_mask[idx]) scale_inv[idx] = 1.0f / scale_reg;
+    amax_history[idx] = 0.0f;
+  }
+}
+
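For reference, the per-element scale math in `UpdateFP8MetaKernel` is the standard delayed-scaling update: the new scale maps `amax` onto `fp8_max` with a `2^margin` safety factor, and the previous scale is kept whenever `amax` is zero or non-finite. A NumPy sketch of the same arithmetic (illustrative, not part of the PR):

```python
import numpy as np

def update_scale(amax, scale, fp8_max, margin):
    """NumPy mirror of UpdateFP8MetaKernel's scale update."""
    sf = (fp8_max / amax) / 2.0**margin    # scale that maps amax onto fp8_max
    ok = (amax > 0.0) & np.isfinite(amax)  # otherwise keep the previous scale
    new_scale = np.where(ok, sf, scale)
    return new_scale, 1.0 / new_scale      # (scale, scale_inv)

# E.g. E4M3 with fp8_max=448, margin=0, amax=2.0 -> scale=224.0, scale_inv~0.00446
```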
 constexpr int BLOCK_SIZE = 512;

 void amax_and_scale_update_inplace(paddle::Tensor &amax_history,  // NOLINT
@@ -1277,6 +1346,63 @@ void amax_and_scale_update_inplace(paddle::Tensor &amax_history,  // NOLINT
                                    static_cast<NVTEDType>(fp8_dtype), margin, amax_history.stream());
 }
+void amax_and_scale_update_inplace_legacy(paddle::Tensor &amax_history,  // NOLINT
+                                          paddle::Tensor &scale,         // NOLINT
+                                          paddle::Tensor &scale_inv,     // NOLINT
+                                          const paddle::Tensor &non_weight_mask,
+                                          const paddle::Tensor &current_step_id_tensor,
+                                          bool update_weight_scale_inv, bool fwd_update,
+                                          float fp8_max, float margin,
+                                          const std::string &amax_compute) {
+  NVTE_CHECK(amax_compute == "max" || amax_compute == "most_recent");
+
+  paddle::Tensor amax;
+  if (amax_compute == "max") {
+    amax = amax_history.max({0});
+  } else {
+    amax = amax_history.slice(0, 1);
+  }
+
+  const auto rolled_amax_history = amax_history.roll({-1}, {0});
+  auto amax_history_numel = amax_history.numel();
+  auto amax_numel = amax.numel();
+  size_t num_blocks = (amax_history_numel + BLOCK_SIZE - 1) / BLOCK_SIZE;
+
+  const int *current_step_id_ptr = nullptr;
+  if (fwd_update) current_step_id_ptr = current_step_id_tensor.data<int>();
+  auto parameterSetter = [current_step_id_ptr,
+                          fwd_update](phi::backends::gpu::CUDAKernelParams &params) {
+    if (fwd_update) {
+      int current_step_id = *current_step_id_ptr;
+      params.As<bool>(7) = (current_step_id == 0);
+    }
+  };
+
+  const float *amax_ptr = amax.data<float>();
+  const float *rolled_amax_history_ptr = rolled_amax_history.data<float>();
+  const bool *non_weight_mask_ptr = non_weight_mask.data<bool>();
+  float *amax_history_ptr = amax_history.data<float>();
+  float *scale_ptr = scale.data<float>();
+  float *scale_inv_ptr = scale_inv.data<float>();
+
+  phi::backends::gpu::CUDAGraphNodeLauncher::cudaKernelCallback_t cudaKernelCallback =
+      [=](unsigned int id) {
+        void *functionPtr = reinterpret_cast<void *>(&UpdateFP8MetaKernel);
+        cudaFunction_t cudaFunc;
+        PADDLE_ENFORCE_GPU_SUCCESS(cudaGetFuncBySymbol(&cudaFunc, functionPtr));
+        UpdateFP8MetaKernel<<<num_blocks, BLOCK_SIZE, 0, amax_history.stream()>>>(
+            id, amax_ptr, rolled_amax_history_ptr, non_weight_mask_ptr, amax_history_ptr,
+            scale_ptr, scale_inv_ptr, update_weight_scale_inv, margin, fp8_max,
+            amax_history_numel, amax_numel);
+        NVTE_CHECK_CUDA(cudaGetLastError());
+        return cudaFunc;
+      };
+  phi::backends::gpu::CUDAGraphNodeLauncher::Instance().KernelNodeLaunch(parameterSetter,
+                                                                         cudaKernelCallback);
+}
+
 void update_latest_amax_history_inplace(paddle::Tensor &history,  // NOLINT
                                         const paddle::Tensor &amax) {
   // Copy amax to history[0]
@@ -1617,6 +1743,16 @@ PD_BUILD_OP(te_scaled_upper_triang_masked_softmax_backward)
     .SetKernelFn(
         PD_KERNEL(transformer_engine::paddle_ext::te_scaled_upper_triang_masked_softmax_backward));

+PD_BUILD_OP(amax_and_scale_update_inplace_legacy)
+    .Inputs({"_amax_history", "_scale", "_scale_inv", "non_weight_mask", "current_step_id_tensor"})
+    .Outputs({"amax_history", "scale", "scale_inv"})
+    .SetInplaceMap({{"_amax_history", "amax_history"},
+                    {"_scale", "scale"},
+                    {"_scale_inv", "scale_inv"}})
+    .Attrs({"update_weight_scale_inv: bool", "fwd_update: bool", "fp8_max: float",
+            "margin: float", "amax_compute: std::string"})
+    .SetKernelFn(PD_KERNEL(transformer_engine::paddle_ext::amax_and_scale_update_inplace_legacy));
+
 PD_BUILD_OP(amax_and_scale_update_inplace)
     .Inputs({"_amax_history", "_scale", "_scale_inv", "non_weight_mask"})
     .Outputs({"amax_history", "scale", "scale_inv"})
...
@@ -14,6 +14,16 @@ import paddle.distributed.fleet.base.topology as tp
 from paddle.distributed.fleet.meta_parallel import get_rng_state_tracker
 from paddle.distributed.fleet.layers.mpu import mp_ops

+try:
+    # This feature is not supported as of Paddle 2.6.
+    from paddle.distributed.fleet.meta_parallel import (
+        PipelineParallelMicroStepLocations,
+        register_global_pipeline_parallel_hook,
+    )
+except ImportError:
+    print("Cannot find register_global_pipeline_parallel_hook!")
+    register_global_pipeline_parallel_hook = None
+
 from .constants import dist_group_type

 _weight_split_axis = {
@@ -54,6 +64,22 @@ def get_tp_group_and_world_size(
     return model_parallel_group, world_size


+def is_pp_enabled() -> bool:
+    """Check if pipeline parallelism is enabled"""
+    if not paddle.distributed.is_initialized():
+        return False
+
+    return tp._HYBRID_PARALLEL_GROUP.get_pipe_parallel_world_size() > 1
+
+
+def register_pp_fwd_begin_hook(forward_begin_hook):
+    """Register the PP hook if register_global_pipeline_parallel_hook exists"""
+    if register_global_pipeline_parallel_hook is not None:
+        register_global_pipeline_parallel_hook(
+            PipelineParallelMicroStepLocations.FORWARD_BEGIN, forward_begin_hook
+        )
+
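Using the new helper is a one-liner. A hedged sketch, assuming a Paddle build that provides `register_global_pipeline_parallel_hook` and that hooks receive the microstep index as `step_id`, as the callbacks added in this PR expect:

```python
from transformer_engine.paddle.distributed import register_pp_fwd_begin_hook

def on_forward_begin(step_id=None, **kwargs):
    # Runs at FORWARD_BEGIN of every pipeline-parallel microstep.
    print(f"starting microstep {step_id}")

# A no-op on Paddle builds without the pipeline-parallel hook API.
register_pp_fwd_begin_hook(on_forward_begin)
```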
 @contextmanager
 def track_rng_state(enable: bool, **kwargs) -> None:
     """
...
@@ -62,6 +62,7 @@ class FP8State:
         self._fp8_autocast_counter = 0
         self._fp8_autocast_depth = 0
         self._fp8_recompute_enabled = False
+        self._use_cudagraph = False
         self._fp8_fwd_buffer = FP8MetaFwdBuffer()
         self._fp8_bwd_buffer = FP8MetaBwdBuffer()
         self._fp8_recompute_buffer = FP8RecomputeBuffer()
@@ -116,6 +117,18 @@ class FP8State:
         """Returns global fp8 recompute buffer."""
         return self._fp8_recompute_buffer

+    def is_cudagraph_enabled(self) -> bool:
+        """Whether CUDA graphs are enabled"""
+        return self._use_cudagraph
+
+    def enable_cudagraph(self):
+        """Enable CUDA graphs. Once enabled, they cannot be disabled within
+        the same execution context in the current implementation."""
+        self._use_cudagraph = True
+        self._fp8_fwd_buffer.enable_cudagraph()
+        self._fp8_bwd_buffer.enable_cudagraph()
+        if self._fp8_recompute_enabled:
+            raise RuntimeError("Recompute is currently not supported with CUDA graphs.")
+
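From user code, the switch is flipped once on the global state before any graph capture; a minimal sketch:

```python
from transformer_engine.paddle.fp8 import get_global_fp8_state

state = get_global_fp8_state()
if not state.is_cudagraph_enabled():
    # One-way switch: enables the CUDA-graph-safe paths in the fwd/bwd FP8
    # buffers. Must happen before capture, and cannot be undone afterwards.
    state.enable_cudagraph()
```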
     def enter(
         self,
         enabled: bool,
@@ -235,25 +248,45 @@ def amax_and_scale_update(
     fp8_meta: Dict[str, Any],
     fwd_update: bool,
     update_weight_scale_inv: bool = True,
+    current_step_id_tensor: Optional[paddle.Tensor] = None,
+    use_cudagraph: bool = False,
 ) -> None:
     """Updates fp8 amaxes/scales for fwd | bwd."""
     amax_compute = fp8_meta["recipe"].amax_compute_algo
     sf_compute = fp8_meta["recipe"].scaling_factor_compute_algo
     fp8_meta_tensor_key = "scaling_fwd" if fwd_update else "scaling_bwd"
+    fp8_max_key = "fp8_max_fwd" if fwd_update else "fp8_max_bwd"

     if not callable(amax_compute) and sf_compute is None:
         non_weight_mask = fp8_meta[fp8_meta_tensor_key].non_weight_mask
-        if update_weight_scale_inv:
-            non_weight_mask = paddle.empty([0])
-        tex.amax_and_scale_update_inplace(
-            _amax_history=fp8_meta[fp8_meta_tensor_key].amax_history,
-            _scale=fp8_meta[fp8_meta_tensor_key].scale,
-            _scale_inv=fp8_meta[fp8_meta_tensor_key].scale_inv,
-            non_weight_mask=non_weight_mask,
-            fp8_dtype=int(get_fp8_te_dtype(fp8_meta["recipe"], fwd_update)),
-            margin=float(fp8_meta["recipe"].margin),
-            amax_compute=amax_compute,
-        )
+        if use_cudagraph:
+            tex.amax_and_scale_update_inplace_legacy(
+                _amax_history=fp8_meta[fp8_meta_tensor_key].amax_history,
+                _scale=fp8_meta[fp8_meta_tensor_key].scale,
+                _scale_inv=fp8_meta[fp8_meta_tensor_key].scale_inv,
+                non_weight_mask=non_weight_mask,
+                current_step_id_tensor=current_step_id_tensor,
+                update_weight_scale_inv=update_weight_scale_inv,
+                fwd_update=fwd_update,
+                fp8_max=fp8_meta[fp8_max_key],
+                margin=float(fp8_meta["recipe"].margin),
+                amax_compute=amax_compute,
+            )
+        else:
+            if update_weight_scale_inv:
+                # Pass an empty mask into the kernel when update_weight_scale_inv is set.
+                non_weight_mask = paddle.empty([0])
+            tex.amax_and_scale_update_inplace(
+                _amax_history=fp8_meta[fp8_meta_tensor_key].amax_history,
+                _scale=fp8_meta[fp8_meta_tensor_key].scale,
+                _scale_inv=fp8_meta[fp8_meta_tensor_key].scale_inv,
+                non_weight_mask=non_weight_mask,
+                fp8_dtype=int(get_fp8_te_dtype(fp8_meta["recipe"], fwd_update)),
+                margin=float(fp8_meta["recipe"].margin),
+                amax_compute=amax_compute,
+            )
     else:
         raise ValueError(
             "We only support the fp8 recipe with 'max' or 'most_recent' "
...
@@ -22,10 +22,12 @@ class FP8MetaBufferBase(ABC):
     """

     def __init__(self):
-        self._data = {}
+        self._global_amax = {}
         self._buffer_delete_key = None
         self._amax_reduce_wait_func = None
         self._dp_amax_reduce_interval = None
+        self._contiguous_amax = None
+        self._use_cudagraph = False
         self._dp_amax_reduce_idx = 0

     @staticmethod
@@ -44,13 +46,13 @@ class FP8MetaBufferBase(ABC):
         """Returns autocast id key in `fp8_meta`."""

     def _get_amax_buffer_key(self, fp8_meta: Dict[str, Any]) -> str:
-        """Return a key in `_data` for the AMAX storage."""
+        """Return a key in `_global_amax` for the AMAX storage."""
         return f"AMAX_{fp8_meta[self._get_autocast_key()]}"

     def _execute_deletion(self) -> None:
         """Delete the key from global amax buffer."""
-        if self._buffer_delete_key is not None and self._buffer_delete_key in self._data:
-            del self._data[self._buffer_delete_key]
+        if self._buffer_delete_key is not None and self._buffer_delete_key in self._global_amax:
+            del self._global_amax[self._buffer_delete_key]

     def _wait_handle_and_split(
         self,
@@ -62,7 +64,12 @@ class FP8MetaBufferBase(ABC):
         """Wait for amax reduction to finish and then copy reduced amax to buffer"""
         if wait_handle is not None:
             wait_handle.wait()
-        self._data[amax_buffer_key] = list(contiguous_amax.split(chunk_sizes))
+        if self._use_cudagraph:
+            split_list = list(contiguous_amax.split(chunk_sizes))
+            for amax, split in zip(self._global_amax[amax_buffer_key], split_list):
+                amax.copy_(split, False)
+        else:
+            self._global_amax[amax_buffer_key] = list(contiguous_amax.split(chunk_sizes))
     def _global_amax_reduction(
         self,
@@ -85,7 +92,7 @@ class FP8MetaBufferBase(ABC):
         amax_buffer_key = self._get_amax_buffer_key(fp8_meta)

         # Key already deleted.
-        if amax_buffer_key not in self._data:
+        if amax_buffer_key not in self._global_amax:
             return None

         # Reduce AMAX in DP-domain at an interval.
@@ -105,18 +112,32 @@ class FP8MetaBufferBase(ABC):
         else:
             return None

-        chunk_sizes = [x.shape[0] for x in self._data[amax_buffer_key]]
-        contiguous_amax = paddle.concat(self._data[amax_buffer_key])
+        chunk_sizes = [x.shape[0] for x in self._global_amax[amax_buffer_key]]
+        if self._use_cudagraph:
+            # _contiguous_amax must stay address-stable under CUDA graphs.
+            if self._contiguous_amax is None:
+                self._contiguous_amax = paddle.concat(self._global_amax[amax_buffer_key])
+            else:
+                self._contiguous_amax.copy_(
+                    paddle.concat(self._global_amax[amax_buffer_key]), False
+                )
+        else:
+            self._contiguous_amax = paddle.concat(self._global_amax[amax_buffer_key])

         wait_handle = _reduce_tensor_across_group_op_max(
-            contiguous_amax,
+            self._contiguous_amax,
             reduce_group,
             not fp8_meta["async_amax_reduction"],
         )

+        if wait_handle is not None and self._use_cudagraph:
+            # Ensure the record/wait pair does not cross the graph capture boundary.
+            wait_handle.wait()
+            wait_handle = None
+
         return partial(
             self._wait_handle_and_split,
-            contiguous_amax,
+            self._contiguous_amax,
             chunk_sizes,
             amax_buffer_key,
             wait_handle,
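The pattern above matters because a captured kernel bakes in device pointers: on replay it reads whatever lives at the captured address, so `_contiguous_amax` must be allocated once and updated in place rather than rebound to a fresh tensor. A standalone sketch of the idiom (illustrative class, not part of the PR):

```python
import paddle

class AddressStableBuffer:
    """Keep one persistent tensor and copy into it, so kernels captured in a
    CUDA graph keep seeing the same device address across replays."""

    def __init__(self):
        self._buf = None  # allocated on first use, then reused forever

    def update(self, parts):
        packed = paddle.concat(parts)
        if self._buf is None:
            self._buf = packed              # first call: adopt the allocation
        else:
            self._buf.copy_(packed, False)  # later calls: async in-place copy
        return self._buf
```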
@@ -128,16 +149,16 @@ class FP8MetaBufferBase(ABC):
         fp8_meta_tensor_key = self._get_meta_tensor_key()
         buffer_position_key = self._get_buffer_position_key()

-        if buffer_key not in self._data:
-            self._data[buffer_key] = [fp8_meta[fp8_meta_tensor_key].amax_history[0]]
+        if buffer_key not in self._global_amax:
+            self._global_amax[buffer_key] = [fp8_meta[fp8_meta_tensor_key].amax_history[0]]
         else:
-            self._data[buffer_key].append(fp8_meta[fp8_meta_tensor_key].amax_history[0])
+            self._global_amax[buffer_key].append(fp8_meta[fp8_meta_tensor_key].amax_history[0])

         if buffer_position_key not in fp8_meta:
-            fp8_meta[buffer_position_key] = len(self._data[buffer_key]) - 1
+            fp8_meta[buffer_position_key] = len(self._global_amax[buffer_key]) - 1

         # Catch incorrect fp8_autocast usage.
-        assert fp8_meta[buffer_position_key] == len(self._data[buffer_key]) - 1, (
+        assert fp8_meta[buffer_position_key] == len(self._global_amax[buffer_key]) - 1, (
             "Same module is being invoked more than once inside an `fp8_autocast` "
             "region when using FP8 with amax reduction. This behavior is currently "
             "unsupported. For more details and correct usage, please see "
@@ -152,12 +173,12 @@ class FP8MetaBufferBase(ABC):
             return

         amax_buffer_key = self._get_amax_buffer_key(fp8_meta)
-        assert amax_buffer_key in self._data, "TE internal error."
+        assert amax_buffer_key in self._global_amax, "TE internal error."

         # Copy amax to amax_history[0]
         tex.update_latest_amax_history_inplace(
             _history=fp8_meta[fp8_meta_tensor_key].amax_history,
-            amax=self._data[amax_buffer_key][fp8_meta[buffer_position_key]],
+            amax=self._global_amax[amax_buffer_key][fp8_meta[buffer_position_key]],
         )

     def set_for_deletion(self, fp8_meta: Dict[str, Any]) -> None:
@@ -179,14 +200,18 @@ class FP8MetaBufferBase(ABC):
     def to_numpy(self) -> Dict[str, List[np.array]]:
         """Convert to numpy arrays"""
         out = {}
-        for k, v in self._data.items():
+        for k, v in self._global_amax.items():
             out[k] = [tensor.numpy() for tensor in v]
         return out

     def from_numpy(self, buffer: Dict[str, np.array]) -> None:
         """Set buffer values from numpy arrays"""
         for k, v in buffer.items():
-            self._data[k] = [paddle.to_tensor(arr) for arr in v]
+            self._global_amax[k] = [paddle.to_tensor(arr) for arr in v]
+
+    def enable_cudagraph(self):
+        """Enable CUDA graphs."""
+        self._use_cudagraph = True


 class FP8MetaFwdBuffer(FP8MetaBufferBase):
@@ -259,7 +284,9 @@ class FP8MetaBwdBuffer(FP8MetaBufferBase):
         Called at FP8 autocast end in backward.
         Performs AMAX reduction and delete unused buffer entries.
         """
-        self._amax_reduce_wait_func = self._global_amax_reduction(fp8_meta, tp_group, tp_size)
+        self._amax_reduce_wait_func = self._global_amax_reduction(
+            fp8_meta, tp_group, tp_size
+        )  # _wait_handle_and_split
         self._execute_deletion()
@@ -267,7 +294,7 @@ class FP8RecomputeBuffer:
     """Buffer used to hold FP8 meta tensors for recompute"""

     def __init__(self):
-        self._data = []
+        self._global_amax = []

     @staticmethod
     def get_buffer_position_key():
@@ -285,11 +312,11 @@ class FP8RecomputeBuffer:
         ]

         if buffer_position_key in fp8_meta:
-            self._data[fp8_meta[buffer_position_key]].append(to_copy)
+            self._global_amax[fp8_meta[buffer_position_key]].append(to_copy)
         else:
-            self._data.append(deque())
-            self._data[-1].append(to_copy)
-            fp8_meta[buffer_position_key] = len(self._data) - 1
+            self._global_amax.append(deque())
+            self._global_amax[-1].append(to_copy)
+            fp8_meta[buffer_position_key] = len(self._global_amax) - 1

     def retrieve_fp8_meta_tensors(self, fp8_meta: Dict[str, Any]) -> None:
         """Switch to the previously saved scaling factors and amaxes"""
@@ -300,7 +327,7 @@ class FP8RecomputeBuffer:
         # Retrieve stashed amaxes and scales from phase 1 pre forward.
         buffer_position_key = self.get_buffer_position_key()
-        stashed_fp8_meta = self._data[fp8_meta[buffer_position_key]].popleft()
+        stashed_fp8_meta = self._global_amax[fp8_meta[buffer_position_key]].popleft()

         # Replace amaxes and scales with stashed values for phase 2 forward
         fp8_meta["scaling_fwd"].amax_history = stashed_fp8_meta[0]
...
@@ -20,7 +20,7 @@ except ImportError:
     from paddle.fluid import core
     from paddle.fluid.framework import _dygraph_tracer

-from ..constants import FP8BwdTensors, dist_group_type
+from ..constants import FP8FwdTensors, FP8BwdTensors, dist_group_type
 from ..cpp_extensions import cast_transpose, cast_transpose_bgrad, cast_to_fp8, transpose
 from ..fp8 import (
     FP8State,
@@ -29,7 +29,7 @@ from ..fp8 import (
     get_global_fp8_state,
     get_fp8_te_dtype,
 )
-from ..distributed import allgather
+from ..distributed import allgather, register_pp_fwd_begin_hook, is_pp_enabled
 from ..profile import nvtx_range
 from ..recompute import is_in_recompute_phase
 from ..fp8_buffer import FP8RecomputeBuffer
@@ -80,8 +80,19 @@ class TransformerEngineBaseLayer(paddle.nn.Layer, ABC):
         self.fp8_meta["async_amax_reduction"] = bool(
             int(os.getenv("NVTE_ASYNC_AMAX_REDUCTION", "0"))
         )
-        self.fp8_weight_shapes = []
+        # Weights stored in FP16 are cast to FP8 on the first microstep.
+        self.fp8_weights = []
         self.fp8_weight_cache = {}
+        self.registered_pp_start_callback = False
+        self.current_step_id = paddle.to_tensor([1], dtype=paddle.int32, place=paddle.CPUPlace())
+
+        def current_step_id_callback(step_id=None, **kwargs):  # pylint: disable=unused-argument
+            self.current_step_id.copy_(
+                paddle.to_tensor([step_id], dtype=paddle.int32, place=paddle.CPUPlace()), True
+            )
+
+        register_pp_fwd_begin_hook(current_step_id_callback)

     def set_activation_dtype(self, inp: paddle.Tensor) -> None:
         """Get activation data type for AMP."""
@@ -157,23 +168,23 @@ class TransformerEngineBaseLayer(paddle.nn.Layer, ABC):
         if not self.fp8_enabled:
             return

-        for i, shape in enumerate(self.fp8_weight_shapes, start=1):
+        for i, weight in enumerate(self.fp8_weights, start=1):
             weight_cast_key = f"weight{i}_fp8"
             weight_transpose_key = f"weight{i}_t_fp8"

             if (
                 weight_cast_key in self.fp8_weight_cache
-                and self.fp8_weight_cache[weight_cast_key].shape == shape
+                and self.fp8_weight_cache[weight_cast_key].shape == weight.shape
             ):
                 return

             self.fp8_weight_cache[weight_cast_key] = paddle.empty(
-                shape=shape,
+                shape=weight.shape,
                 dtype=paddle.uint8,
             )

             self.fp8_weight_cache[weight_transpose_key] = paddle.empty(
-                shape=[shape[1], shape[0]],
+                shape=[weight.shape[1], weight.shape[0]],
                 dtype=paddle.uint8,
             )
@@ -293,12 +304,20 @@ class TransformerEngineBaseLayer(paddle.nn.Layer, ABC):
             if self.fp8_meta["recipe"].reduce_amax:
                 global_fp8_fwd_buffer.copy_amax_from_buffer(self.fp8_meta)
                 amax_and_scale_update(
-                    self.fp8_meta, True, update_weight_scale_inv=update_weight_scale_inv
+                    self.fp8_meta,
+                    fwd_update=True,
+                    update_weight_scale_inv=update_weight_scale_inv,
+                    current_step_id_tensor=self.current_step_id,
+                    use_cudagraph=get_global_fp8_state().is_cudagraph_enabled(),
                 )
                 global_fp8_fwd_buffer.set_for_deletion(self.fp8_meta)
             else:
                 amax_and_scale_update(
-                    self.fp8_meta, True, update_weight_scale_inv=update_weight_scale_inv
+                    self.fp8_meta,
+                    fwd_update=True,
+                    update_weight_scale_inv=update_weight_scale_inv,
+                    current_step_id_tensor=self.current_step_id,
+                    use_cudagraph=get_global_fp8_state().is_cudagraph_enabled(),
                 )

         if self.fp8_enabled and self.training:
@@ -355,13 +374,21 @@ class TransformerEngineBaseLayer(paddle.nn.Layer, ABC):
         if fp8_meta["recipe"].reduce_amax:
             global_fp8_bwd_buffer.copy_amax_from_buffer(fp8_meta)
-            amax_and_scale_update(fp8_meta, False)
+            amax_and_scale_update(
+                fp8_meta,
+                fwd_update=False,
+                use_cudagraph=get_global_fp8_state().is_cudagraph_enabled(),
+            )
             global_fp8_bwd_buffer.set_for_deletion(fp8_meta)

             # Get new backward key.
             fp8_meta["autocast_id_bwd"] = fp8_meta["autocast_id_fwd_stack"].pop(0)
         else:
-            amax_and_scale_update(fp8_meta, False)
+            amax_and_scale_update(
+                fp8_meta,
+                fwd_update=False,
+                use_cudagraph=get_global_fp8_state().is_cudagraph_enabled(),
+            )

         with nvtx_range(name + " backward"):
             yield
@@ -439,14 +466,13 @@ class TransformerEngineBaseLayer(paddle.nn.Layer, ABC):
                     fp8_dtype_backward,
                 )
             bgrad = None

         return grad_output_mat, grad_output_c, grad_output_t, bgrad

     @abstractmethod
     def forward(self):
         """Needs override."""

-    def get_fp8_weights_scratchpad(
+    def get_fp8_weights_scratchpad_and_cast(
         self,
         is_first_microbatch: Union[bool, None],
     ) -> List[Optional[paddle.Tensor]]:
@@ -455,10 +481,10 @@ class TransformerEngineBaseLayer(paddle.nn.Layer, ABC):
         `is_first_microbatch` is not `None`)
         """
         if not self.fp8_enabled or is_first_microbatch is None:
-            return [None, None] * len(self.fp8_weight_shapes)
+            return [None, None] * len(self.fp8_weights)

         out_list = []
-        for i, _ in enumerate(self.fp8_weight_shapes, start=1):
+        for i, _ in enumerate(self.fp8_weights, start=1):
             weight_cast_key = f"weight{i}_fp8"
             weight_transpose_key = f"weight{i}_t_fp8"
@@ -466,10 +492,67 @@ class TransformerEngineBaseLayer(paddle.nn.Layer, ABC):
             assert (
                 weight_cast_key in self.fp8_weight_cache
             ), "TE internal error: fp8 weight buffer is not found"

-            out_list.extend(
-                [
-                    self.fp8_weight_cache[weight_cast_key],
-                    self.fp8_weight_cache[weight_transpose_key],
-                ]
-            )
+            weight_fp8 = self.fp8_weight_cache[weight_cast_key]
+            weight_t_fp8 = self.fp8_weight_cache[weight_transpose_key]
+
+            # FP8 weight cache disabled:
+            #   is_first_microbatch is None -> weights are cast to FP8 on every microstep.
+            # FP8 weight cache enabled:
+            #   is_first_microbatch == True -> weights are cast to FP8 only on the first microstep.
+            out_list.extend([weight_fp8, weight_t_fp8])
+        # If CUDA graphs are enabled, cast the weights before the PP pipeline runs.
+        # The callback is registered only once.
+        if get_global_fp8_state().is_cudagraph_enabled() and (
+            not self.registered_pp_start_callback and is_pp_enabled()
+        ):
+            fp8_dtype_forward = get_fp8_te_dtype(self.fp8_meta["recipe"], fprop_tensor=True)
+
+            def cast_callback(step_id=None, **kwargs):  # pylint: disable=unused-argument
+                update_fp8_weights = step_id == 0
+
+                for i, weight in enumerate(self.fp8_weights, start=1):
+                    weight_cast_key = f"weight{i}_fp8"
+                    weight_transpose_key = f"weight{i}_t_fp8"
+
+                    assert (
+                        weight_cast_key in self.fp8_weight_cache
+                    ), "TE internal error: fp8 weight buffer is not found"
+
+                    weight_fp8 = self.fp8_weight_cache[weight_cast_key]
+                    weight_t_fp8 = self.fp8_weight_cache[weight_transpose_key]
+
+                    if paddle.is_grad_enabled():
+                        if update_fp8_weights:
+                            cast_transpose(
+                                weight,
+                                self.fp8_meta["scaling_fwd"],
+                                (
+                                    FP8FwdTensors.GEMM1_WEIGHT
+                                    if i == 1
+                                    else FP8FwdTensors.GEMM2_WEIGHT
+                                ),
+                                fp8_dtype_forward,
+                                cast_out=weight_fp8,
+                                transpose_out=weight_t_fp8,
+                            )
+                    else:
+                        if update_fp8_weights:
+                            cast_to_fp8(
+                                weight,
+                                self.fp8_meta["scaling_fwd"],
+                                (
+                                    FP8FwdTensors.GEMM1_WEIGHT
+                                    if i == 1
+                                    else FP8FwdTensors.GEMM2_WEIGHT
+                                ),
+                                fp8_dtype_forward,
+                                out=weight_fp8,
+                            )
+
+            cast_callback(0 if is_first_microbatch else 1)
+            register_pp_fwd_begin_hook(cast_callback)
+            self.registered_pp_start_callback = True
         return out_list
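Call sites (see the `Linear`, `LayerNormLinear`, and `LayerNormMLP` changes below) consume the renamed helper unchanged; a sketch of the contract:

```python
# Inside a TransformerEngineBaseLayer subclass's forward (illustrative):
weight_fp8, weight_t_fp8 = self.get_fp8_weights_scratchpad_and_cast(is_first_microbatch)
# is_first_microbatch is None  -> [None, None] per weight; FP8 casts happen
#                                 eagerly each microstep in the GEMM path
# is_first_microbatch is True  -> cached buffers returned and refreshed
#                                 (via the PP hook when CUDA graphs are on)
# is_first_microbatch is False -> cached buffers returned as-is
```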
@@ -562,7 +562,7 @@ class LayerNormLinear(TransformerEngineBaseLayer):
             set_weight_tensor_dist_attr(
                 self.weight, self.tensor_parallel, self.parallel_mode, self.backend
             )
-            self.fp8_weight_shapes.append(self.weight.shape)
+            self.fp8_weights.append(self.weight)

         # Initialize Linear bias parameter
         self.has_bias = self._bias_attr is not False
@@ -616,7 +616,7 @@ class LayerNormLinear(TransformerEngineBaseLayer):
             inp = cast_if_needed(inp, self.activation_dtype)

             # Get persistent fp8 weight buffer. None if buffer does not exist.
-            weight_fp8, weight_t_fp8 = self.get_fp8_weights_scratchpad(is_first_microbatch)
+            weight_fp8, weight_t_fp8 = self.get_fp8_weights_scratchpad_and_cast(is_first_microbatch)

             out = _LayerNormLinear.apply(
                 inp,
...
@@ -814,7 +814,7 @@ class LayerNormMLP(TransformerEngineBaseLayer):
         set_weight_tensor_dist_attr(
             self.fc1_weight, self.tensor_parallel, parallel_mode="column", backend=self.backend
         )
-        self.fp8_weight_shapes.append(self.fc1_weight.shape)
+        self.fp8_weights.append(self.fc1_weight)

         self.has_bias = self._bias_attr is not False
         use_default_bias = self._bias_attr is None or self._bias_attr is True
@@ -846,7 +846,7 @@ class LayerNormMLP(TransformerEngineBaseLayer):
         set_weight_tensor_dist_attr(
             self.fc2_weight, self.tensor_parallel, parallel_mode="row", backend=self.backend
         )
-        self.fp8_weight_shapes.append(self.fc2_weight.shape)
+        self.fp8_weights.append(self.fc2_weight)

         if self.has_bias:
             self.fc2_bias = self.create_parameter(
@@ -892,7 +892,7 @@ class LayerNormMLP(TransformerEngineBaseLayer):
             # Get persistent fp8 weight buffer. None if buffer does not exist.
             fc1_weight_fp8, fc1_weight_t_fp8, fc2_weight_fp8, fc2_weight_t_fp8 = (
-                self.get_fp8_weights_scratchpad(is_first_microbatch)
+                self.get_fp8_weights_scratchpad_and_cast(is_first_microbatch)
             )

             out = _LayerNormMLP.apply(
...
@@ -31,7 +31,7 @@ from ..distributed import (
     set_weight_tensor_dist_attr,
     mark_as_sequence_parallel_parameter,
 )
-from ..fp8 import get_fp8_te_dtype
+from ..fp8 import get_fp8_te_dtype, get_global_fp8_state
 from ..utils import (
     assert_dim_for_fp8_forward_exec,
     cast_if_needed,
@@ -74,27 +74,29 @@ def _linear_fwd_fp8(
     else:
         inputmat_total = inputmat
-    update_fp8_weights = is_first_microbatch is None or is_first_microbatch
-    if is_grad_enabled:
-        if update_fp8_weights:
-            weight_fp8, weight_t_fp8 = cast_transpose(
-                weight,
-                fp8_meta["scaling_fwd"],
-                weight_fp8_index,
-                fp8_dtype_forward,
-                cast_out=weight_fp8,
-                transpose_out=weight_t_fp8,
-            )
-    else:
-        weight_t_fp8 = None
-        if update_fp8_weights:
-            weight_fp8 = cast_to_fp8(
-                weight,
-                fp8_meta["scaling_fwd"],
-                weight_fp8_index,
-                fp8_dtype_forward,
-                out=weight_fp8,
-            )
+    if not get_global_fp8_state().is_cudagraph_enabled():
+        # If CUDA graphs are not enabled, cast the weight here.
+        update_fp8_weights = is_first_microbatch is None or is_first_microbatch
+        if is_grad_enabled:
+            if update_fp8_weights:
+                weight_fp8, weight_t_fp8 = cast_transpose(
+                    weight,
+                    fp8_meta["scaling_fwd"],
+                    weight_fp8_index,
+                    fp8_dtype_forward,
+                    cast_out=weight_fp8,
+                    transpose_out=weight_t_fp8,
+                )
+        else:
+            weight_t_fp8 = None
+            if update_fp8_weights:
+                weight_fp8 = cast_to_fp8(
+                    weight,
+                    fp8_meta["scaling_fwd"],
+                    weight_fp8_index,
+                    fp8_dtype_forward,
+                    out=weight_fp8,
+                )

     out, _ = fp8_gemm(
         weight_fp8,
if parallel_mode == "column" and tensor_parallel and handle is not None: if parallel_mode == "column" and tensor_parallel and handle is not None:
handle.wait() handle.wait()
if parallel_mode == "column" and sequence_parallel:
handle.wait()
return dgrad, wgrad return dgrad, wgrad
...@@ -416,9 +420,10 @@ def _linear_bwd_non_fp8( ...@@ -416,9 +420,10 @@ def _linear_bwd_non_fp8(
elif requires_bgrad: elif requires_bgrad:
bgrad = grad_output.sum(axis=0) bgrad = grad_output.sum(axis=0)
if parallel_mode == "column" and tensor_parallel and handle is not None: if parallel_mode == "column" and tensor_parallel and handle is not None:
handle.wait() handle.wait()
if parallel_mode == "column" and sequence_parallel and handle is not None:
handle.wait()
return dgrad, wgrad, bgrad return dgrad, wgrad, bgrad
@@ -804,7 +809,7 @@ class Linear(TransformerEngineBaseLayer):
         else:
             self.bias = None

-        self.fp8_weight_shapes.append(self.weight.shape)
+        self.fp8_weights.append(self.weight)

         # For RPL, bias has to be added after TP collectives
         # So it cannot be fused with the GEMM
@@ -828,7 +833,7 @@ class Linear(TransformerEngineBaseLayer):
             inp = cast_if_needed(inp, self.activation_dtype)

             # Get persistent fp8 weight buffer. None if buffer does not exist.
-            weight_fp8, weight_t_fp8 = self.get_fp8_weights_scratchpad(is_first_microbatch)
+            weight_fp8, weight_t_fp8 = self.get_fp8_weights_scratchpad_and_cast(is_first_microbatch)

             out = _Linear.apply(
                 self.weight,
...