Unverified Commit 95ec1560 authored by Tian Zheng, committed by GitHub

[Paddle] Add FP8 support for nn Layers (#333)



* Add FP8 support

- Add FP8 recipe
- Add FP8 path for nn layers
- Add MNIST FP8 example
Signed-off-by: Tian Zheng (Engrg-Hardware 1) <tizheng@nvidia.com>

* Update README
Signed-off-by: Tian Zheng (Engrg-Hardware 1) <tizheng@nvidia.com>

* Fix LayerNormMLP FP8 backward
Signed-off-by: Tian Zheng (Engrg-Hardware 1) <tizheng@nvidia.com>

* Fix FP8 training with float32 accumulation
Signed-off-by: Tian Zheng (Engrg-Hardware 1) <tizheng@nvidia.com>

* Fix FP8 checkpointing for non-forward execution cases (same as #323)
Signed-off-by: Tian Zheng (Engrg-Hardware 1) <tizheng@nvidia.com>

* Refactor and improve code style, readability, and organization
Signed-off-by: Tian Zheng (Engrg-Hardware 1) <tizheng@nvidia.com>

* Remove unnecessary pylint override
Signed-off-by: Tian Zheng (Engrg-Hardware 1) <tizheng@nvidia.com>

---------
Signed-off-by: Tian Zheng (Engrg-Hardware 1) <tizheng@nvidia.com>
parent 054a9c17
# Basic MNIST Example
```bash
python test_single_gpu_mnist.py
python test_single_gpu_mnist.py --use-te   # Linear layers from TransformerEngine
python test_single_gpu_mnist.py --use-te --use-fp8   # FP8 + TransformerEngine for Linear layers
```
@@ -4,6 +4,7 @@
"""MNIST example of Transformer Engine Paddle"""
import argparse
import os
import unittest

import paddle
@@ -16,6 +17,7 @@ from paddle.vision.datasets import MNIST
from paddle.metric import Accuracy
import transformer_engine.paddle as te
from transformer_engine.paddle.fp8 import is_fp8_available


class Net(nn.Layer):
@@ -52,12 +54,13 @@ class Net(nn.Layer):
        return x


def train(args, model, train_loader, optimizer, epoch, use_fp8):
    """Training function."""
    model.train()
    for batch_id, (data, labels) in enumerate(train_loader):
        with paddle.amp.auto_cast(dtype='bfloat16', level='O2'):    # pylint: disable=not-context-manager
            with te.fp8_autocast(enabled=use_fp8):
                outputs = model(data)
            loss = F.cross_entropy(outputs, labels)
        loss.backward()
@@ -74,7 +77,7 @@ def train(args, model, train_loader, optimizer, epoch):
    return loss.item()


def evaluate(model, test_loader, epoch, use_fp8):
    """Testing function."""
    model.eval()
    metric = Accuracy()
@@ -83,13 +86,25 @@ def evaluate(model, test_loader, epoch):
    with paddle.no_grad():
        for data, labels in test_loader:
            with paddle.amp.auto_cast(dtype='bfloat16', level='O2'):    # pylint: disable=not-context-manager
                with te.fp8_autocast(enabled=use_fp8):
                    outputs = model(data)
            acc = metric.compute(outputs, labels)
            metric.update(acc)
    print(f"Epoch[{epoch}] - accuracy: {metric.accumulate():.6f}")
    return metric.accumulate()


def calibrate(model, test_loader):
    """Calibration function."""
    model.eval()
    with paddle.no_grad():
        for data, _ in test_loader:
            with paddle.amp.auto_cast(dtype='bfloat16', level='O2'):    # pylint: disable=not-context-manager
                with te.fp8_autocast(enabled=False, calibrating=True):
                    _ = model(data)


def mnist_parser(args):
    """Parse training settings"""
    parser = argparse.ArgumentParser(description="Paddle MNIST Example")
@@ -127,6 +142,12 @@ def mnist_parser(args):
        default=False,
        help="quickly check a single pass",
    )
    parser.add_argument(
        "--save-model",
        action="store_true",
        default=False,
        help="For Saving the current Model",
    )
    parser.add_argument("--seed", type=int, default=1, metavar="S", help="random seed (default: 1)")
    parser.add_argument(
        "--log-interval",
@@ -135,11 +156,22 @@ def mnist_parser(args):
        metavar="N",
        help="how many batches to wait before logging training status",
    )
    parser.add_argument("--use-fp8",
                        action="store_true",
                        default=False,
                        help="Use FP8 for inference and training without recalibration. "
                             "It also enables Transformer Engine implicitly.")
    parser.add_argument("--use-fp8-infer",
                        action="store_true",
                        default=False,
                        help="Use FP8 for inference only. If not using FP8 for training, "
                             "calibration is performed for FP8 inference.")
    parser.add_argument("--use-te",
                        action="store_true",
                        default=False,
                        help="Use Transformer Engine")
    args = parser.parse_args(args)
    return args


def train_and_evaluate(args):
@@ -165,8 +197,18 @@ def train_and_evaluate(args):
        model = paddle.amp.decorate(models=model, level='O2', dtype='bfloat16')

    for epoch in range(1, args.epochs + 1):
        loss = train(args, model, train_loader, optimizer, epoch, args.use_fp8)
        acc = evaluate(model, val_loader, epoch, args.use_fp8)

    if args.use_fp8_infer and not args.use_fp8:
        calibrate(model, val_loader)

    if args.save_model or args.use_fp8_infer:
        paddle.save(model.state_dict(), "mnist_cnn.pdparams")
        print('Eval with reloaded checkpoint : fp8=' + str(args.use_fp8))
        weights = paddle.load("mnist_cnn.pdparams")
        model.set_state_dict(weights)
        acc = evaluate(model, val_loader, 0, args.use_fp8)

    return loss, acc
@@ -174,6 +216,8 @@ def train_and_evaluate(args):
class TestMNIST(unittest.TestCase):
    """MNIST unittests"""

    gpu_has_fp8, reason = is_fp8_available()

    @classmethod
    def setUpClass(cls):
        """Run MNIST without Transformer Engine"""
@@ -192,7 +236,33 @@ class TestMNIST(unittest.TestCase):
    def test_te_bf16(self):
        """Test Transformer Engine with BF16"""
        self.args.use_te = True
        self.args.use_fp8 = False
        self.args.save_model = True
        actual = train_and_evaluate(self.args)
        if os.path.exists("mnist_cnn.pdparams"):
            os.remove("mnist_cnn.pdparams")
        self.verify(actual)

    @unittest.skipIf(not gpu_has_fp8, reason)
    def test_te_fp8(self):
        """Test Transformer Engine with FP8"""
        self.args.use_te = True
        self.args.use_fp8 = True
        self.args.save_model = True
        actual = train_and_evaluate(self.args)
        if os.path.exists("mnist_cnn.pdparams"):
            os.remove("mnist_cnn.pdparams")
        self.verify(actual)

    @unittest.skipIf(not gpu_has_fp8, reason)
    def test_te_fp8_calibration(self):
        """Test Transformer Engine with FP8 calibration"""
        self.args.use_te = True
        self.args.use_fp8 = False
        self.args.use_fp8_infer = True
        actual = train_and_evaluate(self.args)
        if os.path.exists("mnist_cnn.pdparams"):
            os.remove("mnist_cnn.pdparams")
        self.verify(actual)
...
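The calibration path above is the intended route to FP8 inference for a model trained without FP8: run a few batches under `fp8_autocast(enabled=False, calibrating=True)` so amax statistics are recorded, checkpoint (the saved state now carries FP8 metadata), then reload and evaluate with `enabled=True`. A condensed sketch of that flow; `model`, `calib_loader`, and `sample` are hypothetical stand-ins:

```python
import paddle
import transformer_engine.paddle as te

# 1) Record amax statistics without changing numerics.
model.eval()
with paddle.no_grad():
    for data, _ in calib_loader:
        with te.fp8_autocast(enabled=False, calibrating=True):
            model(data)

# 2) The checkpoint now includes the calibrated FP8 scaling state.
paddle.save(model.state_dict(), "model_fp8_calibrated.pdparams")

# 3) Reload and run FP8 inference.
model.set_state_dict(paddle.load("model_fp8_calibrated.pdparams"))
with paddle.no_grad(), te.fp8_autocast(enabled=True):
    out = model(sample)
```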
This diff is collapsed.
@@ -20,6 +20,7 @@ from transformer_engine.paddle.cpp_extensions import (
    fp8_gemm,
    transpose,
    cast_transpose,
    cast_transpose_bgrad,
    te_gelu,
    gelu_fp8,
    dgelu_cast_transpose_bgrad_fp8,
@@ -41,6 +42,8 @@ from transformer_engine.paddle.cpp_extensions import (
    scaled_upper_triang_masked_softmax_backward,
)
from transformer_engine.paddle.fp8 import is_fp8_available
from transformer_engine.paddle.constants import FP8FwdTensors
from transformer_engine.common.recipe import DelayedScaling

np.random.seed(10)
paddle.seed(10)
@@ -60,12 +63,12 @@ def test_quantize_dequantize():
    """
    a = paddle.rand(shape=(32, 32), dtype='float32')
    # Init fp8_meta
    fp8_meta = create_fp8_meta()
    for fp8_dtype in [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2]:
        a_fp8 = cast_to_fp8(a, fp8_meta, FP8FwdTensors.GEMM1_OUTPUT, otype=fp8_dtype)
        b = cast_from_fp8(a_fp8,
                          fp8_meta,
                          FP8FwdTensors.GEMM1_OUTPUT,
                          itype=fp8_dtype,
                          otype=tex.DType.kFloat32)
        assert_allclose(a, b, rtol=5e-2, atol=5e-2)
@@ -114,12 +117,12 @@ class TestTranspose:
        min_val = -8
        max_val = 8
        a = paddle.cast(paddle.randint(min_val, max_val, shape=(16, 32)), 'float32')
        fp8_meta = create_fp8_meta()
        a_fp8 = cast_to_fp8(a, fp8_meta, FP8FwdTensors.GEMM1_INPUT, otype=fp8_dtype)
        a_fp8_transposed = transpose(a_fp8, otype=fp8_dtype)
        a_transposed = cast_from_fp8(a_fp8_transposed,
                                     fp8_meta,
                                     FP8FwdTensors.GEMM1_INPUT,
                                     itype=fp8_dtype,
                                     otype=tex.DType.kFloat32)
        assert_allclose(a_transposed, a.T)
@@ -134,27 +137,59 @@ class TestTranspose:
        min_val = -8
        max_val = 8
        a = paddle.cast(paddle.randint(min_val, max_val, shape=(16, 32)), 'float32')
        fp8_meta = create_fp8_meta()
        a_fp8_casted, a_fp8_transposed = cast_transpose(a,
                                                        fp8_meta,
                                                        FP8FwdTensors.GEMM1_INPUT,
                                                        otype=fp8_dtype)
        a_transposed = cast_from_fp8(a_fp8_transposed,
                                     fp8_meta,
                                     FP8FwdTensors.GEMM1_INPUT,
                                     itype=fp8_dtype,
                                     otype=tex.DType.kFloat32)
        a_casted = cast_from_fp8(a_fp8_casted,
                                 fp8_meta,
                                 FP8FwdTensors.GEMM1_INPUT,
                                 itype=fp8_dtype,
                                 otype=tex.DType.kFloat32)
        assert_allclose(a_casted, a)
        assert_allclose(a_transposed, a.T)

    @staticmethod
    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    @pytest.mark.parametrize('fp8_dtype', [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2])
    def test_cast_transpose_bgrad(fp8_dtype):
        """
        Test cast_transpose_bgrad
        """
        min_val = -8
        max_val = 8
        a = paddle.cast(paddle.randint(min_val, max_val, shape=(16, 32)), 'float32')
        fp8_meta = create_fp8_meta()
        bgrad, a_fp8_casted, a_fp8_transposed = cast_transpose_bgrad(a,
                                                                     fp8_meta,
                                                                     FP8FwdTensors.GEMM1_INPUT,
                                                                     otype=fp8_dtype)
        a_transposed = cast_from_fp8(a_fp8_transposed,
                                     fp8_meta,
                                     FP8FwdTensors.GEMM1_INPUT,
                                     itype=fp8_dtype,
                                     otype=tex.DType.kFloat32)
        a_casted = cast_from_fp8(a_fp8_casted,
                                 fp8_meta,
                                 FP8FwdTensors.GEMM1_INPUT,
                                 itype=fp8_dtype,
                                 otype=tex.DType.kFloat32)
        assert_allclose(a_casted, a)
        assert_allclose(a_transposed, a.T)
        assert_allclose(bgrad, a.sum(axis=0))
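The final assertion leans on the identity that a bias broadcast along rows receives the column-wise sum of the incoming gradient, which is what the fused kernel returns as `bgrad`. A minimal plain-Paddle sketch of that identity (no FP8, shapes chosen to match the test):

```python
import paddle

# For y = x + b (b broadcast over rows), dL/db is the per-column sum of dL/dy.
x = paddle.randn([16, 32])
b = paddle.zeros([32])
b.stop_gradient = False

y = x + b
y_grad = paddle.randn([16, 32])
paddle.autograd.backward([y], [y_grad])

assert paddle.allclose(b.grad, y_grad.sum(axis=0))
```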
class TestActivation:
    """
@@ -180,13 +215,13 @@ class TestActivation:
        Test FP8 GELU Forward
        """
        a = paddle.rand(shape=(16, 32), dtype='float32') * 2 - 1
        fp8_meta = create_fp8_meta()
        gelu_out_fp8 = gelu_fp8(a, fp8_meta, FP8FwdTensors.GEMM1_INPUT, otype=fp8_dtype)
        gelu_out = cast_from_fp8(gelu_out_fp8,
                                 fp8_meta,
                                 FP8FwdTensors.GEMM1_INPUT,
                                 itype=fp8_dtype,
                                 otype=tex.DType.kFloat32)
@@ -208,19 +243,22 @@ class TestActivation:
        y_grad = paddle.rand(shape=(16, 32), dtype='float32') * 2 - 1
        paddle.autograd.backward([y], [y_grad], True)
        # calculate fp8
        fp8_meta = create_fp8_meta()
        x_grad_fp8, x_grad_t_fp8, dbias = dgelu_cast_transpose_bgrad_fp8(y_grad,
                                                                         x,
                                                                         fp8_meta,
                                                                         FP8FwdTensors.GEMM1_INPUT,
                                                                         otype=fp8_dtype)
        x_grad = cast_from_fp8(x_grad_fp8,
                               fp8_meta,
                               FP8FwdTensors.GEMM1_INPUT,
                               itype=fp8_dtype,
                               otype=tex.DType.kFloat32)
        x_grad_t = cast_from_fp8(x_grad_t_fp8,
                                 fp8_meta,
                                 FP8FwdTensors.GEMM1_INPUT,
                                 itype=fp8_dtype,
                                 otype=tex.DType.kFloat32)
@@ -292,19 +330,19 @@ class TestGemm:
        max_val = 8
        fp8_dtype = tex.DType.kFloat8E4M3
        out_dtype = paddle.float32
        fp8_meta = create_fp8_meta(num_gemms=1)
        a = paddle.cast(paddle.randint(min_val, max_val, shape=(m, k)), 'float32')
        a_casted = cast_to_fp8(a, fp8_meta, FP8FwdTensors.GEMM1_INPUT, otype=fp8_dtype)
        b = paddle.cast(paddle.randint(min_val, max_val, shape=(n, k)), 'float32')
        b_casted = cast_to_fp8(b, fp8_meta, FP8FwdTensors.GEMM1_WEIGHT, otype=fp8_dtype)
        workspace = paddle.zeros(shape=[33_554_432], dtype='uint8')
        ref_out = paddle.matmul(a, b.T)
        actual_out = fp8_gemm(b_casted, fp8_meta.scale_inv, FP8FwdTensors.GEMM1_WEIGHT, fp8_dtype,
                              a_casted, fp8_meta.scale_inv, FP8FwdTensors.GEMM1_INPUT, fp8_dtype,
                              out_dtype, workspace)
        assert_allclose(actual_out, ref_out)
@@ -379,8 +417,8 @@ class TestLayerNorm:
        gamma = paddle.uniform(shape=(H,), dtype='float32')
        beta = paddle.uniform(shape=(H,), dtype='float32')
        fp8_tensor = FP8FwdTensors.GEMM1_INPUT
        fp8_meta = create_fp8_meta()

        y_ref, mu_ref, rsigma_ref = layernorm_fwd(x, gamma, beta, eps, tex.DType.kFloat32)
@@ -469,8 +507,8 @@ class TestRMSNorm:
        x = paddle.uniform(shape=(N, H), dtype='float32')
        gamma = paddle.uniform(shape=(H,), dtype='float32')
        fp8_tensor = FP8FwdTensors.GEMM1_INPUT
        fp8_meta = create_fp8_meta()

        y_ref, rsigma_ref = rmsnorm_fwd(x, gamma, eps, tex.DType.kFloat32)
@@ -748,8 +786,9 @@ class TestSoftmax:
    Test softmax operators
    """

    @staticmethod
    @pytest.mark.parametrize('dtype', ['float16', 'bfloat16'])
    def test_scaled_softmax_fwd_bwd(dtype):
        """test scaled softmax"""
        B, H, S = (16, 4, 32)
        scale = 0.8
@@ -768,8 +807,9 @@ class TestSoftmax:
        assert_allclose(y_ref, y, rtol=1e-4, atol=1e-3)
        assert_allclose(dx_ref, dx, rtol=1e-4, atol=1e-3)

    @staticmethod
    @pytest.mark.parametrize('dtype', ['float16', 'bfloat16'])
    def test_scaled_masked_softmax_fwd_bwd(dtype):
        """test scaled masked softmax"""
        B, H, S = (16, 4, 32)
        scale = 0.8
@@ -791,8 +831,9 @@ class TestSoftmax:
        assert_allclose(y_ref, y, rtol=1e-4, atol=1e-3)
        assert_allclose(dx_ref, dx, rtol=1e-4, atol=1e-3)

    @staticmethod
    @pytest.mark.parametrize('dtype', ['float16', 'bfloat16'])
    def test_scaled_upper_triang_masked_softmax_fwd_bwd(dtype):
        """test scaled upper triang masked softmax"""
        B, S = (16, 32)
        scale = 0.8
@@ -818,3 +859,27 @@ class TestSoftmax:
        assert_allclose(y_ref, y, rtol=1e-4, atol=5e-3)
        assert_allclose(dx_ref, dx, rtol=1e-4, atol=5e-3)


def test_update_scale():
    """Test update_scale"""
    num_gemm = 6
    recipe = DelayedScaling()
    fp8_max = recipe.fp8_format.value.max_fwd
    amax_tensor = paddle.rand(shape=[num_gemm], dtype='float32') * fp8_max
    scale_tensor = paddle.ones(shape=[num_gemm], dtype='float32')

    def calc_ref(amax, scale, fp8_max, margin=0):
        """Calculate reference scale"""
        exp = paddle.floor(paddle.log2(fp8_max / amax)) - margin
        sf = paddle.round(2**paddle.abs(exp))
        sf = paddle.where(amax > 0.0, sf, scale)
        sf = paddle.where(paddle.isfinite(amax), sf, scale)
        sf = paddle.where(exp < 0, 1 / sf, sf)
        return sf

    scale_ref = calc_ref(amax_tensor, scale_tensor, fp8_max, 0.)
    scale_actual = tex.update_scale(amax_tensor, scale_tensor, fp8_max, 0.)
    assert_allclose(scale_ref, scale_actual, rtol=1e-5, atol=1e-5)
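The reference computation picks a power-of-two scaling factor so that `amax * scale` stays just below `fp8_max`. A worked instance of the same rule with scalar values (margin 0, E4M3 `fp8_max` of 448; numbers illustrative):

```python
import math

fp8_max, amax = 448.0, 3.0
exp = math.floor(math.log2(fp8_max / amax))   # floor(log2(149.33...)) == 7
scale = 2.0 ** exp                            # 128.0
assert amax * scale <= fp8_max                # 384.0 <= 448.0
scale_inv = 1.0 / scale                       # stored alongside for dequantization
```

For `exp < 0` (amax already above `fp8_max`), the rule stores `1 / sf` instead, shrinking values into representable range.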
@@ -8,17 +8,16 @@ import numpy as np
import paddle

import transformer_engine  # pylint: disable=unused-import
from transformer_engine.paddle.fp8 import FP8TensorMeta


def create_fp8_meta(num_gemms=1, amax_history_len=10):
    """
    Create and initialize FP8TensorMeta
    """
    fp8_meta = FP8TensorMeta(is_forward=True)
    fp8_meta.prepare(num_gemms, amax_history_len)
    return fp8_meta
...
@@ -4,3 +4,4 @@
"""Transformer Engine bindings for Paddle"""

from .layer import Linear, LayerNorm, LayerNormLinear, LayerNormMLP
from .fp8 import fp8_autocast
@@ -3,8 +3,33 @@
# See LICENSE for license information.
"""Constants"""

from enum import Enum

import paddle
import transformer_engine_paddle as tex


class FP8FwdTensors(Enum):
    """Used as named indices on the `scale`, `scale_inv`,
    and `amax` tensors in the `FP8TensorMeta` class."""
    GEMM1_INPUT = 0
    GEMM1_WEIGHT = 1
    GEMM1_OUTPUT = 2
    GEMM2_INPUT = 3
    GEMM2_WEIGHT = 4
    GEMM2_OUTPUT = 5


class FP8BwdTensors(Enum):
    """Used as named indices on the `scale`, `scale_inv`,
    and `amax` tensors in the `FP8TensorMeta` class."""
    GRAD_OUTPUT1 = 0
    GRAD_INPUT1 = 1
    GRAD_OUTPUT2 = 2
    GRAD_INPUT2 = 3
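These enums replace the previously pybind-exported integer enums, so call sites now pass `member.value` when indexing the flat meta tensors. A minimal sketch of how the indices line up with `FP8TensorMeta` (defined in `fp8.py` later in this diff):

```python
from transformer_engine.paddle.constants import FP8FwdTensors
from transformer_engine.paddle.fp8 import FP8TensorMeta

# A forward meta holds 3 FP8 tensors per GEMM: input, weight, output.
meta = FP8TensorMeta(is_forward=True)
meta.prepare(num_gemms=1, amax_history_len=10)

idx = FP8FwdTensors.GEMM1_WEIGHT.value     # == 1
weight_scale = meta.scale[idx]             # per-tensor scaling factor
weight_amaxes = meta.amax_history[:, idx]  # amax history column for this tensor
```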
""" """
Map from paddle dtype to TE dtype Map from paddle dtype to TE dtype
""" """
......
@@ -7,7 +7,8 @@ import math
from typing import Optional, Tuple, Union
import paddle
import transformer_engine_paddle as tex
from .constants import TE_DType, FP8FwdTensors, FP8BwdTensors
from .fp8 import FP8TensorMeta


def gemm(
@@ -97,11 +98,11 @@ def gemm(
def fp8_gemm(
    A: paddle.Tensor,
    A_scale_inv: paddle.Tensor,
    A_fp8_tensor: Union[FP8FwdTensors, FP8BwdTensors],
    A_dtype: tex.DType,
    B: paddle.Tensor,
    B_scale_inv: paddle.Tensor,
    B_fp8_tensor: Union[FP8FwdTensors, FP8BwdTensors],
    B_dtype: tex.DType,
    out_dtype: paddle.dtype,
    workspace: paddle.Tensor,
@@ -109,7 +110,7 @@ def fp8_gemm(
    accumulate: bool = False,
    out: Optional[paddle.Tensor] = None,
    out_index=None,
    fp8_meta_tensor: FP8TensorMeta = None,
    bias: Optional[paddle.Tensor] = None,
    use_bias: bool = False,
    use_split_accumulator: bool = False,
@@ -151,8 +152,8 @@ def fp8_gemm(
        None if out_index is None else fp8_meta_tensor.amax_history,
        gelu_input,  # this is pre_gelu_out
        workspace,
        A_fp8_tensor.value,
        B_fp8_tensor.value,
        0 if out_index is None else out_index,
        int(A_dtype),
        int(B_dtype),
@@ -178,8 +179,8 @@ def fp8_gemm(
def cast_to_fp8(
    inp: paddle.Tensor,
    fp8_meta_tensor: FP8TensorMeta,
    fp8_tensor: Union[FP8FwdTensors, FP8BwdTensors],
    otype: tex.DType,
) -> paddle.Tensor:
    """Cast input to FP8"""
@@ -188,7 +189,7 @@ def cast_to_fp8(
        fp8_meta_tensor.scale,
        fp8_meta_tensor.amax_history,
        fp8_meta_tensor.scale_inv,
        fp8_tensor.value,
        int(otype),
    )
    return out
@@ -196,8 +197,8 @@ def cast_to_fp8(
def cast_from_fp8(
    inp: paddle.Tensor,
    fp8_meta_tensor: FP8TensorMeta,
    fp8_tensor: Union[FP8FwdTensors, FP8BwdTensors],
    itype: tex.DType,
    otype: tex.DType,
) -> paddle.Tensor:
@@ -205,7 +206,7 @@ def cast_from_fp8(
    return tex.cast_from_fp8(
        inp,
        fp8_meta_tensor.scale_inv,
        fp8_tensor.value,
        int(itype),
        int(otype),
    )
@@ -224,8 +225,8 @@ def transpose(
def cast_transpose(
    inp: paddle.Tensor,
    fp8_meta_tensor: FP8TensorMeta,
    fp8_tensor: Union[FP8FwdTensors, FP8BwdTensors],
    otype: tex.DType,
) -> Union[Tuple[paddle.Tensor, paddle.Tensor], None]:
    """Cast + Transpose with FP8 output"""
@@ -234,13 +235,32 @@ def cast_transpose(
        fp8_meta_tensor.scale,
        fp8_meta_tensor.amax_history,
        fp8_meta_tensor.scale_inv,
        fp8_tensor.value,
        int(otype),
    )
    return cast_out, transpose_out


def cast_transpose_bgrad(
    inp: paddle.Tensor,
    fp8_meta_tensor: FP8TensorMeta,
    fp8_tensor: Union[FP8FwdTensors, FP8BwdTensors],
    otype: tex.DType,
) -> Union[Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor], None]:
    """Fused Cast + Transpose + Bias Grad"""
    grad_bias, cast_out, transpose_out, _, _ = tex.te_cast_transpose_bgrad(
        inp,
        fp8_meta_tensor.scale,
        fp8_meta_tensor.amax_history,
        fp8_meta_tensor.scale_inv,
        fp8_tensor.value,
        int(otype),
    )
    return grad_bias, cast_out, transpose_out


def te_gelu(
    inp: paddle.Tensor,
    otype: tex.DType,
@@ -254,8 +274,8 @@ def te_gelu(
def gelu_fp8(
    inp: paddle.Tensor,
    fp8_meta_tensor: FP8TensorMeta,
    fp8_tensor: Union[FP8FwdTensors, FP8BwdTensors],
    otype: tex.DType,
) -> paddle.Tensor:
    """GELU + FP8 cast"""
@@ -264,7 +284,7 @@ def gelu_fp8(
        fp8_meta_tensor.scale,
        fp8_meta_tensor.amax_history,
        fp8_meta_tensor.scale_inv,
        fp8_tensor.value,
        int(otype),
    )
@@ -274,8 +294,8 @@ def gelu_fp8(
def dgelu_cast_transpose_bgrad_fp8(
    grad_output: paddle.Tensor,
    gelu_input: paddle.Tensor,
    fp8_meta_tensor: FP8TensorMeta,
    fp8_tensor: Union[FP8FwdTensors, FP8BwdTensors],
    otype: tex.DType,
) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]:
    """
@@ -288,7 +308,7 @@ def dgelu_cast_transpose_bgrad_fp8(
        fp8_meta_tensor.scale,
        fp8_meta_tensor.amax_history,
        fp8_meta_tensor.scale_inv,
        fp8_tensor.value,
        int(otype),
    )
@@ -300,8 +320,8 @@ def layernorm_fwd_fp8(
    weight: paddle.Tensor,
    bias: paddle.Tensor,
    eps: float,
    fp8_meta_tensor: FP8TensorMeta,
    fp8_tensor: Union[FP8FwdTensors, FP8BwdTensors],
    otype: tex.DType,
    sm_margin: int = 0,
    zero_centered_gamma: bool = False,
@@ -310,7 +330,7 @@ def layernorm_fwd_fp8(
    out, mu, rsigma, _, _ = tex.te_layernorm_fwd_fp8(inp, weight, bias, fp8_meta_tensor.scale,
                                                     fp8_meta_tensor.amax_history,
                                                     fp8_meta_tensor.scale_inv, eps,
                                                     fp8_tensor.value, int(otype), sm_margin,
                                                     zero_centered_gamma)
    return out, mu, rsigma
@@ -356,15 +376,15 @@ def rmsnorm_fwd_fp8(
    inp: paddle.Tensor,
    weight: paddle.Tensor,
    eps: float,
    fp8_meta_tensor: FP8TensorMeta,
    fp8_tensor: Union[FP8FwdTensors, FP8BwdTensors],
    otype: tex.DType,
    sm_margin: int = 0,
) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]:
    """RMSNorm with FP8 output"""
    out, rsigma, _, _ = tex.te_rmsnorm_fwd_fp8(inp, weight, fp8_meta_tensor.scale,
                                               fp8_meta_tensor.amax_history,
                                               fp8_meta_tensor.scale_inv, eps, fp8_tensor.value,
                                               int(otype), sm_margin)
    return out, rsigma
...
@@ -16,36 +16,13 @@
#include <transformer_engine/softmax.h>
#include <transformer_engine/transformer_engine.h>
#include <transformer_engine/transpose.h>
#include <cstdlib>
#include <vector>

#include "paddle/extension.h"

namespace transformer_engine {
namespace paddle_ext {

-// Each tensor here is shape (N, ) holding all scaling
-// data for a single FP8 block, e.g. LayerNormLinear
-class FP8TensorMeta {
- public:
-  paddle::Tensor scale;
-  paddle::Tensor scale_inv;
-  paddle::Tensor amax_history;
-};
-
-// Used as named indices on the `scale`, `scale_inv`,
-// and `amax` tensors in the `FP8TensorMeta` class.
-enum class FP8FwdTensors {
-  GEMM1_INPUT = 0,
-  GEMM1_WEIGHT = 1,
-  GEMM1_OUTPUT = 2,
-  GEMM2_INPUT = 3,
-  GEMM2_WEIGHT = 4,
-  GEMM2_OUTPUT = 5
-};
-
-// Used as named indices on the `scale`, `scale_inv`,
-// and `amax` tensors in the `FP8TensorMeta` class.
-enum class FP8BwdTensors { GRAD_OUTPUT1 = 0, GRAD_INPUT1 = 1, GRAD_OUTPUT2 = 2, GRAD_INPUT2 = 3 };
-
// Paddle Tensor Utils
template <typename T>
inline const void *GetDataPtr(const paddle::Tensor &x, int64_t index) {
...
@@ -131,6 +131,50 @@ std::vector<paddle::Tensor> te_cast_transpose(const paddle::Tensor &input,
    return {input_cast, input_transpose};
}

std::vector<paddle::Tensor> te_cast_transpose_bgrad(const paddle::Tensor &grad_output,
                                                    const paddle::Tensor &scale,
                                                    paddle::Tensor &amax,       // NOLINT
                                                    paddle::Tensor &scale_inv,  // NOLINT
                                                    int64_t index, int64_t otype) {
    auto shape = GetShapeArray(grad_output);
    NVTE_CHECK(shape.size() == 2, "Expect the input to have 2 dimensions.");

    size_t M = shape[0];
    size_t N = shape[1];

    auto grad_bias =
        paddle::empty({grad_output.shape()[1]}, grad_output.dtype(), grad_output.place());
    auto grad_output_cast = paddle::empty_like(grad_output, Nvte2PaddleDType(Int2NvteDType(otype)),
                                               grad_output.place());
    auto grad_output_transpose =
        paddle::empty({grad_output.shape()[1], grad_output.shape()[0]},
                      Nvte2PaddleDType(Int2NvteDType(otype)), grad_output.place());

    auto input_cu = MakeNvteTensor(grad_output);
    void *amax_data = GetDataPtr<float>(amax, index);
    void *scale_data = const_cast<void *>(GetDataPtr<float>(scale, index));
    void *scale_inv_data = GetDataPtr<float>(scale_inv, index);
    auto output_cast_cu = MakeNvteTensor(grad_output_cast.data(), {M, N}, Int2NvteDType(otype),
                                         amax_data, scale_data, scale_inv_data);
    auto output_transpose_cu =
        MakeNvteTensor(grad_output_transpose.data(), {N, M}, Int2NvteDType(otype), amax_data,
                       scale_data, scale_inv_data);
    auto dbias_cu = MakeNvteTensor(grad_bias);

    transformer_engine::TensorWrapper workspace;

    nvte_cast_transpose_dbias(input_cu.data(), output_cast_cu.data(), output_transpose_cu.data(),
                              dbias_cu.data(), workspace.data(), grad_output.stream());

    // Fill workspace
    auto workspace_data = AllocateSpace(workspace.shape(), workspace.dtype(), grad_output.place());
    workspace = MakeNvteTensor(workspace_data.data(), workspace.shape(), workspace.dtype());

    nvte_cast_transpose_dbias(input_cu.data(), output_cast_cu.data(), output_transpose_cu.data(),
                              dbias_cu.data(), workspace.data(), grad_output.stream());

    return {grad_bias, grad_output_cast, grad_output_transpose};
}

void te_gemm(const paddle::Tensor &A, const paddle::optional<paddle::Tensor> &A_scale_inverse,
             const paddle::Tensor &B, const paddle::optional<paddle::Tensor> &B_scale_inverse,
             const paddle::optional<paddle::Tensor> &bias, paddle::Tensor &D,  // NOLINT
@@ -975,6 +1019,30 @@ void te_scaled_upper_triang_masked_softmax_backward(paddle::Tensor &output_grads
                                                    softmax_results.stream());
}

__global__ void UpdateScalesKernel(const float *amax, const float *scale, float margin,
                                   float fp8_max, size_t size, float *scale_out) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx < size) {
        float exp = floor(log2(fp8_max / amax[idx])) - margin;
        float sf = round(powf(2.0f, abs(exp)));
        sf = ((amax[idx] > 0.0f) && isfinite(amax[idx])) ? sf : scale[idx];
        scale_out[idx] = exp < 0.0f ? 1 / sf : sf;
    }
}

std::vector<paddle::Tensor> update_scale(const paddle::Tensor &amax, const paddle::Tensor &scale,
                                         float fp8_max, float margin) {
    const size_t block_size = 512;
    size_t size = static_cast<size_t>(amax.numel());
    size_t num_blocks = (size + block_size - 1) / block_size;
    auto scale_out = paddle::empty_like(scale, scale.dtype(), scale.place());
    UpdateScalesKernel<<<num_blocks, block_size, 0, amax.stream()>>>(
        amax.data<float>(), scale.data<float>(), margin, fp8_max, size, scale_out.data<float>());
    return {scale_out};
}

}  // namespace paddle_ext
}  // namespace transformer_engine
@@ -1021,6 +1089,13 @@ PD_BUILD_OP(te_cast_transpose)
    .Attrs({"index: int64_t", "otype: int64_t"})
    .SetKernelFn(PD_KERNEL(transformer_engine::paddle_ext::te_cast_transpose));

PD_BUILD_OP(te_cast_transpose_bgrad)
    .Inputs({"GradOutput", "Scale", "_Amax", "_ScaleInv"})
    .Outputs({"dBias", "CastedOutput", "TransposedOutput", "Amax", "ScaleInv"})
    .SetInplaceMap({{"_Amax", "Amax"}, {"_ScaleInv", "ScaleInv"}})
    .Attrs({"index: int64_t", "otype: int64_t"})
    .SetKernelFn(PD_KERNEL(transformer_engine::paddle_ext::te_cast_transpose_bgrad));

PD_BUILD_OP(te_gelu_fp8)
    .Inputs({"Input", "Scale", "_Amax", "_ScaleInv"})
    .Outputs({"Output", "Amax", "ScaleInv"})
@@ -1166,3 +1241,9 @@ PD_BUILD_OP(te_scaled_upper_triang_masked_softmax_backward)
    .SetInplaceMap({{"out_grad_", "out_grad"}})
    .SetKernelFn(
        PD_KERNEL(transformer_engine::paddle_ext::te_scaled_upper_triang_masked_softmax_backward));

PD_BUILD_OP(update_scale)
    .Inputs({"Amax", "Scale"})
    .Outputs({"ScaleOut"})
    .Attrs({"fp8_max: float", "margin: float"})
    .SetKernelFn(PD_KERNEL(transformer_engine::paddle_ext::update_scale));
@@ -15,12 +15,6 @@ PYBIND11_MODULE(transformer_engine_paddle, m) {
    // Misc
    m.def("get_cublasLt_version", &get_cublasLt_version, "Get cublasLt version");
    // Data structures
-   py::class_<FP8TensorMeta>(m, "FP8TensorMeta")
-       .def(py::init<>())
-       .def_readwrite("scale", &FP8TensorMeta::scale)
-       .def_readwrite("scale_inv", &FP8TensorMeta::scale_inv)
-       .def_readwrite("amax_history", &FP8TensorMeta::amax_history);
    py::enum_<DType>(m, "DType", py::module_local())
        .value("kByte", DType::kByte)
        .value("kInt32", DType::kInt32)
@@ -29,20 +23,6 @@ PYBIND11_MODULE(transformer_engine_paddle, m) {
        .value("kBFloat16", DType::kBFloat16)
        .value("kFloat8E4M3", DType::kFloat8E4M3)
        .value("kFloat8E5M2", DType::kFloat8E5M2);
-   py::enum_<FP8FwdTensors>(m, "FP8FwdTensors")
-       .value("GEMM1_INPUT", FP8FwdTensors::GEMM1_INPUT)
-       .value("GEMM1_WEIGHT", FP8FwdTensors::GEMM1_WEIGHT)
-       .value("GEMM1_OUTPUT", FP8FwdTensors::GEMM1_OUTPUT)
-       .value("GEMM2_INPUT", FP8FwdTensors::GEMM2_INPUT)
-       .value("GEMM2_WEIGHT", FP8FwdTensors::GEMM2_WEIGHT)
-       .value("GEMM2_OUTPUT", FP8FwdTensors::GEMM2_OUTPUT);
-   py::enum_<FP8BwdTensors>(m, "FP8BwdTensors")
-       .value("GRAD_OUTPUT1", FP8BwdTensors::GRAD_OUTPUT1)
-       .value("GRAD_INPUT1", FP8BwdTensors::GRAD_INPUT1)
-       .value("GRAD_OUTPUT2", FP8BwdTensors::GRAD_OUTPUT2)
-       .value("GRAD_INPUT2", FP8BwdTensors::GRAD_INPUT2);
}

}  // namespace paddle_ext
}  // namespace transformer_engine
@@ -3,13 +3,22 @@
# See LICENSE for license information.
"""FP8 utilities for TransformerEngine"""

from contextlib import contextmanager
from typing import Tuple, Optional, Dict, Any
import numpy as np
import paddle
import transformer_engine_paddle as tex

from transformer_engine.common.recipe import DelayedScaling, Format

# FP8 support
_is_fp8_available = None
_reason_for_no_fp8 = ""
# FP8 status
_FP8_ENABLED = False
_FP8_CALIBRATION = False
_FP8_RECIPE = None


def _check_fp8_support() -> Tuple[bool, str]:
@@ -38,3 +47,156 @@ def is_fp8_available() -> Tuple[bool, str]:
    if _is_fp8_available is None:
        _is_fp8_available, _reason_for_no_fp8 = _check_fp8_support()
    return _is_fp8_available, _reason_for_no_fp8


# Functions used to access fp8 status
def is_fp8_enabled() -> bool:
    """Is FP8 enabled"""
    return _FP8_ENABLED


def is_fp8_calibration() -> bool:
    """Is FP8 calibration"""
    return _FP8_CALIBRATION


def get_fp8_recipe() -> DelayedScaling:
    """Return the fp8 recipe"""
    return _FP8_RECIPE


def get_default_fp8_recipe() -> DelayedScaling:
    """FP8 recipe if not provided by user.
    Margin = 0, interval = 1, E4M3.
    """
    return DelayedScaling()


@contextmanager
def fp8_autocast(
    enabled: bool = False,
    calibrating: bool = False,
    fp8_recipe: Optional[DelayedScaling] = None,
) -> None:
    """
    Context manager for FP8 usage. `enabled` turns on FP8 execution for the wrapped
    layers, `calibrating` records amax statistics without changing numerics, and
    `fp8_recipe` overrides the default DelayedScaling recipe.
    """
    global _FP8_ENABLED, _FP8_CALIBRATION, _FP8_RECIPE
    fp8_state = (_FP8_ENABLED, _FP8_CALIBRATION, _FP8_RECIPE)
    try:
        _FP8_ENABLED = enabled
        _FP8_CALIBRATION = calibrating
        _FP8_RECIPE = get_default_fp8_recipe() if fp8_recipe is None else fp8_recipe

        if enabled:
            fp8_available, reason_for_no_fp8 = is_fp8_available()
            assert fp8_available, reason_for_no_fp8
        yield
    finally:
        (_FP8_ENABLED, _FP8_CALIBRATION, _FP8_RECIPE) = fp8_state
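A minimal usage sketch of the context manager with one of the layers from this PR (layer size and recipe settings are illustrative):

```python
import paddle
import transformer_engine.paddle as te
from transformer_engine.common.recipe import DelayedScaling, Format

layer = te.Linear(768, 768)
recipe = DelayedScaling(margin=0, fp8_format=Format.HYBRID, amax_history_len=16)

x = paddle.randn([32, 768])
with te.fp8_autocast(enabled=True, fp8_recipe=recipe):
    y = layer(x)          # GEMM runs in FP8; amax/scale bookkeeping happens inside

y.mean().backward()       # backward runs outside the context, as in the MNIST example
```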
def get_fp8_te_dtype(fp8_recipe: DelayedScaling, fprop_tensor: bool = True) -> tex.DType:
    """Get fp8 data type according to recipe and tensor"""
    if fp8_recipe.fp8_format == Format.E4M3 or (fp8_recipe.fp8_format == Format.HYBRID
                                                and fprop_tensor):
        return tex.DType.kFloat8E4M3
    return tex.DType.kFloat8E5M2
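Under the HYBRID format this resolves to E4M3 for forward tensors (activations, weights) and E5M2 for gradients, trading mantissa precision for the larger dynamic range gradients need. A quick check of that mapping:

```python
import transformer_engine_paddle as tex
from transformer_engine.common.recipe import DelayedScaling, Format
from transformer_engine.paddle.fp8 import get_fp8_te_dtype

recipe = DelayedScaling(fp8_format=Format.HYBRID)
assert get_fp8_te_dtype(recipe, fprop_tensor=True) == tex.DType.kFloat8E4M3   # fwd tensors
assert get_fp8_te_dtype(recipe, fprop_tensor=False) == tex.DType.kFloat8E5M2  # gradients
```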
def amax_and_scale_update(
    fp8_meta: Dict[str, Any],
    fwd_update: bool,
) -> None:
    """Updates fp8 amaxes/scales for fwd | bwd."""
    amax_compute = fp8_meta["recipe"].amax_compute_algo
    sf_compute = fp8_meta["recipe"].scaling_factor_compute_algo
    fp8_meta_tensor_key = "scaling_fwd" if fwd_update else "scaling_bwd"
    fp8_max_key = "fp8_max_fwd" if fwd_update else "fp8_max_bwd"

    if not callable(amax_compute) and sf_compute is None:
        # Obtain amax from history
        amax_history = fp8_meta[fp8_meta_tensor_key].amax_history
        if amax_compute == "max":
            amax = paddle.max(amax_history, axis=0)
        else:    # amax_compute_algo == "most_recent"
            amax = amax_history[0]

        # Update amax history and set next amax to zero
        if amax_history.shape[0] > 1:
            amax_history = paddle.roll(amax_history, -1, 0)
            amax_history[0] = 0.0
        fp8_meta[fp8_meta_tensor_key].amax_history = amax_history

        # Update scaling factor
        fp8_meta[fp8_meta_tensor_key].scale = tex.update_scale(
            amax=amax,
            scale=fp8_meta[fp8_meta_tensor_key].scale,
            fp8_max=fp8_meta[fp8_max_key],
            margin=float(fp8_meta["recipe"].margin))

        # Update scale_inv
        fp8_meta[fp8_meta_tensor_key].scale_inv = \
            1.0 / fp8_meta[fp8_meta_tensor_key].scale
    else:
        raise ValueError("We only support the fp8 recipe with 'max' or 'most_recent' "
                         "amax_compute_algo and default scaling_factor_compute_algo at this "
                         "moment.")


class FP8TensorMeta():
    """Holds FP8 scaling and amax history for FP8 layers"""

    def __init__(self, is_forward: bool):
        self.scale = paddle.Tensor()
        self.scale_inv = paddle.Tensor()
        self.amax_history = paddle.Tensor()
        self.is_initialized = False
        self.is_forward = is_forward

    def prepare(self, num_gemms: int, amax_history_len: int) -> None:
        """Prepare scales and amax tensors. It is called during fprop in each iteration.
        If the meta tensors are not initialized yet, initialization is performed. If already
        initialized, resize the meta tensors if amax_history_len has changed."""

        if self.is_initialized:
            # Handle changed amax history size.
            curr_len = self.amax_history.shape[0]
            num_fp8_tensors = self.amax_history.shape[1]
            if amax_history_len < curr_len:
                self.amax_history = self.amax_history[:amax_history_len]
            elif amax_history_len > curr_len:
                extra_rows = amax_history_len - curr_len
                self.amax_history = paddle.concat([
                    self.amax_history,
                    paddle.zeros((extra_rows, num_fp8_tensors), dtype='float32')
                ],
                                                  axis=0)
            return

        # Max. number of fp8 tensors per GEMM = 3 (input, weight, output) for fwd and
        # 2 (grad_output and grad_input) for bwd
        num_fp8_tensors = (num_gemms * 3 if self.is_forward else num_gemms * 2)

        self.scale = paddle.ones(num_fp8_tensors, dtype='float32')
        self.scale_inv = paddle.ones(num_fp8_tensors, dtype='float32')
        self.amax_history = paddle.zeros([amax_history_len, num_fp8_tensors], dtype='float32')
        self.is_initialized = True

    def to_numpy(self):
        """Convert FP8 meta tensors to numpy."""
        assert self.is_initialized, "FP8TensorMeta is not initialized yet."
        return {
            'scale': self.scale.numpy(),
            'scale_inv': self.scale_inv.numpy(),
            'amax_history': self.amax_history.numpy(),
        }

    def from_numpy(self, data: Dict[str, np.ndarray]):
        """Set FP8 meta tensors from numpy"""
        self.scale = paddle.to_tensor(data['scale'])
        self.scale_inv = paddle.to_tensor(data['scale_inv'])
        self.amax_history = paddle.to_tensor(data['amax_history'])
        self.is_initialized = True
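These numpy helpers are what the base layer's `state_dict`/`set_state_dict` hooks serialize (see `base.py` below). A standalone round-trip sketch, assuming an initialized meta:

```python
from transformer_engine.paddle.fp8 import FP8TensorMeta

meta = FP8TensorMeta(is_forward=True)
meta.prepare(num_gemms=2, amax_history_len=8)   # 2 GEMMs x 3 fwd tensors each

snapshot = meta.to_numpy()                      # dict of numpy arrays

restored = FP8TensorMeta(is_forward=True)
restored.from_numpy(snapshot)                   # also marks the meta as initialized
assert restored.amax_history.shape == [8, 6]
```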
@@ -5,13 +5,32 @@
from abc import ABC, abstractmethod
from contextlib import contextmanager
import pickle
from typing import Generator, Dict, Tuple, Union, Any

import numpy as np
import paddle
from paddle.fluid import core
from paddle.fluid.framework import _dygraph_tracer

from ..constants import FP8BwdTensors
from ..cpp_extensions import cast_transpose, cast_transpose_bgrad, cast_to_fp8
from ..fp8 import (
    get_fp8_recipe,
    get_default_fp8_recipe,
    is_fp8_enabled,
    is_fp8_calibration,
    amax_and_scale_update,
    get_fp8_te_dtype,
    FP8TensorMeta,
)
from ..profile import nvtx_range
from ..utils import get_bias_dtype, cast_if_needed

_2X_ACC_FPROP = False
_2X_ACC_DGRAD = True
_2X_ACC_WGRAD = True

_cublas_workspace = None
@@ -39,12 +58,20 @@ class TransformerEngineBaseLayer(paddle.nn.Layer, ABC):
    def __init__(self) -> None:
        super().__init__()
        assert 'gpu' in paddle.device.get_device(), "TransformerEngine needs CUDA."
        self.fp8_initialized = False
        self.fp8_enabled = False
        self.fp8_calibration = False
        self.fp8_meta = {}
        self.fp8_meta["fp8_checkpoint"] = False
        self.fp8_meta["recipe"] = get_default_fp8_recipe()
        self.fp8_meta["scaling_fwd"] = FP8TensorMeta(is_forward=True)
        self.fp8_meta["scaling_bwd"] = FP8TensorMeta(is_forward=False)

    def set_activation_dtype(self, inp: paddle.Tensor) -> None:
        """Get activation data type for AMP."""
        # Native AMP (`paddle.amp.auto_cast`) gets highest priority
        tracer = _dygraph_tracer()
        if tracer and tracer._amp_level != core.AmpLevel.O0:
            # Set activation_dtype to the Paddle AMP dtype if under 'paddle.amp.auto_cast' context
            if tracer._amp_dtype == 'float32':
                self.activation_dtype = paddle.float32
            elif tracer._amp_dtype == 'bfloat16':
@@ -53,37 +80,207 @@ class TransformerEngineBaseLayer(paddle.nn.Layer, ABC):
                self.activation_dtype = paddle.float16
            else:
                raise RuntimeError(f"AMP format {tracer._amp_dtype} is not supported.")
        else:
            # If not under paddle.amp.auto_cast, set activation_dtype to the input dtype.
            # Also, make sure the parameters match the input dtype.

            # Skip the check if activation_dtype is already set and if activation_dtype
            # matches input dtype. If they do not match, e.g., when user switches from AMP
            # training to normal training, activation_dtype will still be updated.
            if hasattr(self, "activation_dtype") and self.activation_dtype == inp.dtype:
                return

            dtype = inp.dtype
            for name, param in self.named_parameters():
                if param is not None:
                    assert dtype == param.dtype, (
                        "Data types for parameters must match when outside of autocasted region. "
                        f" Found input dtype: {dtype} and {name!r} dtype: {param.dtype}")
            self.activation_dtype = dtype

    # This routine is shared across FP8 and FP8_calibration paths so should not actually
    # assume FP8 execution.
    def fp8_init(self, num_gemms: int = 1) -> None:
        """Initialize fp8 related metadata and tensors during fprop."""
        self.fp8_enabled = is_fp8_enabled()
        self.fp8_calibration = is_fp8_calibration()
        self.fp8_meta["fp8_checkpoint"] = self.fp8_enabled or self.fp8_calibration

        if self.fp8_enabled or self.fp8_calibration:
            # FP8 init has already been run and recipe is the same, don't do anything.
            if self.fp8_initialized and get_fp8_recipe() == self.fp8_meta["recipe"]:
                return

            # Set FP8, recipe, and other FP8 metadata
            self.fp8_meta["recipe"] = get_fp8_recipe()

            # Set FP8_MAX per tensor according to recipe
            self.fp8_meta["fp8_max_fwd"] = self.fp8_meta["recipe"].fp8_format.value.max_fwd
            self.fp8_meta["fp8_max_bwd"] = self.fp8_meta["recipe"].fp8_format.value.max_bwd

            # Allocate scales and amaxes
            amax_history_len = self.fp8_meta["recipe"].amax_history_len
            self.fp8_meta["scaling_fwd"].prepare(num_gemms, amax_history_len)
            self.fp8_meta["scaling_bwd"].prepare(num_gemms, amax_history_len)
            self.fp8_initialized = True
        else:
            # If fp8 isn't enabled, turn off and return.
            self.fp8_initialized = False
            return

    def _get_fp8_state(self) -> paddle.Tensor:
        """Dump FP8 state to paddle.Tensor."""
        state = None
        if self.fp8_meta["fp8_checkpoint"]:
            state = {}
            state["scaling_fwd"] = self.fp8_meta["scaling_fwd"].to_numpy()
            state["scaling_bwd"] = self.fp8_meta["scaling_bwd"].to_numpy()

            # Store other picklable values.
            extra = {}
            for k, v in self.fp8_meta.items():
                if isinstance(v, (bool, int, float, str)):
                    extra[k] = v
            state["extra_fp8_variables"] = extra

        state_serialized = pickle.dumps(state)
        state_tensor = paddle.to_tensor(np.frombuffer(state_serialized, dtype=np.uint8))

        return state_tensor

    @paddle.no_grad()
    def state_dict(
        self,
        destination=None,
        include_sublayers=True,
        structured_name_prefix="",
        use_hook=True,
    ):
        """Save FP8 State when checkpointing."""
        st = super().state_dict(
            destination=destination,
            include_sublayers=include_sublayers,
            structured_name_prefix=structured_name_prefix,
            use_hook=use_hook,
        )
        st["fp8_state"] = self._get_fp8_state()
        return st

    def _set_fp8_state(self, state: paddle.Tensor) -> None:
        """Load previous state."""
        if state is None:
            return

        state = pickle.loads(state.numpy().tobytes())
        if state is None:
            return

        # Load fp8 meta tensors.
        self.fp8_meta["scaling_fwd"].from_numpy(state["scaling_fwd"])
        self.fp8_meta["scaling_bwd"].from_numpy(state["scaling_bwd"])

        # Load extra items.
        self.fp8_meta.update(state["extra_fp8_variables"])
        self.fp8_meta["recipe"].amax_history_len = \
            self.fp8_meta["scaling_fwd"].amax_history.shape[0]

    @paddle.no_grad()
    def set_state_dict(self, state_dict, use_structured_name=True):
        """Restore FP8 State from checkpoint."""
        fp8_state_tensor = state_dict.pop("fp8_state")
        self._set_fp8_state(fp8_state_tensor)
        return super().set_state_dict(state_dict)

    @contextmanager
    def prepare_forward(
        self,
        inp: paddle.Tensor,
        num_gemms: int = 1,
    ) -> Generator[paddle.Tensor, None, None]:
        """Checks and prep for FWD.
        The context manager is needed because there isn't a way for a module to know
        if it's the last FP8 module in the forward autocast. It is useful
        to setup the forward aggregated amax reduction for every module
        just in case. The autocast exit will pick up the most recent one.
        """
        self.set_activation_dtype(inp)
        self.fp8_init(num_gemms=num_gemms)

        # Previous iteration was grad_enabled
        if self.fp8_meta.get("update_amax_and_scale_fwd", False):
            amax_and_scale_update(self.fp8_meta, True)

        if self.fp8_enabled and self.training:
            self.fp8_meta["update_amax_and_scale_fwd"] = True
        else:
            self.fp8_meta["update_amax_and_scale_fwd"] = False

        with nvtx_range(self.__class__.__name__ + " forward"):
            yield inp

    @staticmethod
    @contextmanager
    def prepare_backward(fp8_enabled: bool,
                         fp8_meta: Dict[str, Any],
                         name: str = "") -> Generator[None, None, None]:
        """Checks and prep for BWD."""
        if fp8_enabled:
            amax_and_scale_update(fp8_meta, False)

        with nvtx_range(name + " backward"):
            yield

    @staticmethod
    def grad_output_preprocess(
            ctx, grad_output: paddle.Tensor) -> Tuple[Union[paddle.Tensor, None], ...]:
        """Utility function for backward.
        Returns tuple in order (all optional/None based on training precision/recipe):
            R1: gathered `grad_output` in higher precision.
            R2: gathered `grad_output` in FP8.
            R3: R2 transposed.
            R4: bias gradient on R1.
        """
        grad_output_mat = grad_output.reshape((-1, grad_output.shape[-1]))

        # No-FP8 case: bgrad is fused with wgrad for this case.
        if not ctx.fp8_enabled:
            return grad_output_mat, None, None, None

        fp8_dtype_backward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=False)

        # FP8 case without gather: cast, transpose, bgrad fused
        if ctx.use_bias:
            bgrad, grad_output_c, grad_output_t = cast_transpose_bgrad(
                grad_output_mat,
                ctx.fp8_meta["scaling_bwd"],
                FP8BwdTensors.GRAD_OUTPUT1,
                fp8_dtype_backward,
            )
            bias_dtype = get_bias_dtype(ctx.activation_dtype)
            bgrad = cast_if_needed(bgrad, bias_dtype)
        else:
            if not ctx.fp8_meta["recipe"].override_linear_precision.wgrad:
                grad_output_c, grad_output_t = cast_transpose(
                    grad_output_mat,
                    ctx.fp8_meta["scaling_bwd"],
                    FP8BwdTensors.GRAD_OUTPUT1,
                    fp8_dtype_backward,
                )
            else:
                grad_output_t = None
                grad_output_c = cast_to_fp8(
                    grad_output_mat,
                    ctx.fp8_meta["scaling_bwd"],
                    FP8BwdTensors.GRAD_OUTPUT1,
                    fp8_dtype_backward,
                )
            bgrad = None
        return grad_output_mat, grad_output_c, grad_output_t, bgrad

    @abstractmethod
    def forward(self):
        """Needs override."""
@@ -4,27 +4,110 @@
"""LayerNormLinear API"""

import os
from typing import Union, Tuple, Dict, Any

import paddle
import paddle.nn.functional as F
from paddle.nn.initializer import Constant

from ..cpp_extensions import (
    cast_to_fp8,
    cast_from_fp8,
    layernorm_fwd,
    layernorm_fwd_fp8,
    layernorm_bwd,
    transpose,
)

from .base import TransformerEngineBaseLayer
from .linear import _linear_fwd, _linear_bwd
from ..constants import TE_DType, FP8FwdTensors, FP8BwdTensors
from ..fp8 import get_fp8_te_dtype
from ..utils import cast_if_needed, cast_if_needed_inplace, assert_dim_for_fp8_forward_exec

__all__ = ["LayerNormLinear", "_layernorm_fwd_fp8_cast", "_layernorm_bwd"]
def _layernorm_fwd_fp8_cast(
    inputmat: paddle.Tensor,
    ln_weight: paddle.Tensor,
    ln_bias: paddle.Tensor,
    out_fp8_index: FP8FwdTensors,
    eps: float,
    fp8_enabled: bool,
    fp8_meta: Dict[str, Any],
    activation_dtype: paddle.dtype,
    return_layernorm_output: bool,
    fwd_ln_sm_margin: int,
    zero_centered_gamma: bool,
):
    """Performs LayerNorm + FP8_Cast for FP8 path. LayerNorm only for BF16 path"""

    ln_weight = cast_if_needed_inplace(ln_weight, activation_dtype)
    ln_bias = cast_if_needed_inplace(ln_bias, activation_dtype)

    if fp8_enabled:
        fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)
        if not return_layernorm_output:
            ln_out, mu, rsigma = layernorm_fwd_fp8(
                inputmat,
                ln_weight,
                ln_bias,
                eps,
                fp8_meta["scaling_fwd"],
                out_fp8_index,
                fp8_dtype_forward,
                fwd_ln_sm_margin,
                zero_centered_gamma,
            )
            ln_out_return = ln_out
        else:
            ln_out_return, mu, rsigma = layernorm_fwd(inputmat, ln_weight, ln_bias, eps,
                                                      TE_DType[activation_dtype],
                                                      fwd_ln_sm_margin, zero_centered_gamma)
            ln_out = cast_to_fp8(
                ln_out_return,
                fp8_meta["scaling_fwd"],
                out_fp8_index,
                fp8_dtype_forward,
            )
    else:
        ln_out, mu, rsigma = layernorm_fwd(inputmat, ln_weight, ln_bias, eps,
                                           TE_DType[activation_dtype], fwd_ln_sm_margin,
                                           zero_centered_gamma)
        ln_out_return = ln_out

    return (
        ln_out_return,
        ln_out,
        mu,
        rsigma,
    )
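# A hedged reference for what the fused FP8 cast above produces; illustrative
# only, since cast_to_fp8/layernorm_fwd_fp8 are fused CUDA kernels that also
# record the observed amax into fp8_meta["scaling_fwd"]. 448.0 is the E4M3
# maximum representable magnitude.
def _fp8_cast_reference(x: paddle.Tensor, scale: paddle.Tensor, fp8_max: float = 448.0):
    amax = paddle.max(paddle.abs(x))                 # fed into the amax history
    y = paddle.clip(x * scale, -fp8_max, fp8_max)    # value range as stored in FP8
    return y, amax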
def _layernorm_bwd(
    inputmat: paddle.Tensor,
    dgrad: paddle.Tensor,
    ln_weight: paddle.Tensor,
    mu: paddle.Tensor,
    rsigma: paddle.Tensor,
    grad_ln_out_return: paddle.Tensor,
    return_layernorm_output: bool,
    bwd_ln_sm_margin: int,
    zero_centered_gamma: bool,
):
    """Performs LayerNorm Bwd, adding the residual gradient when the forward
    also returned the LayerNorm output."""
    # LayerNorm gradient
    d_ln_out = dgrad.reshape(inputmat.shape)

    # Residual gradient
    if return_layernorm_output:
        d_ln_out = d_ln_out + grad_ln_out_return.reshape(d_ln_out.shape)

    return layernorm_bwd(d_ln_out, inputmat, mu, rsigma, ln_weight, bwd_ln_sm_margin,
                         zero_centered_gamma)
class _LayerNormLinear(paddle.autograd.PyLayer):
    """TE implementation of LayerNormLinear"""

    @staticmethod
    def forward(
@@ -36,8 +119,12 @@ class _LayerNormLinear(paddle.autograd.PyLayer):
        bias: Union[paddle.Tensor, None],
        use_bias: bool,
        eps: float,
        fp8_enabled: bool,
        fp8_calibration: bool,
        fp8_meta: Dict[str, Any],
        activation_dtype: paddle.dtype,
        return_layernorm_output: bool,
        is_grad_enabled: bool,
        fwd_ln_sm_margin: int,
        bwd_ln_sm_margin: int,
        zero_centered_gamma: bool,
@@ -46,105 +133,165 @@ class _LayerNormLinear(paddle.autograd.PyLayer):
        in_features = ln_weight.numel()
        assert inp.shape[-1] == in_features, "GEMM not possible"
        inputmat = inp.reshape((-1, in_features))
        if fp8_enabled:
            assert_dim_for_fp8_forward_exec(inputmat)
            assert_dim_for_fp8_forward_exec(weight)

        # LayerNorm Fwd + FP8 Cast
        (
            ln_out_return,
            ln_out,
            mu,
            rsigma,
        ) = _layernorm_fwd_fp8_cast(
            inputmat,
            ln_weight,
            ln_bias,
            FP8FwdTensors.GEMM1_INPUT,
            eps,
            fp8_enabled,
            fp8_meta,
            activation_dtype,
            return_layernorm_output,
            fwd_ln_sm_margin,
            zero_centered_gamma,
        )

        # Linear Fwd
        out, weight_t_fp8 = _linear_fwd(
            ln_out,
            FP8FwdTensors.GEMM1_INPUT,
            weight,
            FP8FwdTensors.GEMM1_WEIGHT,
            bias,
            use_bias,
            fp8_enabled,
            fp8_calibration,
            fp8_meta,
            activation_dtype,
            is_grad_enabled,
        )

        if is_grad_enabled:
            ctx.save_for_backward(
                inputmat,
                ln_weight,
                mu,
                rsigma,
                weight,
                weight_t_fp8 if fp8_enabled else None,
                ln_out,
                fp8_meta["scaling_fwd"].scale_inv.clone() if fp8_enabled else None,
            )
            ctx.activation_dtype = activation_dtype
            ctx.fp8_enabled = fp8_enabled
            ctx.fp8_meta = fp8_meta
            ctx.use_bias = use_bias
            ctx.inp_shape = inp.shape
            ctx.return_layernorm_output = return_layernorm_output
            ctx.bwd_ln_sm_margin = bwd_ln_sm_margin
            ctx.zero_centered_gamma = zero_centered_gamma
            ctx.requires_dgrad = not inp.stop_gradient
            ctx.requires_bgrad = use_bias and not bias.stop_gradient
            ctx.requires_ln_bgrad = not ln_bias.stop_gradient

        # [*, in_features] -> [*, out_features] except first dimension changes for SP
        out = out.reshape((-1, *inp.shape[1:-1], out.shape[-1]))

        if return_layernorm_output:
            return out, ln_out_return.reshape(inp.shape)
        return out
    @staticmethod
    def backward(
            ctx, *grad_outputs: Tuple[paddle.Tensor,
                                      ...]) -> Tuple[Union[paddle.Tensor, None], ...]:
        with TransformerEngineBaseLayer.prepare_backward(ctx.fp8_enabled,
                                                         ctx.fp8_meta,
                                                         name="_LayerNormLinear"):
            (
                inputmat,
                ln_weight,
                mu,
                rsigma,
                weight,
                weight_t_fp8,
                ln_out,
                fwd_scale_inverses,
            ) = ctx.saved_tensor()

            (
                grad_output,
                grad_output_c,
                grad_output_t,
                bgrad,
            ) = TransformerEngineBaseLayer.grad_output_preprocess(ctx, grad_outputs[0])

            # Prepare ln_out for Linear bwd
            ln_out_no_fp8, ln_out_t = None, None
            if ctx.fp8_enabled:
                fp8_dtype_forward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=True)
                fp8_wgrad = not ctx.fp8_meta["recipe"].override_linear_precision.wgrad
                if not weight.stop_gradient:
                    if fp8_wgrad:
                        ln_out_t = transpose(ln_out, fp8_dtype_forward)
                    else:
                        ln_out_no_fp8 = cast_from_fp8(
                            ln_out,
                            ctx.fp8_meta["scaling_fwd"],
                            FP8FwdTensors.GEMM1_INPUT,
                            fp8_dtype_forward,
                            TE_DType[ctx.activation_dtype],
                        )

            # Linear Bwd
            dgrad, wgrad, bgrad_ = _linear_bwd(
                ln_out_no_fp8 if ctx.fp8_enabled else ln_out,
                ln_out_t,
                FP8FwdTensors.GEMM1_INPUT,
                weight,
                weight_t_fp8,
                FP8FwdTensors.GEMM1_WEIGHT,
                grad_output,
                grad_output_c,
                grad_output_t,
                FP8BwdTensors.GRAD_OUTPUT1,
                fwd_scale_inverses,
                ctx.requires_bgrad,
                ctx.fp8_enabled,
                ctx.fp8_meta,
                True,    # Always compute dgrad to feed into LayerNorm bwd
                ctx.activation_dtype,
            )

            if not ctx.fp8_enabled:
                # bgrad is fused with gemm for non-FP8 path
                bgrad = bgrad_

            # LayerNorm Bwd
            dxmat, dgamma, dbeta = _layernorm_bwd(
                inputmat,
                dgrad,
                ln_weight,
                mu,
                rsigma,
                grad_outputs[1] if ctx.return_layernorm_output else None,
                ctx.return_layernorm_output,
                ctx.bwd_ln_sm_margin,
                ctx.zero_centered_gamma,
            )

        bgrad = bgrad if ctx.requires_bgrad else None
        bgrad_out = (bgrad,) if ctx.use_bias else ()

        return (
            dxmat.reshape(ctx.inp_shape) if ctx.requires_dgrad else None,
            dgamma if not ln_weight.stop_gradient else None,
            dbeta if ctx.requires_ln_bgrad else None,
            wgrad if not weight.stop_gradient else None,
            *bgrad_out,
        )
class LayerNormLinear(TransformerEngineBaseLayer):
    r"""
@@ -201,10 +348,11 @@ class LayerNormLinear(TransformerEngineBaseLayer):
        )

        self.has_bias = self._bias_attr is not False
        use_default_bias = self._bias_attr is None or self._bias_attr is True
        if self.has_bias:
            self.bias = self.create_parameter(
                shape=[out_features],
                attr=self._bias_attr if not use_default_bias else paddle.ParamAttr(
                    initializer=Constant(value=0.0)),
                dtype=self._dtype,
                is_bias=True,
""" """
with self.prepare_forward(inp) as inp: with self.prepare_forward(inp) as inp:
# Layer input should be casted outside PyLayer, as performing
# inplace cast to input tensors may cause problems when used
# together with Paddle native layers.
inp = cast_if_needed(inp, self.activation_dtype)
out = _LayerNormLinear.apply( out = _LayerNormLinear.apply(
cast_if_needed(inp, self.activation_dtype), inp,
cast_if_needed(self.ln_weight, self.activation_dtype), self.ln_weight,
cast_if_needed(self.ln_bias, self.activation_dtype), self.ln_bias,
cast_if_needed(self.weight, self.activation_dtype), self.weight,
cast_if_needed(self.bias, self.activation_dtype), self.bias,
self.has_bias, self.has_bias,
self.eps, self.eps,
self.fp8_enabled,
self.fp8_calibration,
self.fp8_meta,
self.activation_dtype, self.activation_dtype,
self.return_layernorm_output, self.return_layernorm_output,
paddle.is_grad_enabled(),
self.fwd_ln_sm_margin, self.fwd_ln_sm_margin,
self.bwd_ln_sm_margin, self.bwd_ln_sm_margin,
self.zero_centered_gamma, self.zero_centered_gamma,
......
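Putting the pieces of this file together, a minimal usage sketch for the FP8 path (a sketch only: the constructor arguments are illustrative, and the shapes are chosen so the flattened input passes the dim0 % 8 / dim1 % 16 asserts):

```python
import paddle
import transformer_engine.paddle as te

layer = te.LayerNormLinear(1024, 1024)
inp = paddle.rand((16, 128, 1024))    # flattens to [2048, 1024]

with paddle.amp.auto_cast(dtype='bfloat16', level='O2'):
    with te.fp8_autocast(enabled=True):
        out = layer(inp)
out.mean().backward()
```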
@@ -3,21 +3,325 @@
# See LICENSE for license information.
"""Linear API"""

from typing import Union, Tuple, Dict, Any

import paddle
import paddle.nn.functional as F
from paddle.nn.initializer import Constant

from .base import (
    TransformerEngineBaseLayer,
    get_workspace,
    _2X_ACC_FPROP,
    _2X_ACC_DGRAD,
    _2X_ACC_WGRAD,
)
from ..fp8 import get_fp8_te_dtype
from ..constants import FP8FwdTensors, FP8BwdTensors
from ..cpp_extensions import gemm, fp8_gemm, cast_to_fp8, cast_transpose
from ..utils import (
    cast_if_needed,
    cast_if_needed_inplace,
    assert_dim_for_fp8_forward_exec,
    get_bias_dtype,
)

__all__ = ["Linear", "_linear_fwd", "_linear_fwd_fp8", "_linear_bwd", "_linear_fwd_non_fp8"]
def _linear_fwd_fp8(
    inputmat: paddle.Tensor,
    inputmat_fp8_index: FP8FwdTensors,
    weight: paddle.Tensor,
    weight_fp8_index: FP8FwdTensors,
    bias: paddle.Tensor,
    use_bias: bool,
    fp8_meta: Dict[str, Any],
    activation_dtype: paddle.dtype,
    is_grad_enabled: bool,
):
    """FP8 path of Linear Fwd"""
    fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)
    bias_dtype = get_bias_dtype(activation_dtype)
    bias = cast_if_needed_inplace(bias, bias_dtype)

    if is_grad_enabled:
        weight_fp8, weight_t_fp8 = cast_transpose(
            weight,
            fp8_meta["scaling_fwd"],
            weight_fp8_index,
            fp8_dtype_forward,
        )
    else:
        weight_t_fp8 = None
        weight_fp8 = cast_to_fp8(
            weight,
            fp8_meta["scaling_fwd"],
            weight_fp8_index,
            fp8_dtype_forward,
        )

    out = fp8_gemm(
        weight_fp8,
        fp8_meta["scaling_fwd"].scale_inv,
        weight_fp8_index,
        fp8_dtype_forward,
        inputmat,
        fp8_meta["scaling_fwd"].scale_inv,
        inputmat_fp8_index,
        fp8_dtype_forward,
        activation_dtype,
        get_workspace(),
        bias=bias,
        use_bias=use_bias,
        use_split_accumulator=_2X_ACC_FPROP,
    )

    return out, weight_t_fp8
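# For intuition, the math of the fp8_gemm call above as a hedged, unfused
# reference; the real kernel consumes FP8 storage directly on Tensor Cores and
# never materializes the dequantized matrices. x_q/w_q stand in for the stored
# FP8 values; multiplying by the saved scale_inv recovers the original range.
def _fp8_gemm_reference(x_q, x_scale_inv, w_q, w_scale_inv, bias=None):
    x = x_q * x_scale_inv                          # dequantize input
    w = w_q * w_scale_inv                          # dequantize weight
    out = paddle.matmul(x, w, transpose_y=True)    # out = x @ w.T (TN layout)
    return out + bias if bias is not None else out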
def _linear_fwd_non_fp8(
    inputmat: paddle.Tensor,
    inputmat_fp8_index: FP8FwdTensors,
    weight: paddle.Tensor,
    weight_fp8_index: FP8FwdTensors,
    bias: paddle.Tensor,
    use_bias: bool,
    fp8_calibration: bool,
    fp8_meta: Dict[str, Any],
    activation_dtype: paddle.dtype,
    activation: str = "",
):
    """Non-FP8 path of Linear Fwd"""

    # Layer parameters are initialized as float32 dtype by default.
    # Cast the parameters to activation_dtype if the current dtype
    # does not match activation_dtype. The casting is inplace, so it
    # only needs to be performed once throughout the training process.
    weight = cast_if_needed_inplace(weight, activation_dtype)
    bias = cast_if_needed_inplace(bias, activation_dtype)

    if fp8_calibration:
        # amax of input
        fp8_meta["scaling_fwd"].amax_history[0, inputmat_fp8_index.value] = \
            paddle.max(paddle.abs(inputmat)).item()
        # amax of weight
        fp8_meta["scaling_fwd"].amax_history[0, weight_fp8_index.value] = \
            paddle.max(paddle.abs(weight)).item()

    outputs = gemm(weight,
                   inputmat,
                   activation_dtype,
                   get_workspace(),
                   bias=bias,
                   use_bias=use_bias,
                   gelu=(activation == 'gelu'))

    if activation == 'gelu':
        gelu_out, _, out = outputs
        return out, gelu_out

    out, _, _ = outputs
    return out
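# Usage note (illustrative): fp8_calibration records forward amaxes while the
# GEMM still runs in high precision, so a later FP8 run starts from calibrated
# scales. Assuming the autocast API exposes a calibration flag analogous to
# TE PyTorch's `calibrating` argument:
#
#     with te.fp8_autocast(enabled=False, calibrating=True):
#         out = layer(inp)    # BF16/FP32 math, but amax_history is populated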
def _linear_fwd(
    inputmat: paddle.Tensor,
    inputmat_fp8_index: FP8FwdTensors,
    weight: paddle.Tensor,
    weight_fp8_index: FP8FwdTensors,
    bias: paddle.Tensor,
    use_bias: bool,
    fp8_enabled: bool,
    fp8_calibration: bool,
    fp8_meta: Dict[str, Any],
    activation_dtype: paddle.dtype,
    is_grad_enabled: bool,
):
    """Dispatch Linear Fwd to the FP8 or non-FP8 path."""
    if fp8_enabled:
        out, weight_t_fp8 = _linear_fwd_fp8(
            inputmat,
            inputmat_fp8_index,
            weight,
            weight_fp8_index,
            bias,
            use_bias,
            fp8_meta,
            activation_dtype,
            is_grad_enabled,
        )
    else:
        out = _linear_fwd_non_fp8(
            inputmat,
            inputmat_fp8_index,
            weight,
            weight_fp8_index,
            bias,
            use_bias,
            fp8_calibration,
            fp8_meta,
            activation_dtype,
        )
    return (
        out,
        weight_t_fp8 if fp8_enabled else None,
    )
def _linear_bwd_fp8(
    inputmat: paddle.Tensor,
    inputmat_t: paddle.Tensor,
    inputmat_fp8_index: FP8FwdTensors,
    weight_t_fp8: paddle.Tensor,
    weight_fp8_index: FP8FwdTensors,
    grad_output: paddle.Tensor,
    grad_output_c: paddle.Tensor,
    grad_output_t: paddle.Tensor,
    grad_output_fp8_index: FP8BwdTensors,
    fwd_scale_inverses: paddle.Tensor,
    fp8_meta: Dict[str, Any],
    requires_dgrad: bool,
    requires_wgrad: bool,
    activation_dtype: paddle.dtype,
):
    """FP8 path of Linear Bwd"""
    dgrad, wgrad = None, None
    fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)
    fp8_dtype_backward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=False)

    if requires_dgrad:
        dgrad = fp8_gemm(
            weight_t_fp8,
            fwd_scale_inverses,
            weight_fp8_index,
            fp8_dtype_forward,
            grad_output_c,
            fp8_meta["scaling_bwd"].scale_inv,
            grad_output_fp8_index,
            fp8_dtype_backward,
            activation_dtype,
            get_workspace(),
            use_split_accumulator=_2X_ACC_DGRAD,
        )

    if requires_wgrad:
        if not fp8_meta["recipe"].override_linear_precision.wgrad:
            wgrad = fp8_gemm(
                inputmat_t,
                fwd_scale_inverses,
                inputmat_fp8_index,
                fp8_dtype_forward,
                grad_output_t,
                fp8_meta["scaling_bwd"].scale_inv,
                grad_output_fp8_index,
                fp8_dtype_backward,
                activation_dtype,
                get_workspace(),
                use_split_accumulator=_2X_ACC_WGRAD,
            )
        else:
            wgrad, _, _ = gemm(
                inputmat,
                grad_output,
                activation_dtype,
                get_workspace(),
                layout="NT",
                grad=True,
            )

    return dgrad, wgrad
def _linear_bwd_non_fp8(
    inputmat: paddle.Tensor,
    weight: paddle.Tensor,
    grad_output: paddle.Tensor,
    requires_bgrad: bool,
    requires_dgrad: bool,
    activation_dtype: paddle.dtype,
    gelu_input: Union[paddle.Tensor, None] = None,
    activation: str = "",
):
    """
    Performs Linear Backward. Optionally, fuses GELU backward and dbias.
    """
    dgrad, wgrad, bgrad = None, None, None
    requires_wgrad = not weight.stop_gradient

    if requires_dgrad:
        dgrad, _, _ = gemm(
            weight,
            grad_output,
            activation_dtype,
            get_workspace(),
            layout="NN",
            gelu=(activation == 'gelu'),
            gelu_input=gelu_input,
            grad=True,
        )

    if requires_wgrad:
        wgrad, bgrad, _ = gemm(
            inputmat,
            grad_output,
            activation_dtype,
            get_workspace(),
            layout="NT",
            grad=True,
            use_bias=requires_bgrad,
        )
    elif requires_bgrad:
        bgrad = grad_output.sum(axis=0)

    return dgrad, wgrad, bgrad
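# For reference, the unfused math behind the fused GEMMs above, given a forward
# pass y = x @ w.T + b with x: [m, k], w: [n, k], dy: [m, n] (illustrative):
def _linear_bwd_reference(x, w, dy):
    dgrad = paddle.matmul(dy, w)                      # dL/dx: [m, k], "NN" layout
    wgrad = paddle.matmul(dy, x, transpose_x=True)    # dL/dw: [n, k], "NT" layout
    bgrad = dy.sum(axis=0)                            # dL/db: [n]
    return dgrad, wgrad, bgrad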
def _linear_bwd(
    inputmat: paddle.Tensor,
    inputmat_t: paddle.Tensor,
    inputmat_fp8_index: FP8FwdTensors,
    weight: paddle.Tensor,
    weight_t_fp8: paddle.Tensor,
    weight_fp8_index: FP8FwdTensors,
    grad_output: paddle.Tensor,
    grad_output_c: paddle.Tensor,
    grad_output_t: paddle.Tensor,
    grad_output_fp8_index: FP8BwdTensors,
    fwd_scale_inverses: paddle.Tensor,
    requires_bgrad: bool,
    fp8_enabled: bool,
    fp8_meta: Dict[str, Any],
    requires_dgrad: bool,
    activation_dtype: paddle.dtype,
):
    """Dispatch Linear Bwd to the FP8 or non-FP8 path."""
    dgrad, wgrad, bgrad = None, None, None
    requires_wgrad = not weight.stop_gradient
    if fp8_enabled:
        dgrad, wgrad = _linear_bwd_fp8(
            inputmat,
            inputmat_t,
            inputmat_fp8_index,
            weight_t_fp8,
            weight_fp8_index,
            grad_output,
            grad_output_c,
            grad_output_t,
            grad_output_fp8_index,
            fwd_scale_inverses,
            fp8_meta,
            requires_dgrad,
            requires_wgrad,
            activation_dtype,
        )
    else:
        dgrad, wgrad, bgrad = _linear_bwd_non_fp8(
            inputmat,
            weight,
            grad_output,
            requires_bgrad,
            requires_dgrad,
            activation_dtype,
        )
    return dgrad, wgrad, bgrad
class _Linear(paddle.autograd.PyLayer):
    """TE implementation of Linear"""

    @staticmethod
    def forward(
@@ -26,69 +330,138 @@ class _Linear(paddle.autograd.PyLayer):
        inp: paddle.Tensor,
        bias: paddle.Tensor,
        use_bias: bool,
        fp8_enabled: bool,
        fp8_calibration: bool,
        fp8_meta: Dict[str, Any],
        activation_dtype: paddle.dtype,
        is_grad_enabled: bool,
    ) -> paddle.Tensor:
        # Make sure input dimensions are compatible
        in_features = weight.shape[-1]
        assert inp.shape[-1] == in_features, "GEMM not possible"
        inputmat = inp.reshape((-1, in_features))
        if fp8_enabled:
            assert_dim_for_fp8_forward_exec(inputmat)
            assert_dim_for_fp8_forward_exec(weight)

        inputmat_no_fp8 = inputmat

        # FP8 casting
        if fp8_enabled:
            fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)

            if not fp8_meta["recipe"].override_linear_precision.wgrad:
                if is_grad_enabled:
                    inputmat, inputmat_t = cast_transpose(
                        inputmat,
                        fp8_meta["scaling_fwd"],
                        FP8FwdTensors.GEMM1_INPUT,
                        fp8_dtype_forward,
                    )
                else:
                    inputmat = cast_to_fp8(
                        inputmat,
                        fp8_meta["scaling_fwd"],
                        FP8FwdTensors.GEMM1_INPUT,
                        fp8_dtype_forward,
                    )
            else:
                inputmat, inputmat_t = cast_to_fp8(
                    inputmat,
                    fp8_meta["scaling_fwd"],
                    FP8FwdTensors.GEMM1_INPUT,
                    fp8_dtype_forward,
                ), None

        # GEMM Fwd
        out, weight_t_fp8 = _linear_fwd(
            inputmat,
            FP8FwdTensors.GEMM1_INPUT,
            weight,
            FP8FwdTensors.GEMM1_WEIGHT,
            bias,
            use_bias,
            fp8_enabled,
            fp8_calibration,
            fp8_meta,
            activation_dtype,
            is_grad_enabled,
        )

        if is_grad_enabled:
            fp8_wgrad = fp8_enabled and not fp8_meta["recipe"].override_linear_precision.wgrad
            ctx.save_for_backward(
                inputmat_no_fp8 if not weight.stop_gradient and not fp8_wgrad else None,
                inputmat_t if not weight.stop_gradient and fp8_wgrad else None,
                weight,
                weight_t_fp8 if fp8_enabled else None,
                fp8_meta["scaling_fwd"].scale_inv.clone() if fp8_enabled else None,
            )
            ctx.activation_dtype = activation_dtype
            ctx.fp8_enabled = fp8_enabled
            ctx.fp8_meta = fp8_meta
            ctx.use_bias = use_bias
            ctx.inp_shape = inp.shape
            ctx.requires_dgrad = not inp.stop_gradient
            ctx.requires_bgrad = use_bias and not bias.stop_gradient

        return out.reshape((-1, *inp.shape[1:-1], out.shape[-1]))
    @staticmethod
    def backward(ctx, grad_output: paddle.Tensor) -> Tuple[Union[paddle.Tensor, None], ...]:
        with TransformerEngineBaseLayer.prepare_backward(ctx.fp8_enabled,
                                                         ctx.fp8_meta,
                                                         name="_Linear"):
            (
                inputmat,
                inputmat_t,
                weight,
                weight_t_fp8,
                fwd_scale_inverses,
            ) = ctx.saved_tensor()

            (
                grad_output,
                grad_output_c,
                grad_output_t,
                bgrad,
            ) = TransformerEngineBaseLayer.grad_output_preprocess(ctx, grad_output)

            dgrad, wgrad, bgrad_ = _linear_bwd(
                inputmat,
                inputmat_t,
                FP8FwdTensors.GEMM1_INPUT,
                weight,
                weight_t_fp8,
                FP8FwdTensors.GEMM1_WEIGHT,
                grad_output,
                grad_output_c,
                grad_output_t,
                FP8BwdTensors.GRAD_OUTPUT1,
                fwd_scale_inverses,
                ctx.requires_bgrad,
                ctx.fp8_enabled,
                ctx.fp8_meta,
                ctx.requires_dgrad,
                ctx.activation_dtype,
            )

            if not ctx.fp8_enabled:
                # bgrad is fused with gemm for non-FP8 path
                bgrad = bgrad_

        if not ctx.use_bias:
            return (
                wgrad if not weight.stop_gradient else None,
                dgrad.reshape(ctx.inp_shape) if ctx.requires_dgrad else None,
            )

        return (
            wgrad if not weight.stop_gradient else None,
            dgrad.reshape(ctx.inp_shape) if ctx.requires_dgrad else None,
            bgrad if ctx.requires_bgrad else None,
        )
class Linear(TransformerEngineBaseLayer):
    """
@@ -121,10 +494,11 @@ class Linear(TransformerEngineBaseLayer):
        )

        self.has_bias = self._bias_attr is not False
        use_default_bias = self._bias_attr is None or self._bias_attr is True
        if self.has_bias:
            self.bias = self.create_parameter(
                shape=[out_features],
                attr=self._bias_attr if not use_default_bias else paddle.ParamAttr(
                    initializer=Constant(value=0.0)),
                dtype=self._dtype,
                is_bias=True,
""" """
Apply the linear transformation to the input. Apply the linear transformation to the input.
""" """
with self.prepare_forward(inp) as inp: with self.prepare_forward(inp) as inp:
# Layer input should be casted outside PyLayer, as performing
# inplace cast to input tensors may cause problems when used
# together with Paddle native layers.
inp = cast_if_needed(inp, self.activation_dtype)
out = _Linear.apply( out = _Linear.apply(
cast_if_needed(self.weight, self.activation_dtype), self.weight,
cast_if_needed(inp, self.activation_dtype), inp,
cast_if_needed(self.bias, self.activation_dtype), self.bias,
self.has_bias, self.has_bias,
self.fp8_enabled,
self.fp8_calibration,
self.fp8_meta,
self.activation_dtype, self.activation_dtype,
paddle.is_grad_enabled(),
) )
return out return out
......
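A matching usage sketch for `Linear` (sizes illustrative; FP8 execution requires the flattened input to satisfy the divisibility checks in the utils below):

```python
import paddle
import transformer_engine.paddle as te

linear = te.Linear(1024, 1024)
x = paddle.rand((2048, 1024))

with te.fp8_autocast(enabled=True):
    y = linear(x)          # FP8 GEMM via _linear_fwd_fp8
y.sum().backward()         # FP8 dgrad/wgrad via _linear_bwd_fp8
```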
@@ -15,6 +15,34 @@ def cast_if_needed(tensor: Union[paddle.Tensor, None],
    return tensor if tensor is None or tensor.dtype == dtype else paddle.cast(tensor, dtype)


def cast_if_needed_inplace(tensor: Union[paddle.Tensor, None],
                           dtype: paddle.dtype) -> Union[paddle.Tensor, None]:
    """Cast tensor to dtype (inplace), not to be used on layer inputs"""
    return tensor if tensor is None or tensor.dtype == dtype else tensor._to(dtype=dtype)


def check_dim_for_fp8_forward_exec(tensor: paddle.Tensor) -> bool:
    """For fp8 fprop (TN layout), inputs and weights must be such
    that dim0 is divisible by 8 and dim1 is divisible by 16.
    """
    return not tensor.shape[0] % 8 and not tensor.shape[1] % 16


def assert_dim_for_fp8_forward_exec(tensor: paddle.Tensor) -> None:
    """For fp8 fprop (TN layout), inputs and weights must be such
    that dim0 is divisible by 8 and dim1 is divisible by 16.
    """
    # single tensor check so it's clear which tensor is triggering the assertion
    assert check_dim_for_fp8_forward_exec(tensor), (
        "Tensor dimensions are not compatible for FP8 execution: "
        f"({tensor.shape[0]} % 8 != 0, {tensor.shape[1]} % 16 != 0)")
def get_bias_dtype(activation_dtype: paddle.dtype):
    """Get bias dtype given activation_dtype"""
    return paddle.bfloat16 if activation_dtype == paddle.float32 else activation_dtype


def get_paddle_act_func(activation):
    """Get paddle activation function"""
    funcs = {
...