Unverified Commit daad219f authored by Tian Zheng's avatar Tian Zheng Committed by GitHub
Browse files

[Paddle] Optimize memory usage when training in pipeline parallel (#580)



* Actively free tensor in bwd
Signed-off-by: default avatarTian Zheng (Engrg-Hardware 1) <tizheng@nvidia.com>

* - Add inplace support for fp8 casting
- Allow skipping weight update in fp8 meta update
Signed-off-by: default avatarTian Zheng (Engrg-Hardware 1) <tizheng@nvidia.com>

* Support weight caching for Linear
Signed-off-by: default avatarTian Zheng (Engrg-Hardware 1) <tizheng@nvidia.com>

* Add weight caching for LayernormLinear
Signed-off-by: default avatarTian Zheng (Engrg-Hardware 1) <tizheng@nvidia.com>

* Add weight caching for LayerNormMLP
Signed-off-by: default avatarTian Zheng (Engrg-Hardware 1) <tizheng@nvidia.com>

* Add weight caching for Transformer layer
Signed-off-by: default avatarTian Zheng (Engrg-Hardware 1) <tizheng@nvidia.com>

* Add PP unittests
Signed-off-by: default avatarTian Zheng (Engrg-Hardware 1) <tizheng@nvidia.com>

* Fix CI
Signed-off-by: default avatarTian Zheng (Engrg-Hardware 1) <tizheng@nvidia.com>

---------
Signed-off-by: default avatarTian Zheng (Engrg-Hardware 1) <tizheng@nvidia.com>
parent 2ae121d7
...@@ -19,6 +19,23 @@ from utils import assert_allclose, set_random_seed ...@@ -19,6 +19,23 @@ from utils import assert_allclose, set_random_seed
import transformer_engine.paddle as te import transformer_engine.paddle as te
class TELinear(te.Linear):
    """te.Linear wrapper that forwards ``is_first_microbatch`` automatically.

    The pipeline-parallel runner invokes ``forward`` once per micro-batch; this
    wrapper counts those invocations so FP8 weight caching can be re-enabled on
    the first micro-batch of every gradient-accumulation window.
    """

    def __init__(self, *args, **kwargs):
        # ``accumulate_steps`` is consumed here because te.Linear does not
        # accept it; pop() removes it before delegating to the base class.
        assert 'accumulate_steps' in kwargs
        self.accumulate_steps = kwargs.pop('accumulate_steps')
        self._micro_batch_id = 0
        super().__init__(*args, **kwargs)

    def forward(self, *args, **kwargs):
        # First micro-batch of each accumulation window must re-cast the FP8
        # weights; later micro-batches reuse the cached casted copy.
        kwargs['is_first_microbatch'] = (self._micro_batch_id % self.accumulate_steps) == 0
        # Only count training forward passes so eval/no-grad calls do not
        # perturb the micro-batch bookkeeping.
        if paddle.is_grad_enabled() and self.training:
            self._micro_batch_id += 1
        return super().forward(*args, **kwargs)
class TEPipelineModel(PipelineLayer): class TEPipelineModel(PipelineLayer):
"""Model for pipeline parallel test""" """Model for pipeline parallel test"""
...@@ -28,6 +45,7 @@ class TEPipelineModel(PipelineLayer): ...@@ -28,6 +45,7 @@ class TEPipelineModel(PipelineLayer):
weight_attrs, weight_attrs,
use_te=True, use_te=True,
use_fp8=False, use_fp8=False,
accumulate_steps=1,
**kwargs): **kwargs):
self.in_features = in_features self.in_features = in_features
self.hidden_features = hidden_features self.hidden_features = hidden_features
...@@ -35,10 +53,22 @@ class TEPipelineModel(PipelineLayer): ...@@ -35,10 +53,22 @@ class TEPipelineModel(PipelineLayer):
hcg = fleet.get_hybrid_communicate_group() hcg = fleet.get_hybrid_communicate_group()
self.dp_group = hcg.get_data_parallel_group() self.dp_group = hcg.get_data_parallel_group()
Linear = te.Linear if use_te else paddle.nn.Linear Linear = TELinear if use_te else paddle.nn.Linear
extra_kwargs = {}
if use_te:
extra_kwargs['accumulate_steps'] = accumulate_steps
model_desc = [ model_desc = [
LayerDesc(Linear, self.in_features, self.hidden_features, weight_attr=weight_attrs[0]), LayerDesc(Linear,
LayerDesc(Linear, self.hidden_features, self.in_features, weight_attr=weight_attrs[1]), self.in_features,
self.hidden_features,
weight_attr=weight_attrs[0],
**extra_kwargs),
LayerDesc(Linear,
self.hidden_features,
self.in_features,
weight_attr=weight_attrs[1],
**extra_kwargs),
] ]
super().__init__(layers=model_desc, loss_fn=paddle.nn.CrossEntropyLoss(), **kwargs) super().__init__(layers=model_desc, loss_fn=paddle.nn.CrossEntropyLoss(), **kwargs)
...@@ -84,8 +114,9 @@ class TestLinearPipelineParallel(unittest.TestCase): ...@@ -84,8 +114,9 @@ class TestLinearPipelineParallel(unittest.TestCase):
"mp_degree": 1, "mp_degree": 1,
"pp_degree": self.pipeline_parallel_size, "pp_degree": self.pipeline_parallel_size,
} }
self.accumulate_steps = self.batch_size // self.micro_batch_size
strategy.pipeline_configs = { strategy.pipeline_configs = {
"accumulate_steps": self.batch_size // self.micro_batch_size, "accumulate_steps": self.accumulate_steps,
"micro_batch_size": self.micro_batch_size, "micro_batch_size": self.micro_batch_size,
} }
fleet.init(is_collective=True, strategy=strategy) fleet.init(is_collective=True, strategy=strategy)
...@@ -128,6 +159,7 @@ class TestLinearPipelineParallel(unittest.TestCase): ...@@ -128,6 +159,7 @@ class TestLinearPipelineParallel(unittest.TestCase):
use_fp8=self.fp8, use_fp8=self.fp8,
seg_method="layer:Linear", seg_method="layer:Linear",
num_stages=self.pipeline_parallel_size, num_stages=self.pipeline_parallel_size,
accumulate_steps=self.accumulate_steps,
) )
# Check if model is split across ranks as expected # Check if model is split across ranks as expected
......
...@@ -5,7 +5,7 @@ ...@@ -5,7 +5,7 @@
import math import math
import os import os
from utils import assert_allclose from utils import assert_allclose, is_fused_attention_supported
import paddle import paddle
import pytest import pytest
...@@ -14,8 +14,6 @@ from transformer_engine.common.recipe import DelayedScaling ...@@ -14,8 +14,6 @@ from transformer_engine.common.recipe import DelayedScaling
import transformer_engine.paddle as te import transformer_engine.paddle as te
from transformer_engine.paddle.fp8 import is_fp8_available, fp8_autocast from transformer_engine.paddle.fp8 import is_fp8_available, fp8_autocast
from utils import is_fused_attention_supported
is_fp8_supported, reason = is_fp8_available() is_fp8_supported, reason = is_fp8_available()
LINEAR_CASES = [(16, 16, 32), (32, 32, 64)] LINEAR_CASES = [(16, 16, 32), (32, 32, 64)]
NORM_CASES = [(16, 32), (256, 1024)] NORM_CASES = [(16, 32), (256, 1024)]
...@@ -200,6 +198,50 @@ class TestLinear: ...@@ -200,6 +198,50 @@ class TestLinear:
if do_calibration: if do_calibration:
assert paddle.count_nonzero(layer_te.fp8_meta["scaling_fwd"].amax_history).item() > 0 assert paddle.count_nonzero(layer_te.fp8_meta["scaling_fwd"].amax_history).item() > 0
@staticmethod
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest.mark.parametrize('bs,in_features,out_features', LINEAR_CASES)
@pytest.mark.parametrize('activation_dtype', ['bfloat16'])
@pytest.mark.parametrize('num_microbatch', [8])
def test_linear_fp8_microbatch(bs, in_features, out_features, activation_dtype, num_microbatch):
    """
    Test FP8 Linear
    """
    rtol, atol = 0.1, 0.1
    recipe = DelayedScaling()
    paddle.set_default_dtype(activation_dtype)

    # Two identical layers: one reuses cached FP8 weights after the first
    # micro-batch, the other re-casts its weights on every call.
    layer_cached = te.Linear(
        in_features=in_features,
        out_features=out_features,
    )
    layer_normal = te.Linear(
        in_features=in_features,
        out_features=out_features,
    )
    layer_cached.weight.copy_(layer_normal.weight, True)
    layer_cached.bias.copy_(layer_normal.bias, True)

    for step in range(num_microbatch):
        inp = paddle.uniform(shape=(bs, in_features), dtype=activation_dtype)
        dy = paddle.uniform(shape=(bs, out_features), dtype=activation_dtype)

        with fp8_autocast(enabled=True, fp8_recipe=recipe):
            out = layer_cached(inp, is_first_microbatch=(step == 0))
        out.backward(dy)

        with fp8_autocast(enabled=True, fp8_recipe=recipe):
            out_ref = layer_normal(inp)
        out_ref.backward(dy)

        # Cached and non-cached paths must stay numerically close every step.
        assert_allclose(out, out_ref, rtol=rtol, atol=atol)
        assert_allclose(layer_cached.weight.grad,
                        layer_normal.weight.grad,
                        rtol=rtol,
                        atol=atol)
@pytest.mark.parametrize('bs,hidden_size', NORM_CASES) @pytest.mark.parametrize('bs,hidden_size', NORM_CASES)
@pytest.mark.parametrize('has_bias,no_dbias', [[True, False], [True, True], [False, False]]) @pytest.mark.parametrize('has_bias,no_dbias', [[True, False], [True, True], [False, False]])
...@@ -411,6 +453,62 @@ class TestLayerNormLinear: ...@@ -411,6 +453,62 @@ class TestLayerNormLinear:
if do_calibration: if do_calibration:
assert paddle.count_nonzero(layer_te.fp8_meta["scaling_fwd"].amax_history).item() > 0 assert paddle.count_nonzero(layer_te.fp8_meta["scaling_fwd"].amax_history).item() > 0
@staticmethod
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest.mark.parametrize('bs,in_features,out_features', LINEAR_CASES)
@pytest.mark.parametrize('activation_dtype', ['bfloat16'])
@pytest.mark.parametrize('num_microbatch', [8])
def test_layernorm_linear_fp8_microbatch(bs, in_features, out_features, activation_dtype,
                                         num_microbatch):
    """
    Test FP8 LayerNormLinear Layer
    """
    paddle.set_default_dtype(activation_dtype)
    eps = 1e-3
    rtol, atol = 0.5, 0.5
    recipe = DelayedScaling()

    layer_cached = te.LayerNormLinear(
        in_features=in_features,
        out_features=out_features,
        eps=eps,
    )
    layer_normal = te.LayerNormLinear(
        in_features=in_features,
        out_features=out_features,
        eps=eps,
    )
    # Mirror every parameter so both layers start from identical state.
    for pname in ('ln_weight', 'ln_bias', 'weight', 'bias'):
        getattr(layer_cached, pname).copy_(getattr(layer_normal, pname), True)

    for step in range(num_microbatch):
        inp = paddle.uniform(shape=(bs, in_features), dtype=activation_dtype)
        dy = paddle.uniform(shape=(bs, out_features), dtype=activation_dtype)

        with fp8_autocast(enabled=True, fp8_recipe=recipe):
            out = layer_cached(inp, is_first_microbatch=(step == 0))
        out.backward(dy)

        with fp8_autocast(enabled=True, fp8_recipe=recipe):
            out_ref = layer_normal(inp)
        out_ref.backward(dy)

        # Weight caching must not change outputs or gradients.
        assert_allclose(out, out_ref, rtol=rtol, atol=atol)
        assert_allclose(layer_cached.weight.grad,
                        layer_normal.weight.grad,
                        rtol=rtol,
                        atol=atol)
        assert_allclose(layer_cached.ln_weight.grad,
                        layer_normal.ln_weight.grad,
                        rtol=rtol,
                        atol=atol)
class TestLayerNormMLP: class TestLayerNormMLP:
""" """
...@@ -615,6 +713,75 @@ class TestLayerNormMLP: ...@@ -615,6 +713,75 @@ class TestLayerNormMLP:
if do_calibration: if do_calibration:
assert paddle.count_nonzero(layer_te.fp8_meta["scaling_fwd"].amax_history).item() > 0 assert paddle.count_nonzero(layer_te.fp8_meta["scaling_fwd"].amax_history).item() > 0
@staticmethod
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest.mark.parametrize('bs,hidden_size,ffn_hidden_size', LINEAR_CASES)
@pytest.mark.parametrize('activation_dtype', ['bfloat16'])
@pytest.mark.parametrize('num_microbatch', [8])
def test_layernorm_mlp_fp8_microbatch(bs, hidden_size, ffn_hidden_size, activation_dtype,
                                      num_microbatch):
    """
    Test FP8 LayerNormMLP Layer
    """
    paddle.set_default_dtype(activation_dtype)
    rtol, atol = 1e-5, 1e-5
    eps = 1e-3
    recipe = DelayedScaling()

    layer_cached = te.LayerNormMLP(
        hidden_size=hidden_size,
        ffn_hidden_size=ffn_hidden_size,
        eps=eps,
    )
    layer_normal = te.LayerNormMLP(
        hidden_size=hidden_size,
        ffn_hidden_size=ffn_hidden_size,
        eps=eps,
    )
    # Mirror every parameter so both layers start from identical state.
    for pname in ('ln_weight', 'ln_bias', 'fc1_weight', 'fc2_weight', 'fc1_bias', 'fc2_bias'):
        getattr(layer_normal, pname).copy_(getattr(layer_cached, pname), True)

    # Calibration to make sure weight scale is the same
    calib_input = paddle.uniform(shape=(bs, hidden_size), dtype=activation_dtype)
    with fp8_autocast(enabled=False, calibrating=True, fp8_recipe=recipe):
        _ = layer_cached(calib_input)
    with fp8_autocast(enabled=False, calibrating=True, fp8_recipe=recipe):
        _ = layer_normal(calib_input)

    for step in range(num_microbatch):
        inp = paddle.uniform(shape=(bs, hidden_size), dtype=activation_dtype)
        dy = paddle.uniform(shape=(bs, hidden_size), dtype=activation_dtype)

        with fp8_autocast(enabled=True, fp8_recipe=recipe):
            out = layer_cached(inp, is_first_microbatch=(step == 0))
        out.backward(dy)

        with fp8_autocast(enabled=True, fp8_recipe=recipe):
            out_ref = layer_normal(inp)
        out_ref.backward(dy)

        # Weight caching must not change outputs or gradients.
        assert_allclose(out, out_ref, rtol=rtol, atol=atol)
        assert_allclose(layer_cached.ln_weight.grad,
                        layer_normal.ln_weight.grad,
                        rtol=rtol,
                        atol=atol)
        assert_allclose(layer_cached.fc1_weight.grad,
                        layer_normal.fc1_weight.grad,
                        rtol=rtol,
                        atol=atol)
        assert_allclose(layer_cached.fc2_weight.grad,
                        layer_normal.fc2_weight.grad,
                        rtol=rtol,
                        atol=atol)
