Unverified commit daad219f authored by Tian Zheng, committed by GitHub

[Paddle] Optimize memory usage when training in pipeline parallel (#580)

* Actively free tensors in bwd
Signed-off-by: Tian Zheng (Engrg-Hardware 1) <tizheng@nvidia.com>

* - Add inplace support for fp8 casting
- Allow skipping weight update in fp8 meta update
Signed-off-by: Tian Zheng (Engrg-Hardware 1) <tizheng@nvidia.com>

* Support weight caching for Linear
Signed-off-by: Tian Zheng (Engrg-Hardware 1) <tizheng@nvidia.com>

* Add weight caching for LayerNormLinear
Signed-off-by: Tian Zheng (Engrg-Hardware 1) <tizheng@nvidia.com>

* Add weight caching for LayerNormMLP
Signed-off-by: Tian Zheng (Engrg-Hardware 1) <tizheng@nvidia.com>

* Add weight caching for TransformerLayer
Signed-off-by: Tian Zheng (Engrg-Hardware 1) <tizheng@nvidia.com>

* Add PP unit tests
Signed-off-by: Tian Zheng (Engrg-Hardware 1) <tizheng@nvidia.com>

* Fix CI
Signed-off-by: Tian Zheng (Engrg-Hardware 1) <tizheng@nvidia.com>

---------
Signed-off-by: Tian Zheng (Engrg-Hardware 1) <tizheng@nvidia.com>
parent 2ae121d7
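For context, the weight caching added in this commit is driven by a new `is_first_microbatch` argument on the TE layers. A minimal usage sketch (assumes FP8-capable hardware; the loop and tensor shapes are illustrative, while the layer and autocast calls mirror the tests in this diff):

import paddle
import transformer_engine.paddle as te
from transformer_engine.common.recipe import DelayedScaling

layer = te.Linear(in_features=1024, out_features=1024)
recipe = DelayedScaling()
# Illustrative stand-in for the microbatches of one minibatch
microbatches = [(paddle.rand([8, 1024]), paddle.rand([8, 1024])) for _ in range(4)]

for step, (inp, grad_out) in enumerate(microbatches):
    with te.fp8_autocast(enabled=True, fp8_recipe=recipe):
        # step 0: weights are cast to FP8 (plus transpose) and cached;
        # later microbatches reuse the cached FP8 buffers and also skip
        # refreshing the weight scale_inv in the FP8 meta update.
        out = layer(inp, is_first_microbatch=(step == 0))
    out.backward(grad_out)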
......@@ -19,6 +19,23 @@ from utils import assert_allclose, set_random_seed
import transformer_engine.paddle as te
class TELinear(te.Linear):
"""To pass is_first_microbatch"""
def __init__(self, *args, **kwargs):
assert 'accumulate_steps' in kwargs
self.accumulate_steps = kwargs['accumulate_steps']
del kwargs['accumulate_steps']
self._micro_batch_id = 0
super().__init__(*args, **kwargs)
def forward(self, *args, **kwargs):
kwargs['is_first_microbatch'] = (self._micro_batch_id % self.accumulate_steps) == 0
if paddle.is_grad_enabled() and self.training:
self._micro_batch_id += 1
return super().forward(*args, **kwargs)
class TEPipelineModel(PipelineLayer):
"""Model for pipeline parallel test"""
......@@ -28,6 +45,7 @@ class TEPipelineModel(PipelineLayer):
weight_attrs,
use_te=True,
use_fp8=False,
accumulate_steps=1,
**kwargs):
self.in_features = in_features
self.hidden_features = hidden_features
......@@ -35,10 +53,22 @@ class TEPipelineModel(PipelineLayer):
hcg = fleet.get_hybrid_communicate_group()
self.dp_group = hcg.get_data_parallel_group()
Linear = te.Linear if use_te else paddle.nn.Linear
Linear = TELinear if use_te else paddle.nn.Linear
extra_kwargs = {}
if use_te:
extra_kwargs['accumulate_steps'] = accumulate_steps
model_desc = [
LayerDesc(Linear, self.in_features, self.hidden_features, weight_attr=weight_attrs[0]),
LayerDesc(Linear, self.hidden_features, self.in_features, weight_attr=weight_attrs[1]),
LayerDesc(Linear,
self.in_features,
self.hidden_features,
weight_attr=weight_attrs[0],
**extra_kwargs),
LayerDesc(Linear,
self.hidden_features,
self.in_features,
weight_attr=weight_attrs[1],
**extra_kwargs),
]
super().__init__(layers=model_desc, loss_fn=paddle.nn.CrossEntropyLoss(), **kwargs)
......@@ -84,8 +114,9 @@ class TestLinearPipelineParallel(unittest.TestCase):
"mp_degree": 1,
"pp_degree": self.pipeline_parallel_size,
}
self.accumulate_steps = self.batch_size // self.micro_batch_size
strategy.pipeline_configs = {
"accumulate_steps": self.batch_size // self.micro_batch_size,
"accumulate_steps": self.accumulate_steps,
"micro_batch_size": self.micro_batch_size,
}
fleet.init(is_collective=True, strategy=strategy)
......@@ -128,6 +159,7 @@ class TestLinearPipelineParallel(unittest.TestCase):
use_fp8=self.fp8,
seg_method="layer:Linear",
num_stages=self.pipeline_parallel_size,
accumulate_steps=self.accumulate_steps,
)
# Check if model is split across ranks as expected
......
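For reference, the pipeline setup exercised by the test above reduces to the following fleet configuration (a condensed sketch; the degrees and batch sizes are illustrative):

from paddle.distributed import fleet

strategy = fleet.DistributedStrategy()
strategy.hybrid_configs = {
    "dp_degree": 1,
    "mp_degree": 1,
    "pp_degree": 2,    # illustrative pipeline-parallel size
}
strategy.pipeline_configs = {
    "accumulate_steps": 4,    # batch_size // micro_batch_size
    "micro_batch_size": 8,
}
fleet.init(is_collective=True, strategy=strategy)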
......@@ -5,7 +5,7 @@
import math
import os
from utils import assert_allclose
from utils import assert_allclose, is_fused_attention_supported
import paddle
import pytest
......@@ -14,8 +14,6 @@ from transformer_engine.common.recipe import DelayedScaling
import transformer_engine.paddle as te
from transformer_engine.paddle.fp8 import is_fp8_available, fp8_autocast
from utils import is_fused_attention_supported
is_fp8_supported, reason = is_fp8_available()
LINEAR_CASES = [(16, 16, 32), (32, 32, 64)]
NORM_CASES = [(16, 32), (256, 1024)]
......@@ -200,6 +198,50 @@ class TestLinear:
if do_calibration:
assert paddle.count_nonzero(layer_te.fp8_meta["scaling_fwd"].amax_history).item() > 0
@staticmethod
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest.mark.parametrize('bs,in_features,out_features', LINEAR_CASES)
@pytest.mark.parametrize('activation_dtype', ['bfloat16'])
@pytest.mark.parametrize('num_microbatch', [8])
def test_linear_fp8_microbatch(bs, in_features, out_features, activation_dtype, num_microbatch):
"""
Test FP8 Linear
"""
rtol = 0.1
atol = 0.1
recipe = DelayedScaling()
paddle.set_default_dtype(activation_dtype)
layer_cached = te.Linear(
in_features=in_features,
out_features=out_features,
)
layer_normal = te.Linear(
in_features=in_features,
out_features=out_features,
)
layer_cached.weight.copy_(layer_normal.weight, True)
layer_cached.bias.copy_(layer_normal.bias, True)
for iteration in range(num_microbatch):
input_tensor = paddle.uniform(shape=(bs, in_features), dtype=activation_dtype)
grad_out = paddle.uniform(shape=(bs, out_features), dtype=activation_dtype)
with fp8_autocast(enabled=True, fp8_recipe=recipe):
out = layer_cached(input_tensor, is_first_microbatch=(iteration == 0))
out.backward(grad_out)
with fp8_autocast(enabled=True, fp8_recipe=recipe):
out_ref = layer_normal(input_tensor)
out_ref.backward(grad_out)
assert_allclose(out, out_ref, rtol=rtol, atol=atol)
assert_allclose(layer_cached.weight.grad,
layer_normal.weight.grad,
rtol=rtol,
atol=atol)
@pytest.mark.parametrize('bs,hidden_size', NORM_CASES)
@pytest.mark.parametrize('has_bias,no_dbias', [[True, False], [True, True], [False, False]])
......@@ -411,6 +453,62 @@ class TestLayerNormLinear:
if do_calibration:
assert paddle.count_nonzero(layer_te.fp8_meta["scaling_fwd"].amax_history).item() > 0
@staticmethod
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest.mark.parametrize('bs,in_features,out_features', LINEAR_CASES)
@pytest.mark.parametrize('activation_dtype', ['bfloat16'])
@pytest.mark.parametrize('num_microbatch', [8])
def test_layernorm_linear_fp8_microbatch(bs, in_features, out_features, activation_dtype,
num_microbatch):
"""
Test FP8 LayerNormLinear Layer
"""
paddle.set_default_dtype(activation_dtype)
eps = 1e-3
rtol = 0.5
atol = 0.5
recipe = DelayedScaling()
layer_cached = te.LayerNormLinear(
in_features=in_features,
out_features=out_features,
eps=eps,
)
layer_normal = te.LayerNormLinear(
in_features=in_features,
out_features=out_features,
eps=eps,
)
layer_cached.ln_weight.copy_(layer_normal.ln_weight, True)
layer_cached.ln_bias.copy_(layer_normal.ln_bias, True)
layer_cached.weight.copy_(layer_normal.weight, True)
layer_cached.bias.copy_(layer_normal.bias, True)
for iteration in range(num_microbatch):
input_tensor = paddle.uniform(shape=(bs, in_features), dtype=activation_dtype)
grad_out = paddle.uniform(shape=(bs, out_features), dtype=activation_dtype)
with fp8_autocast(enabled=True, fp8_recipe=recipe):
out = layer_cached(input_tensor, is_first_microbatch=(iteration == 0))
out.backward(grad_out)
with fp8_autocast(enabled=True, fp8_recipe=recipe):
out_ref = layer_normal(input_tensor)
out_ref.backward(grad_out)
assert_allclose(out, out_ref, rtol=rtol, atol=atol)
assert_allclose(layer_cached.weight.grad,
layer_normal.weight.grad,
rtol=rtol,
atol=atol)
assert_allclose(layer_cached.ln_weight.grad,
layer_normal.ln_weight.grad,
rtol=rtol,
atol=atol)
class TestLayerNormMLP:
"""
......@@ -615,6 +713,75 @@ class TestLayerNormMLP:
if do_calibration:
assert paddle.count_nonzero(layer_te.fp8_meta["scaling_fwd"].amax_history).item() > 0
@staticmethod
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest.mark.parametrize('bs,hidden_size,ffn_hidden_size', LINEAR_CASES)
@pytest.mark.parametrize('activation_dtype', ['bfloat16'])
@pytest.mark.parametrize('num_microbatch', [8])
def test_layernorm_mlp_fp8_microbatch(bs, hidden_size, ffn_hidden_size, activation_dtype,
num_microbatch):
"""
Test FP8 LayerNormMLP Layer
"""
paddle.set_default_dtype(activation_dtype)
rtol = 1e-5
atol = 1e-5
eps = 1e-3
recipe = DelayedScaling()
layer_cached = te.LayerNormMLP(
hidden_size=hidden_size,
ffn_hidden_size=ffn_hidden_size,
eps=eps,
)
layer_normal = te.LayerNormMLP(
hidden_size=hidden_size,
ffn_hidden_size=ffn_hidden_size,
eps=eps,
)
layer_normal.ln_weight.copy_(layer_cached.ln_weight, True)
layer_normal.ln_bias.copy_(layer_cached.ln_bias, True)
layer_normal.fc1_weight.copy_(layer_cached.fc1_weight, True)
layer_normal.fc2_weight.copy_(layer_cached.fc2_weight, True)
layer_normal.fc1_bias.copy_(layer_cached.fc1_bias, True)
layer_normal.fc2_bias.copy_(layer_cached.fc2_bias, True)
# Calibration to make sure weight scale is the same
input_tensor = paddle.uniform(shape=(bs, hidden_size), dtype=activation_dtype)
with fp8_autocast(enabled=False, calibrating=True, fp8_recipe=recipe):
_ = layer_cached(input_tensor)
with fp8_autocast(enabled=False, calibrating=True, fp8_recipe=recipe):
_ = layer_normal(input_tensor)
for iteration in range(num_microbatch):
input_tensor = paddle.uniform(shape=(bs, hidden_size), dtype=activation_dtype)
grad_out = paddle.uniform(shape=(bs, hidden_size), dtype=activation_dtype)
with fp8_autocast(enabled=True, fp8_recipe=recipe):
out = layer_cached(input_tensor, is_first_microbatch=(iteration == 0))
out.backward(grad_out)
with fp8_autocast(enabled=True, fp8_recipe=recipe):
out_ref = layer_normal(input_tensor)
out_ref.backward(grad_out)
assert_allclose(out, out_ref, rtol=rtol, atol=atol)
assert_allclose(layer_cached.ln_weight.grad,
layer_normal.ln_weight.grad,
rtol=rtol,
atol=atol)
assert_allclose(layer_cached.fc1_weight.grad,
layer_normal.fc1_weight.grad,
rtol=rtol,
atol=atol)
assert_allclose(layer_cached.fc2_weight.grad,
layer_normal.fc2_weight.grad,
rtol=rtol,
atol=atol)
@pytest.mark.parametrize('bs', [1, 2, 8])
@pytest.mark.parametrize('hidden_size, num_heads', [[1024, 16], [768, 12]])
......@@ -1172,3 +1339,122 @@ def test_transformer_decoder_layer(bs, hidden_size, num_heads, ffn_hidden_size,
layer_pd.inter_attention.layernorm_query.bias.grad,
rtol=rtol,
atol=atol)
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest.mark.parametrize('bs', [8])
@pytest.mark.parametrize('hidden_size, num_heads, ffn_hidden_size', [[1024, 16, 4096]])
@pytest.mark.parametrize('q_seqlen, kv_seqlen', [[128, 128]])
@pytest.mark.parametrize('mask_type', ['causal'])
@pytest.mark.parametrize('math_dtype', ['bfloat16'])
@pytest.mark.parametrize('num_microbatch', [8])
def test_transformer_encoder_layer_microbatch(bs, hidden_size, num_heads, ffn_hidden_size, q_seqlen,
kv_seqlen, mask_type, math_dtype, num_microbatch):
"""
Test Transformer Encoder Layer with FP8 weight caching
"""
paddle.set_default_dtype(math_dtype)
rtol = 1e-5
atol = 1e-5
eps = 1e-3
# Skip if cuDNN fused attention is not supported
if not is_fused_attention_supported(
num_heads=num_heads,
num_gqa_groups=num_heads,
q_seqlen=q_seqlen,
kv_seqlen=kv_seqlen,
head_size=hidden_size // num_heads,
dtype=math_dtype,
dropout=0.0,
qkv_layout="bs3hd",
bias_type="no_bias",
mask_type=mask_type,
):
pytest.skip("cuDNN fused attention is not supported")
layer_cached = te.TransformerLayer(hidden_size,
ffn_hidden_size,
num_heads,
layernorm_epsilon=eps,
hidden_dropout=0.0,
attention_dropout=0.0,
weight_attr=None,
bias_attr=None,
self_attn_mask_type=mask_type,
layer_type='encoder')
layer_normal = te.TransformerLayer(hidden_size,
ffn_hidden_size,
num_heads,
layernorm_epsilon=eps,
hidden_dropout=0.0,
attention_dropout=0.0,
weight_attr=None,
bias_attr=None,
self_attn_mask_type=mask_type,
layer_type='encoder')
layer_normal.self_attention.layernorm_qkv.ln_weight.copy_(
layer_cached.self_attention.layernorm_qkv.ln_weight, True)
layer_normal.self_attention.layernorm_qkv.ln_bias.copy_(
layer_cached.self_attention.layernorm_qkv.ln_bias, True)
layer_normal.self_attention.layernorm_qkv.weight.copy_(
layer_cached.self_attention.layernorm_qkv.weight, True)
layer_normal.self_attention.layernorm_qkv.bias.copy_(
layer_cached.self_attention.layernorm_qkv.bias, True)
layer_normal.self_attention.proj.weight.copy_(layer_cached.self_attention.proj.weight, True)
layer_normal.self_attention.proj.bias.copy_(layer_cached.self_attention.proj.bias, True)
# LayerNorm MLP params
layer_normal.layernorm_mlp.ln_weight.copy_(layer_cached.layernorm_mlp.ln_weight, True)
layer_normal.layernorm_mlp.ln_bias.copy_(layer_cached.layernorm_mlp.ln_bias, True)
layer_normal.layernorm_mlp.fc1_weight.copy_(layer_cached.layernorm_mlp.fc1_weight, True)
layer_normal.layernorm_mlp.fc2_weight.copy_(layer_cached.layernorm_mlp.fc2_weight, True)
layer_normal.layernorm_mlp.fc1_bias.copy_(layer_cached.layernorm_mlp.fc1_bias, True)
layer_normal.layernorm_mlp.fc2_bias.copy_(layer_cached.layernorm_mlp.fc2_bias, True)
recipe = DelayedScaling()
def generate_input():
encoder_input = paddle.uniform(shape=(bs, q_seqlen, hidden_size), dtype=math_dtype)
q_actual_seqlen = paddle.ones(shape=(bs,), dtype='int32') * q_seqlen
kv_actual_seqlen = q_actual_seqlen
attn_mask = paddle.ones(shape=(bs, 1, q_seqlen, kv_seqlen), dtype='bool')
grad_out = paddle.normal(mean=0.0, std=0.02,
shape=(bs, q_seqlen, hidden_size)).astype('float32')
for i in range(0, bs):
grad_out[i, q_actual_seqlen[i]:, :] = 0
grad_out = grad_out.astype(math_dtype)
for i in range(0, bs):
attn_mask[i, 0, 0:q_actual_seqlen[i], 0:kv_actual_seqlen[i]] = False
return encoder_input, attn_mask, grad_out
# Calibration to make sure weight scale is the same
encoder_input, mask, _ = generate_input()
with fp8_autocast(enabled=False, calibrating=True, fp8_recipe=recipe):
_ = layer_cached(encoder_input, mask)
with fp8_autocast(enabled=False, calibrating=True, fp8_recipe=recipe):
_ = layer_normal(encoder_input, mask)
for iteration in range(num_microbatch):
encoder_input, mask, grad_out = generate_input()
with fp8_autocast(enabled=True, fp8_recipe=recipe):
out = layer_cached(encoder_input, mask, is_first_microbatch=(iteration == 0))
out.backward(grad_out)
with fp8_autocast(enabled=True, fp8_recipe=recipe):
out_ref = layer_normal(encoder_input, mask)
out_ref.backward(grad_out)
assert_allclose(out, out_ref, rtol=rtol, atol=atol)
assert_allclose(layer_cached.self_attention.layernorm_qkv.weight.grad,
layer_normal.self_attention.layernorm_qkv.weight.grad,
rtol=rtol,
atol=atol)
......@@ -10,8 +10,14 @@ import paddle
import paddle.nn.functional as F
import pytest
import transformer_engine # pylint: disable=unused-import
import transformer_engine_paddle as tex # pylint: disable=wrong-import-order
from utils import (
assert_allclose,
create_fp8_meta,
get_fused_attention_backend,
is_fused_attention_supported,
)
import transformer_engine_paddle as tex
from transformer_engine.paddle.cpp_extensions import (
cast_to_fp8,
cast_from_fp8,
......@@ -44,13 +50,6 @@ from transformer_engine.paddle.fp8 import is_fp8_available
from transformer_engine.paddle.constants import FP8FwdTensors
from transformer_engine.common.recipe import DelayedScaling
from utils import (
assert_allclose,
create_fp8_meta,
get_fused_attention_backend,
is_fused_attention_supported,
)
GEMM_CASES = [(256, 256, 512), (32, 32, 32), (16384, 1024, 2816), (16384, 2816, 1024),
(16384, 1024, 1024)]
is_fp8_supported, reason = is_fp8_available()
......@@ -69,20 +68,24 @@ def setup():
yield
def test_quantize_dequantize():
@pytest.mark.parametrize('fp8_dtype', [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2])
@pytest.mark.parametrize('inplace', [True, False])
def test_quantize_dequantize(fp8_dtype, inplace):
"""
Test cast_to_fp8 and cast_from_fp8
"""
a = paddle.rand(shape=(32, 32), dtype='float32')
# Init fp8_meta
fp8_meta = create_fp8_meta()
for fp8_dtype in [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2]:
a_fp8 = cast_to_fp8(a, fp8_meta, FP8FwdTensors.GEMM1_OUTPUT, otype=fp8_dtype)
b = cast_from_fp8(a_fp8,
a_fp8 = paddle.zeros(shape=a.shape, dtype=paddle.uint8) if inplace else None
a_fp8 = cast_to_fp8(a, fp8_meta, FP8FwdTensors.GEMM1_OUTPUT, otype=fp8_dtype, out=a_fp8)
b = cast_from_fp8(
a_fp8,
fp8_meta,
FP8FwdTensors.GEMM1_OUTPUT,
itype=fp8_dtype,
otype=tex.DType.kFloat32)
otype=tex.DType.kFloat32,
)
assert_allclose(a, b, rtol=5e-2, atol=5e-2)
......@@ -142,7 +145,8 @@ class TestTranspose:
@staticmethod
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest.mark.parametrize('fp8_dtype', [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2])
def test_cast_transpose(fp8_dtype):
@pytest.mark.parametrize('inplace', [True, False])
def test_cast_transpose(fp8_dtype, inplace):
"""
Test cast_transpose
"""
......@@ -150,10 +154,16 @@ class TestTranspose:
max_val = 8
a = paddle.cast(paddle.randint(min_val, max_val, shape=(16, 32)), 'float32')
fp8_meta = create_fp8_meta()
a_fp8_casted, a_fp8_transposed = None, None
if inplace:
a_fp8_casted = paddle.zeros(shape=a.shape, dtype=paddle.uint8)
a_fp8_transposed = paddle.zeros(shape=a.T.shape, dtype=paddle.uint8)
a_fp8_casted, a_fp8_transposed = cast_transpose(a,
fp8_meta,
FP8FwdTensors.GEMM1_INPUT,
otype=fp8_dtype)
otype=fp8_dtype,
cast_out=a_fp8_casted,
transpose_out=a_fp8_transposed)
a_transposed = cast_from_fp8(a_fp8_transposed,
fp8_meta,
......@@ -616,7 +626,7 @@ class TestFusedAttn:
assert attn_mode == "self_attn", "only support causal masking for self attention"
for i in range(0, self.batch_size):
for j in range(self.q_actual_seqlen[i]):
self.attn_mask[i, :, j, :j+1] = 0
self.attn_mask[i, :, j, :j + 1] = 0
else:
for i in range(0, self.batch_size):
self.attn_mask[i, :, :self.q_actual_seqlen[i], :self.kv_actual_seqlen[i]] = 0
......@@ -682,11 +692,7 @@ class TestFusedAttn:
q_cu_seqlen_tensor = paddle.to_tensor(self.q_cu_seqlen, dtype="int32", stop_gradient=True)
kv_cu_seqlen_tensor = paddle.to_tensor(self.kv_cu_seqlen, dtype="int32", stop_gradient=True)
qkv_layout = (
"bs3hd"
if self.attn_mode == "self_attn"
else "bshd_bs2hd"
)
qkv_layout = ("bs3hd" if self.attn_mode == "self_attn" else "bshd_bs2hd")
fused_attention_backend = get_fused_attention_backend(
num_heads=self.num_heads,
num_gqa_groups=self.num_heads,
......@@ -932,12 +938,14 @@ class TestSoftmax:
assert_allclose(dx_ref, dx, rtol=1e-4, atol=5e-3)
def test_amax_and_scale_update():
@pytest.mark.parametrize('update_weight_scale_inv', [True, False])
def test_amax_and_scale_update(update_weight_scale_inv):
"""Test update_scale"""
num_gemm = 6
history_len = 1024
recipe = DelayedScaling()
fp8_max = recipe.fp8_format.value.max_fwd
non_weight_mask = paddle.to_tensor([True, False] * (num_gemm // 2))
amax_history_tensor = paddle.rand(shape=[history_len, num_gemm], dtype='float32')
rolled_history_ref = paddle.roll(amax_history_tensor, -1, axis=0)
......@@ -947,13 +955,17 @@ def test_amax_and_scale_update():
def calc_ref(amax, scale, fp8_max, margin=0):
"""Calculate reference scale"""
sf = (fp8_max / amax) / (2 ** margin)
sf = (fp8_max / amax) / (2**margin)
sf = paddle.where(amax > 0.0, sf, scale)
sf = paddle.where(paddle.isfinite(amax), sf, scale)
return sf
scale_ref = calc_ref(amax_tensor, scale_tensor, fp8_max, 0.)
if update_weight_scale_inv:
scale_inv_ref = 1. / scale_ref
else:
scale_inv_ref = paddle.zeros_like(scale_tensor)
scale_inv_ref = paddle.where(non_weight_mask, 1. / scale_ref, scale_inv_ref)
# Placeholder
scale_actual = paddle.zeros_like(scale_tensor)
......@@ -962,6 +974,8 @@ def test_amax_and_scale_update():
tex.amax_and_scale_update_inplace(_amax_history=amax_history_tensor,
_scale=scale_actual,
_scale_inv=scale_inv_actual,
non_weight_mask=non_weight_mask,
update_weight_scale_inv=update_weight_scale_inv,
fp8_max=fp8_max,
margin=0.,
amax_compute="max")
......
......@@ -185,11 +185,22 @@ def cast_to_fp8(
fp8_meta_tensor: FP8TensorMeta,
fp8_tensor: Union[FP8FwdTensors, FP8BwdTensors],
otype: tex.DType,
out: Optional[paddle.Tensor] = None,
) -> paddle.Tensor:
"""Cast input to FP8"""
out, _, _ = tex.cast_to_fp8(
if out is None:
out = paddle.empty(
shape=inp.shape,
dtype=paddle.uint8,
)
else:
assert out.shape == inp.shape, "Output shape does not match input shape."
assert out.dtype == paddle.uint8, "Output should be of uint8 dtype."
tex.cast_to_fp8(
inp,
fp8_meta_tensor.scale,
out,
fp8_meta_tensor.amax_history,
fp8_meta_tensor.scale_inv,
fp8_tensor.value,
......@@ -231,11 +242,34 @@ def cast_transpose(
fp8_meta_tensor: FP8TensorMeta,
fp8_tensor: Union[FP8FwdTensors, FP8BwdTensors],
otype: tex.DType,
cast_out: Optional[paddle.Tensor] = None,
transpose_out: Optional[paddle.Tensor] = None,
) -> Union[Tuple[paddle.Tensor, paddle.Tensor], None]:
"""Cast + Transpose with FP8 output"""
cast_out, transpose_out, _, _ = tex.te_cast_transpose(
if cast_out is None:
cast_out = paddle.empty(
shape=inp.shape,
dtype=paddle.uint8,
)
else:
assert cast_out.shape == inp.shape, "cast_out shape does not match input shape."
assert cast_out.dtype == paddle.uint8, "cast_out should be of uint8 dtype."
if transpose_out is None:
transpose_out = paddle.empty(
shape=[inp.shape[1], inp.shape[0]],
dtype=paddle.uint8,
)
else:
assert transpose_out.shape == [inp.shape[1], inp.shape[0]], \
"transpose_out shape does not match input shape."
assert transpose_out.dtype == paddle.uint8, "transpose_out should be of uint8 dtype."
tex.te_cast_transpose(
inp,
fp8_meta_tensor.scale,
cast_out,
transpose_out,
fp8_meta_tensor.amax_history,
fp8_meta_tensor.scale_inv,
fp8_tensor.value,
......
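Putting the wrappers above together, a minimal sketch of the new in-place casting path with preallocated uint8 buffers (`create_fp8_meta` is the test-suite helper used in the tests earlier in this diff):

import paddle
import transformer_engine_paddle as tex
from transformer_engine.paddle.cpp_extensions import cast_to_fp8, cast_transpose
from transformer_engine.paddle.constants import FP8FwdTensors
from utils import create_fp8_meta  # test-suite helper, as in the tests above

a = paddle.rand(shape=(16, 32), dtype='float32')
fp8_meta = create_fp8_meta()

# Preallocate outputs once; the extension writes into them instead of
# allocating fresh tensors on every call.
cast_buf = paddle.zeros(shape=a.shape, dtype=paddle.uint8)
t_buf = paddle.zeros(shape=[a.shape[1], a.shape[0]], dtype=paddle.uint8)
a_fp8 = cast_to_fp8(a, fp8_meta, FP8FwdTensors.GEMM1_OUTPUT,
                    otype=tex.DType.kFloat8E4M3, out=cast_buf)
a_fp8, a_t_fp8 = cast_transpose(a, fp8_meta, FP8FwdTensors.GEMM1_INPUT,
                                otype=tex.DType.kFloat8E4M3,
                                cast_out=cast_buf, transpose_out=t_buf)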
......@@ -40,21 +40,19 @@ NVTE_Mask_Type get_nvte_mask_type(const std::string mask_type) {
}
}
std::vector<paddle::Tensor> cast_to_fp8(const paddle::Tensor &input, const paddle::Tensor &scale,
paddle::Tensor &amax, paddle::Tensor &scale_inv, // NOLINT
void cast_to_fp8(const paddle::Tensor &input, const paddle::Tensor &scale,
paddle::Tensor &output, // NOLINT
paddle::Tensor &amax, // NOLINT
paddle::Tensor &scale_inv, // NOLINT
int64_t index, int64_t otype) {
auto shape = GetShapeArray(input);
auto output = paddle::empty_like(input, Nvte2PaddleDType(Int2NvteDType(otype)));
auto input_cu = MakeNvteTensor(input);
auto output_cu = MakeNvteTensor(
output.data(), shape, Int2NvteDType(otype), GetDataPtr<float>(amax, index),
const_cast<void *>(GetDataPtr<float>(scale, index)), GetDataPtr<float>(scale_inv, index));
nvte_fp8_quantize(input_cu.data(), output_cu.data(), input.stream());
return {output};
}
std::vector<paddle::Tensor> cast_from_fp8(const paddle::Tensor &input,
......@@ -89,8 +87,9 @@ std::vector<paddle::Tensor> te_transpose(const paddle::Tensor &input, int64_t ot
return {output};
}
std::vector<paddle::Tensor> te_cast_transpose(const paddle::Tensor &input,
const paddle::Tensor &scale,
void te_cast_transpose(const paddle::Tensor &input, const paddle::Tensor &scale,
paddle::Tensor &output_cast, // NOLINT
paddle::Tensor &output_transpose, // NOLINT
paddle::Tensor &amax, // NOLINT
paddle::Tensor &scale_inv, // NOLINT
int64_t index, int64_t otype) {
......@@ -100,24 +99,17 @@ std::vector<paddle::Tensor> te_cast_transpose(const paddle::Tensor &input,
size_t M = shape[0];
size_t N = shape[1];
auto input_cast =
paddle::empty_like(input, Nvte2PaddleDType(Int2NvteDType(otype)), input.place());
auto input_transpose = paddle::empty({input.shape()[1], input.shape()[0]},
Nvte2PaddleDType(Int2NvteDType(otype)), input.place());
auto input_cu = MakeNvteTensor(input);
void *amax_data = GetDataPtr<float>(amax, index);
void *scale_data = const_cast<void *>(GetDataPtr<float>(scale, index));
void *scale_inv_data = GetDataPtr<float>(scale_inv, index);
auto output_cast_cu = MakeNvteTensor(input_cast.data(), {M, N}, Int2NvteDType(otype), amax_data,
scale_data, scale_inv_data);
auto output_transpose_cu = MakeNvteTensor(input_transpose.data(), {N, M}, Int2NvteDType(otype),
auto output_cast_cu = MakeNvteTensor(output_cast.data(), {M, N}, Int2NvteDType(otype),
amax_data, scale_data, scale_inv_data);
auto output_transpose_cu = MakeNvteTensor(output_transpose.data(), {N, M}, Int2NvteDType(otype),
amax_data, scale_data, scale_inv_data);
nvte_cast_transpose(input_cu.data(), output_cast_cu.data(), output_transpose_cu.data(),
input.stream());
return {input_cast, input_transpose};
}
std::vector<paddle::Tensor> te_cast_transpose_bgrad(const paddle::Tensor &grad_output,
......@@ -1021,9 +1013,9 @@ void te_scaled_upper_triang_masked_softmax_backward(paddle::Tensor &output_grads
}
__global__ void UpdateFP8MetaKernel(const float *amax, const float *rolled_amax_history,
float *amax_history, float *scale, float *scale_inv,
float margin, float fp8_max, size_t history_numel,
size_t amax_numel) {
const bool *non_weight_mask, float *amax_history, float *scale,
float *scale_inv, bool update_weight_scale_inv, float margin,
float fp8_max, size_t history_numel, size_t amax_numel) {
int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx >= history_numel) {
......@@ -1036,7 +1028,7 @@ __global__ void UpdateFP8MetaKernel(const float *amax, const float *rolled_amax_
float sf = (fp8_max / amax[idx]) / powf(2.0f, margin);
float scale_reg = ((amax[idx] > 0.0f) && isfinite(amax[idx])) ? sf : scale[idx];
scale[idx] = scale_reg;
scale_inv[idx] = 1.0f / scale_reg;
if (update_weight_scale_inv || non_weight_mask[idx]) scale_inv[idx] = 1.0f / scale_reg;
amax_history[idx] = 0.0f;
}
}
......@@ -1046,7 +1038,9 @@ constexpr int BLOCK_SIZE = 512;
void amax_and_scale_update_inplace(paddle::Tensor &amax_history, // NOLINT
paddle::Tensor &scale, // NOLINT
paddle::Tensor &scale_inv, // NOLINT
float fp8_max, float margin, const std::string &amax_compute) {
const paddle::Tensor &non_weight_mask,
bool update_weight_scale_inv, float fp8_max, float margin,
const std::string &amax_compute) {
NVTE_CHECK(amax_compute == "max" || amax_compute == "most_recent");
paddle::Tensor amax;
......@@ -1062,9 +1056,9 @@ void amax_and_scale_update_inplace(paddle::Tensor &amax_history, // NOLINT
auto size = amax_history.numel();
size_t num_blocks = (size + BLOCK_SIZE - 1) / BLOCK_SIZE;
UpdateFP8MetaKernel<<<num_blocks, BLOCK_SIZE, 0, amax_history.stream()>>>(
amax.data<float>(), rolled_amax_history.data<float>(), amax_history.data<float>(),
scale.data<float>(), scale_inv.data<float>(), margin, fp8_max, amax_history.numel(),
amax.numel());
amax.data<float>(), rolled_amax_history.data<float>(), non_weight_mask.data<bool>(),
amax_history.data<float>(), scale.data<float>(), scale_inv.data<float>(),
update_weight_scale_inv, margin, fp8_max, amax_history.numel(), amax.numel());
NVTE_CHECK_CUDA(cudaGetLastError());
}
......@@ -1178,10 +1172,10 @@ PD_BUILD_OP(te_gemm)
.SetKernelFn(PD_KERNEL(transformer_engine::paddle_ext::te_gemm));
PD_BUILD_OP(cast_to_fp8)
.Inputs({"Input", "Scale", "_Amax", "_ScaleInv"})
.Inputs({"Input", "Scale", "_Output", "_Amax", "_ScaleInv"})
.Outputs({"Output", "Amax", "ScaleInv"})
.Attrs({"index: int64_t", "otype: int64_t"})
.SetInplaceMap({{"_Amax", "Amax"}, {"_ScaleInv", "ScaleInv"}})
.SetInplaceMap({{"_Output", "Output"}, {"_Amax", "Amax"}, {"_ScaleInv", "ScaleInv"}})
.SetKernelFn(PD_KERNEL(transformer_engine::paddle_ext::cast_to_fp8));
PD_BUILD_OP(cast_from_fp8)
......@@ -1197,9 +1191,12 @@ PD_BUILD_OP(te_transpose)
.SetKernelFn(PD_KERNEL(transformer_engine::paddle_ext::te_transpose));
PD_BUILD_OP(te_cast_transpose)
.Inputs({"Input", "Scale", "_Amax", "_ScaleInv"})
.Inputs({"Input", "Scale", "_CastedOutput", "_TransposedOutput", "_Amax", "_ScaleInv"})
.Outputs({"CastedOutput", "TransposedOutput", "Amax", "ScaleInv"})
.SetInplaceMap({{"_Amax", "Amax"}, {"_ScaleInv", "ScaleInv"}})
.SetInplaceMap({{"_CastedOutput", "CastedOutput"},
{"_TransposedOutput", "TransposedOutput"},
{"_Amax", "Amax"},
{"_ScaleInv", "ScaleInv"}})
.Attrs({"index: int64_t", "otype: int64_t"})
.SetKernelFn(PD_KERNEL(transformer_engine::paddle_ext::te_cast_transpose));
......@@ -1361,12 +1358,13 @@ PD_BUILD_OP(te_scaled_upper_triang_masked_softmax_backward)
PD_KERNEL(transformer_engine::paddle_ext::te_scaled_upper_triang_masked_softmax_backward));
PD_BUILD_OP(amax_and_scale_update_inplace)
.Inputs({"_amax_history", "_scale", "_scale_inv"})
.Inputs({"_amax_history", "_scale", "_scale_inv", "non_weight_mask"})
.Outputs({"amax_history", "scale", "scale_inv"})
.SetInplaceMap({{"_amax_history", "amax_history"},
{"_scale", "scale"},
{"_scale_inv", "scale_inv"}})
.Attrs({"fp8_max: float", "margin: float", "amax_compute: std::string"})
.Attrs({"update_weight_scale_inv: bool", "fp8_max: float", "margin: float",
"amax_compute: std::string"})
.SetKernelFn(PD_KERNEL(transformer_engine::paddle_ext::amax_and_scale_update_inplace));
PD_BUILD_OP(update_latest_amax_history_inplace)
......
......@@ -15,10 +15,8 @@ from transformer_engine.common.recipe import DelayedScaling, Format
from .constants import dist_group_type
from .fp8_buffer import FP8MetaFwdBuffer, FP8MetaBwdBuffer, FP8RecomputeBuffer
__all__ = ['fp8_autocast']
# FP8 support
_is_fp8_available = None
_reason_for_no_fp8 = ""
......@@ -227,6 +225,7 @@ def get_fp8_te_dtype(fp8_recipe: DelayedScaling, fprop_tensor: bool = True) -> t
def amax_and_scale_update(
fp8_meta: Dict[str, Any],
fwd_update: bool,
update_weight_scale_inv: bool = True,
) -> None:
"""Updates fp8 amaxes/scales for fwd | bwd."""
amax_compute = fp8_meta["recipe"].amax_compute_algo
......@@ -235,9 +234,12 @@ def amax_and_scale_update(
fp8_max_key = "fp8_max_fwd" if fwd_update else "fp8_max_bwd"
if not callable(amax_compute) and sf_compute is None:
tex.amax_and_scale_update_inplace(_amax_history=fp8_meta[fp8_meta_tensor_key].amax_history,
tex.amax_and_scale_update_inplace(
_amax_history=fp8_meta[fp8_meta_tensor_key].amax_history,
_scale=fp8_meta[fp8_meta_tensor_key].scale,
_scale_inv=fp8_meta[fp8_meta_tensor_key].scale_inv,
non_weight_mask=fp8_meta[fp8_meta_tensor_key].non_weight_mask,
update_weight_scale_inv=update_weight_scale_inv,
fp8_max=fp8_meta[fp8_max_key],
margin=float(fp8_meta["recipe"].margin),
amax_compute=amax_compute)
......@@ -254,10 +256,20 @@ class FP8TensorMeta():
self.scale = paddle.Tensor()
self.scale_inv = paddle.Tensor()
self.amax_history = paddle.Tensor()
self.non_weight_mask = paddle.Tensor()
self.is_initialized = False
self.is_forward = is_forward
def prepare(self, num_gemms: bool, amax_history_len: int) -> None:
def get_non_weight_mask(self, num_gemms: int):
"""Needed for calculation of scale inverses to
preserve scale_inv when caching FP8 weights"""
if self.is_forward:
# [True, False, True]: -> [input, weight, output]
return paddle.to_tensor([True, False, True] * num_gemms)
# [True, True]: -> [grad_output, grad_input]
return paddle.to_tensor([True, True] * num_gemms)
def prepare(self, num_gemms: int, amax_history_len: int) -> None:
"""Prepare scales and amax tensors. It is called during fprop in each iteration.
If the meta tensors are not initialized yet, initialization is performed. If already
initialized, resize the meta tensors if amax_history_len has changed."""
......@@ -284,6 +296,8 @@ class FP8TensorMeta():
self.scale = paddle.ones(num_fp8_tensors, dtype='float32')
self.scale_inv = paddle.ones(num_fp8_tensors, dtype='float32')
self.amax_history = paddle.zeros([amax_history_len, num_fp8_tensors], dtype='float32')
self.non_weight_mask = self.get_non_weight_mask(num_gemms=num_gemms)
self.is_initialized = True
def to_numpy(self):
......@@ -300,4 +314,9 @@ class FP8TensorMeta():
self.scale = paddle.to_tensor(data['scale'])
self.scale_inv = paddle.to_tensor(data['scale_inv'])
self.amax_history = paddle.to_tensor(data['amax_history'])
num_fp8_tensors = self.scale.shape[0]
num_gemms = num_fp8_tensors // 3 if self.is_forward else num_fp8_tensors // 2
self.non_weight_mask = self.get_non_weight_mask(num_gemms=num_gemms)
self.is_initialized = True
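To illustrate the masked update above: forward-meta tensors are laid out per GEMM as [input, weight, output], and when `update_weight_scale_inv` is false only the non-weight entries receive a fresh `scale_inv`. A small reference computation in the spirit of `calc_ref` from the updated test (values are illustrative):

import paddle

num_gemms = 2
# Forward layout per GEMM: [input, weight, output]; weight entries are False
non_weight_mask = paddle.to_tensor([True, False, True] * num_gemms)

new_scale = paddle.rand([3 * num_gemms]) + 0.5    # scale after the amax update
old_scale_inv = paddle.ones([3 * num_gemms])      # scale_inv from the previous step

update_weight_scale_inv = False    # i.e. not the first microbatch
if update_weight_scale_inv:
    scale_inv = 1.0 / new_scale
else:
    # Weight entries keep their previous scale_inv, so cached FP8 weights
    # stay consistent with the scale they were quantized with.
    scale_inv = paddle.where(non_weight_mask, 1.0 / new_scale, old_scale_inv)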
......@@ -552,6 +552,7 @@ class MultiHeadAttention(paddle.nn.Layer):
core_attention_bias: Optional[paddle.Tensor] = None,
set_zero: bool = True,
recompute_core_attention: bool = False,
is_first_microbatch: Optional[bool] = None,
) -> Tuple[Union[paddle.Tensor, None], ...]:
"""
MultiHeadAttention Layer.
......@@ -575,6 +576,16 @@ class MultiHeadAttention(paddle.nn.Layer):
during the backward pass in order to save memory that would
otherwise be occupied to store the forward activations until
backprop.
is_first_microbatch : {True, False, None}, default = None
During training using either gradient accumulation or
pipeline parallelism, a minibatch of data is further split
into microbatches. Between the microbatches of the same minibatch
the model weights are not updated. Setting this parameter indicates
whether the current microbatch is the first in a minibatch or not.
When set, this parameter enables additional optimizations:
* during FP8 training, it allows caching of the FP8 versions of
the weights
"""
if self.attn_mask_type != "causal" and attention_mask is not None:
......@@ -594,13 +605,19 @@ class MultiHeadAttention(paddle.nn.Layer):
if self.attention_type == "self":
if self.input_layernorm:
layernorm_qkv_outputs = self.layernorm_qkv(hidden_states)
layernorm_qkv_outputs = self.layernorm_qkv(
hidden_states,
is_first_microbatch=is_first_microbatch,
)
if self.return_layernorm_output:
mixed_qkv_layer, layernorm_output = layernorm_qkv_outputs
else:
mixed_qkv_layer = layernorm_qkv_outputs
else:
mixed_qkv_layer = self.qkv(hidden_states)
mixed_qkv_layer = self.qkv(
hidden_states,
is_first_microbatch=is_first_microbatch,
)
# [b, s_q, 3 * hidden_size] --> [b, s_q, 3, num_heads, head_size]
mixed_qkv_layer = mixed_qkv_layer.reshape(shape=[
......@@ -631,7 +648,10 @@ class MultiHeadAttention(paddle.nn.Layer):
)
else: # cross attention
mixed_kv_layer = self.key_value(encoder_output)
mixed_kv_layer = self.key_value(
encoder_output,
is_first_microbatch=is_first_microbatch,
)
# [b, s_kv, 2 * hidden_size] --> [b, s_kv, 2, num_heads, head_size]
mixed_kv_layer = mixed_kv_layer.reshape(shape=[
-1, max_seq_len, 2, self.num_attention_heads_per_partition,
......@@ -639,13 +659,19 @@ class MultiHeadAttention(paddle.nn.Layer):
])
if self.input_layernorm:
layernorm_query_outputs = self.layernorm_query(hidden_states)
layernorm_query_outputs = self.layernorm_query(
hidden_states,
is_first_microbatch=is_first_microbatch,
)
if self.return_layernorm_output:
query_layer, layernorm_output = layernorm_query_outputs
else:
query_layer = layernorm_query_outputs
else:
query_layer = self.query_layer(hidden_states)
query_layer = self.query_layer(
hidden_states,
is_first_microbatch=is_first_microbatch,
)
query_layer = query_layer.reshape(shape=[
-1, max_seq_len, self.num_attention_heads_per_partition,
......@@ -681,7 +707,7 @@ class MultiHeadAttention(paddle.nn.Layer):
[-1, context_layer.shape[2] * context_layer.shape[3]])
# Output. [b, s, hidden]
attention_output = self.proj(context_layer)
attention_output = self.proj(context_layer, is_first_microbatch=is_first_microbatch)
if self.input_layernorm and self.return_layernorm_output:
return attention_output, layernorm_output
......
......@@ -7,7 +7,7 @@ from abc import ABC, abstractmethod
from contextlib import contextmanager
import os
import pickle
from typing import Generator, Dict, Tuple, Union, Any
from typing import Generator, Dict, Tuple, Union, Any, List, Optional
import numpy as np
......@@ -70,9 +70,12 @@ class TransformerEngineBaseLayer(paddle.nn.Layer, ABC):
self.fp8_meta["scaling_bwd"] = FP8TensorMeta(is_forward=False)
self.tp_group = None
self.tp_size = 1
self.sequence_parallel = False
self.fp8_meta["autocast_id_fwd_stack"] = []
self.fp8_meta["async_amax_reduction"] = bool(
int(os.getenv("NVTE_ASYNC_AMAX_REDUCTION", "0")))
self.fp8_weight_shapes = []
self.fp8_weight_cache = {}
def set_activation_dtype(self, inp: paddle.Tensor) -> None:
"""Get activation data type for AMP."""
......@@ -140,6 +143,29 @@ class TransformerEngineBaseLayer(paddle.nn.Layer, ABC):
self.fp8_initialized = False
return
def set_fp8_weights(self) -> None:
"""Initializes FP8 weights for the module"""
if not self.fp8_enabled:
return
for i, shape in enumerate(self.fp8_weight_shapes, start=1):
weight_cast_key = f"weight{i}_fp8"
weight_transpose_key = f"weight{i}_t_fp8"
if (weight_cast_key in self.fp8_weight_cache
and self.fp8_weight_cache[weight_cast_key].shape == shape):
return
self.fp8_weight_cache[weight_cast_key] = paddle.empty(
shape=shape,
dtype=paddle.uint8,
)
self.fp8_weight_cache[weight_transpose_key] = paddle.empty(
shape=[shape[1], shape[0]],
dtype=paddle.uint8,
)
def _get_fp8_state(self) -> paddle.Tensor:
"""Dump FP8 state to paddle.Tensor."""
state = None
......@@ -218,6 +244,7 @@ class TransformerEngineBaseLayer(paddle.nn.Layer, ABC):
def prepare_forward(
self,
inp: paddle.Tensor,
is_first_microbatch: Union[bool, None],
num_gemms: int = 1,
) -> Generator[paddle.Tensor, None, None]:
"""Checks and prep for FWD.
......@@ -234,16 +261,32 @@ class TransformerEngineBaseLayer(paddle.nn.Layer, ABC):
self.set_activation_dtype(inp)
self.fp8_init(num_gemms=num_gemms)
# Create persistent tensors for fp8 weights and their transposes
# only when fp8 weight caching is used.
if is_first_microbatch is not None:
self.set_fp8_weights()
if self.fp8_enabled and self.sequence_parallel:
assert self.fp8_meta["recipe"].reduce_amax, \
"Amax reduction across tensor parallel group is " \
"necessary when using sequence parallelism with FP8."
update_weight_scale_inv = is_first_microbatch is None or is_first_microbatch
# Previous iteration was grad_enabled
if self.fp8_meta.get("update_amax_and_scale_fwd", False):
global_fp8_fwd_buffer = get_global_fp8_state().get_fp8_fwd_buffer()
global_fp8_fwd_buffer.wait()
if self.fp8_meta["recipe"].reduce_amax:
global_fp8_fwd_buffer.copy_amax_from_buffer(self.fp8_meta)
amax_and_scale_update(self.fp8_meta, True)
amax_and_scale_update(self.fp8_meta,
True,
update_weight_scale_inv=update_weight_scale_inv)
global_fp8_fwd_buffer.set_for_deletion(self.fp8_meta)
else:
amax_and_scale_update(self.fp8_meta, True)
amax_and_scale_update(self.fp8_meta,
True,
update_weight_scale_inv=update_weight_scale_inv)
if self.fp8_enabled and self.training:
# Setup for amax reduction
......@@ -383,3 +426,28 @@ class TransformerEngineBaseLayer(paddle.nn.Layer, ABC):
@abstractmethod
def forward(self):
"""Needs override."""
def get_fp8_weights_scratchpad(
self,
is_first_microbatch: Union[bool, None],
) -> List[Optional[paddle.Tensor]]:
"""
Fetch the fp8 weight tensor placeholders if they exist (when
`is_first_microbatch` is not `None`)
"""
if not self.fp8_enabled or is_first_microbatch is None:
return [None, None] * len(self.fp8_weight_shapes)
out_list = []
for i, _ in enumerate(self.fp8_weight_shapes, start=1):
weight_cast_key = f"weight{i}_fp8"
weight_transpose_key = f"weight{i}_t_fp8"
assert weight_cast_key in self.fp8_weight_cache, \
"TE internal error: fp8 weight buffer is not found"
out_list.extend([
self.fp8_weight_cache[weight_cast_key],
self.fp8_weight_cache[weight_transpose_key],
])
return out_list
......@@ -3,6 +3,7 @@
# See LICENSE for license information.
"""LayerNormLinear API"""
import warnings
import os
from typing import Union, Tuple, Dict, Any, Optional
......@@ -131,6 +132,8 @@ class _LayerNormLinear(paddle.autograd.PyLayer):
ln_weight: paddle.Tensor,
ln_bias: paddle.Tensor,
weight: paddle.Tensor,
weight_fp8: Optional[paddle.Tensor],
weight_t_fp8: Optional[paddle.Tensor],
bias: Union[paddle.Tensor, None],
use_bias: bool,
eps: float,
......@@ -148,6 +151,7 @@ class _LayerNormLinear(paddle.autograd.PyLayer):
sequence_parallel: bool,
tp_group: Union[dist_group_type, None],
tp_size: int,
is_first_microbatch: Optional[bool],
) -> Union[Tuple[paddle.Tensor, ...], paddle.Tensor]:
# Make sure input dimensions are compatible
in_features = ln_weight.shape[0]
......@@ -182,6 +186,8 @@ class _LayerNormLinear(paddle.autograd.PyLayer):
ln_out,
FP8FwdTensors.GEMM1_INPUT,
weight,
weight_fp8,
weight_t_fp8,
FP8FwdTensors.GEMM1_WEIGHT,
bias,
use_bias,
......@@ -194,6 +200,7 @@ class _LayerNormLinear(paddle.autograd.PyLayer):
sequence_parallel,
tp_group,
is_grad_enabled,
is_first_microbatch,
)
if is_grad_enabled:
......@@ -227,6 +234,8 @@ class _LayerNormLinear(paddle.autograd.PyLayer):
ctx.requires_bgrad = use_bias and not bias.stop_gradient
ctx.requires_ln_bgrad = not ln_bias.stop_gradient
ctx.requires_ln_wgrad = not ln_weight.stop_gradient
ctx.is_first_microbatch = is_first_microbatch
# [*, in_features] -> [*, out_features] except first dimension changes for SP
out = out.reshape((-1, *inp.shape[1:-1], out.shape[-1]))
......@@ -320,11 +329,18 @@ class _LayerNormLinear(paddle.autograd.PyLayer):
bgrad = bgrad if ctx.requires_bgrad else None
bgrad_out = (bgrad,) if ctx.use_bias else ()
if not ctx.fp8_enabled or ctx.is_first_microbatch is None:
weight_cache_grad = ()
else:
# weight_fp8 and weight_t_fp8 are stop_gradient tensors
weight_cache_grad = (None, None)
return (
dxmat.reshape(ctx.inp_shape) if ctx.requires_dgrad else None,
dgamma if ctx.requires_ln_wgrad else None,
dbeta if ctx.requires_ln_bgrad else None,
wgrad if ctx.requires_wgrad else None,
*weight_cache_grad,
*bgrad_out,
)
......@@ -447,6 +463,7 @@ class LayerNormLinear(TransformerEngineBaseLayer):
)
set_weight_tensor_dist_attr(self.weight, self.tensor_parallel, self.parallel_mode,
self.backend)
self.fp8_weight_shapes.append(self.weight.shape)
# Initialize Linear bias parameter
self.has_bias = self._bias_attr is not False
......@@ -483,21 +500,28 @@ class LayerNormLinear(TransformerEngineBaseLayer):
def _te_forward(
self,
inp: paddle.Tensor,
is_first_microbatch: Optional[bool] = None,
) -> Union[paddle.Tensor, Tuple[paddle.Tensor, ...]]:
"""
Apply layer normalization to the input followed by a linear transformation.
"""
with self.prepare_forward(inp) as inp:
with self.prepare_forward(inp, is_first_microbatch=is_first_microbatch) as inp:
# Layer input should be cast outside PyLayer, as performing
# inplace cast to input tensors may cause problems when used
# together with Paddle native layers.
inp = cast_if_needed(inp, self.activation_dtype)
# Get persistent fp8 weight buffer. None if buffer does not exist.
weight_fp8, weight_t_fp8 = self.get_fp8_weights_scratchpad(is_first_microbatch)
out = _LayerNormLinear.apply(
inp,
self.ln_weight,
self.ln_bias,
self.weight,
weight_fp8,
weight_t_fp8,
self.bias if self.gemm_bias_fused_add else None,
self.has_bias and self.gemm_bias_fused_add,
self.eps,
......@@ -515,6 +539,7 @@ class LayerNormLinear(TransformerEngineBaseLayer):
self.sequence_parallel,
self.tp_group,
self.tp_size,
is_first_microbatch,
)
if self.return_layernorm_output:
......@@ -530,12 +555,17 @@ class LayerNormLinear(TransformerEngineBaseLayer):
def _pd_forward(
self,
inp: paddle.Tensor,
is_first_microbatch: Optional[bool] = None,
) -> paddle.Tensor:
"""Calls Paddle OP"""
if self.zero_centered_gamma:
raise NotImplementedError(
"Paddle backend does not support LayerNorm with zero-centered scale.")
if is_first_microbatch is not None:
warnings.warn(
"`is_first_microbatch` is not supported for paddle backend and is ignored.")
ln_out = F.layer_norm(x=inp,
normalized_shape=inp.shape[-1],
weight=self.ln_weight,
......@@ -557,8 +587,18 @@ class LayerNormLinear(TransformerEngineBaseLayer):
Parameters
----------
inp : torch.Tensor
inp : paddle.Tensor
Input tensor.
is_first_microbatch : {True, False, None}, default = None
During training using either gradient accumulation or
pipeline parallelism, a minibatch of data is further split
into microbatches. Between the microbatches of the same minibatch
the model weights are not updated. Setting this parameter indicates
whether the current microbatch is the first in a minibatch or not.
When set, this parameter enables additional optimizations:
* during FP8 training, it allows caching of the FP8 versions of
the weights
"""
if self.backend == 'transformer_engine':
return self._te_forward(*args, **kwargs)
......
......@@ -4,6 +4,7 @@
"""LayerNormMLP API"""
import os
import warnings
from typing import Union, Tuple, Dict, Any, Optional
import paddle
......@@ -46,11 +47,15 @@ def _mlp_forward(
inputmat: paddle.Tensor,
inputmat_fp8_index: FP8FwdTensors,
fc1_weight: paddle.Tensor,
fc1_weight_fp8: Optional[paddle.Tensor],
fc1_weight_t_fp8: Optional[paddle.Tensor],
fc1_weight_fp8_index: FP8FwdTensors,
fc1_bias: Union[paddle.Tensor, None],
use_fc1_bias: bool,
fc2_input_fp8_index: FP8FwdTensors, # FP8FwdTensors.GEMM2_INPUT
fc2_weight: paddle.Tensor,
fc2_weight_fp8: Optional[paddle.Tensor],
fc2_weight_t_fp8: Optional[paddle.Tensor],
fc2_weight_fp8_index: FP8FwdTensors,
fc2_bias: Union[paddle.Tensor, None],
use_fc2_bias: bool,
......@@ -64,6 +69,7 @@ def _mlp_forward(
tensor_parallel: bool,
sequence_parallel: bool,
tp_group: Union[dist_group_type, None],
is_first_microbatch: bool,
):
if fp8_enabled:
fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)
......@@ -71,6 +77,8 @@ def _mlp_forward(
inputmat,
inputmat_fp8_index,
fc1_weight,
fc1_weight_fp8,
fc1_weight_t_fp8,
fc1_weight_fp8_index,
fc1_bias,
use_fc1_bias,
......@@ -81,6 +89,7 @@ def _mlp_forward(
sequence_parallel,
tp_group,
is_grad_enabled,
is_first_microbatch,
)
gelu_out = gelu_fp8(
......@@ -94,6 +103,8 @@ def _mlp_forward(
gelu_out,
fc2_input_fp8_index,
fc2_weight,
fc2_weight_fp8,
fc2_weight_t_fp8,
fc2_weight_fp8_index,
fc2_bias,
use_fc2_bias,
......@@ -104,6 +115,7 @@ def _mlp_forward(
sequence_parallel,
tp_group,
is_grad_enabled,
is_first_microbatch,
)
else:
fc1_out, gelu_out = _linear_fwd_non_fp8(
......@@ -321,9 +333,13 @@ class _LayerNormMLP(paddle.autograd.PyLayer):
ln_weight: paddle.Tensor,
ln_bias: paddle.Tensor,
fc1_weight: paddle.Tensor,
fc1_weight_fp8: Optional[paddle.Tensor],
fc1_weight_t_fp8: Optional[paddle.Tensor],
fc1_bias: Union[paddle.Tensor, None],
use_fc1_bias: bool,
fc2_weight: paddle.Tensor,
fc2_weight_fp8: Optional[paddle.Tensor],
fc2_weight_t_fp8: Optional[paddle.Tensor],
fc2_bias: Union[paddle.Tensor, None],
use_fc2_bias: bool,
eps: float,
......@@ -342,6 +358,7 @@ class _LayerNormMLP(paddle.autograd.PyLayer):
sequence_parallel: bool,
tp_group: Union[dist_group_type, None],
tp_size: int,
is_first_microbatch: Optional[bool],
) -> Union[Tuple[paddle.Tensor, ...], paddle.Tensor]:
# Make sure input dimensions are compatible
in_features = ln_weight.shape[0]
......@@ -385,11 +402,15 @@ class _LayerNormMLP(paddle.autograd.PyLayer):
ln_out,
FP8FwdTensors.GEMM1_INPUT,
fc1_weight,
fc1_weight_fp8,
fc1_weight_t_fp8,
FP8FwdTensors.GEMM1_WEIGHT,
fc1_bias,
use_fc1_bias,
FP8FwdTensors.GEMM2_INPUT,
fc2_weight,
fc2_weight_fp8,
fc2_weight_t_fp8,
FP8FwdTensors.GEMM2_WEIGHT,
fc2_bias,
use_fc2_bias,
......@@ -403,6 +424,7 @@ class _LayerNormMLP(paddle.autograd.PyLayer):
tensor_parallel,
sequence_parallel,
tp_group,
is_first_microbatch,
)
if is_grad_enabled:
......@@ -443,6 +465,7 @@ class _LayerNormMLP(paddle.autograd.PyLayer):
ctx.requires_fc2_bgrad = use_fc2_bias and not fc2_bias.stop_gradient
ctx.requires_ln_bgrad = not ln_bias.stop_gradient
ctx.requires_ln_wgrad = not ln_weight.stop_gradient
ctx.is_first_microbatch = is_first_microbatch
# [*, in_features] -> [*, out_features] except first dimension changes for SP
fc2_out = fc2_out.reshape((-1, *inp.shape[1:-1], fc2_out.shape[-1]))
......@@ -543,13 +566,23 @@ class _LayerNormMLP(paddle.autograd.PyLayer):
fc1_bgrad_out = (fc1_bgrad,) if ctx.use_fc1_bias else ()
fc2_bgrad_out = (fc2_bgrad,) if ctx.use_fc2_bias else ()
if not ctx.fp8_enabled or ctx.is_first_microbatch is None:
fc1_weight_cache_grad = ()
fc2_weight_cache_grad = ()
else:
# weight_fp8 and weight_t_fp8 are stop_gradient tensors
fc1_weight_cache_grad = (None, None)
fc2_weight_cache_grad = (None, None)
return (
dxmat.reshape(ctx.inp_shape) if ctx.requires_dgrad else None,
dgamma if ctx.requires_ln_wgrad else None,
dbeta if ctx.requires_ln_bgrad else None,
fc1_wgrad if ctx.requires_fc1_wgrad else None,
*fc1_weight_cache_grad,
*fc1_bgrad_out,
fc2_wgrad if ctx.requires_fc2_wgrad else None,
*fc2_weight_cache_grad,
*fc2_bgrad_out,
)
......@@ -675,6 +708,7 @@ class LayerNormMLP(TransformerEngineBaseLayer):
self.tensor_parallel,
parallel_mode='column',
backend=self.backend)
self.fp8_weight_shapes.append(self.fc1_weight.shape)
self.has_bias = self._bias_attr is not False
use_default_bias = self._bias_attr is None or self._bias_attr is True
......@@ -704,6 +738,7 @@ class LayerNormMLP(TransformerEngineBaseLayer):
self.tensor_parallel,
parallel_mode='row',
backend=self.backend)
self.fp8_weight_shapes.append(self.fc2_weight.shape)
if self.has_bias:
self.fc2_bias = self.create_parameter(
......@@ -734,24 +769,34 @@ class LayerNormMLP(TransformerEngineBaseLayer):
def _te_forward(
self,
inp: paddle.Tensor,
is_first_microbatch: Optional[bool] = None,
) -> Union[paddle.Tensor, Tuple[paddle.Tensor, ...]]:
"""
Apply layer normalization to the input followed by the MLP block.
"""
with self.prepare_forward(inp, num_gemms=2) as inp:
with self.prepare_forward(inp, num_gemms=2, is_first_microbatch=is_first_microbatch) as inp:
# Layer input should be cast outside PyLayer, as performing
# inplace cast to input tensors may cause problems when used
# together with Paddle native layers.
inp = cast_if_needed(inp, self.activation_dtype)
# Get persistent fp8 weight buffer. None if buffer does not exist.
fc1_weight_fp8, fc1_weight_t_fp8, fc2_weight_fp8, fc2_weight_t_fp8 = \
self.get_fp8_weights_scratchpad(is_first_microbatch)
out = _LayerNormMLP.apply(
inp,
self.ln_weight,
self.ln_bias,
self.fc1_weight,
fc1_weight_fp8,
fc1_weight_t_fp8,
self.fc1_bias,
self.has_bias,
self.fc2_weight,
fc2_weight_fp8,
fc2_weight_t_fp8,
self.fc2_bias,
self.has_bias,
self.eps,
......@@ -770,6 +815,7 @@ class LayerNormMLP(TransformerEngineBaseLayer):
self.sequence_parallel,
self.tp_group,
self.tp_size,
is_first_microbatch,
)
if self.return_layernorm_output:
......@@ -785,12 +831,17 @@ class LayerNormMLP(TransformerEngineBaseLayer):
def _pd_forward(
self,
inp: paddle.Tensor,
is_first_microbatch: Optional[bool] = None,
) -> paddle.Tensor:
"""Calls Paddle OP"""
if self.zero_centered_gamma:
raise NotImplementedError(
"Paddle backend does not support LayerNorm with zero-centered scale.")
if is_first_microbatch is not None:
warnings.warn(
"`is_first_microbatch` is not supported for paddle backend and is ignored.")
ln_out = F.layer_norm(x=inp,
normalized_shape=inp.shape[-1],
weight=self.ln_weight,
......@@ -816,8 +867,18 @@ class LayerNormMLP(TransformerEngineBaseLayer):
Parameters
----------
inp : torch.Tensor
inp : paddle.Tensor
Input tensor.
is_first_microbatch : {True, False, None}, default = None
During training using either gradient accumulation or
pipeline parallelism, a minibatch of data is further split
into microbatches. Between the microbatches of the same minibatch
the model weights are not updated. Setting this parameter indicates
whether the current microbatch is the first in a minibatch or not.
When set, this parameter enables additional optimizations:
* during FP8 training, it allows caching of the FP8 versions of
the weights
"""
if self.backend == 'transformer_engine':
return self._te_forward(*args, **kwargs)
......
......@@ -3,6 +3,7 @@
# See LICENSE for license information.
"""Linear API"""
import warnings
from typing import Union, Tuple, Dict, Any, Optional
import paddle
......@@ -39,6 +40,7 @@ from ..utils import (
get_bias_dtype,
save_for_backward_allow_none,
saved_tensor_allow_none,
clear_tensor_data,
)
__all__ = ["Linear"]
......@@ -48,6 +50,8 @@ def _linear_fwd_fp8(
inputmat: paddle.Tensor,
inputmat_fp8_index: FP8FwdTensors,
weight: paddle.Tensor,
weight_fp8: Optional[paddle.Tensor],
weight_t_fp8: Optional[paddle.Tensor],
weight_fp8_index: FP8FwdTensors,
bias: paddle.Tensor,
use_bias: bool,
......@@ -58,6 +62,7 @@ def _linear_fwd_fp8(
sequence_parallel: bool,
tp_group: Union[dist_group_type, None],
is_grad_enabled: bool,
is_first_microbatch: Optional[bool] = None,
):
"""FP8 path of Linear Fwd"""
fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)
......@@ -69,20 +74,26 @@ def _linear_fwd_fp8(
else:
inputmat_total = inputmat
update_fp8_weights = is_first_microbatch is None or is_first_microbatch
if is_grad_enabled:
if update_fp8_weights:
weight_fp8, weight_t_fp8 = cast_transpose(
weight,
fp8_meta["scaling_fwd"],
weight_fp8_index,
fp8_dtype_forward,
cast_out=weight_fp8,
transpose_out=weight_t_fp8,
)
else:
weight_t_fp8 = None
if update_fp8_weights:
weight_fp8 = cast_to_fp8(
weight,
fp8_meta["scaling_fwd"],
weight_fp8_index,
fp8_dtype_forward,
out=weight_fp8,
)
out = fp8_gemm(
......@@ -146,6 +157,7 @@ def _linear_fwd_non_fp8(
# amax of weight
fp8_meta["scaling_fwd"].amax_history[0, weight_fp8_index.value] = \
paddle.max(paddle.abs(weight)).item()
fp8_meta["update_amax_and_scale_fwd"] = True
outputs = gemm(weight,
inputmat_total,
......@@ -172,6 +184,8 @@ def _linear_fwd(
inputmat: paddle.Tensor,
inputmat_fp8_index: FP8FwdTensors,
weight: paddle.Tensor,
weight_fp8: Optional[paddle.Tensor],
weight_t_fp8: Optional[paddle.Tensor],
weight_fp8_index: FP8FwdTensors,
bias: paddle.Tensor,
use_bias: bool,
......@@ -184,12 +198,15 @@ def _linear_fwd(
sequence_parallel: bool,
tp_group: Union[dist_group_type, None],
is_grad_enabled: bool,
is_first_microbatch: Optional[bool] = None,
):
if fp8_enabled:
out, weight_t_fp8 = _linear_fwd_fp8(
inputmat,
inputmat_fp8_index,
weight,
weight_fp8,
weight_t_fp8,
weight_fp8_index,
bias,
use_bias,
......@@ -200,6 +217,7 @@ def _linear_fwd(
sequence_parallel,
tp_group,
is_grad_enabled,
is_first_microbatch,
)
else:
out = _linear_fwd_non_fp8(
......@@ -270,6 +288,7 @@ def _linear_bwd_fp8(
get_workspace(),
use_split_accumulator=_2X_ACC_DGRAD,
)
clear_tensor_data(grad_output_c)
# Overlap dgrad-RS/AR with wgrad
if parallel_mode == "column" and sequence_parallel:
......@@ -283,6 +302,7 @@ def _linear_bwd_fp8(
if not fp8_meta["recipe"].override_linear_precision.wgrad:
if inputmat_t_total is None:
inputmat_t_total = transpose(inputmat_total, fp8_dtype_backward)
clear_tensor_data(inputmat_total)
wgrad = fp8_gemm(
inputmat_t_total,
fwd_scale_inverses,
......@@ -296,6 +316,7 @@ def _linear_bwd_fp8(
get_workspace(),
use_split_accumulator=_2X_ACC_WGRAD,
)
clear_tensor_data(inputmat_t_total, grad_output_t)
else:
wgrad, _, _ = gemm(
inputmat_total,
......@@ -305,6 +326,7 @@ def _linear_bwd_fp8(
layout="NT",
grad=True,
)
clear_tensor_data(inputmat_total)
if parallel_mode == "column" and tensor_parallel and handle is not None:
handle.wait()
......@@ -446,6 +468,8 @@ class _Linear(paddle.autograd.PyLayer):
def forward(
ctx,
weight: paddle.Tensor,
weight_fp8: Optional[paddle.Tensor],
weight_t_fp8: Optional[paddle.Tensor],
inp: paddle.Tensor,
bias: paddle.Tensor,
use_bias: bool,
......@@ -459,6 +483,7 @@ class _Linear(paddle.autograd.PyLayer):
sequence_parallel: bool,
tp_group: Union[dist_group_type, None],
tp_size: int,
is_first_microbatch: Optional[bool],
) -> paddle.Tensor:
# Make sure input dimensions are compatible
in_features = weight.shape[-1]
......@@ -495,6 +520,8 @@ class _Linear(paddle.autograd.PyLayer):
inputmat,
FP8FwdTensors.GEMM1_INPUT,
weight,
weight_fp8,
weight_t_fp8,
FP8FwdTensors.GEMM1_WEIGHT,
bias,
use_bias,
......@@ -507,6 +534,7 @@ class _Linear(paddle.autograd.PyLayer):
sequence_parallel,
tp_group,
is_grad_enabled,
is_first_microbatch,
)
if is_grad_enabled:
......@@ -536,6 +564,7 @@ class _Linear(paddle.autograd.PyLayer):
ctx.requires_dgrad = not inp.stop_gradient
ctx.requires_wgrad = not weight.stop_gradient
ctx.requires_bgrad = use_bias and not bias.stop_gradient
ctx.is_first_microbatch = is_first_microbatch
return out.reshape((-1, *inp.shape[1:-1], out.shape[-1]))
......@@ -591,14 +620,22 @@ class _Linear(paddle.autograd.PyLayer):
# bgrad is fused with gemm for non-FP8 path
bgrad = bgrad_
if not ctx.fp8_enabled or ctx.is_first_microbatch is None:
weight_cache_grad = ()
else:
# weight_fp8 and weight_t_fp8 are stop_gradient tensors
weight_cache_grad = (None, None)
if not ctx.use_bias:
return (
wgrad if ctx.requires_wgrad else None,
*weight_cache_grad,
dgrad.reshape(ctx.inp_shape) if ctx.requires_dgrad else None,
)
return (
wgrad if ctx.requires_wgrad else None,
*weight_cache_grad,
dgrad.reshape(ctx.inp_shape) if ctx.requires_dgrad else None,
bgrad if ctx.requires_bgrad else None,
)
......@@ -699,6 +736,8 @@ class Linear(TransformerEngineBaseLayer):
else:
self.bias = None
self.fp8_weight_shapes.append(self.weight.shape)
# For RPL, bias has to be added after TP collectives
# So it cannot be fused with the GEMM
if self.parallel_mode == "row" and self.tensor_parallel and self.has_bias:
......@@ -709,17 +748,24 @@ class Linear(TransformerEngineBaseLayer):
def _te_forward(
self,
inp: paddle.Tensor,
is_first_microbatch: Optional[bool] = None,
) -> paddle.Tensor:
"""
Apply the linear transformation to the input.
"""
with self.prepare_forward(inp) as inp:
with self.prepare_forward(inp, is_first_microbatch=is_first_microbatch) as inp:
# Layer input should be cast outside PyLayer, as performing
# inplace cast to input tensors may cause problems when used
# together with Paddle native layers.
inp = cast_if_needed(inp, self.activation_dtype)
# Get persistent fp8 weight buffer. None if buffer does not exist.
weight_fp8, weight_t_fp8 = self.get_fp8_weights_scratchpad(is_first_microbatch)
out = _Linear.apply(
self.weight,
weight_fp8,
weight_t_fp8,
inp,
self.bias if self.gemm_bias_fused_add else None,
self.has_bias and self.gemm_bias_fused_add,
......@@ -733,6 +779,7 @@ class Linear(TransformerEngineBaseLayer):
self.sequence_parallel,
self.tp_group,
self.tp_size,
is_first_microbatch,
)
if not self.gemm_bias_fused_add:
......@@ -743,8 +790,12 @@ class Linear(TransformerEngineBaseLayer):
def _pd_forward(
self,
inp: paddle.Tensor,
is_first_microbatch: Optional[bool] = None,
) -> paddle.Tensor:
"""Calls Paddle OP"""
if is_first_microbatch is not None:
warnings.warn(
"`is_first_microbatch` is not supported for paddle backend and is ignored.")
if self.parallel_mode == 'column' and self.tensor_parallel:
inp = identity(inp, self.tp_group)
out = F.linear(inp, self.weight, self.bias if self.gemm_bias_fused_add else None)
......@@ -759,8 +810,18 @@ class Linear(TransformerEngineBaseLayer):
Parameters
----------
inp : torch.Tensor
inp : paddle.Tensor
Input tensor.
is_first_microbatch : {True, False, None}, default = None
During training using either gradient accumulation or
pipeline parallelism, a minibatch of data is further split
into microbatches. Between the microbatches of the same minibatch
the model weights are not updated. Setting this parameter indicates
whether the current microbatch is the first in a minibatch or not.
When set, this parameter enables additional optimizations:
* during FP8 training, it allows caching of the FP8 versions of
the weights
"""
if self.backend == 'transformer_engine':
return self._te_forward(*args, **kwargs)
......
......@@ -217,6 +217,7 @@ class TransformerLayer(paddle.nn.Layer):
core_attention_bias: Optional[paddle.Tensor] = None,
set_zero: bool = True,
recompute_core_attention: bool = False,
is_first_microbatch: Optional[bool] = None,
) -> paddle.Tensor:
"""
Transformer Layer: attention block and a feedforward network (MLP)
......@@ -248,6 +249,16 @@ class TransformerLayer(paddle.nn.Layer):
during the backward pass in order to save memory that would
otherwise be occupied to store the forward activations until
backprop.
is_first_microbatch : {True, False, None}, default = None
During training using either gradient accumulation or
pipeline parallelism, a minibatch of data is further split
into microbatches. Between the microbatches of the same minibatch
the model weights are not updated. Setting this parameter indicates
whether the current microbatch is the first in a minibatch or not.
When set, this parameter enables additional optimizations:
* during FP8 training, it allows caching of the FP8 versions of
the weights
"""
if self.self_attn_mask_type != "causal" and attention_mask is not None:
......@@ -264,6 +275,7 @@ class TransformerLayer(paddle.nn.Layer):
core_attention_bias=core_attention_bias,
set_zero=set_zero,
recompute_core_attention=recompute_core_attention,
is_first_microbatch=is_first_microbatch,
)
if self.apply_residual_connection_post_layernorm and not self.output_layernorm:
......@@ -286,6 +298,7 @@ class TransformerLayer(paddle.nn.Layer):
core_attention_bias=core_attention_bias,
set_zero=set_zero,
recompute_core_attention=recompute_core_attention,
is_first_microbatch=is_first_microbatch,
)
if self.apply_residual_connection_post_layernorm:
attention_output, residual = inter_attention_outputs
......@@ -298,7 +311,7 @@ class TransformerLayer(paddle.nn.Layer):
bda_output = self.fused_dropout_add2(attention_output, residual)
# MLP.
mlp_outputs = self.layernorm_mlp(bda_output)
mlp_outputs = self.layernorm_mlp(bda_output, is_first_microbatch=is_first_microbatch)
if self.apply_residual_connection_post_layernorm:
mlp_output, residual = mlp_outputs
else:
......
......@@ -121,3 +121,17 @@ def saved_tensor_allow_none(ctx) -> Tuple[Optional[paddle.Tensor]]:
outputs.append(saved_tensors[index])
return tuple(outputs)
def clear_tensor_data(*tensors: Tuple[Optional[paddle.Tensor], ...]) -> None:
"""
Free tensor buffer
"""
def can_free(t):
return (t is not None and isinstance(t, paddle.Tensor) and t._is_initialized()
and t.inplace_version == 0)
for t in tensors:
if can_free(t):
t._clear_dataptr()
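A minimal illustration of the utility above (assumes the private Paddle APIs `_is_initialized` and `_clear_dataptr` behave as on the Paddle version this commit targets):

import paddle
from transformer_engine.paddle.utils import clear_tensor_data

t = paddle.rand(shape=(1024, 1024), dtype='float32')
print(t._is_initialized())    # True: buffer is allocated

# Free the underlying buffer as soon as backward no longer needs it,
# rather than waiting for Python reference counting to release it.
clear_tensor_data(t)
print(t._is_initialized())    # expected False once the data pointer is cleared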