Unverified Commit c32a62cc authored by zlsh80826, committed by GitHub

[Paddle] Eliminate amax update bubbles by using custom_ops (#436)



* Eliminate amax_and_scale_update bubbles
Signed-off-by: rewang <rewang@nvidia.com>

* Add CUDA check
Signed-off-by: rewang <rewang@nvidia.com>

---------
Signed-off-by: rewang <rewang@nvidia.com>
parent 7e759174
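For context, the update this commit fuses into a single kernel is the standard delayed-scaling step: reduce the amax history, roll it by one slot, zero the slot for the next iteration, and recompute scale and scale_inv as powers of two. A rough NumPy sketch of the semantics (illustrative names only, not from the codebase):

    import numpy as np

    def fused_update_reference(amax_history, scale, fp8_max, margin=0.0):
        # Reduce the history (the amax_compute == "max" path).
        amax = amax_history.max(axis=0)
        # Roll the history by one step and clear the newest slot.
        rolled = np.roll(amax_history, -1, axis=0)
        rolled[0] = 0.0
        # Recompute the scaling factor as a power of two, as in the kernel below.
        with np.errstate(divide='ignore'):
            exp = np.floor(np.log2(fp8_max / amax)) - margin
        sf = np.round(np.power(2.0, np.abs(exp)))
        # Keep the old scale when amax is zero or non-finite.
        sf = np.where((amax > 0.0) & np.isfinite(amax), sf, scale)
        new_scale = np.where(exp < 0.0, 1.0 / sf, sf)
        return rolled, new_scale, 1.0 / new_scale

Previously each of these steps ran as a separate Paddle op; the idle gaps between those small kernel launches are the "bubbles" the commit title refers to.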
@@ -865,13 +865,17 @@ class TestSoftmax:
     assert_allclose(dx_ref, dx, rtol=1e-4, atol=5e-3)
 
 
-def test_update_scale():
+def test_amax_and_scale_update():
     """Test update_scale"""
     num_gemm = 6
+    history_len = 1024
     recipe = DelayedScaling()
     fp8_max = recipe.fp8_format.value.max_fwd
-    amax_tensor = paddle.rand(shape=[num_gemm], dtype='float32') * fp8_max
+    amax_history_tensor = paddle.rand(shape=[history_len, num_gemm], dtype='float32')
+    rolled_history_ref = paddle.roll(amax_history_tensor, -1, axis=0)
+    rolled_history_ref[0] = 0.0
+    amax_tensor = paddle.max(amax_history_tensor, axis=0)
     scale_tensor = paddle.ones(shape=[num_gemm], dtype='float32')
 
     def calc_ref(amax, scale, fp8_max, margin=0):
@@ -884,6 +888,32 @@ def test_update_scale():
         return sf
 
     scale_ref = calc_ref(amax_tensor, scale_tensor, fp8_max, 0.)
-    scale_actual = tex.update_scale(amax_tensor, scale_tensor, fp8_max, 0.)
-    assert_allclose(scale_ref, scale_actual, rtol=1e-5, atol=1e-5)
+    scale_inv_ref = 1. / scale_ref
+
+    # Placeholder
+    scale_actual = paddle.zeros_like(scale_tensor)
+    scale_inv_actual = paddle.zeros_like(scale_tensor)
+
+    tex.amax_and_scale_update_inplace(_amax_history=amax_history_tensor,
+                                      _scale=scale_actual,
+                                      _scale_inv=scale_inv_actual,
+                                      fp8_max=fp8_max,
+                                      margin=0.,
+                                      amax_compute="max")
+
+    assert_allclose(scale_actual, scale_ref, rtol=1e-7, atol=1e-7)
+    assert_allclose(scale_inv_actual, scale_inv_ref, rtol=1e-7, atol=1e-7)
+    assert_allclose(amax_history_tensor, rolled_history_ref, rtol=1e-7, atol=1e-7)
+
+
+def test_update_latest_history():
+    """Test update_latest_history"""
+    num_gemm = 6
+    history_len = 1024
+    amax_history_tensor = paddle.rand(shape=[history_len, num_gemm], dtype='float32')
+    amax = paddle.rand(shape=[num_gemm], dtype='float32')
+
+    tex.update_latest_amax_history_inplace(_history=amax_history_tensor, amax=amax)
+
+    assert_allclose(amax_history_tensor[0], amax, rtol=1e-7, atol=1e-7)
@@ -1019,28 +1019,62 @@ void te_scaled_upper_triang_masked_softmax_backward(paddle::Tensor &output_grads
                                              softmax_results.stream());
 }
 
-__global__ void UpdateScalesKernel(const float *amax, const float *scale, float margin,
-                                   float fp8_max, size_t size, float *scale_out) {
+__global__ void UpdateFP8MetaKernel(const float *amax, const float *rolled_amax_history,
+                                    float *amax_history, float *scale, float *scale_inv,
+                                    float margin, float fp8_max, size_t history_numel,
+                                    size_t amax_numel) {
     int idx = blockIdx.x * blockDim.x + threadIdx.x;
-    if (idx < size) {
+    if (idx >= history_numel) {
+        return;
+    }
+
+    amax_history[idx] = rolled_amax_history[idx];
+
+    if (idx < amax_numel) {
         float exp = floor(log2(fp8_max / amax[idx])) - margin;
         float sf = round(powf(2.0f, abs(exp)));
-        sf = ((amax[idx] > 0.0f) && isfinite(amax[idx])) ? sf : scale[idx];
-        scale_out[idx] = exp < 0.0f ? 1 / sf : sf;
+        float scale_reg = scale[idx];
+        sf = ((amax[idx] > 0.0f) && isfinite(amax[idx])) ? sf : scale_reg;
+        scale_reg = exp < 0.0f ? 1 / sf : sf;
+        scale[idx] = scale_reg;
+        scale_inv[idx] = 1.0f / scale_reg;
+        amax_history[idx] = 0.0f;
     }
 }
 
-std::vector<paddle::Tensor> update_scale(const paddle::Tensor &amax, const paddle::Tensor &scale,
-                                         float fp8_max, float margin) {
-    const size_t block_size = 512;
-    size_t size = static_cast<size_t>(amax.numel());
-    size_t num_blocks = (size + block_size - 1) / block_size;
-    auto scale_out = paddle::empty_like(scale, scale.dtype(), scale.place());
-    UpdateScalesKernel<<<num_blocks, block_size, 0, amax.stream()>>>(
-        amax.data<float>(), scale.data<float>(), margin, fp8_max, size, scale_out.data<float>());
-    return {scale_out};
+void amax_and_scale_update_inplace(paddle::Tensor &amax_history,  // NOLINT
+                                   paddle::Tensor &scale,         // NOLINT
+                                   paddle::Tensor &scale_inv,     // NOLINT
+                                   float fp8_max, float margin, const std::string &amax_compute) {
+    NVTE_CHECK(amax_compute == "max" || amax_compute == "most_recent");
+
+    paddle::Tensor amax;
+    if (amax_compute == "max") {
+        amax = amax_history.max({0});
+    } else {
+        amax = amax_history.slice(0, 1);
+    }
+
+    const auto rolled_amax_history = amax_history.roll({-1}, {0});
+
+    auto size = amax_history.numel();
+    constexpr int BLOCK_SIZE = 256;
+    size_t num_blocks = (size + BLOCK_SIZE - 1) / BLOCK_SIZE;
+    UpdateFP8MetaKernel<<<num_blocks, BLOCK_SIZE, 0, amax_history.stream()>>>(
+        amax.data<float>(), rolled_amax_history.data<float>(), amax_history.data<float>(),
+        scale.data<float>(), scale_inv.data<float>(), margin, fp8_max, amax_history.numel(),
+        amax.numel());
+    NVTE_CHECK_CUDA(cudaGetLastError());
+}
+
+void update_latest_amax_history_inplace(paddle::Tensor &history,  // NOLINT
+                                        const paddle::Tensor &amax) {
+    // Copy amax to history[0]
+    NVTE_CHECK_CUDA(cudaMemcpyAsync(history.data(), amax.data(),
+                                    amax.numel() * SizeOf(amax.dtype()), cudaMemcpyDeviceToDevice,
+                                    amax.stream()));
 }
 
 }  // namespace paddle_ext
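A note on the launch configuration above: the grid covers every element of the amax history, and only the first amax_numel threads additionally refresh scale/scale_inv and zero the newest history row. With the shapes used in the tests the arithmetic works out as follows (a small sketch, assuming history_len=1024 and num_gemm=6):

    import math

    history_len, num_gemm = 1024, 6           # shapes used in the tests above
    history_numel = history_len * num_gemm    # 6144 threads copy the rolled history
    amax_numel = num_gemm                     # the first 6 threads also update scale/scale_inv
    BLOCK_SIZE = 256
    num_blocks = math.ceil(history_numel / BLOCK_SIZE)
    print(num_blocks)                         # 24 blocks of 256 threads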
@@ -1242,8 +1276,17 @@ PD_BUILD_OP(te_scaled_upper_triang_masked_softmax_backward)
     .SetKernelFn(
         PD_KERNEL(transformer_engine::paddle_ext::te_scaled_upper_triang_masked_softmax_backward));
 
-PD_BUILD_OP(update_scale)
-    .Inputs({"Amax", "Scale"})
-    .Outputs({"ScaleOut"})
-    .Attrs({"fp8_max: float", "margin: float"})
-    .SetKernelFn(PD_KERNEL(transformer_engine::paddle_ext::update_scale));
+PD_BUILD_OP(amax_and_scale_update_inplace)
+    .Inputs({"_amax_history", "_scale", "_scale_inv"})
+    .Outputs({"amax_history", "scale", "scale_inv"})
+    .SetInplaceMap({{"_amax_history", "amax_history"},
+                    {"_scale", "scale"},
+                    {"_scale_inv", "scale_inv"}})
+    .Attrs({"fp8_max: float", "margin: float", "amax_compute: std::string"})
+    .SetKernelFn(PD_KERNEL(transformer_engine::paddle_ext::amax_and_scale_update_inplace));
+
+PD_BUILD_OP(update_latest_amax_history_inplace)
+    .Inputs({"_history", "amax"})
+    .Outputs({"history"})
+    .SetInplaceMap({{"_history", "history"}})
+    .SetKernelFn(PD_KERNEL(transformer_engine::paddle_ext::update_latest_amax_history_inplace));
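Once registered, the op is reachable from Python through the transformer_engine_paddle module (imported as tex later in this diff). A minimal usage sketch mirroring the test above; the fp8_max value is illustrative (448 is the E4M3 forward maximum):

    import paddle
    import transformer_engine_paddle as tex

    history = paddle.rand(shape=[1024, 6], dtype='float32')
    scale = paddle.ones(shape=[6], dtype='float32')
    scale_inv = paddle.ones(shape=[6], dtype='float32')

    # One call rolls the history, zeroes the newest slot, and refreshes
    # scale and scale_inv in place -- no Python-side roll/max launches.
    tex.amax_and_scale_update_inplace(_amax_history=history, _scale=scale,
                                      _scale_inv=scale_inv, fp8_max=448.0,
                                      margin=0.0, amax_compute="max")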
@@ -197,30 +197,12 @@ def amax_and_scale_update(
     fp8_max_key = "fp8_max_fwd" if fwd_update else "fp8_max_bwd"
 
     if not callable(amax_compute) and sf_compute is None:
-        # Obtain amax from history
-        amax_history = fp8_meta[fp8_meta_tensor_key].amax_history
-        if amax_compute == "max":
-            amax = paddle.max(amax_history, axis=0)
-        else:    # amax_compute_algo == "most_recent"
-            amax = amax_history[0]
-
-        # Update amax history and set next amax to zero
-        if amax_history.shape[0] > 1:
-            amax_history = paddle.roll(amax_history, -1, 0)
-        amax_history[0] = 0.0
-        fp8_meta[fp8_meta_tensor_key].amax_history = amax_history
-
-        # Update scaling factor
-        fp8_meta[fp8_meta_tensor_key].scale = tex.update_scale(
-            amax=amax,
-            scale=fp8_meta[fp8_meta_tensor_key].scale,
-            fp8_max=fp8_meta[fp8_max_key],
-            margin=float(fp8_meta["recipe"].margin))
-
-        # Update scale_inv
-        fp8_meta[fp8_meta_tensor_key].scale_inv = \
-            1.0 / fp8_meta[fp8_meta_tensor_key].scale
+        tex.amax_and_scale_update_inplace(_amax_history=fp8_meta[fp8_meta_tensor_key].amax_history,
+                                          _scale=fp8_meta[fp8_meta_tensor_key].scale,
+                                          _scale_inv=fp8_meta[fp8_meta_tensor_key].scale_inv,
+                                          fp8_max=fp8_meta[fp8_max_key],
+                                          margin=float(fp8_meta["recipe"].margin),
+                                          amax_compute=amax_compute)
     else:
         raise ValueError("We only support the fp8 recipe with 'max' or 'most_recent' "
                          "amax_compute_algo and default scaling_factor_compute_algo at this "
@@ -247,7 +229,7 @@ class FP8TensorMeta():
         curr_len = self.amax_history.shape[0]
         num_fp8_tensors = self.amax_history.shape[1]
         if amax_history_len < curr_len:
-            self.amax_history = (self.amax_history[:amax_history_len])
+            self.amax_history = self.amax_history[:amax_history_len]
         elif amax_history_len > curr_len:
             extra_rows = amax_history_len - curr_len
             self.amax_history = paddle.concat([
...
@@ -11,6 +11,7 @@ from typing import Dict, Any, List, Union
 
 import numpy as np
 import paddle
+import transformer_engine_paddle as tex
 
 from .constants import dist_group_type, RecomputeFunctionNames
 
@@ -152,8 +153,10 @@ class FP8MetaBufferBase(ABC):
         amax_buffer_key = self._get_amax_buffer_key(fp8_meta)
         assert amax_buffer_key in self._data, "TE internal error."
 
-        fp8_meta[fp8_meta_tensor_key].amax_history[0] = self._data[amax_buffer_key][
-            fp8_meta[buffer_position_key]]
+        # Copy amax to amax_history[0]
+        tex.update_latest_amax_history_inplace(
+            _history=fp8_meta[fp8_meta_tensor_key].amax_history,
+            amax=self._data[amax_buffer_key][fp8_meta[buffer_position_key]])
 
     def set_for_deletion(self, fp8_meta: Dict[str, Any]) -> None:
         """Delete this amax key from global buffer during autocast end."""
...
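Likewise, the buffer copy-out above now goes through one device-to-device cudaMemcpyAsync instead of Paddle's indexed-assignment machinery. A small sanity-check sketch of the op's effect (tensor shapes borrowed from the tests):

    import paddle
    import transformer_engine_paddle as tex

    history = paddle.zeros(shape=[1024, 6], dtype='float32')
    amax = paddle.rand(shape=[6], dtype='float32')

    tex.update_latest_amax_history_inplace(_history=history, amax=amax)

    # The effect is exactly history[0] = amax, minus the setitem dispatch.
    assert bool(paddle.allclose(history[0], amax))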