Unverified commit 3a9a4c83 authored by Frank Lin, committed by GitHub

[Paddle][CUDAGraph] 175B GPT-3 Hybrid-Parallel Training with CUDAGraph (#957)



* NVTE_OVERRIDE_MAX_SEQ_LEN
Signed-off-by: Frank Lin <eee4017@gmail.com>

* small fix
Signed-off-by: Frank Lin <eee4017@gmail.com>

* preserve old amax_and_scale_update_inplace and new amax_and_scale_update_inplace
Signed-off-by: Frank Lin <eee4017@gmail.com>

* remove useless code path; try to simplify logic within the baseline
Signed-off-by: Frank Lin <eee4017@gmail.com>

* simplify logic
Signed-off-by: Frank Lin <eee4017@gmail.com>

* small fix
Signed-off-by: Frank Lin <eee4017@gmail.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix comments from Timmoon
Signed-off-by: Frank Lin <eee4017@gmail.com>

* fix comments from Timmoon
Signed-off-by: Frank Lin <eee4017@gmail.com>

* Update transformer_engine/paddle/distributed.py
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Signed-off-by: Frank Lin <eee4017@gmail.com>

* disable bw fp8 update
Signed-off-by: Frank Lin <eee4017@gmail.com>

* fix lint
Signed-off-by: Frank Lin <eee4017@gmail.com>

* fix ci error
Signed-off-by: Frank Lin <eee4017@gmail.com>

---------
Signed-off-by: Frank Lin <eee4017@gmail.com>
Co-authored-by: Frank Lin (Engrg-Hardware 1) <fralin@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
parent 67b67432
......@@ -9,7 +9,7 @@ import paddle
import paddle.nn.functional as F
from transformer_engine import transformer_engine_paddle as tex
from .constants import TE_DType, FusedAttnBackend, FP8FwdTensors, FP8BwdTensors
from .fp8 import FP8TensorMeta
from .fp8 import FP8TensorMeta, get_global_fp8_state
BACKEND_F16m512_THREADS_PER_CTA = 128
BACKEND_F16arb_ELTS_PER_THREADS = 16
......@@ -526,6 +526,8 @@ def mask_to_cu_seqlens(
) -> paddle.Tensor:
"""Convert mask to cu_seqlens"""
# mask shape: [b, 1, s_q, s_kv]
if get_global_fp8_state().is_cudagraph_enabled():
raise RuntimeError("mask_to_cu_seqlens is not supported with cuda graphs.")
q_seqlen, kv_seqlen = mask.shape[2], mask.shape[3]
q_cu_seqlens = paddle.empty(shape=[mask.shape[0] + 1], dtype=paddle.int32)
q_cu_seqlens[0] = 0
......
......@@ -10,6 +10,7 @@
#include "common.h"
#include "common/common.h"
#include "paddle/phi/backends/gpu/cuda/cuda_graph.h"
namespace transformer_engine {
namespace paddle_ext {
......@@ -581,11 +582,44 @@ std::vector<paddle::Tensor> te_rmsnorm_bwd(const paddle::Tensor &dz, const paddl
return {dx, dgamma};
}
__global__ void set_rng_state(std::pair<uint64_t, uint64_t> seed_offset, int64_t *rng_state_ptr) {
__global__ void set_rng_state(
[[maybe_unused]] unsigned int
identifier,  // Used to relate this kernel launch to its cudaGraph node; see https://github.com/PaddlePaddle/Paddle/pull/60516
std::pair<uint64_t, uint64_t> seed_offset, int64_t *rng_state_ptr) {
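// Store the seed/offset pair into the two-element rng_state buffer that the fused
// attention kernels consume.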
rng_state_ptr[0] = static_cast<int64_t>(seed_offset.first);
rng_state_ptr[1] = static_cast<int64_t>(seed_offset.second);
}
void UpdateRandomGenerator(phi::Place place, cudaStream_t stream, int rng_elts_per_thread,
paddle::Tensor &rng_state) {
// extract random number generator seed and offset
const phi::DeviceContext *dev_ctx =
paddle::experimental::DeviceContextPool::Instance().Get(place);
phi::Generator *gen_cuda = dev_ctx->GetGenerator();
auto seed_offset = gen_cuda->IncrementOffset(rng_elts_per_thread);
auto state_index = gen_cuda->GetStateIndex();
int64_t *rng_state_p = static_cast<int64_t *>(rng_state.data());
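// The parameterSetter runs on the host before each graph replay: it restores the
// generator state index, draws a fresh seed/offset, and patches kernel argument 1
// (seed_offset) of the captured node so every replay uses new random state.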
auto parameterSetter = [gen_cuda, state_index,
rng_elts_per_thread](phi::backends::gpu::CUDAKernelParams &params) {
// ensure the generator uses the correct state index
gen_cuda->SetStateIndex(state_index);
auto seed_offset = gen_cuda->IncrementOffset(rng_elts_per_thread);
params.As<std::pair<int64_t, int64_t>>(1) = seed_offset;
};
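// The launch callback performs the actual (capturable) kernel launch tagged with the
// node identifier and returns the kernel's cudaFunction_t so CUDAGraphNodeLauncher can
// map the launch to its graph node (see PaddlePaddle/Paddle#60516).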
phi::backends::gpu::CUDAGraphNodeLauncher::cudaKernelCallback_t cudaKernelCallback =
[=](unsigned int id) {
void *functionPtr = reinterpret_cast<void *>(&set_rng_state);
cudaFunction_t cudaFunc;
PADDLE_ENFORCE_GPU_SUCCESS(cudaGetFuncBySymbol(&cudaFunc, functionPtr));
set_rng_state<<<1, 1, 0, stream>>>(id, seed_offset, rng_state_p);
return cudaFunc;
};
phi::backends::gpu::CUDAGraphNodeLauncher::Instance().KernelNodeLaunch(parameterSetter,
cudaKernelCallback);
}
void te_fused_attn_fwd_qkvpacked(const paddle::Tensor &QKV, const paddle::Tensor &cu_seqlens,
const paddle::optional<paddle::Tensor> &Bias,
paddle::Tensor &O, // NOLINT
......@@ -623,12 +657,7 @@ void te_fused_attn_fwd_qkvpacked(const paddle::Tensor &QKV, const paddle::Tensor
NVTE_Bias_Type bias_type_enum = get_nvte_bias_type(bias_type);
NVTE_Mask_Type attn_mask_type_enum = get_nvte_mask_type(attn_mask_type);
// extract random number generator seed and offset
auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(QKV.place());
auto gen_cuda = dev_ctx->GetGenerator();
auto seed_offset = gen_cuda->IncrementOffset(rng_elts_per_thread);
set_rng_state<<<1, 1, 0, QKV.stream()>>>(seed_offset, static_cast<int64_t *>(rng_state.data()));
UpdateRandomGenerator(QKV.place(), QKV.stream(), rng_elts_per_thread, rng_state);
auto te_rng_state = MakeNvteTensor(rng_state);
// create auxiliary output tensors
......@@ -799,10 +828,7 @@ void te_fused_attn_fwd_kvpacked(
NVTE_Bias_Type bias_type_enum = get_nvte_bias_type(bias_type);
NVTE_Mask_Type attn_mask_type_enum = get_nvte_mask_type(attn_mask_type);
auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(Q.place());
auto gen_cuda = dev_ctx->GetGenerator();
auto seed_offset = gen_cuda->IncrementOffset(rng_elts_per_thread);
set_rng_state<<<1, 1, 0, Q.stream()>>>(seed_offset, static_cast<int64_t *>(rng_state.data()));
UpdateRandomGenerator(Q.place(), Q.stream(), rng_elts_per_thread, rng_state);
auto te_rng_state = MakeNvteTensor(rng_state);
// create auxiliary output tensors
......@@ -979,7 +1005,27 @@ void te_fused_attn_fwd(const paddle::Tensor &Q, const paddle::Tensor &K, const p
auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(Q.place());
auto gen_cuda = dev_ctx->GetGenerator();
auto seed_offset = gen_cuda->IncrementOffset(rng_elts_per_thread);
set_rng_state<<<1, 1, 0, Q.stream()>>>(seed_offset, static_cast<int64_t *>(rng_state.data()));
auto state_index = gen_cuda->GetStateIndex();
auto rng_state_p = static_cast<int64_t *>(rng_state.data());
auto stream = Q.stream();
auto parameterSetter = [gen_cuda, state_index,
rng_elts_per_thread](phi::backends::gpu::CUDAKernelParams &params) {
// ensure the generator uses the correct state index
gen_cuda->SetStateIndex(state_index);
auto seed_offset = gen_cuda->IncrementOffset(rng_elts_per_thread);
params.As<std::pair<int64_t, int64_t>>(1) = seed_offset;
};
phi::backends::gpu::CUDAGraphNodeLauncher::cudaKernelCallback_t cudaKernelCallback =
[=](unsigned int id) {
void *functionPtr = reinterpret_cast<void *>(&set_rng_state);
cudaFunction_t cudaFunc;
PADDLE_ENFORCE_GPU_SUCCESS(cudaGetFuncBySymbol(&cudaFunc, functionPtr));
set_rng_state<<<1, 1, 0, stream>>>(id, seed_offset, rng_state_p);
return cudaFunc;
};
phi::backends::gpu::CUDAGraphNodeLauncher::Instance().KernelNodeLaunch(parameterSetter,
cudaKernelCallback);
auto te_rng_state = MakeNvteTensor(rng_state);
......@@ -1260,6 +1306,29 @@ void te_scaled_upper_triang_masked_softmax_backward(paddle::Tensor &output_grads
softmax_results.stream());
}
__global__ void UpdateFP8MetaKernel(
[[maybe_unused]] unsigned int
identifier,  // Used to relate this kernel launch to its cudaGraph node; see https://github.com/PaddlePaddle/Paddle/pull/60516
const float *amax, const float *rolled_amax_history, const bool *non_weight_mask,
float *amax_history, float *scale, float *scale_inv, bool update_weight_scale_inv, float margin,
float fp8_max, size_t history_numel, size_t amax_numel) {
int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx >= history_numel) {
return;
}
amax_history[idx] = rolled_amax_history[idx];
if (idx < amax_numel) {
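// New scaling factor: scale = (fp8_max / amax) / 2^margin; fall back to the previous
// scale when amax is not positive or not finite.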
float sf = (fp8_max / amax[idx]) / powf(2.0f, margin);
float scale_reg = ((amax[idx] > 0.0f) && isfinite(amax[idx])) ? sf : scale[idx];
scale[idx] = scale_reg;
if (update_weight_scale_inv || non_weight_mask[idx]) scale_inv[idx] = 1.0f / scale_reg;
amax_history[idx] = 0.0f;
}
}
constexpr int BLOCK_SIZE = 512;
void amax_and_scale_update_inplace(paddle::Tensor &amax_history, // NOLINT
......@@ -1277,6 +1346,63 @@ void amax_and_scale_update_inplace(paddle::Tensor &amax_history, // NOLINT
static_cast<NVTEDType>(fp8_dtype), margin, amax_history.stream());
}
void amax_and_scale_update_inplace_legacy(paddle::Tensor &amax_history, // NOLINT
paddle::Tensor &scale, // NOLINT
paddle::Tensor &scale_inv, // NOLINT
const paddle::Tensor &non_weight_mask,
const paddle::Tensor &current_step_id_tensor,
bool update_weight_scale_inv, bool fwd_update,
float fp8_max, float margin,
const std::string &amax_compute) {
NVTE_CHECK(amax_compute == "max" || amax_compute == "most_recent");
paddle::Tensor amax;
if (amax_compute == "max") {
amax = amax_history.max({0});
} else {
amax = amax_history.slice(0, 1);
}
const auto rolled_amax_history = amax_history.roll({-1}, {0});
auto amax_history_numel = amax_history.numel();
auto amax_numel = amax.numel();
size_t num_blocks = (amax_history_numel + BLOCK_SIZE - 1) / BLOCK_SIZE;
const int *current_step_id_ptr = nullptr;
if (fwd_update) current_step_id_ptr = current_step_id_tensor.data<int>();
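// At graph replay time the parameterSetter re-reads current_step_id from host memory and
// patches kernel argument 7 (update_weight_scale_inv), so weight scale_inv is refreshed
// only on the first microstep of each iteration.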
auto parameterSetter = [current_step_id_ptr,
fwd_update](phi::backends::gpu::CUDAKernelParams &params) {
if (fwd_update) {
int current_step_id = *current_step_id_ptr;
params.As<bool>(7) = (current_step_id == 0);
}
};
const float *amax_ptr = amax.data<float>();
const float *rolled_amax_history_ptr = rolled_amax_history.data<float>();
const bool *non_weight_mask_ptr = non_weight_mask.data<bool>();
float *amax_history_ptr = amax_history.data<float>();
float *scale_ptr = scale.data<float>();
float *scale_inv_ptr = scale_inv.data<float>();
phi::backends::gpu::CUDAGraphNodeLauncher::cudaKernelCallback_t cudaKernelCallback =
[=](unsigned int id) {
void *functionPtr = reinterpret_cast<void *>(&UpdateFP8MetaKernel);
cudaFunction_t cudaFunc;
PADDLE_ENFORCE_GPU_SUCCESS(cudaGetFuncBySymbol(&cudaFunc, functionPtr));
UpdateFP8MetaKernel<<<num_blocks, BLOCK_SIZE, 0, amax_history.stream()>>>(
id, amax_ptr, rolled_amax_history_ptr, non_weight_mask_ptr, amax_history_ptr, scale_ptr,
scale_inv_ptr, update_weight_scale_inv, margin, fp8_max, amax_history_numel,
amax_numel);
NVTE_CHECK_CUDA(cudaGetLastError());
return cudaFunc;
};
phi::backends::gpu::CUDAGraphNodeLauncher::Instance().KernelNodeLaunch(parameterSetter,
cudaKernelCallback);
}
void update_latest_amax_history_inplace(paddle::Tensor &history, // NOLINT
const paddle::Tensor &amax) {
// Copy amax to history[0]
......@@ -1617,6 +1743,16 @@ PD_BUILD_OP(te_scaled_upper_triang_masked_softmax_backward)
.SetKernelFn(
PD_KERNEL(transformer_engine::paddle_ext::te_scaled_upper_triang_masked_softmax_backward));
PD_BUILD_OP(amax_and_scale_update_inplace_legacy)
.Inputs({"_amax_history", "_scale", "_scale_inv", "non_weight_mask", "current_step_id_tensor"})
.Outputs({"amax_history", "scale", "scale_inv"})
.SetInplaceMap({{"_amax_history", "amax_history"},
{"_scale", "scale"},
{"_scale_inv", "scale_inv"}})
.Attrs({"update_weight_scale_inv: bool", "fwd_update: bool", "fp8_max: float", "margin: float",
"amax_compute: std::string"})
.SetKernelFn(PD_KERNEL(transformer_engine::paddle_ext::amax_and_scale_update_inplace_legacy));
PD_BUILD_OP(amax_and_scale_update_inplace)
.Inputs({"_amax_history", "_scale", "_scale_inv", "non_weight_mask"})
.Outputs({"amax_history", "scale", "scale_inv"})
......
......@@ -14,6 +14,16 @@ import paddle.distributed.fleet.base.topology as tp
from paddle.distributed.fleet.meta_parallel import get_rng_state_tracker
from paddle.distributed.fleet.layers.mpu import mp_ops
try:
# This feature is not supported as of Paddle 2.6.
from paddle.distributed.fleet.meta_parallel import (
PipelineParallelMicroStepLocations,
register_global_pipeline_parallel_hook,
)
except ImportError:
print("Cannot find register_global_pipeline_parallel_hook !")
register_global_pipeline_parallel_hook = None
from .constants import dist_group_type
_weight_split_axis = {
......@@ -54,6 +64,22 @@ def get_tp_group_and_world_size(
return model_parallel_group, world_size
def is_pp_enabled() -> bool:
"""Check if pipeline parallel is enabled"""
if not paddle.distributed.is_initialized():
return False
return tp._HYBRID_PARALLEL_GROUP.get_pipe_parallel_world_size() > 1
def register_pp_fwd_begin_hook(forward_begin_hook):
"""Register the pp hook if register_global_pipeline_parallel_hook exist"""
if register_global_pipeline_parallel_hook is not None:
register_global_pipeline_parallel_hook(
PipelineParallelMicroStepLocations.FORWARD_BEGIN, forward_begin_hook
)
@contextmanager
def track_rng_state(enable: bool, **kwargs) -> None:
"""
......
......@@ -62,6 +62,7 @@ class FP8State:
self._fp8_autocast_counter = 0
self._fp8_autocast_depth = 0
self._fp8_recompute_enabled = False
self._use_cudagraph = False
self._fp8_fwd_buffer = FP8MetaFwdBuffer()
self._fp8_bwd_buffer = FP8MetaBwdBuffer()
self._fp8_recompute_buffer = FP8RecomputeBuffer()
......@@ -116,6 +117,18 @@ class FP8State:
"""Returns global fp8 recompute buffer."""
return self._fp8_recompute_buffer
def is_cudagraph_enabled(self) -> bool:
"""Is CUDAGraph enabled"""
return self._use_cudagraph
def enable_cudagraph(self):
"""Enable CUDA Graphs. Once CUDA Graphs are enabled, they cannot be disabled within the same execution context at current implementation."""
self._use_cudagraph = True
self._fp8_fwd_buffer.enable_cudagraph()
self._fp8_bwd_buffer.enable_cudagraph()
if self._fp8_recompute_enabled:
raise RuntimeError("Currently, We do not allow recompute with cudagraph")
def enter(
self,
enabled: bool,
......@@ -235,25 +248,45 @@ def amax_and_scale_update(
fp8_meta: Dict[str, Any],
fwd_update: bool,
update_weight_scale_inv: bool = True,
current_step_id_tensor: Optional[paddle.Tensor] = None,
use_cudagraph: bool = False,
) -> None:
"""Updates fp8 amaxes/scales for fwd | bwd."""
amax_compute = fp8_meta["recipe"].amax_compute_algo
sf_compute = fp8_meta["recipe"].scaling_factor_compute_algo
fp8_meta_tensor_key = "scaling_fwd" if fwd_update else "scaling_bwd"
fp8_max_key = "fp8_max_fwd" if fwd_update else "fp8_max_bwd"
if not callable(amax_compute) and sf_compute is None:
non_weight_mask = fp8_meta[fp8_meta_tensor_key].non_weight_mask
if update_weight_scale_inv:
non_weight_mask = paddle.empty([0])
tex.amax_and_scale_update_inplace(
_amax_history=fp8_meta[fp8_meta_tensor_key].amax_history,
_scale=fp8_meta[fp8_meta_tensor_key].scale,
_scale_inv=fp8_meta[fp8_meta_tensor_key].scale_inv,
non_weight_mask=non_weight_mask,
fp8_dtype=int(get_fp8_te_dtype(fp8_meta["recipe"], fwd_update)),
margin=float(fp8_meta["recipe"].margin),
amax_compute=amax_compute,
)
if use_cudagraph:
tex.amax_and_scale_update_inplace_legacy(
_amax_history=fp8_meta[fp8_meta_tensor_key].amax_history,
_scale=fp8_meta[fp8_meta_tensor_key].scale,
_scale_inv=fp8_meta[fp8_meta_tensor_key].scale_inv,
non_weight_mask=non_weight_mask,
current_step_id_tensor=current_step_id_tensor,
update_weight_scale_inv=update_weight_scale_inv,
fwd_update=fwd_update,
fp8_max=fp8_meta[fp8_max_key],
margin=float(fp8_meta["recipe"].margin),
amax_compute=amax_compute,
)
else:
if update_weight_scale_inv:
# We pass an empty tensor (nullptr in the kernel) when update_weight_scale_inv is requested
non_weight_mask = paddle.empty([0])
tex.amax_and_scale_update_inplace(
_amax_history=fp8_meta[fp8_meta_tensor_key].amax_history,
_scale=fp8_meta[fp8_meta_tensor_key].scale,
_scale_inv=fp8_meta[fp8_meta_tensor_key].scale_inv,
non_weight_mask=non_weight_mask,
fp8_dtype=int(get_fp8_te_dtype(fp8_meta["recipe"], fwd_update)),
margin=float(fp8_meta["recipe"].margin),
amax_compute=amax_compute,
)
else:
raise ValueError(
"We only support the fp8 recipe with 'max' or 'most_recent' "
......
......@@ -22,10 +22,12 @@ class FP8MetaBufferBase(ABC):
"""
def __init__(self):
self._data = {}
self._global_amax = {}
self._buffer_delete_key = None
self._amax_reduce_wait_func = None
self._dp_amax_reduce_interval = None
self._contiguous_amax = None
self._use_cudagraph = False
self._dp_amax_reduce_idx = 0
@staticmethod
......@@ -44,13 +46,13 @@ class FP8MetaBufferBase(ABC):
"""Returns autocast id key in `fp8_meta`."""
def _get_amax_buffer_key(self, fp8_meta: Dict[str, Any]) -> str:
"""Return a key in `_data` for the AMAX storage."""
"""Return a key in `_global_amax` for the AMAX storage."""
return f"AMAX_{fp8_meta[self._get_autocast_key()]}"
def _execute_deletion(self) -> None:
"""Delete the key from global amax buffer."""
if self._buffer_delete_key is not None and self._buffer_delete_key in self._data:
del self._data[self._buffer_delete_key]
if self._buffer_delete_key is not None and self._buffer_delete_key in self._global_amax:
del self._global_amax[self._buffer_delete_key]
def _wait_handle_and_split(
self,
......@@ -62,7 +64,12 @@ class FP8MetaBufferBase(ABC):
"""Wait for amax reduction to finish and then copy reduced amax to buffer"""
if wait_handle is not None:
wait_handle.wait()
self._data[amax_buffer_key] = list(contiguous_amax.split(chunk_sizes))
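# Under CUDA Graphs the per-layer amax tensors must keep stable addresses across replays,
# so the reduced values are copied back in place instead of rebinding the list to new tensors.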
if self._use_cudagraph:
split_list = list(contiguous_amax.split(chunk_sizes))
for amax, split in zip(self._global_amax[amax_buffer_key], split_list):
amax.copy_(split, False)
else:
self._global_amax[amax_buffer_key] = list(contiguous_amax.split(chunk_sizes))
def _global_amax_reduction(
self,
......@@ -85,7 +92,7 @@ class FP8MetaBufferBase(ABC):
amax_buffer_key = self._get_amax_buffer_key(fp8_meta)
# Key already deleted.
if amax_buffer_key not in self._data:
if amax_buffer_key not in self._global_amax:
return None
# Reduce AMAX in DP-domain at an interval.
......@@ -105,18 +112,32 @@ class FP8MetaBufferBase(ABC):
else:
return None
chunk_sizes = [x.shape[0] for x in self._data[amax_buffer_key]]
contiguous_amax = paddle.concat(self._data[amax_buffer_key])
chunk_sizes = [x.shape[0] for x in self._global_amax[amax_buffer_key]]
if self._use_cudagraph:
# we need to ensure the _contiguous_amax is address-stable under cudagraph
if self._contiguous_amax is None:
self._contiguous_amax = paddle.concat(self._global_amax[amax_buffer_key])
else:
self._contiguous_amax.copy_(
paddle.concat(self._global_amax[amax_buffer_key]), False
)
else:
self._contiguous_amax = paddle.concat(self._global_amax[amax_buffer_key])
wait_handle = _reduce_tensor_across_group_op_max(
contiguous_amax,
self._contiguous_amax,
reduce_group,
not fp8_meta["async_amax_reduction"],
)
if wait_handle is not None and self._use_cudagraph:
# we need to ensure record/wait does not cross the boundary of the graph
wait_handle.wait()
wait_handle = None
return partial(
self._wait_handle_and_split,
contiguous_amax,
self._contiguous_amax,
chunk_sizes,
amax_buffer_key,
wait_handle,
......@@ -128,16 +149,16 @@ class FP8MetaBufferBase(ABC):
fp8_meta_tensor_key = self._get_meta_tensor_key()
buffer_position_key = self._get_buffer_position_key()
if buffer_key not in self._data:
self._data[buffer_key] = [fp8_meta[fp8_meta_tensor_key].amax_history[0]]
if buffer_key not in self._global_amax:
self._global_amax[buffer_key] = [fp8_meta[fp8_meta_tensor_key].amax_history[0]]
else:
self._data[buffer_key].append(fp8_meta[fp8_meta_tensor_key].amax_history[0])
self._global_amax[buffer_key].append(fp8_meta[fp8_meta_tensor_key].amax_history[0])
if buffer_position_key not in fp8_meta:
fp8_meta[buffer_position_key] = len(self._data[buffer_key]) - 1
fp8_meta[buffer_position_key] = len(self._global_amax[buffer_key]) - 1
# Catch incorrect fp8_autocast usage.
assert fp8_meta[buffer_position_key] == len(self._data[buffer_key]) - 1, (
assert fp8_meta[buffer_position_key] == len(self._global_amax[buffer_key]) - 1, (
"Same module is being invoked more than once inside an `fp8_autocast` "
"region when using FP8 with amax reduction. This behavior is currently "
"unsupported. For more details and correct usage, please see "
......@@ -152,12 +173,12 @@ class FP8MetaBufferBase(ABC):
return
amax_buffer_key = self._get_amax_buffer_key(fp8_meta)
assert amax_buffer_key in self._data, "TE internal error."
assert amax_buffer_key in self._global_amax, "TE internal error."
# Copy amax to amax_history[0]
tex.update_latest_amax_history_inplace(
_history=fp8_meta[fp8_meta_tensor_key].amax_history,
amax=self._data[amax_buffer_key][fp8_meta[buffer_position_key]],
amax=self._global_amax[amax_buffer_key][fp8_meta[buffer_position_key]],
)
def set_for_deletion(self, fp8_meta: Dict[str, Any]) -> None:
......@@ -179,14 +200,18 @@ class FP8MetaBufferBase(ABC):
def to_numpy(self) -> Dict[str, List[np.array]]:
"""Convert to numpy arrays"""
out = {}
for k, v in self._data.items():
for k, v in self._global_amax.items():
out[k] = [tensor.numpy() for tensor in v]
return out
def from_numpy(self, buffer: Dict[str, np.array]) -> None:
"""Set buffer values from numpy arrays"""
for k, v in buffer.items():
self._data[k] = [paddle.to_tensor(arr) for arr in v]
self._global_amax[k] = [paddle.to_tensor(arr) for arr in v]
def enable_cudagraph(self):
"""Enable CUDA Graphs."""
self._use_cudagraph = True
class FP8MetaFwdBuffer(FP8MetaBufferBase):
......@@ -259,7 +284,9 @@ class FP8MetaBwdBuffer(FP8MetaBufferBase):
Called at FP8 autocast end in backward.
Performs AMAX reduction and delete unused buffer entries.
"""
self._amax_reduce_wait_func = self._global_amax_reduction(fp8_meta, tp_group, tp_size)
self._amax_reduce_wait_func = self._global_amax_reduction(
fp8_meta, tp_group, tp_size
)  # returns a partial of _wait_handle_and_split
self._execute_deletion()
......@@ -267,7 +294,7 @@ class FP8RecomputeBuffer:
"""Buffer used to hold FP8 meta tensors for recompute"""
def __init__(self):
self._data = []
self._global_amax = []
@staticmethod
def get_buffer_position_key():
......@@ -285,11 +312,11 @@ class FP8RecomputeBuffer:
]
if buffer_position_key in fp8_meta:
self._data[fp8_meta[buffer_position_key]].append(to_copy)
self._global_amax[fp8_meta[buffer_position_key]].append(to_copy)
else:
self._data.append(deque())
self._data[-1].append(to_copy)
fp8_meta[buffer_position_key] = len(self._data) - 1
self._global_amax.append(deque())
self._global_amax[-1].append(to_copy)
fp8_meta[buffer_position_key] = len(self._global_amax) - 1
def retrieve_fp8_meta_tensors(self, fp8_meta: Dict[str, Any]) -> None:
"""Switch to the previously saved scaling factors and amaxes"""
......@@ -300,7 +327,7 @@ class FP8RecomputeBuffer:
# Retrieve stashed amaxes and scales from phase 1 pre forward.
buffer_position_key = self.get_buffer_position_key()
stashed_fp8_meta = self._data[fp8_meta[buffer_position_key]].popleft()
stashed_fp8_meta = self._global_amax[fp8_meta[buffer_position_key]].popleft()
# Replace amaxes and scales with stashed values for phase 2 forward
fp8_meta["scaling_fwd"].amax_history = stashed_fp8_meta[0]
......
......@@ -20,7 +20,7 @@ except ImportError:
from paddle.fluid import core
from paddle.fluid.framework import _dygraph_tracer
from ..constants import FP8BwdTensors, dist_group_type
from ..constants import FP8FwdTensors, FP8BwdTensors, dist_group_type
from ..cpp_extensions import cast_transpose, cast_transpose_bgrad, cast_to_fp8, transpose
from ..fp8 import (
FP8State,
......@@ -29,7 +29,7 @@ from ..fp8 import (
get_global_fp8_state,
get_fp8_te_dtype,
)
from ..distributed import allgather
from ..distributed import allgather, register_pp_fwd_begin_hook, is_pp_enabled
from ..profile import nvtx_range
from ..recompute import is_in_recompute_phase
from ..fp8_buffer import FP8RecomputeBuffer
......@@ -80,8 +80,19 @@ class TransformerEngineBaseLayer(paddle.nn.Layer, ABC):
self.fp8_meta["async_amax_reduction"] = bool(
int(os.getenv("NVTE_ASYNC_AMAX_REDUCTION", "0"))
)
self.fp8_weight_shapes = []
# Weights stored in FP16 are cast to FP8 on the first microstep of each step
self.fp8_weights = []
self.fp8_weight_cache = {}
self.registered_pp_start_callback = False
self.current_step_id = paddle.to_tensor([1], dtype=paddle.int32, place=paddle.CPUPlace())
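# current_step_id lives on the CPU so the pipeline-parallel FORWARD_BEGIN hook below can
# update it in place every microstep; the CUDA Graph parameterSetter callbacks read it on
# the host at replay time (see amax_and_scale_update_inplace_legacy).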
def current_step_id_callback(step_id=None, **kwargs): # pylint: disable=unused-argument
self.current_step_id.copy_(
paddle.to_tensor([step_id], dtype=paddle.int32, place=paddle.CPUPlace()), True
)
register_pp_fwd_begin_hook(current_step_id_callback)
def set_activation_dtype(self, inp: paddle.Tensor) -> None:
"""Get activation data type for AMP."""
......@@ -157,23 +168,23 @@ class TransformerEngineBaseLayer(paddle.nn.Layer, ABC):
if not self.fp8_enabled:
return
for i, shape in enumerate(self.fp8_weight_shapes, start=1):
for i, weight in enumerate(self.fp8_weights, start=1):
weight_cast_key = f"weight{i}_fp8"
weight_transpose_key = f"weight{i}_t_fp8"
if (
weight_cast_key in self.fp8_weight_cache
and self.fp8_weight_cache[weight_cast_key].shape == shape
and self.fp8_weight_cache[weight_cast_key].shape == weight.shape
):
return
self.fp8_weight_cache[weight_cast_key] = paddle.empty(
shape=shape,
shape=weight.shape,
dtype=paddle.uint8,
)
self.fp8_weight_cache[weight_transpose_key] = paddle.empty(
shape=[shape[1], shape[0]],
shape=[weight.shape[1], weight.shape[0]],
dtype=paddle.uint8,
)
......@@ -293,12 +304,20 @@ class TransformerEngineBaseLayer(paddle.nn.Layer, ABC):
if self.fp8_meta["recipe"].reduce_amax:
global_fp8_fwd_buffer.copy_amax_from_buffer(self.fp8_meta)
amax_and_scale_update(
self.fp8_meta, True, update_weight_scale_inv=update_weight_scale_inv
self.fp8_meta,
fwd_update=True,
update_weight_scale_inv=update_weight_scale_inv,
current_step_id_tensor=self.current_step_id,
use_cudagraph=get_global_fp8_state().is_cudagraph_enabled(),
)
global_fp8_fwd_buffer.set_for_deletion(self.fp8_meta)
else:
amax_and_scale_update(
self.fp8_meta, True, update_weight_scale_inv=update_weight_scale_inv
self.fp8_meta,
fwd_update=True,
update_weight_scale_inv=update_weight_scale_inv,
current_step_id_tensor=self.current_step_id,
use_cudagraph=get_global_fp8_state().is_cudagraph_enabled(),
)
if self.fp8_enabled and self.training:
......@@ -355,13 +374,21 @@ class TransformerEngineBaseLayer(paddle.nn.Layer, ABC):
if fp8_meta["recipe"].reduce_amax:
global_fp8_bwd_buffer.copy_amax_from_buffer(fp8_meta)
amax_and_scale_update(fp8_meta, False)
amax_and_scale_update(
fp8_meta,
fwd_update=False,
use_cudagraph=get_global_fp8_state().is_cudagraph_enabled(),
)
global_fp8_bwd_buffer.set_for_deletion(fp8_meta)
# Get new backward key.
fp8_meta["autocast_id_bwd"] = fp8_meta["autocast_id_fwd_stack"].pop(0)
else:
amax_and_scale_update(fp8_meta, False)
amax_and_scale_update(
fp8_meta,
fwd_update=False,
use_cudagraph=get_global_fp8_state().is_cudagraph_enabled(),
)
with nvtx_range(name + " backward"):
yield
......@@ -439,14 +466,13 @@ class TransformerEngineBaseLayer(paddle.nn.Layer, ABC):
fp8_dtype_backward,
)
bgrad = None
return grad_output_mat, grad_output_c, grad_output_t, bgrad
@abstractmethod
def forward(self):
"""Needs override."""
def get_fp8_weights_scratchpad(
def get_fp8_weights_scratchpad_and_cast(
self,
is_first_microbatch: Union[bool, None],
) -> List[Optional[paddle.Tensor]]:
......@@ -455,10 +481,10 @@ class TransformerEngineBaseLayer(paddle.nn.Layer, ABC):
`is_first_microbatch` is not `None`)
"""
if not self.fp8_enabled or is_first_microbatch is None:
return [None, None] * len(self.fp8_weight_shapes)
return [None, None] * len(self.fp8_weights)
out_list = []
for i, _ in enumerate(self.fp8_weight_shapes, start=1):
for i, _ in enumerate(self.fp8_weights, start=1):
weight_cast_key = f"weight{i}_fp8"
weight_transpose_key = f"weight{i}_t_fp8"
......@@ -466,10 +492,67 @@ class TransformerEngineBaseLayer(paddle.nn.Layer, ABC):
weight_cast_key in self.fp8_weight_cache
), "TE internal error: fp8 weight buffer is not found"
out_list.extend(
[
self.fp8_weight_cache[weight_cast_key],
self.fp8_weight_cache[weight_transpose_key],
]
)
weight_fp8 = self.fp8_weight_cache[weight_cast_key]
weight_t_fp8 = self.fp8_weight_cache[weight_transpose_key]
# FP8 weight cache disabled:
#   is_first_microbatch is None -> weights are cast to FP8 on every microstep
# FP8 weight cache enabled:
#   is_first_microbatch == True -> weights are cast to FP8 only on the first microstep
out_list.extend([weight_fp8, weight_t_fp8])
# If cudagraph is enabled, the weights are cast before the pipeline-parallel schedule runs.
# The callback is registered only once.
if get_global_fp8_state().is_cudagraph_enabled() and (
not self.registered_pp_start_callback and is_pp_enabled()
):
fp8_dtype_forward = get_fp8_te_dtype(self.fp8_meta["recipe"], fprop_tensor=True)
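# When CUDA Graphs are enabled with pipeline parallelism, the weight cast is moved out of
# the (captured) layer forward into a FORWARD_BEGIN hook: on microstep 0 the high-precision
# weights are cast into the persistent FP8 buffers, and later microsteps reuse them.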
def cast_callback(step_id=None, **kwargs): # pylint: disable=unused-argument
update_fp8_weights = step_id == 0
for i, weight in enumerate(self.fp8_weights, start=1):
weight_cast_key = f"weight{i}_fp8"
weight_transpose_key = f"weight{i}_t_fp8"
assert (
weight_cast_key in self.fp8_weight_cache
), "TE internal error: fp8 weight buffer is not found"
weight_fp8 = self.fp8_weight_cache[weight_cast_key]
weight_t_fp8 = self.fp8_weight_cache[weight_transpose_key]
if paddle.is_grad_enabled():
if update_fp8_weights:
cast_transpose(
weight,
self.fp8_meta["scaling_fwd"],
(
FP8FwdTensors.GEMM1_WEIGHT
if i == 1
else FP8FwdTensors.GEMM2_WEIGHT
),
fp8_dtype_forward,
cast_out=weight_fp8,
transpose_out=weight_t_fp8,
)
else:
if update_fp8_weights:
cast_to_fp8(
weight,
self.fp8_meta["scaling_fwd"],
(
FP8FwdTensors.GEMM1_WEIGHT
if i == 1
else FP8FwdTensors.GEMM2_WEIGHT
),
fp8_dtype_forward,
out=weight_fp8,
)
cast_callback(0 if is_first_microbatch else 1)
register_pp_fwd_begin_hook(cast_callback)
self.registered_pp_start_callback = True
return out_list
......@@ -562,7 +562,7 @@ class LayerNormLinear(TransformerEngineBaseLayer):
set_weight_tensor_dist_attr(
self.weight, self.tensor_parallel, self.parallel_mode, self.backend
)
self.fp8_weight_shapes.append(self.weight.shape)
self.fp8_weights.append(self.weight)
# Initialize Linear bias parameter
self.has_bias = self._bias_attr is not False
......@@ -616,7 +616,7 @@ class LayerNormLinear(TransformerEngineBaseLayer):
inp = cast_if_needed(inp, self.activation_dtype)
# Get persistent fp8 weight buffer. None if buffer does not exist.
weight_fp8, weight_t_fp8 = self.get_fp8_weights_scratchpad(is_first_microbatch)
weight_fp8, weight_t_fp8 = self.get_fp8_weights_scratchpad_and_cast(is_first_microbatch)
out = _LayerNormLinear.apply(
inp,
......
......@@ -814,7 +814,7 @@ class LayerNormMLP(TransformerEngineBaseLayer):
set_weight_tensor_dist_attr(
self.fc1_weight, self.tensor_parallel, parallel_mode="column", backend=self.backend
)
self.fp8_weight_shapes.append(self.fc1_weight.shape)
self.fp8_weights.append(self.fc1_weight)
self.has_bias = self._bias_attr is not False
use_default_bias = self._bias_attr is None or self._bias_attr is True
......@@ -846,7 +846,7 @@ class LayerNormMLP(TransformerEngineBaseLayer):
set_weight_tensor_dist_attr(
self.fc2_weight, self.tensor_parallel, parallel_mode="row", backend=self.backend
)
self.fp8_weight_shapes.append(self.fc2_weight.shape)
self.fp8_weights.append(self.fc2_weight)
if self.has_bias:
self.fc2_bias = self.create_parameter(
......@@ -892,7 +892,7 @@ class LayerNormMLP(TransformerEngineBaseLayer):
# Get persistent fp8 weight buffer. None if buffer does not exist.
fc1_weight_fp8, fc1_weight_t_fp8, fc2_weight_fp8, fc2_weight_t_fp8 = (
self.get_fp8_weights_scratchpad(is_first_microbatch)
self.get_fp8_weights_scratchpad_and_cast(is_first_microbatch)
)
out = _LayerNormMLP.apply(
......
......@@ -31,7 +31,7 @@ from ..distributed import (
set_weight_tensor_dist_attr,
mark_as_sequence_parallel_parameter,
)
from ..fp8 import get_fp8_te_dtype
from ..fp8 import get_fp8_te_dtype, get_global_fp8_state
from ..utils import (
assert_dim_for_fp8_forward_exec,
cast_if_needed,
......@@ -74,27 +74,29 @@ def _linear_fwd_fp8(
else:
inputmat_total = inputmat
update_fp8_weights = is_first_microbatch is None or is_first_microbatch
if is_grad_enabled:
if update_fp8_weights:
weight_fp8, weight_t_fp8 = cast_transpose(
weight,
fp8_meta["scaling_fwd"],
weight_fp8_index,
fp8_dtype_forward,
cast_out=weight_fp8,
transpose_out=weight_t_fp8,
)
else:
weight_t_fp8 = None
if update_fp8_weights:
weight_fp8 = cast_to_fp8(
weight,
fp8_meta["scaling_fwd"],
weight_fp8_index,
fp8_dtype_forward,
out=weight_fp8,
)
if not get_global_fp8_state().is_cudagraph_enabled():
# if cuda graph is not enabled, we cast the weight here
update_fp8_weights = is_first_microbatch is None or is_first_microbatch
if is_grad_enabled:
if update_fp8_weights:
weight_fp8, weight_t_fp8 = cast_transpose(
weight,
fp8_meta["scaling_fwd"],
weight_fp8_index,
fp8_dtype_forward,
cast_out=weight_fp8,
transpose_out=weight_t_fp8,
)
else:
weight_t_fp8 = None
if update_fp8_weights:
weight_fp8 = cast_to_fp8(
weight,
fp8_meta["scaling_fwd"],
weight_fp8_index,
fp8_dtype_forward,
out=weight_fp8,
)
out, _ = fp8_gemm(
weight_fp8,
......@@ -346,6 +348,8 @@ def _linear_bwd_fp8(
if parallel_mode == "column" and tensor_parallel and handle is not None:
handle.wait()
if parallel_mode == "column" and sequence_parallel:
handle.wait()
return dgrad, wgrad
......@@ -416,9 +420,10 @@ def _linear_bwd_non_fp8(
elif requires_bgrad:
bgrad = grad_output.sum(axis=0)
if parallel_mode == "column" and tensor_parallel and handle is not None:
handle.wait()
if parallel_mode == "column" and sequence_parallel and handle is not None:
handle.wait()
return dgrad, wgrad, bgrad
......@@ -804,7 +809,7 @@ class Linear(TransformerEngineBaseLayer):
else:
self.bias = None
self.fp8_weight_shapes.append(self.weight.shape)
self.fp8_weights.append(self.weight)
# For RPL, bias has to be added after TP collectives
# So it cannot be fused with the GEMM
......@@ -828,7 +833,7 @@ class Linear(TransformerEngineBaseLayer):
inp = cast_if_needed(inp, self.activation_dtype)
# Get persistent fp8 weight buffer. None if buffer does not exist.
weight_fp8, weight_t_fp8 = self.get_fp8_weights_scratchpad(is_first_microbatch)
weight_fp8, weight_t_fp8 = self.get_fp8_weights_scratchpad_and_cast(is_first_microbatch)
out = _Linear.apply(
self.weight,
......