@pytest.mark.parametrize('bs', [1, 2, 8]) @pytest.mark.parametrize('bs', [1, 2, 8])
@pytest.mark.parametrize('hidden_size, num_heads', [[1024, 16], [768, 12]]) @pytest.mark.parametrize('hidden_size, num_heads', [[1024, 16], [768, 12]])
...@@ -1172,3 +1339,122 @@ def test_transformer_decoder_layer(bs, hidden_size, num_heads, ffn_hidden_size, ...@@ -1172,3 +1339,122 @@ def test_transformer_decoder_layer(bs, hidden_size, num_heads, ffn_hidden_size,
layer_pd.inter_attention.layernorm_query.bias.grad, layer_pd.inter_attention.layernorm_query.bias.grad,
rtol=rtol, rtol=rtol,
atol=atol) atol=atol)
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest.mark.parametrize('bs', [8])
@pytest.mark.parametrize('hidden_size, num_heads, ffn_hidden_size', [[1024, 16, 4096]])
@pytest.mark.parametrize('q_seqlen, kv_seqlen', [[128, 128]])
@pytest.mark.parametrize('mask_type', ['causal'])
@pytest.mark.parametrize('math_dtype', ['bfloat16'])
@pytest.mark.parametrize('num_microbatch', [8])
def test_transformer_encoder_layer_microbatch(bs, hidden_size, num_heads, ffn_hidden_size, q_seqlen,
                                              kv_seqlen, mask_type, math_dtype, num_microbatch):
    """
    Test Transformer Encoder Layer with FP8 weight caching.

    Runs ``num_microbatch`` iterations comparing a layer that caches its FP8
    weights (``is_first_microbatch`` flag) against an identical layer that
    re-casts the weights every call; outputs and QKV weight gradients must
    agree within tolerance on every iteration.
    """
    paddle.set_default_dtype(math_dtype)
    rtol = 1e-5
    atol = 1e-5
    eps = 1e-3

    # Skip if cuDNN fused attention is not supported
    if not is_fused_attention_supported(
            num_heads=num_heads,
            num_gqa_groups=num_heads,
            q_seqlen=q_seqlen,
            kv_seqlen=kv_seqlen,
            head_size=hidden_size // num_heads,
            dtype=math_dtype,
            dropout=0.0,
            qkv_layout="bs3hd",
            bias_type="no_bias",
            mask_type=mask_type,
    ):
        pytest.skip("cuDNN fused attention is not supported")

    # Dropout disabled on both layers so the comparison is deterministic.
    layer_cached = te.TransformerLayer(hidden_size,
                                       ffn_hidden_size,
                                       num_heads,
                                       layernorm_epsilon=eps,
                                       hidden_dropout=0.0,
                                       attention_dropout=0.0,
                                       weight_attr=None,
                                       bias_attr=None,
                                       self_attn_mask_type=mask_type,
                                       layer_type='encoder')
    layer_normal = te.TransformerLayer(hidden_size,
                                       ffn_hidden_size,
                                       num_heads,
                                       layernorm_epsilon=eps,
                                       hidden_dropout=0.0,
                                       attention_dropout=0.0,
                                       weight_attr=None,
                                       bias_attr=None,
                                       self_attn_mask_type=mask_type,
                                       layer_type='encoder')

    # Copy every self-attention parameter so both layers start identical.
    layer_normal.self_attention.layernorm_qkv.ln_weight.copy_(
        layer_cached.self_attention.layernorm_qkv.ln_weight, True)
    layer_normal.self_attention.layernorm_qkv.ln_bias.copy_(
        layer_cached.self_attention.layernorm_qkv.ln_bias, True)
    layer_normal.self_attention.layernorm_qkv.weight.copy_(
        layer_cached.self_attention.layernorm_qkv.weight, True)
    layer_normal.self_attention.layernorm_qkv.bias.copy_(
        layer_cached.self_attention.layernorm_qkv.bias, True)
    layer_normal.self_attention.proj.weight.copy_(layer_cached.self_attention.proj.weight, True)
    layer_normal.self_attention.proj.bias.copy_(layer_cached.self_attention.proj.bias, True)
    # LayerNorm MLP params
    layer_normal.layernorm_mlp.ln_weight.copy_(layer_cached.layernorm_mlp.ln_weight, True)
    layer_normal.layernorm_mlp.ln_bias.copy_(layer_cached.layernorm_mlp.ln_bias, True)
    layer_normal.layernorm_mlp.fc1_weight.copy_(layer_cached.layernorm_mlp.fc1_weight, True)
    layer_normal.layernorm_mlp.fc2_weight.copy_(layer_cached.layernorm_mlp.fc2_weight, True)
    layer_normal.layernorm_mlp.fc1_bias.copy_(layer_cached.layernorm_mlp.fc1_bias, True)
    layer_normal.layernorm_mlp.fc2_bias.copy_(layer_cached.layernorm_mlp.fc2_bias, True)

    recipe = DelayedScaling()

    def generate_input():
        # Fresh random input, attention mask and upstream gradient per call.
        # All sequences use the full length here, so the per-sample zeroing
        # below is effectively a no-op for this parametrization.
        encoder_input = paddle.uniform(shape=(bs, q_seqlen, hidden_size), dtype=math_dtype)
        q_actual_seqlen = paddle.ones(shape=(bs,), dtype='int32') * q_seqlen
        kv_actual_seqlen = q_actual_seqlen
        attn_mask = paddle.ones(shape=(bs, 1, q_seqlen, kv_seqlen), dtype='bool')
        grad_out = paddle.normal(mean=0.0, std=0.02,
                                 shape=(bs, q_seqlen, hidden_size)).astype('float32')
        # Zero the gradient beyond each sample's actual length (float32 first
        # since bf16 in-place slice assignment support varies — then cast).
        for i in range(0, bs):
            grad_out[i, q_actual_seqlen[i]:, :] = 0
        grad_out = grad_out.astype(math_dtype)
        # NOTE(review): mask starts all-True and valid positions are set to
        # False — True appears to mean "masked out"; confirm against the
        # attention implementation's mask convention.
        for i in range(0, bs):
            attn_mask[i, 0, 0:q_actual_seqlen[i], 0:kv_actual_seqlen[i]] = False
        return encoder_input, attn_mask, grad_out

    # Calibration to make sure weight scale is the same
    encoder_input, mask, _ = generate_input()
    with fp8_autocast(enabled=False, calibrating=True, fp8_recipe=recipe):
        _ = layer_cached(encoder_input, mask)
    with fp8_autocast(enabled=False, calibrating=True, fp8_recipe=recipe):
        _ = layer_normal(encoder_input, mask)

    for iteration in range(num_microbatch):
        encoder_input, mask, grad_out = generate_input()

        with fp8_autocast(enabled=True, fp8_recipe=recipe):
            out = layer_cached(encoder_input, mask, is_first_microbatch=(iteration == 0))
        out.backward(grad_out)

        with fp8_autocast(enabled=True, fp8_recipe=recipe):
            out_ref = layer_normal(encoder_input, mask)
        out_ref.backward(grad_out)

        # Cached-weight path must match the reference on every micro-batch.
        assert_allclose(out, out_ref, rtol=rtol, atol=atol)
        assert_allclose(layer_cached.self_attention.layernorm_qkv.weight.grad,
                        layer_normal.self_attention.layernorm_qkv.weight.grad,
                        rtol=rtol,
                        atol=atol)
...@@ -10,8 +10,14 @@ import paddle ...@@ -10,8 +10,14 @@ import paddle
import paddle.nn.functional as F import paddle.nn.functional as F
import pytest import pytest
import transformer_engine # pylint: disable=unused-import from utils import (
import transformer_engine_paddle as tex # pylint: disable=wrong-import-order assert_allclose,
create_fp8_meta,
get_fused_attention_backend,
is_fused_attention_supported,
)
import transformer_engine_paddle as tex
from transformer_engine.paddle.cpp_extensions import ( from transformer_engine.paddle.cpp_extensions import (
cast_to_fp8, cast_to_fp8,
cast_from_fp8, cast_from_fp8,
...@@ -44,13 +50,6 @@ from transformer_engine.paddle.fp8 import is_fp8_available ...@@ -44,13 +50,6 @@ from transformer_engine.paddle.fp8 import is_fp8_available
from transformer_engine.paddle.constants import FP8FwdTensors from transformer_engine.paddle.constants import FP8FwdTensors
from transformer_engine.common.recipe import DelayedScaling from transformer_engine.common.recipe import DelayedScaling
from utils import (
assert_allclose,
create_fp8_meta,
get_fused_attention_backend,
is_fused_attention_supported,
)
GEMM_CASES = [(256, 256, 512), (32, 32, 32), (16384, 1024, 2816), (16384, 2816, 1024), GEMM_CASES = [(256, 256, 512), (32, 32, 32), (16384, 1024, 2816), (16384, 2816, 1024),
(16384, 1024, 1024)] (16384, 1024, 1024)]
is_fp8_supported, reason = is_fp8_available() is_fp8_supported, reason = is_fp8_available()
...@@ -69,20 +68,24 @@ def setup(): ...@@ -69,20 +68,24 @@ def setup():
yield yield
def test_quantize_dequantize(): @pytest.mark.parametrize('fp8_dtype', [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2])
@pytest.mark.parametrize('inplace', [True, False])
def test_quantize_dequantize(fp8_dtype, inplace):
""" """
Test cast_to_fp8 and cast_from_fp8 Test cast_to_fp8 and cast_from_fp8
""" """
a = paddle.rand(shape=(32, 32), dtype='float32') a = paddle.rand(shape=(32, 32), dtype='float32')
# Init fp8_meta # Init fp8_meta
fp8_meta = create_fp8_meta() fp8_meta = create_fp8_meta()
for fp8_dtype in [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2]: a_fp8 = paddle.zeros(shape=a.shape, dtype=paddle.uint8) if inplace else None
a_fp8 = cast_to_fp8(a, fp8_meta, FP8FwdTensors.GEMM1_OUTPUT, otype=fp8_dtype) a_fp8 = cast_to_fp8(a, fp8_meta, FP8FwdTensors.GEMM1_OUTPUT, otype=fp8_dtype, out=a_fp8)
b = cast_from_fp8(a_fp8, b = cast_from_fp8(
a_fp8,
fp8_meta, fp8_meta,
FP8FwdTensors.GEMM1_OUTPUT, FP8FwdTensors.GEMM1_OUTPUT,
itype=fp8_dtype, itype=fp8_dtype,
otype=tex.DType.kFloat32) otype=tex.DType.kFloat32,
)
assert_allclose(a, b, rtol=5e-2, atol=5e-2) assert_allclose(a, b, rtol=5e-2, atol=5e-2)
...@@ -142,7 +145,8 @@ class TestTranspose: ...@@ -142,7 +145,8 @@ class TestTranspose:
@staticmethod @staticmethod
@pytest.mark.skipif(not is_fp8_supported, reason=reason) @pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest.mark.parametrize('fp8_dtype', [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2]) @pytest.mark.parametrize('fp8_dtype', [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2])
def test_cast_transpose(fp8_dtype): @pytest.mark.parametrize('inplace', [True, False])
def test_cast_transpose(fp8_dtype, inplace):
""" """
Test cast_transpose Test cast_transpose
""" """
...@@ -150,10 +154,16 @@ class TestTranspose: ...@@ -150,10 +154,16 @@ class TestTranspose:
max_val = 8 max_val = 8
a = paddle.cast(paddle.randint(min_val, max_val, shape=(16, 32)), 'float32') a = paddle.cast(paddle.randint(min_val, max_val, shape=(16, 32)), 'float32')
fp8_meta = create_fp8_meta() fp8_meta = create_fp8_meta()
a_fp8_casted, a_fp8_transposed = None, None
if inplace:
a_fp8_casted = paddle.zeros(shape=a.shape, dtype=paddle.uint8)
a_fp8_transposed = paddle.zeros(shape=a.T.shape, dtype=paddle.uint8)
a_fp8_casted, a_fp8_transposed = cast_transpose(a, a_fp8_casted, a_fp8_transposed = cast_transpose(a,
fp8_meta, fp8_meta,
FP8FwdTensors.GEMM1_INPUT, FP8FwdTensors.GEMM1_INPUT,
otype=fp8_dtype) otype=fp8_dtype,
cast_out=a_fp8_casted,
transpose_out=a_fp8_transposed)
a_transposed = cast_from_fp8(a_fp8_transposed, a_transposed = cast_from_fp8(a_fp8_transposed,
fp8_meta, fp8_meta,
...@@ -616,7 +626,7 @@ class TestFusedAttn: ...@@ -616,7 +626,7 @@ class TestFusedAttn:
assert attn_mode == "self_attn", "only support causal masking for self attention" assert attn_mode == "self_attn", "only support causal masking for self attention"
for i in range(0, self.batch_size): for i in range(0, self.batch_size):
for j in range(self.q_actual_seqlen[i]): for j in range(self.q_actual_seqlen[i]):
self.attn_mask[i, :, j, :j+1] = 0 self.attn_mask[i, :, j, :j + 1] = 0
else: else:
for i in range(0, self.batch_size): for i in range(0, self.batch_size):
self.attn_mask[i, :, :self.q_actual_seqlen[i], :self.kv_actual_seqlen[i]] = 0 self.attn_mask[i, :, :self.q_actual_seqlen[i], :self.kv_actual_seqlen[i]] = 0
...@@ -682,11 +692,7 @@ class TestFusedAttn: ...@@ -682,11 +692,7 @@ class TestFusedAttn:
q_cu_seqlen_tensor = paddle.to_tensor(self.q_cu_seqlen, dtype="int32", stop_gradient=True) q_cu_seqlen_tensor = paddle.to_tensor(self.q_cu_seqlen, dtype="int32", stop_gradient=True)
kv_cu_seqlen_tensor = paddle.to_tensor(self.kv_cu_seqlen, dtype="int32", stop_gradient=True) kv_cu_seqlen_tensor = paddle.to_tensor(self.kv_cu_seqlen, dtype="int32", stop_gradient=True)
qkv_layout = ( qkv_layout = ("bs3hd" if self.attn_mode == "self_attn" else "bshd_bs2hd")
"bs3hd"
if self.attn_mode == "self_attn"
else "bshd_bs2hd"
)
fused_attention_backend = get_fused_attention_backend( fused_attention_backend = get_fused_attention_backend(
num_heads=self.num_heads, num_heads=self.num_heads,
num_gqa_groups=self.num_heads, num_gqa_groups=self.num_heads,
...@@ -932,12 +938,14 @@ class TestSoftmax: ...@@ -932,12 +938,14 @@ class TestSoftmax:
assert_allclose(dx_ref, dx, rtol=1e-4, atol=5e-3) assert_allclose(dx_ref, dx, rtol=1e-4, atol=5e-3)
def test_amax_and_scale_update(): @pytest.mark.parametrize('update_weight_scale_inv', [True, False])
def test_amax_and_scale_update(update_weight_scale_inv):
"""Test update_scale""" """Test update_scale"""
num_gemm = 6 num_gemm = 6
history_len = 1024 history_len = 1024
recipe = DelayedScaling() recipe = DelayedScaling()
fp8_max = recipe.fp8_format.value.max_fwd fp8_max = recipe.fp8_format.value.max_fwd
non_weight_mask = paddle.to_tensor([True, False] * (num_gemm // 2))
amax_history_tensor = paddle.rand(shape=[history_len, num_gemm], dtype='float32') amax_history_tensor = paddle.rand(shape=[history_len, num_gemm], dtype='float32')
rolled_history_ref = paddle.roll(amax_history_tensor, -1, axis=0) rolled_history_ref = paddle.roll(amax_history_tensor, -1, axis=0)
...@@ -947,13 +955,17 @@ def test_amax_and_scale_update(): ...@@ -947,13 +955,17 @@ def test_amax_and_scale_update():
def calc_ref(amax, scale, fp8_max, margin=0): def calc_ref(amax, scale, fp8_max, margin=0):
"""Calculate reference scale""" """Calculate reference scale"""
sf = (fp8_max / amax) / (2 ** margin) sf = (fp8_max / amax) / (2**margin)
sf = paddle.where(amax > 0.0, sf, scale) sf = paddle.where(amax > 0.0, sf, scale)
sf = paddle.where(paddle.isfinite(amax), sf, scale) sf = paddle.where(paddle.isfinite(amax), sf, scale)
return sf return sf
scale_ref = calc_ref(amax_tensor, scale_tensor, fp8_max, 0.) scale_ref = calc_ref(amax_tensor, scale_tensor, fp8_max, 0.)
if update_weight_scale_inv:
scale_inv_ref = 1. / scale_ref scale_inv_ref = 1. / scale_ref
else:
scale_inv_ref = paddle.zeros_like(scale_tensor)
scale_inv_ref = paddle.where(non_weight_mask, 1. / scale_ref, scale_inv_ref)
# Placeholder # Placeholder
scale_actual = paddle.zeros_like(scale_tensor) scale_actual = paddle.zeros_like(scale_tensor)
...@@ -962,6 +974,8 @@ def test_amax_and_scale_update(): ...@@ -962,6 +974,8 @@ def test_amax_and_scale_update():
tex.amax_and_scale_update_inplace(_amax_history=amax_history_tensor, tex.amax_and_scale_update_inplace(_amax_history=amax_history_tensor,
_scale=scale_actual, _scale=scale_actual,
_scale_inv=scale_inv_actual, _scale_inv=scale_inv_actual,
non_weight_mask=non_weight_mask,
update_weight_scale_inv=update_weight_scale_inv,
fp8_max=fp8_max, fp8_max=fp8_max,
margin=0., margin=0.,
amax_compute="max") amax_compute="max")
......
...@@ -185,11 +185,22 @@ def cast_to_fp8( ...@@ -185,11 +185,22 @@ def cast_to_fp8(
fp8_meta_tensor: FP8TensorMeta, fp8_meta_tensor: FP8TensorMeta,
fp8_tensor: Union[FP8FwdTensors, FP8BwdTensors], fp8_tensor: Union[FP8FwdTensors, FP8BwdTensors],
otype: tex.DType, otype: tex.DType,
out: Optional[paddle.Tensor] = None,
) -> paddle.Tensor: ) -> paddle.Tensor:
"""Cast input to FP8""" """Cast input to FP8"""
out, _, _ = tex.cast_to_fp8( if out is None:
out = paddle.empty(
shape=inp.shape,
dtype=paddle.uint8,
)
else:
assert out.shape == inp.shape, "Output shape does not match input shape."
assert out.dtype == paddle.uint8, "Output should be of uint8 dtype."
tex.cast_to_fp8(
inp, inp,
fp8_meta_tensor.scale, fp8_meta_tensor.scale,
out,
fp8_meta_tensor.amax_history, fp8_meta_tensor.amax_history,
fp8_meta_tensor.scale_inv, fp8_meta_tensor.scale_inv,
fp8_tensor.value, fp8_tensor.value,
...@@ -231,11 +242,34 @@ def cast_transpose( ...@@ -231,11 +242,34 @@ def cast_transpose(
fp8_meta_tensor: FP8TensorMeta, fp8_meta_tensor: FP8TensorMeta,
fp8_tensor: Union[FP8FwdTensors, FP8BwdTensors], fp8_tensor: Union[FP8FwdTensors, FP8BwdTensors],
otype: tex.DType, otype: tex.DType,
cast_out: Optional[paddle.Tensor] = None,
transpose_out: Optional[paddle.Tensor] = None,
) -> Union[Tuple[paddle.Tensor, paddle.Tensor], None]: ) -> Union[Tuple[paddle.Tensor, paddle.Tensor], None]:
"""Cast + Transpose with FP8 output""" """Cast + Transpose with FP8 output"""
cast_out, transpose_out, _, _ = tex.te_cast_transpose( if cast_out is None:
cast_out = paddle.empty(
shape=inp.shape,
dtype=paddle.uint8,
)
else:
assert cast_out.shape == inp.shape, "cast_out shape does not match input shape."
assert cast_out.dtype == paddle.uint8, "cast_out should be of uint8 dtype."
if transpose_out is None:
transpose_out = paddle.empty(
shape=[inp.shape[1], inp.shape[0]],
dtype=paddle.uint8,
)
else:
assert transpose_out.shape == [inp.shape[1], inp.shape[0]
], "Transposed output shape does not match input shape."
assert transpose_out.dtype == paddle.uint8, "Output should be of uint8 dtype."
tex.te_cast_transpose(
inp, inp,
fp8_meta_tensor.scale, fp8_meta_tensor.scale,
cast_out,
transpose_out,
fp8_meta_tensor.amax_history, fp8_meta_tensor.amax_history,
fp8_meta_tensor.scale_inv, fp8_meta_tensor.scale_inv,
fp8_tensor.value, fp8_tensor.value,
......
...@@ -40,21 +40,19 @@ NVTE_Mask_Type get_nvte_mask_type(const std::string mask_type) { ...@@ -40,21 +40,19 @@ NVTE_Mask_Type get_nvte_mask_type(const std::string mask_type) {
} }
} }
std::vector<paddle::Tensor> cast_to_fp8(const paddle::Tensor &input, const paddle::Tensor &scale, void cast_to_fp8(const paddle::Tensor &input, const paddle::Tensor &scale,
paddle::Tensor &amax, paddle::Tensor &scale_inv, // NOLINT paddle::Tensor &output, // NOLINT
paddle::Tensor &amax, // NOLINT
paddle::Tensor &scale_inv, // NOLINT
int64_t index, int64_t otype) { int64_t index, int64_t otype) {
auto shape = GetShapeArray(input); auto shape = GetShapeArray(input);
auto output = paddle::empty_like(input, Nvte2PaddleDType(Int2NvteDType(otype)));
auto input_cu = MakeNvteTensor(input); auto input_cu = MakeNvteTensor(input);
auto output_cu = MakeNvteTensor( auto output_cu = MakeNvteTensor(
output.data(), shape, Int2NvteDType(otype), GetDataPtr<float>(amax, index), output.data(), shape, Int2NvteDType(otype), GetDataPtr<float>(amax, index),
const_cast<void *>(GetDataPtr<float>(scale, index)), GetDataPtr<float>(scale_inv, index)); const_cast<void *>(GetDataPtr<float>(scale, index)), GetDataPtr<float>(scale_inv, index));
nvte_fp8_quantize(input_cu.data(), output_cu.data(), input.stream()); nvte_fp8_quantize(input_cu.data(), output_cu.data(), input.stream());
return {output};
} }
std::vector<paddle::Tensor> cast_from_fp8(const paddle::Tensor &input, std::vector<paddle::Tensor> cast_from_fp8(const paddle::Tensor &input,
...@@ -89,8 +87,9 @@ std::vector<paddle::Tensor> te_transpose(const paddle::Tensor &input, int64_t ot ...@@ -89,8 +87,9 @@ std::vector<paddle::Tensor> te_transpose(const paddle::Tensor &input, int64_t ot
return {output}; return {output};
} }
std::vector<paddle::Tensor> te_cast_transpose(const paddle::Tensor &input, void te_cast_transpose(const paddle::Tensor &input, const paddle::Tensor &scale,
const paddle::Tensor &scale, paddle::Tensor &output_cast, // NOLINT
paddle::Tensor &output_transpose, // NOLINT
paddle::Tensor &amax, // NOLINT paddle::Tensor &amax, // NOLINT
paddle::Tensor &scale_inv, // NOLINT paddle::Tensor &scale_inv, // NOLINT
int64_t index, int64_t otype) { int64_t index, int64_t otype) {
...@@ -100,24 +99,17 @@ std::vector<paddle::Tensor> te_cast_transpose(const paddle::Tensor &input, ...@@ -100,24 +99,17 @@ std::vector<paddle::Tensor> te_cast_transpose(const paddle::Tensor &input,
size_t M = shape[0]; size_t M = shape[0];
size_t N = shape[1]; size_t N = shape[1];
auto input_cast =
paddle::empty_like(input, Nvte2PaddleDType(Int2NvteDType(otype)), input.place());
auto input_transpose = paddle::empty({input.shape()[1], input.shape()[0]},
Nvte2PaddleDType(Int2NvteDType(otype)), input.place());
auto input_cu = MakeNvteTensor(input); auto input_cu = MakeNvteTensor(input);
void *amax_data = GetDataPtr<float>(amax, index); void *amax_data = GetDataPtr<float>(amax, index);
void *scale_data = const_cast<void *>(GetDataPtr<float>(scale, index)); void *scale_data = const_cast<void *>(GetDataPtr<float>(scale, index));
void *scale_inv_data = GetDataPtr<float>(scale_inv, index); void *scale_inv_data = GetDataPtr<float>(scale_inv, index);
auto output_cast_cu = MakeNvteTensor(input_cast.data(), {M, N}, Int2NvteDType(otype), amax_data, auto output_cast_cu = MakeNvteTensor(output_cast.data(), {M, N}, Int2NvteDType(otype),
scale_data, scale_inv_data); amax_data, scale_data, scale_inv_data);
auto output_transpose_cu = MakeNvteTensor(input_transpose.data(), {N, M}, Int2NvteDType(otype), auto output_transpose_cu = MakeNvteTensor(output_transpose.data(), {N, M}, Int2NvteDType(otype),
amax_data, scale_data, scale_inv_data); amax_data, scale_data, scale_inv_data);
nvte_cast_transpose(input_cu.data(), output_cast_cu.data(), output_transpose_cu.data(), nvte_cast_transpose(input_cu.data(), output_cast_cu.data(), output_transpose_cu.data(),
input.stream()); input.stream());
return {input_cast, input_transpose};
} }
std::vector<paddle::Tensor> te_cast_transpose_bgrad(const paddle::Tensor &grad_output, std::vector<paddle::Tensor> te_cast_transpose_bgrad(const paddle::Tensor &grad_output,
...@@ -1021,9 +1013,9 @@ void te_scaled_upper_triang_masked_softmax_backward(paddle::Tensor &output_grads ...@@ -1021,9 +1013,9 @@ void te_scaled_upper_triang_masked_softmax_backward(paddle::Tensor &output_grads
} }
__global__ void UpdateFP8MetaKernel(const float *amax, const float *rolled_amax_history, __global__ void UpdateFP8MetaKernel(const float *amax, const float *rolled_amax_history,
float *amax_history, float *scale, float *scale_inv, const bool *non_weight_mask, float *amax_history, float *scale,
float margin, float fp8_max, size_t history_numel, float *scale_inv, bool update_weight_scale_inv, float margin,
size_t amax_numel) { float fp8_max, size_t history_numel, size_t amax_numel) {
int idx = blockIdx.x * blockDim.x + threadIdx.x; int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx >= history_numel) { if (idx >= history_numel) {
...@@ -1036,7 +1028,7 @@ __global__ void UpdateFP8MetaKernel(const float *amax, const float *rolled_amax_ ...@@ -1036,7 +1028,7 @@ __global__ void UpdateFP8MetaKernel(const float *amax, const float *rolled_amax_
float sf = (fp8_max / amax[idx]) / powf(2.0f, margin); float sf = (fp8_max / amax[idx]) / powf(2.0f, margin);
float scale_reg = ((amax[idx] > 0.0f) && isfinite(amax[idx])) ? sf : scale[idx]; float scale_reg = ((amax[idx] > 0.0f) && isfinite(amax[idx])) ? sf : scale[idx];
scale[idx] = scale_reg; scale[idx] = scale_reg;
scale_inv[idx] = 1.0f / scale_reg; if (update_weight_scale_inv || non_weight_mask[idx]) scale_inv[idx] = 1.0f / scale_reg;
amax_history[idx] = 0.0f; amax_history[idx] = 0.0f;
} }
} }
...@@ -1046,7 +1038,9 @@ constexpr int BLOCK_SIZE = 512; ...@@ -1046,7 +1038,9 @@ constexpr int BLOCK_SIZE = 512;
void amax_and_scale_update_inplace(paddle::Tensor &amax_history, // NOLINT void amax_and_scale_update_inplace(paddle::Tensor &amax_history, // NOLINT
paddle::Tensor &scale, // NOLINT paddle::Tensor &scale, // NOLINT
paddle::Tensor &scale_inv, // NOLINT paddle::Tensor &scale_inv, // NOLINT
float fp8_max, float margin, const std::string &amax_compute) { const paddle::Tensor &non_weight_mask,
bool update_weight_scale_inv, float fp8_max, float margin,
const std::string &amax_compute) {
NVTE_CHECK(amax_compute == "max" || amax_compute == "most_recent"); NVTE_CHECK(amax_compute == "max" || amax_compute == "most_recent");
paddle::Tensor amax; paddle::Tensor amax;
...@@ -1062,9 +1056,9 @@ void amax_and_scale_update_inplace(paddle::Tensor &amax_history, // NOLINT ...@@ -1062,9 +1056,9 @@ void amax_and_scale_update_inplace(paddle::Tensor &amax_history, // NOLINT
auto size = amax_history.numel(); auto size = amax_history.numel();
size_t num_blocks = (size + BLOCK_SIZE - 1) / BLOCK_SIZE; size_t num_blocks = (size + BLOCK_SIZE - 1) / BLOCK_SIZE;
UpdateFP8MetaKernel<<<num_blocks, BLOCK_SIZE, 0, amax_history.stream()>>>( UpdateFP8MetaKernel<<<num_blocks, BLOCK_SIZE, 0, amax_history.stream()>>>(
amax.data<float>(), rolled_amax_history.data<float>(), amax_history.data<float>(), amax.data<float>(), rolled_amax_history.data<float>(), non_weight_mask.data<bool>(),
scale.data<float>(), scale_inv.data<float>(), margin, fp8_max, amax_history.numel(), amax_history.data<float>(), scale.data<float>(), scale_inv.data<float>(),
amax.numel()); update_weight_scale_inv, margin, fp8_max, amax_history.numel(), amax.numel());
NVTE_CHECK_CUDA(cudaGetLastError()); NVTE_CHECK_CUDA(cudaGetLastError());
} }
...@@ -1178,10 +1172,10 @@ PD_BUILD_OP(te_gemm) ...@@ -1178,10 +1172,10 @@ PD_BUILD_OP(te_gemm)
.SetKernelFn(PD_KERNEL(transformer_engine::paddle_ext::te_gemm)); .SetKernelFn(PD_KERNEL(transformer_engine::paddle_ext::te_gemm));
PD_BUILD_OP(cast_to_fp8) PD_BUILD_OP(cast_to_fp8)
.Inputs({"Input", "Scale", "_Amax", "_ScaleInv"}) .Inputs({"Input", "Scale", "_Output", "_Amax", "_ScaleInv"})
.Outputs({"Output", "Amax", "ScaleInv"}) .Outputs({"Output", "Amax", "ScaleInv"})
.Attrs({"index: int64_t", "otype: int64_t"}) .Attrs({"index: int64_t", "otype: int64_t"})
.SetInplaceMap({{"_Amax", "Amax"}, {"_ScaleInv", "ScaleInv"}}) .SetInplaceMap({{"_Output", "Output"}, {"_Amax", "Amax"}, {"_ScaleInv", "ScaleInv"}})
.SetKernelFn(PD_KERNEL(transformer_engine::paddle_ext::cast_to_fp8)); .SetKernelFn(PD_KERNEL(transformer_engine::paddle_ext::cast_to_fp8));
PD_BUILD_OP(cast_from_fp8) PD_BUILD_OP(cast_from_fp8)
...@@ -1197,9 +1191,12 @@ PD_BUILD_OP(te_transpose) ...@@ -1197,9 +1191,12 @@ PD_BUILD_OP(te_transpose)
.SetKernelFn(PD_KERNEL(transformer_engine::paddle_ext::te_transpose)); .SetKernelFn(PD_KERNEL(transformer_engine::paddle_ext::te_transpose));
PD_BUILD_OP(te_cast_transpose) PD_BUILD_OP(te_cast_transpose)
.Inputs({"Input", "Scale", "_Amax", "_ScaleInv"}) .Inputs({"Input", "Scale", "_CastedOutput", "_TransposedOutput", "_Amax", "_ScaleInv"})
.Outputs({"CastedOutput", "TransposedOutput", "Amax", "ScaleInv"}) .Outputs({"CastedOutput", "TransposedOutput", "Amax", "ScaleInv"})
.SetInplaceMap({{"_Amax", "Amax"}, {"_ScaleInv", "ScaleInv"}}) .SetInplaceMap({{"_CastedOutput", "CastedOutput"},
{"_TransposedOutput", "TransposedOutput"},
{"_Amax", "Amax"},
{"_ScaleInv", "ScaleInv"}})
.Attrs({"index: int64_t", "otype: int64_t"}) .Attrs({"index: int64_t", "otype: int64_t"})
.SetKernelFn(PD_KERNEL(transformer_engine::paddle_ext::te_cast_transpose)); .SetKernelFn(PD_KERNEL(transformer_engine::paddle_ext::te_cast_transpose));
...@@ -1361,12 +1358,13 @@ PD_BUILD_OP(te_scaled_upper_triang_masked_softmax_backward) ...@@ -1361,12 +1358,13 @@ PD_BUILD_OP(te_scaled_upper_triang_masked_softmax_backward)
PD_KERNEL(transformer_engine::paddle_ext::te_scaled_upper_triang_masked_softmax_backward)); PD_KERNEL(transformer_engine::paddle_ext::te_scaled_upper_triang_masked_softmax_backward));
PD_BUILD_OP(amax_and_scale_update_inplace) PD_BUILD_OP(amax_and_scale_update_inplace)
.Inputs({"_amax_history", "_scale", "_scale_inv"}) .Inputs({"_amax_history", "_scale", "_scale_inv", "non_weight_mask"})
.Outputs({"amax_history", "scale", "scale_inv"}) .Outputs({"amax_history", "scale", "scale_inv"})
.SetInplaceMap({{"_amax_history", "amax_history"}, .SetInplaceMap({{"_amax_history", "amax_history"},
{"_scale", "scale"}, {"_scale", "scale"},
{"_scale_inv", "scale_inv"}}) {"_scale_inv", "scale_inv"}})
.Attrs({"fp8_max: float", "margin: float", "amax_compute: std::string"}) .Attrs({"update_weight_scale_inv: bool", "fp8_max: float", "margin: float",
"amax_compute: std::string"})
.SetKernelFn(PD_KERNEL(transformer_engine::paddle_ext::amax_and_scale_update_inplace)); .SetKernelFn(PD_KERNEL(transformer_engine::paddle_ext::amax_and_scale_update_inplace));
PD_BUILD_OP(update_latest_amax_history_inplace) PD_BUILD_OP(update_latest_amax_history_inplace)
......
...@@ -15,10 +15,8 @@ from transformer_engine.common.recipe import DelayedScaling, Format ...@@ -15,10 +15,8 @@ from transformer_engine.common.recipe import DelayedScaling, Format
from .constants import dist_group_type from .constants import dist_group_type
from .fp8_buffer import FP8MetaFwdBuffer, FP8MetaBwdBuffer, FP8RecomputeBuffer from .fp8_buffer import FP8MetaFwdBuffer, FP8MetaBwdBuffer, FP8RecomputeBuffer
__all__ = ['fp8_autocast'] __all__ = ['fp8_autocast']
# FP8 support # FP8 support
_is_fp8_available = None _is_fp8_available = None
_reason_for_no_fp8 = "" _reason_for_no_fp8 = ""
...@@ -227,6 +225,7 @@ def get_fp8_te_dtype(fp8_recipe: DelayedScaling, fprop_tensor: bool = True) -> t ...@@ -227,6 +225,7 @@ def get_fp8_te_dtype(fp8_recipe: DelayedScaling, fprop_tensor: bool = True) -> t
def amax_and_scale_update( def amax_and_scale_update(
fp8_meta: Dict[str, Any], fp8_meta: Dict[str, Any],
fwd_update: bool, fwd_update: bool,
update_weight_scale_inv: bool = True,
) -> None: ) -> None:
"""Updates fp8 amaxes/scales for fwd | bwd.""" """Updates fp8 amaxes/scales for fwd | bwd."""
amax_compute = fp8_meta["recipe"].amax_compute_algo amax_compute = fp8_meta["recipe"].amax_compute_algo
...@@ -235,9 +234,12 @@ def amax_and_scale_update( ...@@ -235,9 +234,12 @@ def amax_and_scale_update(
fp8_max_key = "fp8_max_fwd" if fwd_update else "fp8_max_bwd" fp8_max_key = "fp8_max_fwd" if fwd_update else "fp8_max_bwd"
if not callable(amax_compute) and sf_compute is None: if not callable(amax_compute) and sf_compute is None:
tex.amax_and_scale_update_inplace(_amax_history=fp8_meta[fp8_meta_tensor_key].amax_history, tex.amax_and_scale_update_inplace(
_amax_history=fp8_meta[fp8_meta_tensor_key].amax_history,
_scale=fp8_meta[fp8_meta_tensor_key].scale, _scale=fp8_meta[fp8_meta_tensor_key].scale,
_scale_inv=fp8_meta[fp8_meta_tensor_key].scale_inv, _scale_inv=fp8_meta[fp8_meta_tensor_key].scale_inv,
non_weight_mask=fp8_meta[fp8_meta_tensor_key].non_weight_mask,
update_weight_scale_inv=update_weight_scale_inv,
fp8_max=fp8_meta[fp8_max_key], fp8_max=fp8_meta[fp8_max_key],
margin=float(fp8_meta["recipe"].margin), margin=float(fp8_meta["recipe"].margin),
amax_compute=amax_compute) amax_compute=amax_compute)
...@@ -254,10 +256,20 @@ class FP8TensorMeta(): ...@@ -254,10 +256,20 @@ class FP8TensorMeta():
self.scale = paddle.Tensor() self.scale = paddle.Tensor()
self.scale_inv = paddle.Tensor() self.scale_inv = paddle.Tensor()
self.amax_history = paddle.Tensor() self.amax_history = paddle.Tensor()
self.non_weight_mask = paddle.Tensor()
self.is_initialized = False self.is_initialized = False
self.is_forward = is_forward self.is_forward = is_forward
def prepare(self, num_gemms: bool, amax_history_len: int) -> None: def get_non_weight_mask(self, num_gemms: int):
"""Needed for calculation of scale inverses to
preserve scale_inv when caching FP8 weights"""
if self.is_forward:
# [True, False, True]: -> [input, weight, output]
return paddle.to_tensor([True, False, True] * num_gemms)
# [True, True]: -> [grad_output, grad_input]
return paddle.to_tensor([True, True] * num_gemms)
def prepare(self, num_gemms: int, amax_history_len: int) -> None:
"""Prepare scales and amax tensors. It is called during fprop in each iteration. """Prepare scales and amax tensors. It is called during fprop in each iteration.
If the meta tensors are not initialized yet, initialization is performed. If already If the meta tensors are not initialized yet, initialization is performed. If already
initialized, resize the meta tensors if amax_history_len has changed.""" initialized, resize the meta tensors if amax_history_len has changed."""
...@@ -284,6 +296,8 @@ class FP8TensorMeta(): ...@@ -284,6 +296,8 @@ class FP8TensorMeta():
self.scale = paddle.ones(num_fp8_tensors, dtype='float32') self.scale = paddle.ones(num_fp8_tensors, dtype='float32')
self.scale_inv = paddle.ones(num_fp8_tensors, dtype='float32') self.scale_inv = paddle.ones(num_fp8_tensors, dtype='float32')
self.amax_history = paddle.zeros([amax_history_len, num_fp8_tensors], dtype='float32') self.amax_history = paddle.zeros([amax_history_len, num_fp8_tensors], dtype='float32')
self.non_weight_mask = self.get_non_weight_mask(num_gemms=num_gemms)
self.is_initialized = True self.is_initialized = True
def to_numpy(self): def to_numpy(self):
...@@ -300,4 +314,9 @@ class FP8TensorMeta(): ...@@ -300,4 +314,9 @@ class FP8TensorMeta():
self.scale = paddle.to_tensor(data['scale']) self.scale = paddle.to_tensor(data['scale'])
self.scale_inv = paddle.to_tensor(data['scale_inv']) self.scale_inv = paddle.to_tensor(data['scale_inv'])
self.amax_history = paddle.to_tensor(data['amax_history']) self.amax_history = paddle.to_tensor(data['amax_history'])
num_fp8_tensors = self.scale.shape[0]
num_gemms = num_fp8_tensors // 3 if self.is_forward else num_fp8_tensors // 2
self.non_weight_mask = self.get_non_weight_mask(num_gemms=num_gemms)
self.is_initialized = True self.is_initialized = True
...@@ -552,6 +552,7 @@ class MultiHeadAttention(paddle.nn.Layer): ...@@ -552,6 +552,7 @@ class MultiHeadAttention(paddle.nn.Layer):
core_attention_bias: Optional[paddle.Tensor] = None, core_attention_bias: Optional[paddle.Tensor] = None,
set_zero: bool = True, set_zero: bool = True,
recompute_core_attention: bool = False, recompute_core_attention: bool = False,
is_first_microbatch: Optional[bool] = None,
) -> Tuple[Union[paddle.Tensor, None], ...]: ) -> Tuple[Union[paddle.Tensor, None], ...]:
""" """
MultiHeadAttention Layer. MultiHeadAttention Layer.
...@@ -575,6 +576,16 @@ class MultiHeadAttention(paddle.nn.Layer): ...@@ -575,6 +576,16 @@ class MultiHeadAttention(paddle.nn.Layer):
during the backward pass in order to save memory that would during the backward pass in order to save memory that would
otherwise be occupied to store the forward activations until otherwise be occupied to store the forward activations until
backprop. backprop.
is_first_microbatch : {True, False, None}, default = None
During training using either gradient accumulation or
pipeline parallelism a minibatch of data is further split
into microbatches. Between the microbatches of the same minibatch
the model weights are not updated. Setting this parameter indicates
whether the current microbatch is the first in a minibatch or not.
When set, this parameter enables additional optimizations:
* during FP8 training, it allows caching of the FP8 versions of
the weights
""" """
if self.attn_mask_type != "causal" and attention_mask is not None: if self.attn_mask_type != "causal" and attention_mask is not None:
...@@ -594,13 +605,19 @@ class MultiHeadAttention(paddle.nn.Layer): ...@@ -594,13 +605,19 @@ class MultiHeadAttention(paddle.nn.Layer):
if self.attention_type == "self": if self.attention_type == "self":
if self.input_layernorm: if self.input_layernorm:
layernorm_qkv_outputs = self.layernorm_qkv(hidden_states) layernorm_qkv_outputs = self.layernorm_qkv(
hidden_states,
is_first_microbatch=is_first_microbatch,
)
if self.return_layernorm_output: if self.return_layernorm_output:
mixed_qkv_layer, layernorm_output = layernorm_qkv_outputs mixed_qkv_layer, layernorm_output = layernorm_qkv_outputs
else: else:
mixed_qkv_layer = layernorm_qkv_outputs mixed_qkv_layer = layernorm_qkv_outputs
else: else:
mixed_qkv_layer = self.qkv(hidden_states) mixed_qkv_layer = self.qkv(
hidden_states,
is_first_microbatch=is_first_microbatch,
)
# [b, s_q, 3 * hidden_size] --> [b, s_q, 3, num_heads, head_size] # [b, s_q, 3 * hidden_size] --> [b, s_q, 3, num_heads, head_size]
mixed_qkv_layer = mixed_qkv_layer.reshape(shape=[ mixed_qkv_layer = mixed_qkv_layer.reshape(shape=[
...@@ -631,7 +648,10 @@ class MultiHeadAttention(paddle.nn.Layer): ...@@ -631,7 +648,10 @@ class MultiHeadAttention(paddle.nn.Layer):
) )
else: # cross attention else: # cross attention
mixed_kv_layer = self.key_value(encoder_output) mixed_kv_layer = self.key_value(
encoder_output,
is_first_microbatch=is_first_microbatch,
)
# [b, s_kv, 2 * hidden_size] --> [b, s_kv, 2, num_heads, head_size] # [b, s_kv, 2 * hidden_size] --> [b, s_kv, 2, num_heads, head_size]
mixed_kv_layer = mixed_kv_layer.reshape(shape=[ mixed_kv_layer = mixed_kv_layer.reshape(shape=[
-1, max_seq_len, 2, self.num_attention_heads_per_partition, -1, max_seq_len, 2, self.num_attention_heads_per_partition,
...@@ -639,13 +659,19 @@ class MultiHeadAttention(paddle.nn.Layer): ...@@ -639,13 +659,19 @@ class MultiHeadAttention(paddle.nn.Layer):
]) ])
if self.input_layernorm: if self.input_layernorm:
layernorm_query_outputs = self.layernorm_query(hidden_states) layernorm_query_outputs = self.layernorm_query(
hidden_states,
is_first_microbatch=is_first_microbatch,
)
if self.return_layernorm_output: if self.return_layernorm_output:
query_layer, layernorm_output = layernorm_query_outputs query_layer, layernorm_output = layernorm_query_outputs
else: else:
query_layer = layernorm_query_outputs query_layer = layernorm_query_outputs
else: else:
query_layer = self.query_layer(hidden_states) query_layer = self.query_layer(
hidden_states,
is_first_microbatch=is_first_microbatch,
)
query_layer = query_layer.reshape(shape=[ query_layer = query_layer.reshape(shape=[
-1, max_seq_len, self.num_attention_heads_per_partition, -1, max_seq_len, self.num_attention_heads_per_partition,
...@@ -681,7 +707,7 @@ class MultiHeadAttention(paddle.nn.Layer): ...@@ -681,7 +707,7 @@ class MultiHeadAttention(paddle.nn.Layer):
[-1, context_layer.shape[2] * context_layer.shape[3]]) [-1, context_layer.shape[2] * context_layer.shape[3]])
# Output. [b, s, hidden] # Output. [b, s, hidden]
attention_output = self.proj(context_layer) attention_output = self.proj(context_layer, is_first_microbatch=is_first_microbatch)
if self.input_layernorm and self.return_layernorm_output: if self.input_layernorm and self.return_layernorm_output:
return attention_output, layernorm_output return attention_output, layernorm_output
......
...@@ -7,7 +7,7 @@ from abc import ABC, abstractmethod ...@@ -7,7 +7,7 @@ from abc import ABC, abstractmethod
from contextlib import contextmanager from contextlib import contextmanager
import os import os
import pickle import pickle
from typing import Generator, Dict, Tuple, Union, Any from typing import Generator, Dict, Tuple, Union, Any, List, Optional
import numpy as np import numpy as np
...@@ -70,9 +70,12 @@ class TransformerEngineBaseLayer(paddle.nn.Layer, ABC): ...@@ -70,9 +70,12 @@ class TransformerEngineBaseLayer(paddle.nn.Layer, ABC):
self.fp8_meta["scaling_bwd"] = FP8TensorMeta(is_forward=False) self.fp8_meta["scaling_bwd"] = FP8TensorMeta(is_forward=False)
self.tp_group = None self.tp_group = None
self.tp_size = 1 self.tp_size = 1
self.sequence_parallel = False
self.fp8_meta["autocast_id_fwd_stack"] = [] self.fp8_meta["autocast_id_fwd_stack"] = []
self.fp8_meta["async_amax_reduction"] = bool( self.fp8_meta["async_amax_reduction"] = bool(
int(os.getenv("NVTE_ASYNC_AMAX_REDUCTION", "0"))) int(os.getenv("NVTE_ASYNC_AMAX_REDUCTION", "0")))
self.fp8_weight_shapes = []
self.fp8_weight_cache = {}
def set_activation_dtype(self, inp: paddle.Tensor) -> None: def set_activation_dtype(self, inp: paddle.Tensor) -> None:
"""Get activation data type for AMP.""" """Get activation data type for AMP."""
...@@ -140,6 +143,29 @@ class TransformerEngineBaseLayer(paddle.nn.Layer, ABC): ...@@ -140,6 +143,29 @@ class TransformerEngineBaseLayer(paddle.nn.Layer, ABC):
self.fp8_initialized = False self.fp8_initialized = False
return return
def set_fp8_weights(self) -> None:
"""Initializes FP8 weights for the module"""
if not self.fp8_enabled:
return
for i, shape in enumerate(self.fp8_weight_shapes, start=1):
weight_cast_key = f"weight{i}_fp8"
weight_transpose_key = f"weight{i}_t_fp8"
if (weight_cast_key in self.fp8_weight_cache
and self.fp8_weight_cache[weight_cast_key].shape == shape):
return
self.fp8_weight_cache[weight_cast_key] = paddle.empty(
shape=shape,
dtype=paddle.uint8,
)
self.fp8_weight_cache[weight_transpose_key] = paddle.empty(
shape=[shape[1], shape[0]],
dtype=paddle.uint8,
)
def _get_fp8_state(self) -> paddle.Tensor: def _get_fp8_state(self) -> paddle.Tensor:
"""Dump FP8 state to paddle.Tensor.""" """Dump FP8 state to paddle.Tensor."""
state = None state = None
...@@ -218,6 +244,7 @@ class TransformerEngineBaseLayer(paddle.nn.Layer, ABC): ...@@ -218,6 +244,7 @@ class TransformerEngineBaseLayer(paddle.nn.Layer, ABC):
def prepare_forward( def prepare_forward(
self, self,
inp: paddle.Tensor, inp: paddle.Tensor,
is_first_microbatch: Union[bool, None],
num_gemms: int = 1, num_gemms: int = 1,
) -> Generator[paddle.Tensor, None, None]: ) -> Generator[paddle.Tensor, None, None]:
"""Checks and prep for FWD. """Checks and prep for FWD.
...@@ -234,16 +261,32 @@ class TransformerEngineBaseLayer(paddle.nn.Layer, ABC): ...@@ -234,16 +261,32 @@ class TransformerEngineBaseLayer(paddle.nn.Layer, ABC):
self.set_activation_dtype(inp) self.set_activation_dtype(inp)
self.fp8_init(num_gemms=num_gemms) self.fp8_init(num_gemms=num_gemms)
# Create persistent tensors for fp8 weights and their transposes
# only when fp8 weight caching is used.
if is_first_microbatch is not None:
self.set_fp8_weights()
if self.fp8_enabled and self.sequence_parallel:
assert self.fp8_meta["recipe"].reduce_amax, \
"Amax reduction across tensor parallel group is " \
"necessary when using sequence parallelism with FP8."
update_weight_scale_inv = is_first_microbatch is None or is_first_microbatch
# Previous iteration was grad_enabled # Previous iteration was grad_enabled
if self.fp8_meta.get("update_amax_and_scale_fwd", False): if self.fp8_meta.get("update_amax_and_scale_fwd", False):
global_fp8_fwd_buffer = get_global_fp8_state().get_fp8_fwd_buffer() global_fp8_fwd_buffer = get_global_fp8_state().get_fp8_fwd_buffer()
global_fp8_fwd_buffer.wait() global_fp8_fwd_buffer.wait()
if self.fp8_meta["recipe"].reduce_amax: if self.fp8_meta["recipe"].reduce_amax:
global_fp8_fwd_buffer.copy_amax_from_buffer(self.fp8_meta) global_fp8_fwd_buffer.copy_amax_from_buffer(self.fp8_meta)
amax_and_scale_update(self.fp8_meta, True) amax_and_scale_update(self.fp8_meta,
True,
update_weight_scale_inv=update_weight_scale_inv)
global_fp8_fwd_buffer.set_for_deletion(self.fp8_meta) global_fp8_fwd_buffer.set_for_deletion(self.fp8_meta)
else: else:
amax_and_scale_update(self.fp8_meta, True) amax_and_scale_update(self.fp8_meta,
True,
update_weight_scale_inv=update_weight_scale_inv)
if self.fp8_enabled and self.training: if self.fp8_enabled and self.training:
# Setup for amax reduction # Setup for amax reduction
...@@ -383,3 +426,28 @@ class TransformerEngineBaseLayer(paddle.nn.Layer, ABC): ...@@ -383,3 +426,28 @@ class TransformerEngineBaseLayer(paddle.nn.Layer, ABC):
@abstractmethod @abstractmethod
def forward(self): def forward(self):
"""Needs override.""" """Needs override."""
def get_fp8_weights_scratchpad(
self,
is_first_microbatch: Union[bool, None],
) -> List[Optional[paddle.Tensor]]:
"""
Fetch the fp8 weight tensor placeholders if they exist (when
`is_first_microbatch` is not `None`)
"""
if not self.fp8_enabled or is_first_microbatch is None:
return [None, None] * len(self.fp8_weight_shapes)
out_list = []
for i, _ in enumerate(self.fp8_weight_shapes, start=1):
weight_cast_key = f"weight{i}_fp8"
weight_transpose_key = f"weight{i}_t_fp8"
assert weight_cast_key in self.fp8_weight_cache, \
"TE internal error: fp8 weight buffer is not found"
out_list.extend([
self.fp8_weight_cache[weight_cast_key],
self.fp8_weight_cache[weight_transpose_key],
])
return out_list
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
# See LICENSE for license information. # See LICENSE for license information.
"""LayerNormLinear API""" """LayerNormLinear API"""
import warnings
import os import os
from typing import Union, Tuple, Dict, Any, Optional from typing import Union, Tuple, Dict, Any, Optional
...@@ -131,6 +132,8 @@ class _LayerNormLinear(paddle.autograd.PyLayer): ...@@ -131,6 +132,8 @@ class _LayerNormLinear(paddle.autograd.PyLayer):
ln_weight: paddle.Tensor, ln_weight: paddle.Tensor,
ln_bias: paddle.Tensor, ln_bias: paddle.Tensor,
weight: paddle.Tensor, weight: paddle.Tensor,
weight_fp8: Optional[paddle.Tensor],
weight_t_fp8: Optional[paddle.Tensor],
bias: Union[paddle.Tensor, None], bias: Union[paddle.Tensor, None],
use_bias: bool, use_bias: bool,
eps: float, eps: float,
...@@ -148,6 +151,7 @@ class _LayerNormLinear(paddle.autograd.PyLayer): ...@@ -148,6 +151,7 @@ class _LayerNormLinear(paddle.autograd.PyLayer):
sequence_parallel: bool, sequence_parallel: bool,
tp_group: Union[dist_group_type, None], tp_group: Union[dist_group_type, None],
tp_size: int, tp_size: int,
is_first_microbatch: bool,
) -> Union[Tuple[paddle.Tensor, ...], paddle.Tensor]: ) -> Union[Tuple[paddle.Tensor, ...], paddle.Tensor]:
# Make sure input dimensions are compatible # Make sure input dimensions are compatible
in_features = ln_weight.shape[0] in_features = ln_weight.shape[0]
...@@ -182,6 +186,8 @@ class _LayerNormLinear(paddle.autograd.PyLayer): ...@@ -182,6 +186,8 @@ class _LayerNormLinear(paddle.autograd.PyLayer):
ln_out, ln_out,
FP8FwdTensors.GEMM1_INPUT, FP8FwdTensors.GEMM1_INPUT,
weight, weight,
weight_fp8,
weight_t_fp8,
FP8FwdTensors.GEMM1_WEIGHT, FP8FwdTensors.GEMM1_WEIGHT,
bias, bias,
use_bias, use_bias,
...@@ -194,6 +200,7 @@ class _LayerNormLinear(paddle.autograd.PyLayer): ...@@ -194,6 +200,7 @@ class _LayerNormLinear(paddle.autograd.PyLayer):
sequence_parallel, sequence_parallel,
tp_group, tp_group,
is_grad_enabled, is_grad_enabled,
is_first_microbatch,
) )
if is_grad_enabled: if is_grad_enabled:
...@@ -227,6 +234,8 @@ class _LayerNormLinear(paddle.autograd.PyLayer): ...@@ -227,6 +234,8 @@ class _LayerNormLinear(paddle.autograd.PyLayer):
ctx.requires_bgrad = use_bias and not bias.stop_gradient ctx.requires_bgrad = use_bias and not bias.stop_gradient
ctx.requires_ln_bgrad = not ln_bias.stop_gradient ctx.requires_ln_bgrad = not ln_bias.stop_gradient
ctx.requires_ln_wgrad = not ln_weight.stop_gradient ctx.requires_ln_wgrad = not ln_weight.stop_gradient
ctx.is_first_microbatch = is_first_microbatch
# [*, in_features] -> [*, out_features] except first dimension changes for SP # [*, in_features] -> [*, out_features] except first dimension changes for SP
out = out.reshape((-1, *inp.shape[1:-1], out.shape[-1])) out = out.reshape((-1, *inp.shape[1:-1], out.shape[-1]))
...@@ -320,11 +329,18 @@ class _LayerNormLinear(paddle.autograd.PyLayer): ...@@ -320,11 +329,18 @@ class _LayerNormLinear(paddle.autograd.PyLayer):
bgrad = bgrad if ctx.requires_bgrad else None bgrad = bgrad if ctx.requires_bgrad else None
bgrad_out = (bgrad,) if ctx.use_bias else () bgrad_out = (bgrad,) if ctx.use_bias else ()
if not ctx.fp8_enabled or ctx.is_first_microbatch is None:
weight_cache_grad = ()
else:
# weight_fp8 and weight_t_fp8 are stop_gradient tensors
weight_cache_grad = (None, None)
return ( return (
dxmat.reshape(ctx.inp_shape) if ctx.requires_dgrad else None, dxmat.reshape(ctx.inp_shape) if ctx.requires_dgrad else None,
dgamma if ctx.requires_ln_wgrad else None, dgamma if ctx.requires_ln_wgrad else None,
dbeta if ctx.requires_ln_bgrad else None, dbeta if ctx.requires_ln_bgrad else None,
wgrad if ctx.requires_wgrad else None, wgrad if ctx.requires_wgrad else None,
*weight_cache_grad,
*bgrad_out, *bgrad_out,
) )
...@@ -447,6 +463,7 @@ class LayerNormLinear(TransformerEngineBaseLayer): ...@@ -447,6 +463,7 @@ class LayerNormLinear(TransformerEngineBaseLayer):
) )
set_weight_tensor_dist_attr(self.weight, self.tensor_parallel, self.parallel_mode, set_weight_tensor_dist_attr(self.weight, self.tensor_parallel, self.parallel_mode,
self.backend) self.backend)
self.fp8_weight_shapes.append(self.weight.shape)
# Initialize Linear bias parameter # Initialize Linear bias parameter
self.has_bias = self._bias_attr is not False self.has_bias = self._bias_attr is not False
...@@ -483,21 +500,28 @@ class LayerNormLinear(TransformerEngineBaseLayer): ...@@ -483,21 +500,28 @@ class LayerNormLinear(TransformerEngineBaseLayer):
def _te_forward( def _te_forward(
self, self,
inp: paddle.Tensor, inp: paddle.Tensor,
is_first_microbatch: Optional[bool] = None,
) -> Union[paddle.Tensor, Tuple[paddle.Tensor, ...]]: ) -> Union[paddle.Tensor, Tuple[paddle.Tensor, ...]]:
""" """
Apply layer normalization to the input followed by a linear transformation. Apply layer normalization to the input followed by a linear transformation.
""" """
with self.prepare_forward(inp) as inp: with self.prepare_forward(inp, is_first_microbatch=is_first_microbatch) as inp:
# Layer input should be casted outside PyLayer, as performing # Layer input should be casted outside PyLayer, as performing
# inplace cast to input tensors may cause problems when used # inplace cast to input tensors may cause problems when used
# together with Paddle native layers. # together with Paddle native layers.
inp = cast_if_needed(inp, self.activation_dtype) inp = cast_if_needed(inp, self.activation_dtype)
# Get persistent fp8 weight buffer. None if buffer does not exist.
weight_fp8, weight_t_fp8 = self.get_fp8_weights_scratchpad(is_first_microbatch)
out = _LayerNormLinear.apply( out = _LayerNormLinear.apply(
inp, inp,
self.ln_weight, self.ln_weight,
self.ln_bias, self.ln_bias,
self.weight, self.weight,
weight_fp8,
weight_t_fp8,
self.bias if self.gemm_bias_fused_add else None, self.bias if self.gemm_bias_fused_add else None,
self.has_bias and self.gemm_bias_fused_add, self.has_bias and self.gemm_bias_fused_add,
self.eps, self.eps,
...@@ -515,6 +539,7 @@ class LayerNormLinear(TransformerEngineBaseLayer): ...@@ -515,6 +539,7 @@ class LayerNormLinear(TransformerEngineBaseLayer):
self.sequence_parallel, self.sequence_parallel,
self.tp_group, self.tp_group,
self.tp_size, self.tp_size,
is_first_microbatch,
) )
if self.return_layernorm_output: if self.return_layernorm_output:
...@@ -530,12 +555,17 @@ class LayerNormLinear(TransformerEngineBaseLayer): ...@@ -530,12 +555,17 @@ class LayerNormLinear(TransformerEngineBaseLayer):
def _pd_forward( def _pd_forward(
self, self,
inp: paddle.Tensor, inp: paddle.Tensor,
is_first_microbatch: Optional[bool] = None,
) -> paddle.Tensor: ) -> paddle.Tensor:
"""Calls Paddle OP""" """Calls Paddle OP"""
if self.zero_centered_gamma: if self.zero_centered_gamma:
raise NotImplementedError( raise NotImplementedError(
"Paddle backend does not support LayerNorm with zero-centered scale.") "Paddle backend does not support LayerNorm with zero-centered scale.")
if is_first_microbatch is not None:
warnings.warn(
"`is_first_microbatch` is not supported for paddle backend and is ignored.")
ln_out = F.layer_norm(x=inp, ln_out = F.layer_norm(x=inp,
normalized_shape=inp.shape[-1], normalized_shape=inp.shape[-1],
weight=self.ln_weight, weight=self.ln_weight,
...@@ -557,8 +587,18 @@ class LayerNormLinear(TransformerEngineBaseLayer): ...@@ -557,8 +587,18 @@ class LayerNormLinear(TransformerEngineBaseLayer):
Parameters Parameters
---------- ----------
inp : torch.Tensor inp : paddle.Tensor
Input tensor. Input tensor.
is_first_microbatch : {True, False, None}, default = None
During training using either gradient accumulation or
pipeline parallelism a minibatch of data is further split
into microbatches. Between the microbatches of the same minibatch
the model weights are not updated. Setting this parameter indicates
whether the current microbatch is the first in a minibatch or not.
When set, this parameter enables additional optimizations:
* during FP8 training, it allows caching of the FP8 versions of
the weights
""" """
if self.backend == 'transformer_engine': if self.backend == 'transformer_engine':
return self._te_forward(*args, **kwargs) return self._te_forward(*args, **kwargs)
......
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
"""LayerNormMLP API""" """LayerNormMLP API"""
import os import os
import warnings
from typing import Union, Tuple, Dict, Any, Optional from typing import Union, Tuple, Dict, Any, Optional
import paddle import paddle
...@@ -46,11 +47,15 @@ def _mlp_forward( ...@@ -46,11 +47,15 @@ def _mlp_forward(
inputmat: paddle.Tensor, inputmat: paddle.Tensor,
inputmat_fp8_index: FP8FwdTensors, inputmat_fp8_index: FP8FwdTensors,
fc1_weight: paddle.Tensor, fc1_weight: paddle.Tensor,
fc1_weight_fp8: Optional[paddle.Tensor],
fc1_weight_t_fp8: Optional[paddle.Tensor],
fc1_weight_fp8_index: FP8FwdTensors, fc1_weight_fp8_index: FP8FwdTensors,
fc1_bias: Union[paddle.Tensor, None], fc1_bias: Union[paddle.Tensor, None],
use_fc1_bias: bool, use_fc1_bias: bool,
fc2_input_fp8_index: FP8FwdTensors, # FP8FwdTensors.GEMM2_INPUT fc2_input_fp8_index: FP8FwdTensors, # FP8FwdTensors.GEMM2_INPUT
fc2_weight: paddle.Tensor, fc2_weight: paddle.Tensor,
fc2_weight_fp8: Optional[paddle.Tensor],
fc2_weight_t_fp8: Optional[paddle.Tensor],
fc2_weight_fp8_index: FP8FwdTensors, fc2_weight_fp8_index: FP8FwdTensors,
fc2_bias: Union[paddle.Tensor, None], fc2_bias: Union[paddle.Tensor, None],
use_fc2_bias: bool, use_fc2_bias: bool,
...@@ -64,6 +69,7 @@ def _mlp_forward( ...@@ -64,6 +69,7 @@ def _mlp_forward(
tensor_parallel: bool, tensor_parallel: bool,
sequence_parallel: bool, sequence_parallel: bool,
tp_group: Union[dist_group_type, None], tp_group: Union[dist_group_type, None],
is_first_microbatch: bool,
): ):
if fp8_enabled: if fp8_enabled:
fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True) fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)
...@@ -71,6 +77,8 @@ def _mlp_forward( ...@@ -71,6 +77,8 @@ def _mlp_forward(
inputmat, inputmat,
inputmat_fp8_index, inputmat_fp8_index,
fc1_weight, fc1_weight,
fc1_weight_fp8,
fc1_weight_t_fp8,
fc1_weight_fp8_index, fc1_weight_fp8_index,
fc1_bias, fc1_bias,
use_fc1_bias, use_fc1_bias,
...@@ -81,6 +89,7 @@ def _mlp_forward( ...@@ -81,6 +89,7 @@ def _mlp_forward(
sequence_parallel, sequence_parallel,
tp_group, tp_group,
is_grad_enabled, is_grad_enabled,
is_first_microbatch,
) )
gelu_out = gelu_fp8( gelu_out = gelu_fp8(
...@@ -94,6 +103,8 @@ def _mlp_forward( ...@@ -94,6 +103,8 @@ def _mlp_forward(
gelu_out, gelu_out,
fc2_input_fp8_index, fc2_input_fp8_index,
fc2_weight, fc2_weight,
fc2_weight_fp8,
fc2_weight_t_fp8,
fc2_weight_fp8_index, fc2_weight_fp8_index,
fc2_bias, fc2_bias,
use_fc2_bias, use_fc2_bias,
...@@ -104,6 +115,7 @@ def _mlp_forward( ...@@ -104,6 +115,7 @@ def _mlp_forward(
sequence_parallel, sequence_parallel,
tp_group, tp_group,
is_grad_enabled, is_grad_enabled,
is_first_microbatch,
) )
else: else:
fc1_out, gelu_out = _linear_fwd_non_fp8( fc1_out, gelu_out = _linear_fwd_non_fp8(
...@@ -321,9 +333,13 @@ class _LayerNormMLP(paddle.autograd.PyLayer): ...@@ -321,9 +333,13 @@ class _LayerNormMLP(paddle.autograd.PyLayer):
ln_weight: paddle.Tensor, ln_weight: paddle.Tensor,
ln_bias: paddle.Tensor, ln_bias: paddle.Tensor,
fc1_weight: paddle.Tensor, fc1_weight: paddle.Tensor,
fc1_weight_fp8: Optional[paddle.Tensor],
fc1_weight_t_fp8: Optional[paddle.Tensor],
fc1_bias: Union[paddle.Tensor, None], fc1_bias: Union[paddle.Tensor, None],
use_fc1_bias: bool, use_fc1_bias: bool,
fc2_weight: paddle.Tensor, fc2_weight: paddle.Tensor,
fc2_weight_fp8: Optional[paddle.Tensor],
fc2_weight_t_fp8: Optional[paddle.Tensor],
fc2_bias: Union[paddle.Tensor, None], fc2_bias: Union[paddle.Tensor, None],
use_fc2_bias: bool, use_fc2_bias: bool,
eps: float, eps: float,
...@@ -342,6 +358,7 @@ class _LayerNormMLP(paddle.autograd.PyLayer): ...@@ -342,6 +358,7 @@ class _LayerNormMLP(paddle.autograd.PyLayer):
sequence_parallel: bool, sequence_parallel: bool,
tp_group: Union[dist_group_type, None], tp_group: Union[dist_group_type, None],
tp_size: int, tp_size: int,
is_first_microbatch: bool,
) -> Union[Tuple[paddle.Tensor, ...], paddle.Tensor]: ) -> Union[Tuple[paddle.Tensor, ...], paddle.Tensor]:
# Make sure input dimensions are compatible # Make sure input dimensions are compatible
in_features = ln_weight.shape[0] in_features = ln_weight.shape[0]
...@@ -385,11 +402,15 @@ class _LayerNormMLP(paddle.autograd.PyLayer): ...@@ -385,11 +402,15 @@ class _LayerNormMLP(paddle.autograd.PyLayer):
ln_out, ln_out,
FP8FwdTensors.GEMM1_INPUT, FP8FwdTensors.GEMM1_INPUT,
fc1_weight, fc1_weight,
fc1_weight_fp8,
fc1_weight_t_fp8,
FP8FwdTensors.GEMM1_WEIGHT, FP8FwdTensors.GEMM1_WEIGHT,
fc1_bias, fc1_bias,
use_fc1_bias, use_fc1_bias,
FP8FwdTensors.GEMM2_INPUT, FP8FwdTensors.GEMM2_INPUT,
fc2_weight, fc2_weight,
fc2_weight_fp8,
fc2_weight_t_fp8,
FP8FwdTensors.GEMM2_WEIGHT, FP8FwdTensors.GEMM2_WEIGHT,
fc2_bias, fc2_bias,
use_fc2_bias, use_fc2_bias,
...@@ -403,6 +424,7 @@ class _LayerNormMLP(paddle.autograd.PyLayer): ...@@ -403,6 +424,7 @@ class _LayerNormMLP(paddle.autograd.PyLayer):
tensor_parallel, tensor_parallel,
sequence_parallel, sequence_parallel,
tp_group, tp_group,
is_first_microbatch,
) )
if is_grad_enabled: if is_grad_enabled:
...@@ -443,6 +465,7 @@ class _LayerNormMLP(paddle.autograd.PyLayer): ...@@ -443,6 +465,7 @@ class _LayerNormMLP(paddle.autograd.PyLayer):
ctx.requires_fc2_bgrad = use_fc2_bias and not fc2_bias.stop_gradient ctx.requires_fc2_bgrad = use_fc2_bias and not fc2_bias.stop_gradient
ctx.requires_ln_bgrad = not ln_bias.stop_gradient ctx.requires_ln_bgrad = not ln_bias.stop_gradient
ctx.requires_ln_wgrad = not ln_weight.stop_gradient ctx.requires_ln_wgrad = not ln_weight.stop_gradient
ctx.is_first_microbatch = is_first_microbatch
# [*, in_features] -> [*, out_features] except first dimension changes for SP # [*, in_features] -> [*, out_features] except first dimension changes for SP
fc2_out = fc2_out.reshape((-1, *inp.shape[1:-1], fc2_out.shape[-1])) fc2_out = fc2_out.reshape((-1, *inp.shape[1:-1], fc2_out.shape[-1]))
...@@ -543,13 +566,23 @@ class _LayerNormMLP(paddle.autograd.PyLayer): ...@@ -543,13 +566,23 @@ class _LayerNormMLP(paddle.autograd.PyLayer):
fc1_bgrad_out = (fc1_bgrad,) if ctx.use_fc1_bias else () fc1_bgrad_out = (fc1_bgrad,) if ctx.use_fc1_bias else ()
fc2_bgrad_out = (fc2_bgrad,) if ctx.use_fc2_bias else () fc2_bgrad_out = (fc2_bgrad,) if ctx.use_fc2_bias else ()
if not ctx.fp8_enabled or ctx.is_first_microbatch is None:
fc1_weight_cache_grad = ()
fc2_weight_cache_grad = ()
else:
# weight_fp8 and weight_t_fp8 are stop_gradient tensors
fc1_weight_cache_grad = (None, None)
fc2_weight_cache_grad = (None, None)
return ( return (
dxmat.reshape(ctx.inp_shape) if ctx.requires_dgrad else None, dxmat.reshape(ctx.inp_shape) if ctx.requires_dgrad else None,
dgamma if ctx.requires_ln_wgrad else None, dgamma if ctx.requires_ln_wgrad else None,
dbeta if ctx.requires_ln_bgrad else None, dbeta if ctx.requires_ln_bgrad else None,
fc1_wgrad if ctx.requires_fc1_wgrad else None, fc1_wgrad if ctx.requires_fc1_wgrad else None,
*fc1_weight_cache_grad,
*fc1_bgrad_out, *fc1_bgrad_out,
fc2_wgrad if ctx.requires_fc2_wgrad else None, fc2_wgrad if ctx.requires_fc2_wgrad else None,
*fc2_weight_cache_grad,
*fc2_bgrad_out, *fc2_bgrad_out,
) )
...@@ -675,6 +708,7 @@ class LayerNormMLP(TransformerEngineBaseLayer): ...@@ -675,6 +708,7 @@ class LayerNormMLP(TransformerEngineBaseLayer):
self.tensor_parallel, self.tensor_parallel,
parallel_mode='column', parallel_mode='column',
backend=self.backend) backend=self.backend)
self.fp8_weight_shapes.append(self.fc1_weight.shape)
self.has_bias = self._bias_attr is not False self.has_bias = self._bias_attr is not False
use_default_bias = self._bias_attr is None or self._bias_attr is True use_default_bias = self._bias_attr is None or self._bias_attr is True
...@@ -704,6 +738,7 @@ class LayerNormMLP(TransformerEngineBaseLayer): ...@@ -704,6 +738,7 @@ class LayerNormMLP(TransformerEngineBaseLayer):
self.tensor_parallel, self.tensor_parallel,
parallel_mode='row', parallel_mode='row',
backend=self.backend) backend=self.backend)
self.fp8_weight_shapes.append(self.fc2_weight.shape)
if self.has_bias: if self.has_bias:
self.fc2_bias = self.create_parameter( self.fc2_bias = self.create_parameter(
...@@ -734,24 +769,34 @@ class LayerNormMLP(TransformerEngineBaseLayer): ...@@ -734,24 +769,34 @@ class LayerNormMLP(TransformerEngineBaseLayer):
def _te_forward( def _te_forward(
self, self,
inp: paddle.Tensor, inp: paddle.Tensor,
is_first_microbatch: Optional[bool] = None,
) -> Union[paddle.Tensor, Tuple[paddle.Tensor, ...]]: ) -> Union[paddle.Tensor, Tuple[paddle.Tensor, ...]]:
""" """
Apply layer normalization to the input followed by a linear transformation. Apply layer normalization to the input followed by a linear transformation.
""" """
with self.prepare_forward(inp, num_gemms=2) as inp: with self.prepare_forward(inp, num_gemms=2, is_first_microbatch=is_first_microbatch) as inp:
# Layer input should be casted outside PyLayer, as performing # Layer input should be casted outside PyLayer, as performing
# inplace cast to input tensors may cause problems when used # inplace cast to input tensors may cause problems when used
# together with Paddle native layers. # together with Paddle native layers.
inp = cast_if_needed(inp, self.activation_dtype) inp = cast_if_needed(inp, self.activation_dtype)
# Get persistent fp8 weight buffer. None if buffer does not exist.
fc1_weight_fp8, fc1_weight_t_fp8, fc2_weight_fp8, fc2_weight_t_fp8 = \
self.get_fp8_weights_scratchpad(is_first_microbatch)
out = _LayerNormMLP.apply( out = _LayerNormMLP.apply(
inp, inp,
self.ln_weight, self.ln_weight,
self.ln_bias, self.ln_bias,
self.fc1_weight, self.fc1_weight,
fc1_weight_fp8,
fc1_weight_t_fp8,
self.fc1_bias, self.fc1_bias,
self.has_bias, self.has_bias,
self.fc2_weight, self.fc2_weight,
fc2_weight_fp8,
fc2_weight_t_fp8,
self.fc2_bias, self.fc2_bias,
self.has_bias, self.has_bias,
self.eps, self.eps,
...@@ -770,6 +815,7 @@ class LayerNormMLP(TransformerEngineBaseLayer): ...@@ -770,6 +815,7 @@ class LayerNormMLP(TransformerEngineBaseLayer):
self.sequence_parallel, self.sequence_parallel,
self.tp_group, self.tp_group,
self.tp_size, self.tp_size,
is_first_microbatch,
) )
if self.return_layernorm_output: if self.return_layernorm_output:
...@@ -785,12 +831,17 @@ class LayerNormMLP(TransformerEngineBaseLayer): ...@@ -785,12 +831,17 @@ class LayerNormMLP(TransformerEngineBaseLayer):
def _pd_forward( def _pd_forward(
self, self,
inp: paddle.Tensor, inp: paddle.Tensor,
is_first_microbatch: Optional[bool] = None,
) -> paddle.Tensor: ) -> paddle.Tensor:
"""Calls Paddle OP""" """Calls Paddle OP"""
if self.zero_centered_gamma: if self.zero_centered_gamma:
raise NotImplementedError( raise NotImplementedError(
"Paddle backend does not support LayerNorm with zero-centered scale.") "Paddle backend does not support LayerNorm with zero-centered scale.")
if is_first_microbatch is not None:
warnings.warn(
"`is_first_microbatch` is not supported for paddle backend and is ignored.")
ln_out = F.layer_norm(x=inp, ln_out = F.layer_norm(x=inp,
normalized_shape=inp.shape[-1], normalized_shape=inp.shape[-1],
weight=self.ln_weight, weight=self.ln_weight,
...@@ -816,8 +867,18 @@ class LayerNormMLP(TransformerEngineBaseLayer): ...@@ -816,8 +867,18 @@ class LayerNormMLP(TransformerEngineBaseLayer):
Parameters Parameters
---------- ----------
inp : torch.Tensor inp : paddle.Tensor
Input tensor. Input tensor.
is_first_microbatch : {True, False, None}, default = None
During training using either gradient accumulation or
pipeline parallelism a minibatch of data is further split
into microbatches. Between the microbatches of the same minibatch
the model weights are not updated. Setting this parameter indicates
whether the current microbatch is the first in a minibatch or not.
When set, this parameter enables additional optimizations:
* during FP8 training, it allows caching of the FP8 versions of
the weights
""" """
if self.backend == 'transformer_engine': if self.backend == 'transformer_engine':
return self._te_forward(*args, **kwargs) return self._te_forward(*args, **kwargs)
......
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
# See LICENSE for license information. # See LICENSE for license information.
"""Linear API""" """Linear API"""
import warnings
from typing import Union, Tuple, Dict, Any, Optional from typing import Union, Tuple, Dict, Any, Optional
import paddle import paddle
...@@ -39,6 +40,7 @@ from ..utils import ( ...@@ -39,6 +40,7 @@ from ..utils import (
get_bias_dtype, get_bias_dtype,
save_for_backward_allow_none, save_for_backward_allow_none,
saved_tensor_allow_none, saved_tensor_allow_none,
clear_tensor_data,
) )
__all__ = ["Linear"] __all__ = ["Linear"]
...@@ -48,6 +50,8 @@ def _linear_fwd_fp8( ...@@ -48,6 +50,8 @@ def _linear_fwd_fp8(
inputmat: paddle.Tensor, inputmat: paddle.Tensor,
inputmat_fp8_index: FP8FwdTensors, inputmat_fp8_index: FP8FwdTensors,
weight: paddle.Tensor, weight: paddle.Tensor,
weight_fp8: Optional[paddle.Tensor],
weight_t_fp8: Optional[paddle.Tensor],
weight_fp8_index: FP8FwdTensors, weight_fp8_index: FP8FwdTensors,
bias: paddle.Tensor, bias: paddle.Tensor,
use_bias: bool, use_bias: bool,
...@@ -58,6 +62,7 @@ def _linear_fwd_fp8( ...@@ -58,6 +62,7 @@ def _linear_fwd_fp8(
sequence_parallel: bool, sequence_parallel: bool,
tp_group: Union[dist_group_type, None], tp_group: Union[dist_group_type, None],
is_grad_enabled: bool, is_grad_enabled: bool,
is_first_microbatch: bool = None,
): ):
"""FP8 path of Linear Fwd""" """FP8 path of Linear Fwd"""
fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True) fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)
...@@ -69,20 +74,26 @@ def _linear_fwd_fp8( ...@@ -69,20 +74,26 @@ def _linear_fwd_fp8(
else: else:
inputmat_total = inputmat inputmat_total = inputmat
update_fp8_weights = is_first_microbatch is None or is_first_microbatch
if is_grad_enabled: if is_grad_enabled:
if update_fp8_weights:
weight_fp8, weight_t_fp8 = cast_transpose( weight_fp8, weight_t_fp8 = cast_transpose(
weight, weight,
fp8_meta["scaling_fwd"], fp8_meta["scaling_fwd"],
weight_fp8_index, weight_fp8_index,
fp8_dtype_forward, fp8_dtype_forward,
cast_out=weight_fp8,
transpose_out=weight_t_fp8,
) )
else: else:
weight_t_fp8 = None weight_t_fp8 = None
if update_fp8_weights:
weight_fp8 = cast_to_fp8( weight_fp8 = cast_to_fp8(
weight, weight,
fp8_meta["scaling_fwd"], fp8_meta["scaling_fwd"],
weight_fp8_index, weight_fp8_index,
fp8_dtype_forward, fp8_dtype_forward,
out=weight_fp8,
) )
out = fp8_gemm( out = fp8_gemm(
...@@ -146,6 +157,7 @@ def _linear_fwd_non_fp8( ...@@ -146,6 +157,7 @@ def _linear_fwd_non_fp8(
# amax of weight # amax of weight
fp8_meta["scaling_fwd"].amax_history[0, weight_fp8_index.value] = \ fp8_meta["scaling_fwd"].amax_history[0, weight_fp8_index.value] = \
paddle.max(paddle.abs(weight)).item() paddle.max(paddle.abs(weight)).item()
fp8_meta["update_amax_and_scale_fwd"] = True
outputs = gemm(weight, outputs = gemm(weight,
inputmat_total, inputmat_total,
...@@ -172,6 +184,8 @@ def _linear_fwd( ...@@ -172,6 +184,8 @@ def _linear_fwd(
inputmat: paddle.Tensor, inputmat: paddle.Tensor,
inputmat_fp8_index: FP8FwdTensors, inputmat_fp8_index: FP8FwdTensors,
weight: paddle.Tensor, weight: paddle.Tensor,
weight_fp8: Optional[paddle.Tensor],
weight_t_fp8: Optional[paddle.Tensor],
weight_fp8_index: FP8FwdTensors, weight_fp8_index: FP8FwdTensors,
bias: paddle.Tensor, bias: paddle.Tensor,
use_bias: bool, use_bias: bool,
...@@ -184,12 +198,15 @@ def _linear_fwd( ...@@ -184,12 +198,15 @@ def _linear_fwd(
sequence_parallel: bool, sequence_parallel: bool,
tp_group: Union[dist_group_type, None], tp_group: Union[dist_group_type, None],
is_grad_enabled: bool, is_grad_enabled: bool,
is_first_microbatch: bool = None,
): ):
if fp8_enabled: if fp8_enabled:
out, weight_t_fp8 = _linear_fwd_fp8( out, weight_t_fp8 = _linear_fwd_fp8(
inputmat, inputmat,
inputmat_fp8_index, inputmat_fp8_index,
weight, weight,
weight_fp8,
weight_t_fp8,
weight_fp8_index, weight_fp8_index,
bias, bias,
use_bias, use_bias,
...@@ -200,6 +217,7 @@ def _linear_fwd( ...@@ -200,6 +217,7 @@ def _linear_fwd(
sequence_parallel, sequence_parallel,
tp_group, tp_group,
is_grad_enabled, is_grad_enabled,
is_first_microbatch,
) )
else: else:
out = _linear_fwd_non_fp8( out = _linear_fwd_non_fp8(
...@@ -270,6 +288,7 @@ def _linear_bwd_fp8( ...@@ -270,6 +288,7 @@ def _linear_bwd_fp8(
get_workspace(), get_workspace(),
use_split_accumulator=_2X_ACC_DGRAD, use_split_accumulator=_2X_ACC_DGRAD,
) )
clear_tensor_data(grad_output_c)
# Overlap dgrad-RS/AR with wgrad # Overlap dgrad-RS/AR with wgrad
if parallel_mode == "column" and sequence_parallel: if parallel_mode == "column" and sequence_parallel:
...@@ -283,6 +302,7 @@ def _linear_bwd_fp8( ...@@ -283,6 +302,7 @@ def _linear_bwd_fp8(
if not fp8_meta["recipe"].override_linear_precision.wgrad: if not fp8_meta["recipe"].override_linear_precision.wgrad:
if inputmat_t_total is None: if inputmat_t_total is None:
inputmat_t_total = transpose(inputmat_total, fp8_dtype_backward) inputmat_t_total = transpose(inputmat_total, fp8_dtype_backward)
clear_tensor_data(inputmat_total)
wgrad = fp8_gemm( wgrad = fp8_gemm(
inputmat_t_total, inputmat_t_total,
fwd_scale_inverses, fwd_scale_inverses,
...@@ -296,6 +316,7 @@ def _linear_bwd_fp8( ...@@ -296,6 +316,7 @@ def _linear_bwd_fp8(
get_workspace(), get_workspace(),
use_split_accumulator=_2X_ACC_WGRAD, use_split_accumulator=_2X_ACC_WGRAD,
) )
clear_tensor_data(inputmat_t_total, grad_output_t)
else: else:
wgrad, _, _ = gemm( wgrad, _, _ = gemm(
inputmat_total, inputmat_total,
...@@ -305,6 +326,7 @@ def _linear_bwd_fp8( ...@@ -305,6 +326,7 @@ def _linear_bwd_fp8(
layout="NT", layout="NT",
grad=True, grad=True,
) )
clear_tensor_data(inputmat_total)
if parallel_mode == "column" and tensor_parallel and handle is not None: if parallel_mode == "column" and tensor_parallel and handle is not None:
handle.wait() handle.wait()
...@@ -446,6 +468,8 @@ class _Linear(paddle.autograd.PyLayer): ...@@ -446,6 +468,8 @@ class _Linear(paddle.autograd.PyLayer):
def forward( def forward(
ctx, ctx,
weight: paddle.Tensor, weight: paddle.Tensor,
weight_fp8: Optional[paddle.Tensor],
weight_t_fp8: Optional[paddle.Tensor],
inp: paddle.Tensor, inp: paddle.Tensor,
bias: paddle.Tensor, bias: paddle.Tensor,
use_bias: bool, use_bias: bool,
...@@ -459,6 +483,7 @@ class _Linear(paddle.autograd.PyLayer): ...@@ -459,6 +483,7 @@ class _Linear(paddle.autograd.PyLayer):
sequence_parallel: bool, sequence_parallel: bool,
tp_group: Union[dist_group_type, None], tp_group: Union[dist_group_type, None],
tp_size: int, tp_size: int,
is_first_microbatch: bool,
) -> paddle.Tensor: ) -> paddle.Tensor:
# Make sure input dimensions are compatible # Make sure input dimensions are compatible
in_features = weight.shape[-1] in_features = weight.shape[-1]
...@@ -495,6 +520,8 @@ class _Linear(paddle.autograd.PyLayer): ...@@ -495,6 +520,8 @@ class _Linear(paddle.autograd.PyLayer):
inputmat, inputmat,
FP8FwdTensors.GEMM1_INPUT, FP8FwdTensors.GEMM1_INPUT,
weight, weight,
weight_fp8,
weight_t_fp8,
FP8FwdTensors.GEMM1_WEIGHT, FP8FwdTensors.GEMM1_WEIGHT,
bias, bias,
use_bias, use_bias,
...@@ -507,6 +534,7 @@ class _Linear(paddle.autograd.PyLayer): ...@@ -507,6 +534,7 @@ class _Linear(paddle.autograd.PyLayer):
sequence_parallel, sequence_parallel,
tp_group, tp_group,
is_grad_enabled, is_grad_enabled,
is_first_microbatch,
) )
if is_grad_enabled: if is_grad_enabled:
...@@ -536,6 +564,7 @@ class _Linear(paddle.autograd.PyLayer): ...@@ -536,6 +564,7 @@ class _Linear(paddle.autograd.PyLayer):
ctx.requires_dgrad = not inp.stop_gradient ctx.requires_dgrad = not inp.stop_gradient
ctx.requires_wgrad = not weight.stop_gradient ctx.requires_wgrad = not weight.stop_gradient
ctx.requires_bgrad = use_bias and not bias.stop_gradient ctx.requires_bgrad = use_bias and not bias.stop_gradient
ctx.is_first_microbatch = is_first_microbatch
return out.reshape((-1, *inp.shape[1:-1], out.shape[-1])) return out.reshape((-1, *inp.shape[1:-1], out.shape[-1]))
...@@ -591,14 +620,22 @@ class _Linear(paddle.autograd.PyLayer): ...@@ -591,14 +620,22 @@ class _Linear(paddle.autograd.PyLayer):
# bgrad is fused with gemm for non-FP8 path # bgrad is fused with gemm for non-FP8 path
bgrad = bgrad_ bgrad = bgrad_
if not ctx.fp8_enabled or ctx.is_first_microbatch is None:
weight_cache_grad = ()
else:
# weight_fp8 and weight_t_fp8 are stop_gradient tensors
weight_cache_grad = (None, None)
if not ctx.use_bias: if not ctx.use_bias:
return ( return (
wgrad if ctx.requires_wgrad else None, wgrad if ctx.requires_wgrad else None,
*weight_cache_grad,
dgrad.reshape(ctx.inp_shape) if ctx.requires_dgrad else None, dgrad.reshape(ctx.inp_shape) if ctx.requires_dgrad else None,
) )
return ( return (
wgrad if ctx.requires_wgrad else None, wgrad if ctx.requires_wgrad else None,
*weight_cache_grad,
dgrad.reshape(ctx.inp_shape) if ctx.requires_dgrad else None, dgrad.reshape(ctx.inp_shape) if ctx.requires_dgrad else None,
bgrad if ctx.requires_bgrad else None, bgrad if ctx.requires_bgrad else None,
) )
...@@ -699,6 +736,8 @@ class Linear(TransformerEngineBaseLayer): ...@@ -699,6 +736,8 @@ class Linear(TransformerEngineBaseLayer):
else: else:
self.bias = None self.bias = None
self.fp8_weight_shapes.append(self.weight.shape)
# For RPL, bias has to be added after TP collectives # For RPL, bias has to be added after TP collectives
# So it cannot be fused with the GEMM # So it cannot be fused with the GEMM
if self.parallel_mode == "row" and self.tensor_parallel and self.has_bias: if self.parallel_mode == "row" and self.tensor_parallel and self.has_bias:
...@@ -709,17 +748,24 @@ class Linear(TransformerEngineBaseLayer): ...@@ -709,17 +748,24 @@ class Linear(TransformerEngineBaseLayer):
def _te_forward( def _te_forward(
self, self,
inp: paddle.Tensor, inp: paddle.Tensor,
is_first_microbatch: Optional[bool] = None,
) -> paddle.Tensor: ) -> paddle.Tensor:
""" """
Apply the linear transformation to the input. Apply the linear transformation to the input.
""" """
with self.prepare_forward(inp) as inp: with self.prepare_forward(inp, is_first_microbatch=is_first_microbatch) as inp:
# Layer input should be casted outside PyLayer, as performing # Layer input should be casted outside PyLayer, as performing
# inplace cast to input tensors may cause problems when used # inplace cast to input tensors may cause problems when used
# together with Paddle native layers. # together with Paddle native layers.
inp = cast_if_needed(inp, self.activation_dtype) inp = cast_if_needed(inp, self.activation_dtype)
# Get persistent fp8 weight buffer. None if buffer does not exist.
weight_fp8, weight_t_fp8 = self.get_fp8_weights_scratchpad(is_first_microbatch)
out = _Linear.apply( out = _Linear.apply(
self.weight, self.weight,
weight_fp8,
weight_t_fp8,
inp, inp,
self.bias if self.gemm_bias_fused_add else None, self.bias if self.gemm_bias_fused_add else None,
self.has_bias and self.gemm_bias_fused_add, self.has_bias and self.gemm_bias_fused_add,
...@@ -733,6 +779,7 @@ class Linear(TransformerEngineBaseLayer): ...@@ -733,6 +779,7 @@ class Linear(TransformerEngineBaseLayer):
self.sequence_parallel, self.sequence_parallel,
self.tp_group, self.tp_group,
self.tp_size, self.tp_size,
is_first_microbatch,
) )
if not self.gemm_bias_fused_add: if not self.gemm_bias_fused_add:
...@@ -743,8 +790,12 @@ class Linear(TransformerEngineBaseLayer): ...@@ -743,8 +790,12 @@ class Linear(TransformerEngineBaseLayer):
def _pd_forward( def _pd_forward(
self, self,
inp: paddle.Tensor, inp: paddle.Tensor,
is_first_microbatch: Optional[bool] = None,
) -> paddle.Tensor: ) -> paddle.Tensor:
"""Calls Paddle OP""" """Calls Paddle OP"""
if is_first_microbatch is not None:
warnings.warn(
"`is_first_microbatch` is not supported for paddle backend and is ignored.")
if self.parallel_mode == 'column' and self.tensor_parallel: if self.parallel_mode == 'column' and self.tensor_parallel:
inp = identity(inp, self.tp_group) inp = identity(inp, self.tp_group)
out = F.linear(inp, self.weight, self.bias if self.gemm_bias_fused_add else None) out = F.linear(inp, self.weight, self.bias if self.gemm_bias_fused_add else None)
...@@ -759,8 +810,18 @@ class Linear(TransformerEngineBaseLayer): ...@@ -759,8 +810,18 @@ class Linear(TransformerEngineBaseLayer):
Parameters Parameters
---------- ----------
inp : torch.Tensor inp : paddle.Tensor
Input tensor. Input tensor.
is_first_microbatch : {True, False, None}, default = None
During training using either gradient accumulation or
pipeline parallelism a minibatch of data is further split
into microbatches. Between the microbatches of the same minibatch
the model weights are not updated. Setting this parameter indicates
whether the current microbatch is the first in a minibatch or not.
When set, this parameter enables additional optimizations:
* during FP8 training, it allows caching of the FP8 versions of
the weights
""" """
if self.backend == 'transformer_engine': if self.backend == 'transformer_engine':
return self._te_forward(*args, **kwargs) return self._te_forward(*args, **kwargs)
......
...@@ -217,6 +217,7 @@ class TransformerLayer(paddle.nn.Layer): ...@@ -217,6 +217,7 @@ class TransformerLayer(paddle.nn.Layer):
core_attention_bias: Optional[paddle.Tensor] = None, core_attention_bias: Optional[paddle.Tensor] = None,
set_zero: bool = True, set_zero: bool = True,
recompute_core_attention: bool = False, recompute_core_attention: bool = False,
is_first_microbatch: Optional[bool] = None,
) -> paddle.Tensor: ) -> paddle.Tensor:
""" """
Transformer Layer: attention block and a feedforward network (MLP) Transformer Layer: attention block and a feedforward network (MLP)
...@@ -248,6 +249,16 @@ class TransformerLayer(paddle.nn.Layer): ...@@ -248,6 +249,16 @@ class TransformerLayer(paddle.nn.Layer):
during the backward pass in order to save memory that would during the backward pass in order to save memory that would
otherwise be occupied to store the forward activations until otherwise be occupied to store the forward activations until
backprop. backprop.
is_first_microbatch : {True, False, None}, default = None
During training using either gradient accumulation or
pipeline parallelism a minibatch of data is further split
into microbatches. Between the microbatches of the same minibatch
the model weights are not updated. Setting this parameter indicates
whether the current microbatch is the first in a minibatch or not.
When set, this parameter enables additional optimizations:
* during FP8 training, it allows caching of the FP8 versions of
the weights
""" """
if self.self_attn_mask_type != "causal" and attention_mask is not None: if self.self_attn_mask_type != "causal" and attention_mask is not None:
...@@ -264,6 +275,7 @@ class TransformerLayer(paddle.nn.Layer): ...@@ -264,6 +275,7 @@ class TransformerLayer(paddle.nn.Layer):
core_attention_bias=core_attention_bias, core_attention_bias=core_attention_bias,
set_zero=set_zero, set_zero=set_zero,
recompute_core_attention=recompute_core_attention, recompute_core_attention=recompute_core_attention,
is_first_microbatch=is_first_microbatch,
) )
if self.apply_residual_connection_post_layernorm and not self.output_layernorm: if self.apply_residual_connection_post_layernorm and not self.output_layernorm:
...@@ -286,6 +298,7 @@ class TransformerLayer(paddle.nn.Layer): ...@@ -286,6 +298,7 @@ class TransformerLayer(paddle.nn.Layer):
core_attention_bias=core_attention_bias, core_attention_bias=core_attention_bias,
set_zero=set_zero, set_zero=set_zero,
recompute_core_attention=recompute_core_attention, recompute_core_attention=recompute_core_attention,
is_first_microbatch=is_first_microbatch,
) )
if self.apply_residual_connection_post_layernorm: if self.apply_residual_connection_post_layernorm:
attention_output, residual = inter_attention_outputs attention_output, residual = inter_attention_outputs
...@@ -298,7 +311,7 @@ class TransformerLayer(paddle.nn.Layer): ...@@ -298,7 +311,7 @@ class TransformerLayer(paddle.nn.Layer):
bda_output = self.fused_dropout_add2(attention_output, residual) bda_output = self.fused_dropout_add2(attention_output, residual)
# MLP. # MLP.
mlp_outputs = self.layernorm_mlp(bda_output) mlp_outputs = self.layernorm_mlp(bda_output, is_first_microbatch=is_first_microbatch)
if self.apply_residual_connection_post_layernorm: if self.apply_residual_connection_post_layernorm:
mlp_output, residual = mlp_outputs mlp_output, residual = mlp_outputs
else: else:
......
...@@ -121,3 +121,17 @@ def saved_tensor_allow_none(ctx) -> Tuple[Optional[paddle.Tensor]]: ...@@ -121,3 +121,17 @@ def saved_tensor_allow_none(ctx) -> Tuple[Optional[paddle.Tensor]]:
outputs.append(saved_tensors[index]) outputs.append(saved_tensors[index])
return tuple(outputs) return tuple(outputs)
def clear_tensor_data(*tensors: Tuple[Optional[paddle.Tensor], ...]) -> None:
"""
Free tensor buffer
"""
def can_free(t):
return (t is not None and isinstance(t, paddle.Tensor) and t._is_initialized()
and t.inplace_version == 0)
for t in tensors:
if can_free(t):
t._clear_dataptr()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment