Commit 544dd14b authored by Przemek Tredak

Update main branch with TE 2.0 code, update version to 2.1.0.dev0

Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
parent e5369541
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Unittest for Transformer layer in tensor parallel"""
import unittest
import paddle
from paddle.distributed import fleet
from paddle.distributed.fleet.layers.mpu import mp_ops
from utils import assert_allclose, set_random_seed, register_sequence_parallel_allreduce_hooks
import transformer_engine.paddle as te
class TestAttentionTp(unittest.TestCase):
"""Tests MultiHeadAttention layer with model parallel in BF16"""
def setUp(self):
self.set_attr()
self.init_dist_env()
paddle.set_default_dtype(self.global_dtype)
def init_dist_env(self):
"""Init Paddle Fleet environment"""
strategy = fleet.DistributedStrategy()
self.model_parallel_size = 2
strategy.hybrid_configs = {
"dp_degree": 1,
"mp_degree": self.model_parallel_size,
"pp_degree": 1,
}
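        # Every rank generates identical inputs from the same seed, so Fleet's
        # automatic data broadcast across model-parallel ranks is unnecessary here.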
strategy.hybrid_configs["mp_configs"].need_broadcast_data = False
fleet.init(is_collective=True, strategy=strategy)
self.rank = fleet.worker_index()
self.hcg = fleet.get_hybrid_communicate_group()
self.tp_group = self.hcg.get_model_parallel_group()
self.world_size = self.hcg.get_model_parallel_world_size()
def set_attr(self):
"""Set test configs"""
self.batch_size = 16
self.hidden_size = 1024
self.num_heads = 16
self.q_seqlen = 128
self.kv_seqlen = 128
self.mask_type = "padding"
self.global_dtype = "bfloat16"
self.rtol = 5e-3
self.atol = 5e-3
self.eps = 1e-3
self.fp8 = False
self.sequence_parallel = False
def _train_one_step(self, layer, inp_list, optimizer, fp8_enabled, sequence_parallel=False):
inp, mask = inp_list
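        # Under sequence parallelism each rank consumes only its slice of the
        # input's leading dimension; outputs are gathered again below.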
if sequence_parallel:
split_size = inp.shape[0] // self.world_size
input_parallel = inp[split_size * self.rank : split_size * (self.rank + 1), :]
else:
input_parallel = inp
with te.fp8_autocast(enabled=fp8_enabled):
out = layer(input_parallel, mask)
if sequence_parallel:
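            # _c_concat gathers along the last dim, so split the hidden dim back
            # into per-rank chunks and stack them along the leading dim.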
total_out = mp_ops._c_concat(out, group=self.tp_group)
total_out = paddle.concat(paddle.split(total_out, self.world_size, axis=-1), axis=0)
else:
total_out = out
loss = total_out.mean()
loss.backward()
optimizer.step()
optimizer.clear_grad()
return loss, total_out
def test_parallel_layer(self):
"""Tests parallel Transformer"""
set_random_seed(1024)
common_args = (
self.hidden_size,
self.num_heads,
)
common_kwargs = {
"layernorm_epsilon": self.eps,
"attention_dropout": 0.0,
"attn_mask_type": self.mask_type,
"attention_type": "self",
"tp_group": self.tp_group,
"input_layernorm": True,
}
layer_tp = te.MultiHeadAttention(
*common_args,
**common_kwargs,
set_parallel_mode=True,
sequence_parallel=self.sequence_parallel,
)
layer_single = te.MultiHeadAttention(*common_args, **common_kwargs, set_parallel_mode=False)
def _get_total_weight(local_weight, tp_group, axis, interleave=False):
total_weight = []
partial_weight = local_weight.clone().detach()
paddle.distributed.all_gather(total_weight, partial_weight, group=tp_group)
if interleave:
# Due to the interleaved qkv layout, need to concat on num_head
# dimension for column parallel linear in MultiHeadAttention layer
assert axis == 0
assert [
3 * self.hidden_size // self.world_size,
self.hidden_size,
] == partial_weight.shape
local_num_head = self.num_heads // self.world_size
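                    # Each gathered shard is [3 * local_heads * head_dim, hidden];
                    # reshape so shards can be joined along the head dimension.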
for idx, _ in enumerate(total_weight):
total_weight[idx] = total_weight[idx].reshape(
[3, local_num_head, -1, self.hidden_size]
)
total_weight = paddle.concat(total_weight, axis=1).reshape([-1, self.hidden_size])
else:
total_weight = paddle.concat(total_weight, axis=axis)
return total_weight
def _get_weight(obj, weight_names):
for name in weight_names:
obj = getattr(obj, name)
return obj
def copy_weight(layer_src, layer_dst, partition_mode, weight_names, interleave=False):
weight_src = _get_weight(layer_src, weight_names)
weight_dst = _get_weight(layer_dst, weight_names)
if partition_mode is None:
total_weight = weight_src
elif partition_mode == "column":
total_weight = _get_total_weight(
weight_src, tp_group=self.tp_group, axis=0, interleave=interleave
)
elif partition_mode == "row":
total_weight = _get_total_weight(weight_src, tp_group=self.tp_group, axis=1)
else:
raise ValueError(f"Partition Mode {partition_mode} is not supported.")
assert (
weight_dst.shape == total_weight.shape
), f"Shapes of src:{total_weight.shape} and dst:{weight_dst.shape} do not match."
weight_dst.copy_(total_weight, True)
copy_weight(layer_tp, layer_single, None, ["layernorm_qkv", "ln_weight"])
copy_weight(layer_tp, layer_single, "column", ["layernorm_qkv", "weight"], interleave=True)
copy_weight(layer_tp, layer_single, "row", ["proj", "weight"])
if self.sequence_parallel:
register_sequence_parallel_allreduce_hooks(layer_tp, accumulation_steps=1)
optimizer_tp = paddle.optimizer.SGD(learning_rate=0.01, parameters=layer_tp.parameters())
optimizer_single = paddle.optimizer.SGD(
learning_rate=0.01, parameters=layer_single.parameters()
)
layer_tp = fleet.distributed_model(layer_tp)
optimizer_tp = fleet.distributed_optimizer(optimizer_tp)
for _ in range(5):
inp = paddle.uniform(
[self.batch_size, self.q_seqlen, self.hidden_size], self.global_dtype
)
mask = paddle.zeros(
shape=(self.batch_size, 1, self.q_seqlen, self.kv_seqlen), dtype="bool"
)
loss_tp, out_tp = self._train_one_step(
layer_tp, [inp, mask], optimizer_tp, self.fp8, self.sequence_parallel
)
loss_single, out_single = self._train_one_step(
layer_single, [inp, mask], optimizer_single, self.fp8
)
assert_allclose(out_tp, out_single, rtol=self.rtol, atol=self.atol)
assert_allclose(loss_tp, loss_single, rtol=self.rtol, atol=self.atol)
class TestAttentionTpFp8(TestAttentionTp):
"""Tests MultiHeadAttention layer with model parallel in FP8"""
def set_attr(self):
"""Set test configs"""
self.batch_size = 16
self.hidden_size = 1024
self.num_heads = 16
self.q_seqlen = 128
self.kv_seqlen = 128
self.mask_type = "padding"
self.global_dtype = "bfloat16"
self.rtol = 5e-2
self.atol = 5e-2
self.eps = 1e-3
self.fp8 = True
self.sequence_parallel = False
class TestAttentionSp(TestAttentionTp):
"""Tests MultiHeadAttention layer with sequence parallel in BF16"""
def set_attr(self):
"""Set test configs"""
self.batch_size = 16
self.hidden_size = 1024
self.num_heads = 16
self.q_seqlen = 128
self.kv_seqlen = 128
self.mask_type = "padding"
self.global_dtype = "bfloat16"
self.rtol = 5e-3
self.atol = 5e-3
self.eps = 1e-3
self.fp8 = False
self.sequence_parallel = True
class TestAttentionSpFp8(TestAttentionTp):
"""Tests MultiHeadAttention layer with sequence parallel in FP8"""
def set_attr(self):
"""Set test configs"""
self.batch_size = 16
self.hidden_size = 1024
self.num_heads = 16
self.q_seqlen = 128
self.kv_seqlen = 128
self.mask_type = "padding"
self.global_dtype = "bfloat16"
self.rtol = 5e-2
self.atol = 1e-1
self.eps = 1e-3
self.fp8 = True
self.sequence_parallel = True
if __name__ == "__main__":
unittest.main()
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Unittest for group sharding"""
import unittest
import paddle
from paddle.distributed import fleet
from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer import (
DygraphShardingOptimizer,
)
from utils import assert_allclose, set_random_seed
import transformer_engine.paddle as te
class TestGroupSharding(unittest.TestCase):
"""Tests group sharding"""
def setUp(self):
self.set_attr()
self.init_dist_env()
paddle.set_default_dtype(self.global_dtype)
def set_attr(self):
"""Set test configs"""
self.sharding_degree = 2
self.global_dtype = "float32"
self.rtol = 1e-5
self.atol = 1e-5
self.batch_size = 16
self.in_channels = 16
self.out_channels = 32
self.fp8 = False
def init_dist_env(self):
"""Init Paddle Fleet environment"""
strategy = fleet.DistributedStrategy()
strategy.hybrid_configs = {
"dp_degree": 1,
"mp_degree": 1,
"pp_degree": 1,
"sharding_degree": self.sharding_degree,
}
self.strategy = strategy
fleet.init(is_collective=True, strategy=strategy)
def _get_model_and_optimizer(self, model, stage):
if stage == 1:
optimizer = DygraphShardingOptimizer(
paddle.optimizer.AdamW(learning_rate=0.01, parameters=model.parameters()),
fleet.get_hybrid_communicate_group(),
)
model = fleet.distributed_model(model)
optimizer = fleet.distributed_optimizer(optimizer)
elif stage in [2, 3]:
optimizer = paddle.optimizer.AdamW(learning_rate=0.01, parameters=model.parameters())
group = fleet.get_hybrid_communicate_group().get_sharding_parallel_group()
            class ShardingLevel:  # pylint: disable=too-few-public-methods
"""Paddle sharding options"""
kStage1 = "os"
kStage2 = "os_g"
kStage3 = "p_g_os"
level = ShardingLevel.kStage3 if stage == 3 else ShardingLevel.kStage2
model, optimizer, _ = paddle.distributed.sharding.group_sharded_parallel(
model=model,
optimizer=optimizer,
level=level,
group=group,
segment_size=256,
)
else:
raise ValueError(f"Stage {stage} not supported")
return model, optimizer
def test_group_sharding_stage1(self):
"""Tests group sharding training"""
set_random_seed(1024)
model_te = te.Linear(self.in_channels, self.out_channels)
model_pd = paddle.nn.Linear(self.in_channels, self.out_channels)
model_pd.weight.copy_(model_te.weight.T, True)
model_pd.bias.copy_(model_te.bias, True)
model_te, optimizer_te = self._get_model_and_optimizer(model_te, stage=1)
model_pd, optimizer_pd = self._get_model_and_optimizer(model_pd, stage=1)
rank_id = paddle.distributed.get_rank()
paddle.seed(rank_id)
def train_one_step(model, inp, optimizer):
out = model(inp)
loss = out.mean()
loss.backward()
optimizer.step()
optimizer.clear_grad()
return loss
for _ in range(5):
inp = paddle.uniform([self.batch_size, self.in_channels], self.global_dtype)
with te.fp8_autocast(enabled=False):
loss_te = train_one_step(model_te, inp, optimizer_te)
loss_pd = train_one_step(model_pd, inp, optimizer_pd)
assert_allclose(loss_te, loss_pd, rtol=self.rtol, atol=self.atol)
assert (
len(optimizer_te.state_dict()) == 4
), "Expect each rank to hold 4 optimizer state entries."
def test_group_sharding_stage2(self):
"""Tests group sharding training"""
set_random_seed(1024)
model_te = te.Linear(self.in_channels, self.out_channels)
model_pd = paddle.nn.Linear(self.in_channels, self.out_channels)
model_pd.weight.copy_(model_te.weight.T, True)
model_pd.bias.copy_(model_te.bias, True)
model_te, optimizer_te = self._get_model_and_optimizer(model_te, stage=2)
model_pd, optimizer_pd = self._get_model_and_optimizer(model_pd, stage=2)
rank_id = paddle.distributed.get_rank()
paddle.seed(rank_id)
def train_one_step(model, inp, optimizer):
out = model(inp)
loss = out.mean()
loss.backward()
# Check gradients are split to different trainers
if rank_id == 0:
assert model.bias.grad is None and model.weight.grad is not None
elif rank_id == 1:
assert model.weight.grad is None and model.bias.grad is not None
optimizer.step()
optimizer.clear_grad()
return loss
for _ in range(5):
inp = paddle.uniform([self.batch_size, self.in_channels], self.global_dtype)
with te.fp8_autocast(enabled=False):
loss_te = train_one_step(model_te, inp, optimizer_te)
loss_pd = train_one_step(model_pd, inp, optimizer_pd)
assert_allclose(loss_te, loss_pd, rtol=self.rtol, atol=self.atol)
assert (
len(optimizer_te.state_dict()) == 4
), "Expect each rank to hold 4 optimizer state entries."
def test_group_sharding_stage3(self):
"""Tests group sharding training"""
set_random_seed(1024)
model_te = te.Linear(self.in_channels, self.out_channels)
model_pd = paddle.nn.Linear(self.in_channels, self.out_channels)
model_pd.weight.copy_(model_te.weight.T, True)
model_pd.bias.copy_(model_te.bias, True)
model_te, optimizer_te = self._get_model_and_optimizer(model_te, stage=3)
model_pd, optimizer_pd = self._get_model_and_optimizer(model_pd, stage=3)
rank_id = paddle.distributed.get_rank()
paddle.seed(rank_id)
def train_one_step(model, inp, optimizer):
out = model(inp)
loss = out.mean()
loss.backward()
optimizer.step()
optimizer.clear_grad()
return loss
for _ in range(5):
inp = paddle.uniform([self.batch_size, self.in_channels], self.global_dtype)
with te.fp8_autocast(enabled=False):
loss_te = train_one_step(model_te, inp, optimizer_te)
loss_pd = train_one_step(model_pd, inp, optimizer_pd)
assert_allclose(loss_te, loss_pd, rtol=self.rtol, atol=self.atol)
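        # "moment1" is AdamW's first-moment buffer; under stage 3 it is sharded
        # across ranks, so each rank holds only 1/sharding_degree of its elements.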
for name, value in optimizer_te.state_dict().items():
if name.endswith("w_0_moment1_0"):
assert (
value.numel() == self.in_channels * self.out_channels // self.sharding_degree
), "Expect optimizer state to be sharded across trainers."
if __name__ == "__main__":
unittest.main()
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Unittest for LayerNormLinear layer in tensor parallel"""
import unittest
import paddle
from paddle.distributed import fleet
from paddle.distributed.fleet.layers.mpu import mp_ops
from utils import assert_allclose, assert_shape, set_random_seed
import transformer_engine.paddle as te
class TestLayerNormLinearTp(unittest.TestCase):
"""Tests LayerNormLinear layer with column/row parallelism in BF16"""
def setUp(self):
self.set_attr()
self.init_dist_env()
paddle.set_default_dtype(self.global_dtype)
def init_dist_env(self):
"""Init Paddle Fleet environment"""
strategy = fleet.DistributedStrategy()
self.model_parallel_size = 2
strategy.hybrid_configs = {
"dp_degree": 1,
"mp_degree": self.model_parallel_size,
"pp_degree": 1,
}
strategy.hybrid_configs["mp_configs"].need_broadcast_data = False
fleet.init(is_collective=True, strategy=strategy)
self.rank = fleet.worker_index()
self.hcg = fleet.get_hybrid_communicate_group()
self.tp_group = self.hcg.get_model_parallel_group()
self.world_size = self.hcg.get_model_parallel_world_size()
def set_attr(self):
"""Set test configs"""
self.batch_size = 16
self.in_features = 32
self.out_features = 64
self.global_dtype = "bfloat16"
self.rtol = 1e-3
self.atol = 1e-3
self.eps = 1e-3
self.fp8 = False
self.sequence_parallel = False
def _train_one_step(self, layer, inp, optimizer, split_input="none", gather_output=False):
inp = paddle.to_tensor(inp, stop_gradient=True)
assert split_input in ["none", "column", "row"]
if split_input == "column":
split_size = inp.shape[1] // self.world_size
input_parallel = inp[:, split_size * self.rank : split_size * (self.rank + 1)]
elif split_input == "row":
split_size = inp.shape[0] // self.world_size
input_parallel = inp[split_size * self.rank : split_size * (self.rank + 1), :]
else:
input_parallel = inp
input_parallel.stop_gradient = False
out = layer(input_parallel)
if gather_output:
total_out = mp_ops._c_concat(out, group=self.tp_group)
else:
total_out = out
loss = total_out.mean()
loss.backward()
optimizer.step()
optimizer.clear_grad()
if split_input != "none":
grad_input = []
paddle.distributed.all_gather(grad_input, input_parallel.grad, group=self.tp_group)
if split_input == "column":
grad_input = paddle.concat(grad_input, axis=1)
elif split_input == "row":
grad_input = paddle.concat(grad_input, axis=0)
else:
grad_input = input_parallel.grad
return loss, grad_input
def test_column_parallel_layer(self):
"""Tests column parallel LayerNormLinear"""
set_random_seed(1024)
layer_te = te.LayerNormLinear(
self.in_features,
self.out_features,
eps=self.eps,
parallel_mode="column",
sequence_parallel=self.sequence_parallel,
)
layer_pd = te.LayerNormLinear(
self.in_features,
self.out_features,
eps=self.eps,
backend="paddle",
)
# Get total weight
total_weight = []
partial_weight = layer_te.weight.clone().detach()
paddle.distributed.all_gather(total_weight, partial_weight, group=self.tp_group)
total_weight = paddle.concat(total_weight, axis=0)
layer_pd.weight.copy_(total_weight.T, True)
assert_shape(
layer_te.weight, [self.out_features // self.model_parallel_size, self.in_features]
)
assert_shape(layer_te.bias, [self.out_features // self.model_parallel_size])
optimizer_te = paddle.optimizer.SGD(learning_rate=0.001, parameters=layer_te.parameters())
optimizer_pd = paddle.optimizer.SGD(learning_rate=0.001, parameters=layer_pd.parameters())
layer_te = fleet.distributed_model(layer_te)
optimizer_te = fleet.distributed_optimizer(optimizer_te)
for _ in range(5):
inp = paddle.uniform([self.batch_size, self.in_features], self.global_dtype)
with te.fp8_autocast(enabled=self.fp8):
loss_tp, grad_input = self._train_one_step(
layer_te,
inp,
optimizer_te,
split_input="row" if self.sequence_parallel else "none",
gather_output=True,
)
loss_ref, grad_input_ref = self._train_one_step(layer_pd, inp, optimizer_pd)
assert_allclose(loss_tp, loss_ref, rtol=self.rtol, atol=self.atol)
assert_allclose(grad_input, grad_input_ref, rtol=self.rtol, atol=self.atol)
class TestLayerNormLinearTpFp8(TestLayerNormLinearTp):
"""Tests LayernormLinear layer with column/row parallelism in FP8"""
def set_attr(self):
"""Set test configs"""
self.batch_size = 16
self.in_features = 32
self.out_features = 64
self.global_dtype = "bfloat16"
self.rtol = 1e-2
self.atol = 1e-2
self.eps = 1e-3
self.fp8 = True
self.sequence_parallel = False
class TestLayerNormLinearSp(TestLayerNormLinearTp):
"""Tests LayernormLinear layer with sequence parallelism"""
def set_attr(self):
"""Set test configs"""
self.batch_size = 16
self.in_features = 32
self.out_features = 64
self.global_dtype = "bfloat16"
self.rtol = 1e-3
self.atol = 1e-3
self.eps = 1e-3
self.fp8 = False
self.sequence_parallel = True
class TestLayerNormLinearSpFp8(TestLayerNormLinearTp):
"""Tests LayernormLinear layer with sequence parallelism in FP8"""
def set_attr(self):
"""Set test configs"""
self.batch_size = 16
self.in_features = 32
self.out_features = 64
self.global_dtype = "bfloat16"
self.rtol = 1e-2
self.atol = 1e-2
self.eps = 1e-3
self.fp8 = True
self.sequence_parallel = True
if __name__ == "__main__":
unittest.main()
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Unittest for LayerNormMLP layer in tensor parallel"""
import unittest
import paddle
from paddle.distributed import fleet
from paddle.distributed.fleet.layers.mpu import mp_ops
from utils import assert_allclose, assert_shape, set_random_seed
import transformer_engine.paddle as te
class TestLayerNormMLPTp(unittest.TestCase):
"""Tests LayerNormMLP layer with model parallel in BF16"""
def setUp(self):
self.set_attr()
self.init_dist_env()
paddle.set_default_dtype(self.global_dtype)
def init_dist_env(self):
"""Init Paddle Fleet environment"""
strategy = fleet.DistributedStrategy()
self.model_parallel_size = 2
strategy.hybrid_configs = {
"dp_degree": 1,
"mp_degree": self.model_parallel_size,
"pp_degree": 1,
}
strategy.hybrid_configs["mp_configs"].need_broadcast_data = False
fleet.init(is_collective=True, strategy=strategy)
self.rank = fleet.worker_index()
self.hcg = fleet.get_hybrid_communicate_group()
self.tp_group = self.hcg.get_model_parallel_group()
self.world_size = self.hcg.get_model_parallel_world_size()
def set_attr(self):
"""Set test configs"""
self.batch_size = 16
self.hidden_size = 32
self.ffn_hidden_size = 64
self.global_dtype = "bfloat16"
self.rtol = 1e-3
self.atol = 1e-3
self.eps = 1e-3
self.fp8 = False
self.sequence_parallel = False
def _train_one_step(self, layer, inp, optimizer, split_input="none", gather_output=False):
inp = paddle.to_tensor(inp, stop_gradient=True)
assert split_input in ["none", "column", "row"]
if split_input == "column":
split_size = inp.shape[1] // self.world_size
input_parallel = inp[:, split_size * self.rank : split_size * (self.rank + 1)]
elif split_input == "row":
split_size = inp.shape[0] // self.world_size
input_parallel = inp[split_size * self.rank : split_size * (self.rank + 1), :]
else:
input_parallel = inp
input_parallel.stop_gradient = False
out = layer(input_parallel)
if gather_output:
# Need to concat on the first dim, while _c_concat concats on the last dim
total_out = mp_ops._c_concat(out.T, group=self.tp_group).T
else:
total_out = out
loss = total_out.mean()
loss.backward()
optimizer.step()
optimizer.clear_grad()
if split_input != "none":
grad_input = []
paddle.distributed.all_gather(grad_input, input_parallel.grad, group=self.tp_group)
if split_input == "column":
grad_input = paddle.concat(grad_input, axis=1)
elif split_input == "row":
grad_input = paddle.concat(grad_input, axis=0)
else:
grad_input = input_parallel.grad
return loss, grad_input
def test_parallel_layer(self):
"""Tests parallel LayerNormMLP"""
set_random_seed(1024)
layer_te = te.LayerNormMLP(
hidden_size=self.hidden_size,
ffn_hidden_size=self.ffn_hidden_size,
eps=self.eps,
set_parallel_mode=True,
sequence_parallel=self.sequence_parallel,
)
layer_pd = te.LayerNormMLP(
hidden_size=self.hidden_size,
ffn_hidden_size=self.ffn_hidden_size,
eps=self.eps,
set_parallel_mode=False,
backend="paddle",
)
def _get_total_weight(local_weight, tp_group, axis):
total_weight = []
partial_weight = local_weight.clone().detach()
paddle.distributed.all_gather(total_weight, partial_weight, group=tp_group)
total_weight = paddle.concat(total_weight, axis=axis)
return total_weight
# Get total weight
total_fc1_weight = _get_total_weight(layer_te.fc1_weight, tp_group=self.tp_group, axis=0)
total_fc2_weight = _get_total_weight(layer_te.fc2_weight, tp_group=self.tp_group, axis=1)
layer_pd.fc1_weight.copy_(total_fc1_weight.T, True)
layer_pd.fc2_weight.copy_(total_fc2_weight.T, True)
assert_shape(
layer_te.fc1_weight,
[self.ffn_hidden_size // self.model_parallel_size, self.hidden_size],
)
assert_shape(layer_te.fc1_bias, [self.ffn_hidden_size // self.model_parallel_size])
assert_shape(
layer_te.fc2_weight,
[self.hidden_size, self.ffn_hidden_size // self.model_parallel_size],
)
assert_shape(layer_te.fc2_bias, [self.hidden_size])
optimizer_te = paddle.optimizer.SGD(learning_rate=0.001, parameters=layer_te.parameters())
optimizer_pd = paddle.optimizer.SGD(learning_rate=0.001, parameters=layer_pd.parameters())
layer_te = fleet.distributed_model(layer_te)
optimizer_te = fleet.distributed_optimizer(optimizer_te)
for _ in range(5):
inp = paddle.uniform([self.batch_size, self.hidden_size], self.global_dtype)
with te.fp8_autocast(enabled=self.fp8):
loss_tp, grad_input = self._train_one_step(
layer_te,
inp,
optimizer_te,
split_input="row" if self.sequence_parallel else "none",
gather_output=self.sequence_parallel,
)
loss_ref, grad_input_ref = self._train_one_step(layer_pd, inp, optimizer_pd)
assert_allclose(loss_tp, loss_ref, rtol=self.rtol, atol=self.atol)
assert_allclose(grad_input, grad_input_ref, rtol=self.rtol, atol=self.atol)
class TestLayerNormMLPTpFp8(TestLayerNormMLPTp):
"""Tests LayerNormMLP layer with tensor parallelism in FP8"""
def set_attr(self):
"""Set test configs"""
self.batch_size = 16
self.hidden_size = 32
self.ffn_hidden_size = 64
self.global_dtype = "bfloat16"
self.rtol = 1e-2
self.atol = 1e-2
self.eps = 1e-3
self.fp8 = True
self.sequence_parallel = False
class TestLayerNormMLPSp(TestLayerNormMLPTp):
"""Tests LayerNormMLP layer with sequence parallel in BF16"""
def set_attr(self):
"""Set test configs"""
self.batch_size = 16
self.hidden_size = 32
self.ffn_hidden_size = 64
self.global_dtype = "bfloat16"
self.rtol = 1e-3
self.atol = 1e-3
self.eps = 1e-3
self.fp8 = False
self.sequence_parallel = True
class TestLayerNormMLPSpFp8(TestLayerNormMLPTp):
"""Tests LayerNormMLP layer with sequence parallelism in FP8"""
def set_attr(self):
"""Set test configs"""
self.batch_size = 16
self.hidden_size = 32
self.ffn_hidden_size = 64
self.global_dtype = "bfloat16"
self.rtol = 1e-2
self.atol = 1e-2
self.eps = 1e-3
self.fp8 = True
self.sequence_parallel = True
if __name__ == "__main__":
unittest.main()
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Unittest for Linear layer in pipeline parallel"""
import unittest
import numpy as np
import paddle
from paddle.distributed import fleet
from paddle.distributed.fleet.meta_parallel import (
LayerDesc,
PipelineLayer,
)
from utils import assert_allclose, set_random_seed
import transformer_engine.paddle as te
class TELinear(te.Linear):
"""To pass is_first_microbatch"""
def __init__(self, *args, **kwargs):
assert "accumulate_steps" in kwargs
self.accumulate_steps = kwargs["accumulate_steps"]
del kwargs["accumulate_steps"]
self._micro_batch_id = 0
super().__init__(*args, **kwargs)
def forward(self, *args, **kwargs):
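        # TE recomputes and caches the FP8 weight cast when is_first_microbatch is
        # True, then reuses the cache for the rest of the accumulation window.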
kwargs["is_first_microbatch"] = (self._micro_batch_id % self.accumulate_steps) == 0
if paddle.is_grad_enabled() and self.training:
self._micro_batch_id += 1
return super().forward(*args, **kwargs)
class TEPipelineModel(PipelineLayer):
"""Model for pipeline parallel test"""
def __init__(
self,
in_features,
hidden_features,
weight_attrs,
use_te=True,
use_fp8=False,
accumulate_steps=1,
**kwargs,
):
self.in_features = in_features
self.hidden_features = hidden_features
self.fp8 = use_fp8
hcg = fleet.get_hybrid_communicate_group()
self.dp_group = hcg.get_data_parallel_group()
Linear = TELinear if use_te else paddle.nn.Linear
extra_kwargs = {}
if use_te:
extra_kwargs["accumulate_steps"] = accumulate_steps
model_desc = [
LayerDesc(
Linear,
self.in_features,
self.hidden_features,
weight_attr=weight_attrs[0],
**extra_kwargs,
),
LayerDesc(
Linear,
self.hidden_features,
self.in_features,
weight_attr=weight_attrs[1],
**extra_kwargs,
),
]
super().__init__(layers=model_desc, loss_fn=paddle.nn.CrossEntropyLoss(), **kwargs)
def forward(self, *args, **kwargs):
with te.fp8_autocast(enabled=self.fp8, fp8_group=self.dp_group):
return super().forward(*args, **kwargs)
class StandaloneModel(paddle.nn.Layer):
"""Model for pipeline parallel test"""
def __init__(self, in_features, hidden_features, weight_attrs):
super().__init__()
self.in_features = in_features
self.hidden_features = hidden_features
Linear = paddle.nn.Linear
self.layer = paddle.nn.Sequential(
Linear(self.in_features, self.hidden_features, weight_attr=weight_attrs[0]),
Linear(self.hidden_features, self.in_features, weight_attr=weight_attrs[1]),
)
self.loss = paddle.nn.CrossEntropyLoss()
def forward(self, inp):
out = self.layer(inp[0])
loss = self.loss(out, inp[1])
return loss
class TestLinearPipelineParallel(unittest.TestCase):
"""Tests Linear layer with pipeline parallel"""
def setUp(self):
self.set_attr()
self.init_dist_env()
paddle.set_default_dtype(self.global_dtype)
def init_dist_env(self):
"""Init Paddle Fleet environment"""
strategy = fleet.DistributedStrategy()
self.pipeline_parallel_size = 2
strategy.hybrid_configs = {
"dp_degree": 1,
"mp_degree": 1,
"pp_degree": self.pipeline_parallel_size,
}
self.accumulate_steps = self.batch_size // self.micro_batch_size
strategy.pipeline_configs = {
"accumulate_steps": self.accumulate_steps,
"micro_batch_size": self.micro_batch_size,
}
fleet.init(is_collective=True, strategy=strategy)
self.rank = fleet.worker_index()
self.hcg = fleet.get_hybrid_communicate_group()
def set_attr(self):
"""Set test configs"""
self.batch_size = 32
self.micro_batch_size = 16
self.in_features = 32
self.hidden_features = 64
self.global_dtype = "float32"
self.rtol = 1e-5
self.atol = 1e-5
self.iter = 10
self.fp8 = False
def test_pipeline_train(self):
"""Test pipeline parallel training"""
set_random_seed(1024)
np.random.seed(1024)
weight1_np = np.random.normal(size=[self.in_features, self.hidden_features])
weight2_np = np.random.normal(size=[self.hidden_features, self.in_features])
weight_attrs = [
paddle.ParamAttr(initializer=paddle.nn.initializer.Assign(weight1_np)),
paddle.ParamAttr(initializer=paddle.nn.initializer.Assign(weight2_np)),
]
weight_attrs_transposed = [
paddle.ParamAttr(initializer=paddle.nn.initializer.Assign(weight1_np.T)),
paddle.ParamAttr(initializer=paddle.nn.initializer.Assign(weight2_np.T)),
]
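        # seg_method="layer:Linear" places pipeline stage boundaries at Linear
        # layers, so each of the two pipeline ranks owns exactly one Linear.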
pipe_model = TEPipelineModel(
self.in_features,
self.hidden_features,
weight_attrs_transposed,
use_te=True,
use_fp8=self.fp8,
seg_method="layer:Linear",
num_stages=self.pipeline_parallel_size,
accumulate_steps=self.accumulate_steps,
)
# Check if model is split across ranks as expected
for name, sublayer in pipe_model.named_sublayers():
if name in ("_loss_fn", "shared_layers"):
continue
if self.rank == 0:
assert tuple(sublayer.weight.shape) == weight1_np.T.shape, (
f"Shape does not match, expect: {weight1_np.T.shape}, "
f"actual: {tuple(sublayer.weight.shape)}"
)
elif self.rank == 1:
assert tuple(sublayer.weight.shape) == weight2_np.T.shape, (
f"Shape does not match, expect: {weight2_np.T.shape}, "
f"actual: {tuple(sublayer.weight.shape)}"
)
standalone_model = StandaloneModel(
self.in_features,
self.hidden_features,
weight_attrs,
)
optimizer_te = paddle.optimizer.SGD(learning_rate=0.1, parameters=pipe_model.parameters())
optimizer_pd = paddle.optimizer.SGD(
learning_rate=0.1, parameters=standalone_model.parameters()
)
pipe_model = fleet.distributed_model(pipe_model)
optimizer_te = fleet.distributed_optimizer(optimizer_te)
def train_one_step(layer, inp, optimizer):
loss = layer(inp)
loss.backward()
optimizer.step()
optimizer.clear_grad()
return loss
for i in range(self.iter):
inp = paddle.to_tensor(
np.random.normal(size=[self.batch_size, self.in_features]), dtype=self.global_dtype
)
label = paddle.to_tensor(np.random.randint(self.in_features, size=[self.batch_size, 1]))
loss_te = pipe_model.train_batch([inp, label], optimizer_te)
loss_pd = train_one_step(standalone_model, [inp, label], optimizer_pd)
print(f"Iter: {i}, loss_te: {loss_te.item()}, loss_pd: {loss_pd.item()}")
assert_allclose(loss_te, loss_pd, rtol=self.rtol, atol=self.atol)
class TestLinearPipelineParallelFP8(TestLinearPipelineParallel):
"""Tests Linear layer with column/row parallelism in FP8"""
def set_attr(self):
"""Set test configs"""
self.batch_size = 32
self.micro_batch_size = 16
self.in_features = 32
self.hidden_features = 64
self.global_dtype = "float32"
self.rtol = 5e-2
self.atol = 5e-2
self.iter = 10
self.fp8 = True
if __name__ == "__main__":
unittest.main()
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Unittest for Linear layer in tensor parallel"""
import unittest
import paddle
from paddle.distributed import fleet
from paddle.distributed.fleet.layers.mpu import mp_ops
from utils import assert_allclose, assert_shape, set_random_seed
import transformer_engine.paddle as te
class TestLinearTp(unittest.TestCase):
"""Tests Linear layer with column/row parallelism in BF16"""
def setUp(self):
self.set_attr()
self.init_dist_env()
paddle.set_default_dtype(self.global_dtype)
def init_dist_env(self):
"""Init Paddle Fleet environment"""
strategy = fleet.DistributedStrategy()
self.model_parallel_size = 2
strategy.hybrid_configs = {
"dp_degree": 1,
"mp_degree": self.model_parallel_size,
"pp_degree": 1,
}
strategy.hybrid_configs["mp_configs"].need_broadcast_data = False
fleet.init(is_collective=True, strategy=strategy)
self.rank = fleet.worker_index()
self.hcg = fleet.get_hybrid_communicate_group()
self.tp_group = self.hcg.get_model_parallel_group()
self.world_size = self.hcg.get_model_parallel_world_size()
def set_attr(self):
"""Set test configs"""
self.batch_size = 16
self.in_features = 32
self.out_features = 64
self.global_dtype = "bfloat16"
self.rtol = 1e-3
self.atol = 1e-3
self.fp8 = False
self.sequence_parallel = False
def _train_one_step(self, layer, inp, optimizer, split_input="none", gather_output=False):
inp = paddle.to_tensor(inp, stop_gradient=True)
assert split_input in ["none", "column", "row"]
if split_input == "column":
split_size = inp.shape[1] // self.world_size
input_parallel = inp[:, split_size * self.rank : split_size * (self.rank + 1)]
elif split_input == "row":
split_size = inp.shape[0] // self.world_size
input_parallel = inp[split_size * self.rank : split_size * (self.rank + 1), :]
else:
input_parallel = inp
input_parallel.stop_gradient = False
out = layer(input_parallel)
if gather_output:
total_out = mp_ops._c_concat(out, group=self.tp_group)
else:
total_out = out
loss = total_out.mean()
loss.backward()
optimizer.step()
optimizer.clear_grad()
if split_input != "none":
grad_input = []
paddle.distributed.all_gather(grad_input, input_parallel.grad, group=self.tp_group)
if split_input == "column":
grad_input = paddle.concat(grad_input, axis=1)
elif split_input == "row":
grad_input = paddle.concat(grad_input, axis=0)
else:
grad_input = input_parallel.grad
return loss, grad_input
def test_column_parallel_layer(self):
"""Tests column parallel linear"""
set_random_seed(1024)
layer_te = te.Linear(
self.in_features,
self.out_features,
parallel_mode="column",
sequence_parallel=self.sequence_parallel,
)
layer_pd = te.Linear(
self.in_features,
self.out_features,
backend="paddle",
)
# Get total weight
total_weight = []
partial_weight = layer_te.weight.clone().detach()
paddle.distributed.all_gather(total_weight, partial_weight, group=self.tp_group)
total_weight = paddle.concat(total_weight, axis=0)
layer_pd.weight.copy_(total_weight.T, True)
assert_shape(
layer_te.weight, [self.out_features // self.model_parallel_size, self.in_features]
)
assert_shape(layer_te.bias, [self.out_features // self.model_parallel_size])
optimizer_te = paddle.optimizer.SGD(learning_rate=0.001, parameters=layer_te.parameters())
optimizer_pd = paddle.optimizer.SGD(learning_rate=0.001, parameters=layer_pd.parameters())
layer_te = fleet.distributed_model(layer_te)
optimizer_te = fleet.distributed_optimizer(optimizer_te)
for _ in range(5):
inp = paddle.uniform([self.batch_size, self.in_features], self.global_dtype)
with te.fp8_autocast(enabled=self.fp8):
loss_tp, grad_input = self._train_one_step(
layer_te,
inp,
optimizer_te,
split_input="row" if self.sequence_parallel else "none",
gather_output=True,
)
loss_ref, grad_input_ref = self._train_one_step(layer_pd, inp, optimizer_pd)
assert_allclose(loss_tp, loss_ref, rtol=self.rtol, atol=self.atol)
assert_allclose(grad_input, grad_input_ref, rtol=self.rtol, atol=self.atol)
def test_row_parallel_layer(self):
"""Tests row parallel linear"""
set_random_seed(1024)
layer_te = te.Linear(
self.in_features,
self.out_features,
parallel_mode="row",
sequence_parallel=self.sequence_parallel,
)
layer_pd = te.Linear(
self.in_features,
self.out_features,
backend="paddle",
)
# Get total weight
total_weight = []
partial_weight = layer_te.weight.clone().detach()
paddle.distributed.all_gather(total_weight, partial_weight, group=self.tp_group)
total_weight = paddle.concat(total_weight, axis=1)
layer_pd.weight.copy_(total_weight.T, True)
assert_shape(
layer_te.weight, [self.out_features, self.in_features // self.model_parallel_size]
)
assert_shape(layer_te.bias, [self.out_features])
optimizer_te = paddle.optimizer.SGD(learning_rate=0.001, parameters=layer_te.parameters())
optimizer_pd = paddle.optimizer.SGD(learning_rate=0.001, parameters=layer_pd.parameters())
layer_te = fleet.distributed_model(layer_te)
optimizer_te = fleet.distributed_optimizer(optimizer_te)
for _ in range(5):
inp = paddle.uniform([self.batch_size, self.in_features], self.global_dtype)
with te.fp8_autocast(enabled=self.fp8):
loss_tp, grad_input = self._train_one_step(
layer_te,
inp,
optimizer_te,
split_input="column",
gather_output=self.sequence_parallel,
)
loss_ref, grad_input_ref = self._train_one_step(layer_pd, inp, optimizer_pd)
assert_allclose(loss_tp, loss_ref, rtol=self.rtol, atol=self.atol)
assert_allclose(grad_input, grad_input_ref, rtol=self.rtol, atol=self.atol)
class TestLinearTpFP8(TestLinearTp):
"""Tests Linear layer with column/row parallelism in FP8"""
def set_attr(self):
"""Set test configs"""
self.batch_size = 16
self.in_features = 32
self.out_features = 64
self.global_dtype = "bfloat16"
self.rtol = 1e-2
self.atol = 1e-2
self.fp8 = True
self.sequence_parallel = False
class TestLinearSp(TestLinearTp):
"""Tests Linear layer with sequence parallelism"""
def set_attr(self):
"""Set test configs"""
self.batch_size = 16
self.in_features = 32
self.out_features = 64
self.global_dtype = "bfloat16"
self.rtol = 1e-3
self.atol = 1e-3
self.fp8 = False
self.sequence_parallel = True
class TestLinearSpFP8(TestLinearTp):
"""Tests Linear layer with sequence parallelism in FP8"""
def set_attr(self):
"""Set test configs"""
self.batch_size = 16
self.in_features = 32
self.out_features = 64
self.global_dtype = "bfloat16"
self.rtol = 1e-2
self.atol = 1e-2
self.fp8 = True
self.sequence_parallel = True
if __name__ == "__main__":
unittest.main()
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Unittest for Transformer layer in tensor parallel"""
import unittest
import paddle
from paddle.distributed import fleet
from paddle.distributed.fleet.layers.mpu import mp_ops
from utils import assert_allclose, set_random_seed, register_sequence_parallel_allreduce_hooks
import transformer_engine.paddle as te
class TestTransformerTp(unittest.TestCase):
"""Tests Transformer layer with model parallel in BF16"""
def setUp(self):
self.set_attr()
self.init_dist_env()
paddle.set_default_dtype(self.global_dtype)
def init_dist_env(self):
"""Init Paddle Fleet environment"""
strategy = fleet.DistributedStrategy()
self.model_parallel_size = 2
strategy.hybrid_configs = {
"dp_degree": 1,
"mp_degree": self.model_parallel_size,
"pp_degree": 1,
}
strategy.hybrid_configs["mp_configs"].need_broadcast_data = False
fleet.init(is_collective=True, strategy=strategy)
self.rank = fleet.worker_index()
self.hcg = fleet.get_hybrid_communicate_group()
self.tp_group = self.hcg.get_model_parallel_group()
self.world_size = self.hcg.get_model_parallel_world_size()
def set_attr(self):
"""Set test configs"""
self.batch_size = 16
self.hidden_size = 1024
self.num_heads = 16
self.ffn_hidden_size = 4096
self.q_seqlen = 128
self.kv_seqlen = 128
self.mask_type = "padding"
self.layer_type = "encoder"
self.global_dtype = "bfloat16"
self.rtol = 5e-2
self.atol = 5e-2
self.eps = 1e-3
self.fp8 = False
self.sequence_parallel = False
def _train_one_step(self, layer, inp_list, optimizer, fp8_enabled, sequence_parallel=False):
inp, mask = inp_list
if sequence_parallel:
split_size = inp.shape[0] // self.world_size
input_parallel = inp[split_size * self.rank : split_size * (self.rank + 1), :]
else:
input_parallel = inp
with te.fp8_autocast(enabled=fp8_enabled):
out = layer(input_parallel, mask)
if sequence_parallel:
total_out = mp_ops._c_concat(out, group=self.tp_group)
total_out = paddle.concat(paddle.split(total_out, self.world_size, axis=-1), axis=0)
else:
total_out = out
loss = total_out.mean()
loss.backward()
optimizer.step()
optimizer.clear_grad()
return loss, total_out
def test_parallel_layer(self):
"""Tests parallel Transformer"""
set_random_seed(1024)
common_args = [
self.hidden_size,
self.ffn_hidden_size,
self.num_heads,
]
common_kwargs = {
"layernorm_epsilon": self.eps,
"hidden_dropout": 0.0,
"attention_dropout": 0.0,
"self_attn_mask_type": self.mask_type,
"layer_type": self.layer_type,
}
layer_tp = te.TransformerLayer(
*common_args,
**common_kwargs,
set_parallel_mode=True,
sequence_parallel=self.sequence_parallel,
)
layer_single = te.TransformerLayer(*common_args, **common_kwargs, set_parallel_mode=False)
def _get_total_weight(local_weight, tp_group, axis, interleave=False):
total_weight = []
partial_weight = local_weight.clone().detach()
paddle.distributed.all_gather(total_weight, partial_weight, group=tp_group)
if interleave:
# Due to the interleaved qkv layout, need to concat on num_head
# dimension for column parallel linear in MultiHeadAttention layer
assert axis == 0
assert [
3 * self.hidden_size // self.world_size,
self.hidden_size,
] == partial_weight.shape
local_num_head = self.num_heads // self.world_size
for idx, _ in enumerate(total_weight):
total_weight[idx] = total_weight[idx].reshape(
[3, local_num_head, -1, self.hidden_size]
)
total_weight = paddle.concat(total_weight, axis=1).reshape([-1, self.hidden_size])
else:
total_weight = paddle.concat(total_weight, axis=axis)
return total_weight
def _get_weight(obj, weight_names):
for name in weight_names:
obj = getattr(obj, name)
return obj
def copy_weight(layer_src, layer_dst, partition_mode, weight_names, interleave=False):
weight_src = _get_weight(layer_src, weight_names)
weight_dst = _get_weight(layer_dst, weight_names)
if partition_mode is None:
total_weight = weight_src
elif partition_mode == "column":
total_weight = _get_total_weight(
weight_src, tp_group=self.tp_group, axis=0, interleave=interleave
)
elif partition_mode == "row":
total_weight = _get_total_weight(weight_src, tp_group=self.tp_group, axis=1)
else:
raise ValueError(f"Partition Mode {partition_mode} is not supported.")
assert (
weight_dst.shape == total_weight.shape
), f"Shapes of src:{total_weight.shape} and dst:{weight_dst.shape} do not match."
weight_dst.copy_(total_weight, True)
copy_weight(layer_tp, layer_single, None, ["self_attention", "layernorm_qkv", "ln_weight"])
copy_weight(
layer_tp,
layer_single,
"column",
["self_attention", "layernorm_qkv", "weight"],
interleave=True,
)
copy_weight(layer_tp, layer_single, "row", ["self_attention", "proj", "weight"])
copy_weight(layer_tp, layer_single, None, ["layernorm_mlp", "ln_weight"])
copy_weight(layer_tp, layer_single, "column", ["layernorm_mlp", "fc1_weight"])
copy_weight(layer_tp, layer_single, "row", ["layernorm_mlp", "fc2_weight"])
if self.sequence_parallel:
register_sequence_parallel_allreduce_hooks(layer_tp, accumulation_steps=1)
optimizer_tp = paddle.optimizer.SGD(learning_rate=0.01, parameters=layer_tp.parameters())
optimizer_single = paddle.optimizer.SGD(
learning_rate=0.01, parameters=layer_single.parameters()
)
layer_tp = fleet.distributed_model(layer_tp)
optimizer_tp = fleet.distributed_optimizer(optimizer_tp)
for _ in range(5):
inp = paddle.uniform(
[self.batch_size, self.q_seqlen, self.hidden_size], self.global_dtype
)
mask = paddle.zeros(
shape=(self.batch_size, 1, self.q_seqlen, self.kv_seqlen), dtype="bool"
)
loss_tp, out_tp = self._train_one_step(
layer_tp, [inp, mask], optimizer_tp, self.fp8, self.sequence_parallel
)
loss_single, out_single = self._train_one_step(
layer_single, [inp, mask], optimizer_single, self.fp8
)
assert_allclose(out_tp, out_single, rtol=self.rtol, atol=self.atol)
assert_allclose(loss_tp, loss_single, rtol=self.rtol, atol=self.atol)
class TestTransformerTpFp8(TestTransformerTp):
"""Tests Transformer layer with tensor parallelism in FP8"""
def set_attr(self):
"""Set test configs"""
self.batch_size = 16
self.hidden_size = 1024
self.num_heads = 16
self.ffn_hidden_size = 4096
self.q_seqlen = 128
self.kv_seqlen = 128
self.mask_type = "padding"
self.layer_type = "encoder"
self.global_dtype = "bfloat16"
self.rtol = 5e-2
self.atol = 0.5
self.eps = 1e-3
self.fp8 = True
self.sequence_parallel = False
class TestTransformerSp(TestTransformerTp):
"""Tests Transformer layer with sequence parallel in BF16"""
def set_attr(self):
"""Set test configs"""
self.batch_size = 16
self.hidden_size = 1024
self.num_heads = 16
self.ffn_hidden_size = 4096
self.q_seqlen = 128
self.kv_seqlen = 128
self.mask_type = "padding"
self.layer_type = "encoder"
self.global_dtype = "bfloat16"
self.rtol = 5e-2
self.atol = 5e-2
self.eps = 1e-3
self.fp8 = False
self.sequence_parallel = True
class TestTransformerSpFp8(TestTransformerSp):
"""Tests Transformer layer with sequence parallelism in FP8"""
def set_attr(self):
"""Set test configs"""
self.batch_size = 16
self.hidden_size = 1024
self.num_heads = 16
self.ffn_hidden_size = 4096
self.q_seqlen = 128
self.kv_seqlen = 128
self.mask_type = "padding"
self.layer_type = "encoder"
self.global_dtype = "bfloat16"
self.rtol = 5e-2
self.atol = 0.5
self.eps = 1e-3
self.fp8 = True
self.sequence_parallel = True
if __name__ == "__main__":
unittest.main()
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Test TransformerLayer encoder recompute"""
import sys
import paddle
import transformer_engine.paddle as te
class Net(paddle.nn.Layer):
"""Network use for recompute testing"""
def __init__(self, layers):
super().__init__()
self.layers = layers
def forward(self, inp, mask, enable_recompute, use_reentrant):
for layer in self.layers:
if enable_recompute:
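                # te.recompute drops activations in the forward pass and recomputes
                # them during backward, trading compute for lower peak memory.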
out = te.recompute(layer, inp, mask, use_reentrant=use_reentrant)
else:
out = layer(inp, mask)
return out
def main():
"""Main function"""
paddle.seed(10)
batch_size = 16
hidden_size = 4096
num_heads = 32
ffn_hidden_size = 16384
q_seqlen = 512
kv_seqlen = 512
num_layers = 4
enable_recompute = int(sys.argv[1])
use_reentrant = int(sys.argv[2])
layers = paddle.nn.LayerList(
[
te.TransformerLayer(
hidden_size,
ffn_hidden_size,
num_heads,
layer_type="encoder",
)
for _ in range(num_layers)
]
)
model = Net(layers)
optimizer = paddle.optimizer.AdamW(learning_rate=0.001, parameters=model.parameters())
for _ in range(10):
inp = paddle.uniform([batch_size, q_seqlen, hidden_size])
inp.stop_gradient = False
mask = paddle.zeros(shape=(batch_size, 1, q_seqlen, kv_seqlen), dtype="bool")
with te.fp8_autocast(enabled=True):
out = model(inp, mask, enable_recompute, use_reentrant)
loss = out.mean()
loss.backward()
optimizer.step()
optimizer.clear_grad()
print("Loss: ", float(loss))
print("Peak memory: ", paddle.device.cuda.max_memory_allocated(0))
if __name__ == "__main__":
main()
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Test basic installation of Paddle extensions"""
def test_import():
"""
Test if Paddle extension can be imported normally
"""
import transformer_engine.paddle # pylint: disable=unused-import
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Test TE Paddle Layer-level APIs"""
import os
from utils import assert_allclose, is_fused_attention_supported
import paddle
import pytest
from transformer_engine.common.recipe import DelayedScaling
import transformer_engine.paddle as te
from transformer_engine.paddle.fp8 import is_fp8_available, fp8_autocast
is_fp8_supported, reason = is_fp8_available()
LINEAR_CASES = [(16, 16, 32), (32, 32, 64)]
NORM_CASES = [(16, 32), (256, 1024)]
@pytest.fixture(autouse=True)
def setup():
"""Setup random seed before each test"""
paddle.seed(10)
yield
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest.mark.parametrize("use_fp8", [True, False])
def test_checkpoint(use_fp8):
"""Test checkpoint save / load"""
bs = 16
in_features = 16
out_features = 32
file_name = "model.pdparams"
input_tensor = paddle.uniform(shape=(bs, in_features), dtype="float32")
model = te.Linear(in_features, out_features)
model_loaded = te.Linear(in_features, out_features)
# Populate amax_history
with fp8_autocast(enabled=False, calibrating=True):
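        # calibrating=True runs in the original precision while still recording
        # amax statistics, so FP8 scaling factors are initialized before saving.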
_ = model(input_tensor)
# Save model
paddle.save(model.state_dict(), file_name)
# Get ref output
with fp8_autocast(enabled=use_fp8):
out_ref = model(input_tensor)
# Load model
model_loaded.set_state_dict(paddle.load(file_name))
if os.path.exists(file_name):
os.remove(file_name)
# Get actual output
with fp8_autocast(enabled=use_fp8):
out = model_loaded(input_tensor)
assert_allclose(out, out_ref)
def calc_output_and_grad(layer, x, dy):
"""
Calculate forward and backward pass
"""
inp = paddle.to_tensor(x)
inp.stop_gradient = x.stop_gradient
y = layer(inp)
y.backward(dy)
return y, inp.grad if not inp.stop_gradient else None
def calc_output_and_grad_ln_out(layer, x, dy, return_ln_out=False):
"""
Calculate forward and backward pass for layernorm
"""
inp = paddle.to_tensor(x)
inp.stop_gradient = x.stop_gradient
outputs = layer(inp)
ln_out = None
if return_ln_out:
y, ln_out = outputs
else:
y = outputs
y.backward(dy)
return y, ln_out, inp.grad if not inp.stop_gradient else None
class TestLinear:
"""
Tests for Linear layer
"""
@staticmethod
@pytest.mark.skipif(
paddle.device.cuda.get_device_capability() < (8, 0),
reason="BF16 Linear requires Ampere+ GPU",
)
@pytest.mark.parametrize("bs,in_features,out_features", LINEAR_CASES)
@pytest.mark.parametrize("has_bias,no_dbias", [[True, False], [True, True], [False, False]])
@pytest.mark.parametrize("no_dgrad", [True, False])
@pytest.mark.parametrize("no_wgrad", [True, False])
@pytest.mark.parametrize("activation_dtype", ["bfloat16", "float32"])
def test_linear_bf16(
bs, in_features, out_features, has_bias, no_dbias, no_dgrad, no_wgrad, activation_dtype
):
"""
Test BF16 Linear
"""
rtol = 5e-2
atol = 5e-2
input_tensor = paddle.uniform(shape=(bs, in_features), dtype=activation_dtype)
input_tensor.stop_gradient = no_dgrad
grad_out = paddle.uniform(shape=(bs, out_features), dtype=activation_dtype)
paddle.set_default_dtype(activation_dtype)
layer_te = te.Linear(in_features, out_features, bias_attr=None if has_bias else False)
layer_pd = te.Linear(
in_features, out_features, bias_attr=None if has_bias else False, backend="paddle"
)
layer_pd.weight.copy_(layer_te.weight.T, True)
if has_bias:
layer_pd.bias.copy_(layer_te.bias, True)
layer_te.weight.stop_gradient = no_wgrad
layer_pd.weight.stop_gradient = no_wgrad
if has_bias:
layer_te.bias.stop_gradient = no_dbias
layer_pd.bias.stop_gradient = no_dbias
out_ref, grad_input_ref = calc_output_and_grad(layer_pd, input_tensor, grad_out)
out, grad_input = calc_output_and_grad(layer_te, input_tensor, grad_out)
assert_allclose(out, out_ref, rtol=rtol, atol=atol)
if not no_dgrad:
assert_allclose(grad_input, grad_input_ref, rtol=rtol, atol=atol)
if not no_wgrad:
assert_allclose(layer_te.weight.grad, layer_pd.weight.grad.T, rtol=rtol, atol=atol)
if has_bias and not no_dbias:
assert_allclose(layer_te.bias.grad, layer_pd.bias.grad, rtol=rtol, atol=atol)
@staticmethod
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest.mark.parametrize("bs,in_features,out_features", LINEAR_CASES)
@pytest.mark.parametrize("has_bias,no_dbias", [[True, False], [True, True], [False, False]])
@pytest.mark.parametrize("no_dgrad", [True, False])
@pytest.mark.parametrize("no_wgrad", [True, False])
@pytest.mark.parametrize("fp8_wgrad", [True, False])
@pytest.mark.parametrize("do_calibration", [True, False])
@pytest.mark.parametrize("activation_dtype", ["bfloat16", "float32"])
def test_linear_fp8(
bs,
in_features,
out_features,
has_bias,
no_dbias,
no_dgrad,
no_wgrad,
fp8_wgrad,
do_calibration,
activation_dtype,
):
"""
Test FP8 Linear
"""
rtol = 0.1
atol = 0.5
input_tensor = paddle.uniform(shape=(bs, in_features), dtype=activation_dtype)
input_tensor.stop_gradient = no_dgrad
grad_out = paddle.uniform(shape=(bs, out_features), dtype=activation_dtype)
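        # override_linear_precision = (fprop, dgrad, wgrad); True keeps that GEMM
        # out of FP8, so wgrad runs in higher precision when fp8_wgrad is False.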
recipe = DelayedScaling(override_linear_precision=(False, False, not fp8_wgrad))
paddle.set_default_dtype(activation_dtype)
layer_te = te.Linear(
in_features=in_features,
out_features=out_features,
bias_attr=None if has_bias else False,
)
layer_pd = te.Linear(
in_features=in_features,
out_features=out_features,
bias_attr=None if has_bias else False,
backend="paddle",
)
layer_pd.weight.copy_(layer_te.weight.T, True)
if has_bias:
layer_pd.bias.copy_(layer_te.bias, True)
layer_te.weight.stop_gradient = no_wgrad
layer_pd.weight.stop_gradient = no_wgrad
if has_bias:
layer_te.bias.stop_gradient = no_dbias
layer_pd.bias.stop_gradient = no_dbias
with fp8_autocast(
enabled=not do_calibration, calibrating=do_calibration, fp8_recipe=recipe
):
out_ref, grad_input_ref = calc_output_and_grad(layer_pd, input_tensor, grad_out)
out, grad_input = calc_output_and_grad(layer_te, input_tensor, grad_out)
assert_allclose(out, out_ref, rtol=rtol, atol=atol)
if not no_dgrad:
assert_allclose(grad_input, grad_input_ref, rtol=rtol, atol=atol)
if not no_wgrad:
assert_allclose(layer_te.weight.grad, layer_pd.weight.grad.T, rtol=rtol, atol=atol)
if has_bias and not no_dbias:
assert_allclose(layer_te.bias.grad, layer_pd.bias.grad, rtol=rtol, atol=atol)
if do_calibration:
assert paddle.count_nonzero(layer_te.fp8_meta["scaling_fwd"].amax_history).item() > 0
@staticmethod
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest.mark.parametrize("bs,in_features,out_features", LINEAR_CASES)
@pytest.mark.parametrize("activation_dtype", ["bfloat16"])
@pytest.mark.parametrize("num_microbatch", [8])
def test_linear_fp8_microbatch(bs, in_features, out_features, activation_dtype, num_microbatch):
"""
        Test FP8 Linear with is_first_microbatch weight caching
"""
rtol = 0.1
atol = 0.1
recipe = DelayedScaling()
paddle.set_default_dtype(activation_dtype)
layer_cached = te.Linear(
in_features=in_features,
out_features=out_features,
)
layer_normal = te.Linear(
in_features=in_features,
out_features=out_features,
)
layer_cached.weight.copy_(layer_normal.weight, True)
layer_cached.bias.copy_(layer_normal.bias, True)
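        # layer_cached reuses its FP8-cast weights after the first micro-batch;
        # its results should still match layer_normal, which re-casts every step.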
for iteration in range(num_microbatch):
input_tensor = paddle.uniform(shape=(bs, in_features), dtype=activation_dtype)
grad_out = paddle.uniform(shape=(bs, out_features), dtype=activation_dtype)
with fp8_autocast(enabled=True, fp8_recipe=recipe):
out = layer_cached(input_tensor, is_first_microbatch=(iteration == 0))
out.backward(grad_out)
with fp8_autocast(enabled=True, fp8_recipe=recipe):
out_ref = layer_normal(input_tensor)
out_ref.backward(grad_out)
assert_allclose(out, out_ref, rtol=rtol, atol=atol)
assert_allclose(
layer_cached.weight.grad, layer_normal.weight.grad, rtol=rtol, atol=atol
)
@pytest.mark.parametrize("bs,hidden_size", NORM_CASES)
@pytest.mark.parametrize("has_bias,no_dbias", [[True, False], [True, True], [False, False]])
@pytest.mark.parametrize("no_dgrad", [True, False])
@pytest.mark.parametrize("no_wgrad", [True, False])
@pytest.mark.parametrize("activation_dtype", ["bfloat16", "float32"])
def test_layernorm_bf16(bs, hidden_size, has_bias, no_dbias, no_dgrad, no_wgrad, activation_dtype):
"""
Test BF16 LayerNorm
"""
eps = 1e-3
rtol = 1e-2
atol = 1e-2
x = paddle.uniform(shape=(bs, hidden_size), dtype=activation_dtype)
x.stop_gradient = no_dgrad
grad_out = paddle.uniform(shape=(bs, hidden_size), dtype=activation_dtype)
paddle.set_default_dtype(activation_dtype)
layer_te = te.LayerNorm(hidden_size=hidden_size, eps=eps, bias_attr=None if has_bias else False)
layer_pd = te.LayerNorm(
hidden_size=hidden_size, eps=eps, bias_attr=None if has_bias else False, backend="paddle"
)
layer_pd.weight.copy_(layer_te.weight, True)
if has_bias:
layer_pd.bias.copy_(layer_te.bias, True)
layer_te.weight.stop_gradient = no_wgrad
layer_pd.weight.stop_gradient = no_wgrad
if has_bias:
layer_te.bias.stop_gradient = no_dbias
layer_pd.bias.stop_gradient = no_dbias
out_ref, grad_input_ref = calc_output_and_grad(layer_pd, x, grad_out)
out, grad_input = calc_output_and_grad(layer_te, x, grad_out)
assert_allclose(out, out_ref, rtol=rtol, atol=atol)
if not no_dgrad:
assert_allclose(grad_input, grad_input_ref, rtol=rtol, atol=atol)
if not no_wgrad:
assert_allclose(layer_te.weight.grad, layer_pd.weight.grad, rtol=rtol, atol=atol)
if has_bias and not no_dbias:
assert_allclose(layer_te.bias.grad, layer_pd.bias.grad, rtol=rtol, atol=atol)
class TestLayerNormLinear:
"""
Tests for LayerNormLinear layer
"""
@staticmethod
@pytest.mark.skipif(
paddle.device.cuda.get_device_capability() < (8, 0),
reason="BF16 Linear requires Ampere+ GPU",
)
@pytest.mark.parametrize("bs,in_features,out_features", LINEAR_CASES)
@pytest.mark.parametrize("has_bias,no_dbias", [[True, False], [True, True], [False, False]])
@pytest.mark.parametrize("no_dgrad", [True, False])
@pytest.mark.parametrize("no_wgrad", [True, False])
@pytest.mark.parametrize("return_ln_out", [True, False])
@pytest.mark.parametrize("activation_dtype", ["bfloat16", "float32"])
@pytest.mark.parametrize("normalization", ["RMSNorm", "LayerNorm"])
def test_layernorm_linear_bf16(
bs,
in_features,
out_features,
has_bias,
no_dbias,
no_dgrad,
no_wgrad,
return_ln_out,
activation_dtype,
normalization,
):
"""
Test BF16 LayerNormLinear Layer
"""
paddle.set_default_dtype(activation_dtype)
rtol = 5e-2
atol = 5e-2
input_tensor = paddle.uniform(shape=(bs, in_features), dtype=activation_dtype)
input_tensor.stop_gradient = no_dgrad
grad_out = paddle.uniform(shape=(bs, out_features), dtype=activation_dtype)
eps = 1e-3
has_ln_bias = normalization == "LayerNorm"
layer_te = te.LayerNormLinear(
in_features=in_features,
out_features=out_features,
eps=eps,
normalization=normalization,
bias_attr=None if has_bias else False,
return_layernorm_output=return_ln_out,
)
layer_pd = te.LayerNormLinear(
in_features=in_features,
out_features=out_features,
eps=eps,
normalization=normalization,
bias_attr=None if has_bias else False,
return_layernorm_output=return_ln_out,
backend="paddle",
)
layer_pd.ln_weight.copy_(layer_te.ln_weight, True)
if has_ln_bias:
layer_pd.ln_bias.copy_(layer_te.ln_bias, True)
layer_pd.weight.copy_(layer_te.weight.T, True)
if has_bias:
layer_pd.bias.copy_(layer_te.bias, True)
layer_te.weight.stop_gradient = no_wgrad
layer_te.ln_weight.stop_gradient = no_wgrad
layer_pd.weight.stop_gradient = no_wgrad
layer_pd.ln_weight.stop_gradient = no_wgrad
if has_ln_bias:
layer_te.ln_bias.stop_gradient = no_dbias
layer_pd.ln_bias.stop_gradient = no_dbias
if has_bias:
layer_te.bias.stop_gradient = no_dbias
layer_pd.bias.stop_gradient = no_dbias
out_ref, ln_out_ref, grad_input_ref = calc_output_and_grad_ln_out(
layer_pd, input_tensor, grad_out, return_ln_out=return_ln_out
)
out, ln_out, grad_input = calc_output_and_grad_ln_out(
layer_te, input_tensor, grad_out, return_ln_out=return_ln_out
)
assert_allclose(out, out_ref, rtol=rtol, atol=atol)
if not no_dgrad:
assert_allclose(grad_input, grad_input_ref, rtol=rtol, atol=atol)
if not no_wgrad:
assert_allclose(layer_te.weight.grad, layer_pd.weight.grad.T, rtol=rtol, atol=atol)
assert_allclose(layer_te.ln_weight.grad, layer_pd.ln_weight.grad, rtol=rtol, atol=atol)
if not no_dbias:
if has_ln_bias:
assert_allclose(layer_te.ln_bias.grad, layer_pd.ln_bias.grad, rtol=rtol, atol=atol)
if has_bias:
assert_allclose(layer_te.bias.grad, layer_pd.bias.grad, rtol=rtol, atol=atol)
if return_ln_out:
assert_allclose(ln_out, ln_out_ref, rtol=rtol, atol=atol)
@staticmethod
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest.mark.parametrize("bs,in_features,out_features", LINEAR_CASES)
@pytest.mark.parametrize("has_bias,no_dbias", [[True, False], [True, True], [False, False]])
@pytest.mark.parametrize("no_dgrad", [True, False])
@pytest.mark.parametrize("no_wgrad", [True, False])
@pytest.mark.parametrize("fp8_wgrad", [True, False])
@pytest.mark.parametrize("do_calibration", [True, False])
@pytest.mark.parametrize("return_ln_out", [True, False])
@pytest.mark.parametrize("activation_dtype", ["bfloat16", "float32"])
@pytest.mark.parametrize("normalization", ["RMSNorm", "LayerNorm"])
def test_layernorm_linear_fp8(
bs,
in_features,
out_features,
has_bias,
no_dbias,
no_dgrad,
no_wgrad,
fp8_wgrad,
do_calibration,
return_ln_out,
activation_dtype,
normalization,
):
"""
Test FP8 LayerNormLinear Layer
"""
paddle.set_default_dtype(activation_dtype)
rtol = 0.1
atol = 0.75
input_tensor = paddle.uniform(shape=(bs, in_features), dtype=activation_dtype)
input_tensor.stop_gradient = no_dgrad
grad_out = paddle.uniform(shape=(bs, out_features), dtype=activation_dtype)
eps = 1e-3
has_ln_bias = normalization == "LayerNorm"
recipe = DelayedScaling(override_linear_precision=(False, False, not fp8_wgrad))
layer_te = te.LayerNormLinear(
in_features=in_features,
out_features=out_features,
eps=eps,
normalization=normalization,
bias_attr=None if has_bias else False,
return_layernorm_output=return_ln_out,
)
layer_pd = te.LayerNormLinear(
in_features=in_features,
out_features=out_features,
eps=eps,
normalization=normalization,
bias_attr=None if has_bias else False,
return_layernorm_output=return_ln_out,
backend="paddle",
)
layer_pd.ln_weight.copy_(layer_te.ln_weight, True)
if has_ln_bias:
layer_pd.ln_bias.copy_(layer_te.ln_bias, True)
layer_pd.weight.copy_(layer_te.weight.T, True)
if has_bias:
layer_pd.bias.copy_(layer_te.bias, True)
layer_te.weight.stop_gradient = no_wgrad
layer_te.ln_weight.stop_gradient = no_wgrad
layer_pd.weight.stop_gradient = no_wgrad
layer_pd.ln_weight.stop_gradient = no_wgrad
if has_ln_bias:
layer_te.ln_bias.stop_gradient = no_dbias
layer_pd.ln_bias.stop_gradient = no_dbias
if has_bias:
layer_te.bias.stop_gradient = no_dbias
layer_pd.bias.stop_gradient = no_dbias
with fp8_autocast(
enabled=not do_calibration, calibrating=do_calibration, fp8_recipe=recipe
):
out_ref, ln_out_ref, grad_input_ref = calc_output_and_grad_ln_out(
layer_pd, input_tensor, grad_out, return_ln_out=return_ln_out
)
out, ln_out, grad_input = calc_output_and_grad_ln_out(
layer_te, input_tensor, grad_out, return_ln_out=return_ln_out
)
assert_allclose(out, out_ref, rtol=rtol, atol=atol)
if not no_dgrad:
assert_allclose(grad_input, grad_input_ref, rtol=rtol, atol=atol)
if not no_wgrad:
assert_allclose(layer_te.weight.grad, layer_pd.weight.grad.T, rtol=rtol, atol=atol)
assert_allclose(layer_te.ln_weight.grad, layer_pd.ln_weight.grad, rtol=rtol, atol=atol)
if not no_dbias:
if has_ln_bias:
assert_allclose(layer_te.ln_bias.grad, layer_pd.ln_bias.grad, rtol=rtol, atol=atol)
if has_bias:
assert_allclose(layer_te.bias.grad, layer_pd.bias.grad, rtol=rtol, atol=atol)
if return_ln_out:
assert_allclose(ln_out, ln_out_ref, rtol=rtol, atol=atol)
if do_calibration:
assert paddle.count_nonzero(layer_te.fp8_meta["scaling_fwd"].amax_history).item() > 0
@staticmethod
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest.mark.parametrize("bs,in_features,out_features", LINEAR_CASES)
@pytest.mark.parametrize("activation_dtype", ["bfloat16"])
@pytest.mark.parametrize("num_microbatch", [8])
def test_layernorm_linear_fp8_microbatch(
bs, in_features, out_features, activation_dtype, num_microbatch
):
"""
Test FP8 LayerNormLinear layer with FP8 weight caching across microbatches
"""
paddle.set_default_dtype(activation_dtype)
eps = 1e-3
rtol = 0.5
atol = 0.5
recipe = DelayedScaling()
layer_cached = te.LayerNormLinear(
in_features=in_features,
out_features=out_features,
eps=eps,
)
layer_normal = te.LayerNormLinear(
in_features=in_features,
out_features=out_features,
eps=eps,
)
layer_cached.ln_weight.copy_(layer_normal.ln_weight, True)
layer_cached.ln_bias.copy_(layer_normal.ln_bias, True)
layer_cached.weight.copy_(layer_normal.weight, True)
layer_cached.bias.copy_(layer_normal.bias, True)
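# is_first_microbatch=True casts and caches the FP8 weights; subsequent microbatches reuse the cached copies.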
for iteration in range(num_microbatch):
input_tensor = paddle.uniform(shape=(bs, in_features), dtype=activation_dtype)
grad_out = paddle.uniform(shape=(bs, out_features), dtype=activation_dtype)
with fp8_autocast(enabled=True, fp8_recipe=recipe):
out = layer_cached(input_tensor, is_first_microbatch=(iteration == 0))
out.backward(grad_out)
with fp8_autocast(enabled=True, fp8_recipe=recipe):
out_ref = layer_normal(input_tensor)
out_ref.backward(grad_out)
assert_allclose(out, out_ref, rtol=rtol, atol=atol)
assert_allclose(
layer_cached.weight.grad, layer_normal.weight.grad, rtol=rtol, atol=atol
)
assert_allclose(
layer_cached.ln_weight.grad, layer_normal.ln_weight.grad, rtol=rtol, atol=atol
)
class TestLayerNormMLP:
"""
Test LayerNormMLP Layer
"""
@staticmethod
@pytest.mark.skipif(
paddle.device.cuda.get_device_capability() < (8, 0),
reason="BF16 Linear requires Ampere+ GPU",
)
@pytest.mark.parametrize("bs,hidden_size,ffn_hidden_size", LINEAR_CASES)
@pytest.mark.parametrize("has_bias,no_dbias", [[True, False], [True, True], [False, False]])
@pytest.mark.parametrize("no_dgrad", [True, False])
@pytest.mark.parametrize("no_wgrad", [True, False])
@pytest.mark.parametrize("return_ln_out", [True, False])
@pytest.mark.parametrize("activation_dtype", ["bfloat16", "float32"])
@pytest.mark.parametrize("normalization", ["RMSNorm", "LayerNorm"])
@pytest.mark.parametrize("activation", ["gelu", "swiglu"])
def test_layernorm_mlp_bf16(
bs,
hidden_size,
ffn_hidden_size,
has_bias,
no_dbias,
no_dgrad,
no_wgrad,
return_ln_out,
activation_dtype,
normalization,
activation,
):
"""
Test BF16 LayerNormMLP layer
"""
paddle.set_default_dtype(activation_dtype)
rtol = 5e-2
atol = 5e-2
input_tensor = paddle.uniform(shape=(bs, hidden_size), dtype=activation_dtype)
input_tensor.stop_gradient = no_dgrad
grad_out = paddle.uniform(shape=(bs, hidden_size), dtype=activation_dtype)
eps = 1e-3
has_ln_bias = normalization == "LayerNorm"
layer_te = te.LayerNormMLP(
hidden_size=hidden_size,
ffn_hidden_size=ffn_hidden_size,
eps=eps,
normalization=normalization,
activation=activation,
bias_attr=None if has_bias else False,
return_layernorm_output=return_ln_out,
)
layer_pd = te.LayerNormMLP(
hidden_size=hidden_size,
ffn_hidden_size=ffn_hidden_size,
eps=eps,
normalization=normalization,
activation=activation,
bias_attr=None if has_bias else False,
return_layernorm_output=return_ln_out,
backend="paddle",
)
layer_pd.ln_weight.copy_(layer_te.ln_weight, True)
if has_ln_bias:
layer_pd.ln_bias.copy_(layer_te.ln_bias, True)
layer_pd.fc1_weight.copy_(layer_te.fc1_weight.T, True)
layer_pd.fc2_weight.copy_(layer_te.fc2_weight.T, True)
if has_bias:
layer_pd.fc1_bias.copy_(layer_te.fc1_bias, True)
layer_pd.fc2_bias.copy_(layer_te.fc2_bias, True)
layer_te.fc1_weight.stop_gradient = no_wgrad
layer_te.fc2_weight.stop_gradient = no_wgrad
layer_te.ln_weight.stop_gradient = no_wgrad
layer_pd.fc1_weight.stop_gradient = no_wgrad
layer_pd.fc2_weight.stop_gradient = no_wgrad
layer_pd.ln_weight.stop_gradient = no_wgrad
if has_ln_bias:
layer_te.ln_bias.stop_gradient = no_dbias
layer_pd.ln_bias.stop_gradient = no_dbias
if has_bias:
layer_te.fc1_bias.stop_gradient = no_dbias
layer_te.fc2_bias.stop_gradient = no_dbias
layer_pd.fc1_bias.stop_gradient = no_dbias
layer_pd.fc2_bias.stop_gradient = no_dbias
out_ref, ln_out_ref, grad_input_ref = calc_output_and_grad_ln_out(
layer_pd, input_tensor, grad_out, return_ln_out=return_ln_out
)
out, ln_out, grad_input = calc_output_and_grad_ln_out(
layer_te, input_tensor, grad_out, return_ln_out=return_ln_out
)
assert_allclose(out, out_ref, rtol=rtol, atol=atol)
if not no_dgrad:
assert_allclose(grad_input, grad_input_ref, rtol=rtol, atol=atol)
if not no_wgrad:
assert_allclose(layer_te.ln_weight.grad, layer_pd.ln_weight.grad, rtol=rtol, atol=atol)
assert_allclose(
layer_te.fc1_weight.grad, layer_pd.fc1_weight.grad.T, rtol=rtol, atol=atol
)
assert_allclose(
layer_te.fc2_weight.grad, layer_pd.fc2_weight.grad.T, rtol=rtol, atol=atol
)
if not no_dbias:
if has_ln_bias:
assert_allclose(layer_te.ln_bias.grad, layer_pd.ln_bias.grad, rtol=rtol, atol=atol)
if has_bias:
assert_allclose(
layer_te.fc1_bias.grad, layer_pd.fc1_bias.grad, rtol=rtol, atol=atol
)
assert_allclose(
layer_te.fc2_bias.grad, layer_pd.fc2_bias.grad, rtol=rtol, atol=atol
)
if return_ln_out:
assert_allclose(ln_out, ln_out_ref, rtol=rtol, atol=atol)
@staticmethod
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest.mark.parametrize("bs,hidden_size,ffn_hidden_size", LINEAR_CASES)
@pytest.mark.parametrize("has_bias,no_dbias", [[True, False], [True, True], [False, False]])
@pytest.mark.parametrize("no_dgrad", [True, False])
@pytest.mark.parametrize("no_wgrad", [True, False])
@pytest.mark.parametrize("fp8_wgrad", [True, False])
@pytest.mark.parametrize("do_calibration", [True, False])
@pytest.mark.parametrize("return_ln_out", [True, False])
@pytest.mark.parametrize("activation_dtype", ["bfloat16", "float32"])
@pytest.mark.parametrize("normalization", ["RMSNorm", "LayerNorm"])
@pytest.mark.parametrize("activation", ["gelu", "swiglu"])
def test_layernorm_mlp_fp8(
bs,
hidden_size,
ffn_hidden_size,
has_bias,
no_dbias,
no_dgrad,
no_wgrad,
fp8_wgrad,
do_calibration,
return_ln_out,
activation_dtype,
normalization,
activation,
):
"""
Test FP8 LayerNormMLP Layer
"""
paddle.set_default_dtype(activation_dtype)
rtol = 0.1
atol = 0.75
input_tensor = paddle.uniform(shape=(bs, hidden_size), dtype=activation_dtype)
input_tensor.stop_gradient = no_dgrad
grad_out = paddle.uniform(shape=(bs, hidden_size), dtype=activation_dtype)
eps = 1e-3
has_ln_bias = normalization == "LayerNorm"
recipe = DelayedScaling(override_linear_precision=(False, False, not fp8_wgrad))
layer_te = te.LayerNormMLP(
hidden_size=hidden_size,
ffn_hidden_size=ffn_hidden_size,
eps=eps,
normalization=normalization,
activation=activation,
bias_attr=None if has_bias else False,
return_layernorm_output=return_ln_out,
)
layer_pd = te.LayerNormMLP(
hidden_size=hidden_size,
ffn_hidden_size=ffn_hidden_size,
eps=eps,
normalization=normalization,
activation=activation,
bias_attr=None if has_bias else False,
return_layernorm_output=return_ln_out,
backend="paddle",
)
layer_pd.ln_weight.copy_(layer_te.ln_weight, True)
if has_ln_bias:
layer_pd.ln_bias.copy_(layer_te.ln_bias, True)
layer_pd.fc1_weight.copy_(layer_te.fc1_weight.T, True)
layer_pd.fc2_weight.copy_(layer_te.fc2_weight.T, True)
if has_bias:
layer_pd.fc1_bias.copy_(layer_te.fc1_bias, True)
layer_pd.fc2_bias.copy_(layer_te.fc2_bias, True)
layer_te.fc1_weight.stop_gradient = no_wgrad
layer_te.fc2_weight.stop_gradient = no_wgrad
layer_te.ln_weight.stop_gradient = no_wgrad
layer_pd.fc1_weight.stop_gradient = no_wgrad
layer_pd.fc2_weight.stop_gradient = no_wgrad
layer_pd.ln_weight.stop_gradient = no_wgrad
if has_ln_bias:
layer_te.ln_bias.stop_gradient = no_dbias
layer_pd.ln_bias.stop_gradient = no_dbias
if has_bias:
layer_te.fc1_bias.stop_gradient = no_dbias
layer_te.fc2_bias.stop_gradient = no_dbias
layer_pd.fc1_bias.stop_gradient = no_dbias
layer_pd.fc2_bias.stop_gradient = no_dbias
with fp8_autocast(
enabled=not do_calibration, calibrating=do_calibration, fp8_recipe=recipe
):
out_ref, ln_out_ref, grad_input_ref = calc_output_and_grad_ln_out(
layer_pd, input_tensor, grad_out, return_ln_out=return_ln_out
)
out, ln_out, grad_input = calc_output_and_grad_ln_out(
layer_te, input_tensor, grad_out, return_ln_out=return_ln_out
)
assert_allclose(out, out_ref, rtol=rtol, atol=atol)
if not no_dgrad:
assert_allclose(grad_input, grad_input_ref, rtol=rtol, atol=atol)
if not no_wgrad:
assert_allclose(layer_te.ln_weight.grad, layer_pd.ln_weight.grad, rtol=rtol, atol=atol)
assert_allclose(
layer_te.fc1_weight.grad, layer_pd.fc1_weight.grad.T, rtol=rtol, atol=atol
)
assert_allclose(
layer_te.fc2_weight.grad, layer_pd.fc2_weight.grad.T, rtol=rtol, atol=atol
)
if not no_dbias:
if has_ln_bias:
assert_allclose(layer_te.ln_bias.grad, layer_pd.ln_bias.grad, rtol=rtol, atol=atol)
if has_bias:
assert_allclose(
layer_te.fc1_bias.grad, layer_pd.fc1_bias.grad, rtol=rtol, atol=atol
)
assert_allclose(
layer_te.fc2_bias.grad, layer_pd.fc2_bias.grad, rtol=rtol, atol=atol
)
if return_ln_out:
assert_allclose(ln_out, ln_out_ref, rtol=rtol, atol=atol)
if do_calibration:
assert paddle.count_nonzero(layer_te.fp8_meta["scaling_fwd"].amax_history).item() > 0
@staticmethod
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest.mark.parametrize("bs,hidden_size,ffn_hidden_size", LINEAR_CASES)
@pytest.mark.parametrize("activation_dtype", ["bfloat16"])
@pytest.mark.parametrize("num_microbatch", [8])
def test_layernorm_mlp_fp8_microbatch(
bs, hidden_size, ffn_hidden_size, activation_dtype, num_microbatch
):
"""
Test FP8 LayerNormMLP layer with FP8 weight caching across microbatches
"""
paddle.set_default_dtype(activation_dtype)
rtol = 1e-5
atol = 1e-5
eps = 1e-3
recipe = DelayedScaling()
layer_cached = te.LayerNormMLP(
hidden_size=hidden_size,
ffn_hidden_size=ffn_hidden_size,
eps=eps,
)
layer_normal = te.LayerNormMLP(
hidden_size=hidden_size,
ffn_hidden_size=ffn_hidden_size,
eps=eps,
)
layer_normal.ln_weight.copy_(layer_cached.ln_weight, True)
layer_normal.ln_bias.copy_(layer_cached.ln_bias, True)
layer_normal.fc1_weight.copy_(layer_cached.fc1_weight, True)
layer_normal.fc2_weight.copy_(layer_cached.fc2_weight, True)
layer_normal.fc1_bias.copy_(layer_cached.fc1_bias, True)
layer_normal.fc2_bias.copy_(layer_cached.fc2_bias, True)
# Calibration to make sure weight scale is the same
input_tensor = paddle.uniform(shape=(bs, hidden_size), dtype=activation_dtype)
with fp8_autocast(enabled=False, calibrating=True, fp8_recipe=recipe):
_ = layer_cached(input_tensor)
with fp8_autocast(enabled=False, calibrating=True, fp8_recipe=recipe):
_ = layer_normal(input_tensor)
for iteration in range(num_microbatch):
input_tensor = paddle.uniform(shape=(bs, hidden_size), dtype=activation_dtype)
grad_out = paddle.uniform(shape=(bs, hidden_size), dtype=activation_dtype)
with fp8_autocast(enabled=True, fp8_recipe=recipe):
out = layer_cached(input_tensor, is_first_microbatch=(iteration == 0))
out.backward(grad_out)
with fp8_autocast(enabled=True, fp8_recipe=recipe):
out_ref = layer_normal(input_tensor)
out_ref.backward(grad_out)
assert_allclose(out, out_ref, rtol=rtol, atol=atol)
assert_allclose(
layer_cached.ln_weight.grad, layer_normal.ln_weight.grad, rtol=rtol, atol=atol
)
assert_allclose(
layer_cached.fc1_weight.grad, layer_normal.fc1_weight.grad, rtol=rtol, atol=atol
)
assert_allclose(
layer_cached.fc2_weight.grad, layer_normal.fc2_weight.grad, rtol=rtol, atol=atol
)
@pytest.mark.parametrize("bs", [1, 2])
@pytest.mark.parametrize("hidden_size, num_heads", [[1024, 16]])
@pytest.mark.parametrize("q_seqlen, kv_seqlen", [[1024, 1024]])
@pytest.mark.parametrize("attn_type", ["self", "cross"])
@pytest.mark.parametrize("mask_type", ["causal", "padding"])
@pytest.mark.parametrize("math_dtype", ["bfloat16", "float16"])
@pytest.mark.parametrize("deterministic", [True, False])
def test_dot_product_attention(
bs, hidden_size, num_heads, q_seqlen, kv_seqlen, attn_type, mask_type, math_dtype, deterministic
):
"""
Test DotProductAttention Layer
"""
paddle.set_default_dtype(math_dtype)
rtol = 1e-4
atol = 2e-2
head_size = hidden_size // num_heads
# Skip if cuDNN fused attention is not supported
if not is_fused_attention_supported(
num_heads=num_heads,
num_gqa_groups=num_heads,
q_seqlen=q_seqlen,
kv_seqlen=kv_seqlen,
head_size=head_size,
dtype=math_dtype,
dropout=0.0,
qkv_layout="bshd_bshd_bshd",
bias_type="no_bias",
mask_type=mask_type,
):
pytest.skip("cuDNN fused attention is not supported")
attn_q_input = paddle.normal(
mean=0.0, std=0.02, shape=(bs, q_seqlen, num_heads, head_size)
).astype(math_dtype)
attn_k_input = paddle.normal(
mean=0.0, std=0.02, shape=(bs, kv_seqlen, num_heads, head_size)
).astype(math_dtype)
attn_v_input = paddle.normal(
mean=0.0, std=0.02, shape=(bs, kv_seqlen, num_heads, head_size)
).astype(math_dtype)
q_actual_seqlen = paddle.randint(low=20, high=q_seqlen, shape=(bs,), dtype="int32")
kv_actual_seqlen = (
paddle.randint(low=20, high=kv_seqlen, shape=(bs,), dtype="int32")
if attn_type == "cross"
else q_actual_seqlen
)
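# Mask convention: True marks positions to exclude from attention; valid positions are cleared to False below.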
attn_mask = paddle.ones(shape=(bs, 1, q_seqlen, kv_seqlen), dtype="bool")
grad_out = paddle.normal(mean=0.0, std=0.02, shape=(bs, q_seqlen, num_heads, head_size)).astype(
"float32"
)
for i in range(0, bs):
grad_out[i, q_actual_seqlen[i] :, :, :] = 0
grad_out = grad_out.astype(math_dtype)
for i in range(0, bs):
attn_mask[i, 0, 0 : q_actual_seqlen[i], 0 : kv_actual_seqlen[i]] = False
if deterministic:
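# NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 restricts the fused attention backend to deterministic kernels.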
os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "0"
layer_te = te.DotProductAttention(
num_heads,
head_size,
attention_dropout=0.0,
attn_mask_type=mask_type,
attention_type=attn_type,
backend="transformer_engine",
)
layer_pd = te.DotProductAttention(
num_heads,
head_size,
attention_dropout=0.0,
attn_mask_type=mask_type,
attention_type=attn_type,
backend="paddle",
)
def calc_attn_output_and_grad(layer, q, k, v, mask, dout):
_q = paddle.to_tensor(q, stop_gradient=False)
_k = paddle.to_tensor(k, stop_gradient=False)
_v = paddle.to_tensor(v, stop_gradient=False)
out = layer(_q, _k, _v, mask)
out.backward(dout)
return out, _q.grad, _k.grad, _v.grad
out, q_grad, k_grad, v_grad = calc_attn_output_and_grad(
layer_te, attn_q_input, attn_k_input, attn_v_input, attn_mask, grad_out
)
out_ref, q_grad_ref, k_grad_ref, v_grad_ref = calc_attn_output_and_grad(
layer_pd, attn_q_input, attn_k_input, attn_v_input, attn_mask, grad_out
)
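# Zero out the padded region of the reference outputs/grads before comparing; the fused
# kernel is expected to write zeros beyond each sequence's valid length.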
valid_out_ref = paddle.full_like(out_ref, 0)
for i in range(0, bs):
valid_out_ref[i, 0 : q_actual_seqlen[i], :, :] = out_ref[i, 0 : q_actual_seqlen[i], :, :]
valid_q_grad_ref = paddle.full_like(q_grad_ref, 0)
valid_k_grad_ref = paddle.full_like(k_grad_ref, 0)
valid_v_grad_ref = paddle.full_like(v_grad_ref, 0)
for i in range(0, bs):
valid_q_grad_ref[i, 0 : q_actual_seqlen[i], :, :] = q_grad_ref[
i, 0 : q_actual_seqlen[i], :, :
]
valid_k_grad_ref[i, 0 : kv_actual_seqlen[i], :, :] = k_grad_ref[
i, 0 : kv_actual_seqlen[i], :, :
]
valid_v_grad_ref[i, 0 : kv_actual_seqlen[i], :, :] = v_grad_ref[
i, 0 : kv_actual_seqlen[i], :, :
]
assert_allclose(out, valid_out_ref, rtol=rtol, atol=atol)
assert_allclose(q_grad, valid_q_grad_ref, rtol=rtol, atol=atol)
assert_allclose(k_grad, valid_k_grad_ref, rtol=rtol, atol=atol)
assert_allclose(v_grad, valid_v_grad_ref, rtol=rtol, atol=atol)
if deterministic:
out2, q_grad2, k_grad2, v_grad2 = calc_attn_output_and_grad(
layer_te, attn_q_input, attn_k_input, attn_v_input, attn_mask, grad_out
)
assert_allclose(out, out2, rtol=1e-12, atol=1e-12)
assert_allclose(q_grad, q_grad2, rtol=1e-12, atol=1e-12)
assert_allclose(k_grad, k_grad2, rtol=1e-12, atol=1e-12)
assert_allclose(v_grad, v_grad2, rtol=1e-12, atol=1e-12)
os.environ.pop("NVTE_ALLOW_NONDETERMINISTIC_ALGO", None)
@pytest.mark.parametrize("bs", [1, 2])
@pytest.mark.parametrize("num_gqa_groups", [1, 2, 4])
@pytest.mark.parametrize("hidden_size, num_heads, ffn_hidden_size", [[256, 4, 1024]])
@pytest.mark.parametrize("q_seqlen, kv_seqlen", [[1024, 1024]])
@pytest.mark.parametrize("has_bias, no_dbias", [[False, True], [True, True], [True, False]])
@pytest.mark.parametrize("no_wgrad", [True, False])
@pytest.mark.parametrize("mask_type", ["causal", "padding"])
@pytest.mark.parametrize("math_dtype", ["bfloat16", "float16"])
@pytest.mark.parametrize("output_layernorm", [True, False])
@pytest.mark.parametrize("return_layernorm_output", [True, False])
@pytest.mark.parametrize("normalization", ["RMSNorm", "LayerNorm"])
def test_transformer_encoder_layer(
bs,
hidden_size,
num_heads,
num_gqa_groups,
ffn_hidden_size,
has_bias,
no_dbias,
no_wgrad,
q_seqlen,
kv_seqlen,
mask_type,
math_dtype,
output_layernorm,
return_layernorm_output,
normalization,
):
"""
Test Transformer Encoder Layer
"""
paddle.set_default_dtype(math_dtype)
rtol = 5e-2
atol = 5e-2
eps = 1e-3
has_ln_bias = normalization == "LayerNorm"
# Skip if cuDNN fused attention is not supported
if not is_fused_attention_supported(
num_heads=num_heads,
num_gqa_groups=num_gqa_groups,
q_seqlen=q_seqlen,
kv_seqlen=kv_seqlen,
head_size=hidden_size // num_heads,
dtype=math_dtype,
dropout=0.0,
qkv_layout="bshd_bshd_bshd",
bias_type="no_bias",
mask_type=mask_type,
):
pytest.skip("cuDNN fused attention is not supported")
encoder_input = paddle.uniform(shape=(bs, q_seqlen, hidden_size), dtype=math_dtype)
q_actual_seqlen = paddle.ones(shape=(bs,), dtype="int32") * q_seqlen
kv_actual_seqlen = q_actual_seqlen
attn_mask = paddle.ones(shape=(bs, 1, q_seqlen, kv_seqlen), dtype="bool")
grad_out = paddle.normal(mean=0.0, std=0.02, shape=(bs, q_seqlen, hidden_size)).astype(
"float32"
)
for i in range(0, bs):
grad_out[i, q_actual_seqlen[i] :, :] = 0
grad_out = grad_out.astype(math_dtype)
for i in range(0, bs):
attn_mask[i, 0, 0 : q_actual_seqlen[i], 0 : kv_actual_seqlen[i]] = False
layer_te = te.TransformerLayer(
hidden_size,
ffn_hidden_size,
num_heads,
num_gqa_groups=num_gqa_groups,
layernorm_epsilon=eps,
hidden_dropout=0.0,
attention_dropout=0.0,
weight_attr=None,
bias_attr=None if has_bias else False,
self_attn_mask_type=mask_type,
apply_residual_connection_post_layernorm=return_layernorm_output,
output_layernorm=output_layernorm,
layer_type="encoder",
normalization=normalization,
backend="transformer_engine",
)
layer_pd = te.TransformerLayer(
hidden_size,
ffn_hidden_size,
num_heads,
num_gqa_groups=num_gqa_groups,
layernorm_epsilon=eps,
hidden_dropout=0.0,
attention_dropout=0.0,
weight_attr=None,
bias_attr=None if has_bias else False,
self_attn_mask_type=mask_type,
apply_residual_connection_post_layernorm=return_layernorm_output,
output_layernorm=output_layernorm,
layer_type="encoder",
normalization=normalization,
backend="paddle",
)
# MultiHeadAttention params
if output_layernorm:
layer_pd.self_attention.qkv.weight.copy_(layer_te.self_attention.qkv.weight.T, True)
layer_pd.self_attention.qkv.weight.stop_gradient = no_wgrad
layer_te.self_attention.qkv.weight.stop_gradient = no_wgrad
if has_bias:
layer_pd.self_attention.qkv.bias.copy_(layer_te.self_attention.qkv.bias, True)
layer_pd.self_attention.qkv.bias.stop_gradient = no_dbias
layer_te.self_attention.qkv.bias.stop_gradient = no_dbias
else:
layer_pd.self_attention.layernorm_qkv.ln_weight.copy_(
layer_te.self_attention.layernorm_qkv.ln_weight, True
)
layer_pd.self_attention.layernorm_qkv.weight.copy_(
layer_te.self_attention.layernorm_qkv.weight.T, True
)
layer_pd.self_attention.layernorm_qkv.ln_weight.stop_gradient = no_wgrad
layer_pd.self_attention.layernorm_qkv.weight.stop_gradient = no_wgrad
layer_te.self_attention.layernorm_qkv.ln_weight.stop_gradient = no_wgrad
layer_te.self_attention.layernorm_qkv.weight.stop_gradient = no_wgrad
if has_ln_bias:
layer_pd.self_attention.layernorm_qkv.ln_bias.copy_(
layer_te.self_attention.layernorm_qkv.ln_bias, True
)
layer_pd.self_attention.layernorm_qkv.ln_bias.stop_gradient = no_dbias
layer_te.self_attention.layernorm_qkv.ln_bias.stop_gradient = no_dbias
if has_bias:
layer_pd.self_attention.layernorm_qkv.bias.copy_(
layer_te.self_attention.layernorm_qkv.bias, True
)
layer_pd.self_attention.layernorm_qkv.bias.stop_gradient = no_dbias
layer_te.self_attention.layernorm_qkv.bias.stop_gradient = no_dbias
layer_pd.self_attention.proj.weight.copy_(layer_te.self_attention.proj.weight.T, True)
layer_pd.self_attention.proj.weight.stop_gradient = no_wgrad
layer_te.self_attention.proj.weight.stop_gradient = no_wgrad
if has_bias:
layer_pd.self_attention.proj.bias.copy_(layer_te.self_attention.proj.bias, True)
layer_pd.self_attention.proj.bias.stop_gradient = no_dbias
layer_te.self_attention.proj.bias.stop_gradient = no_dbias
# LayerNorm MLP params
layer_pd.layernorm_mlp.ln_weight.copy_(layer_te.layernorm_mlp.ln_weight, True)
layer_pd.layernorm_mlp.fc1_weight.copy_(layer_te.layernorm_mlp.fc1_weight.T, True)
layer_pd.layernorm_mlp.fc2_weight.copy_(layer_te.layernorm_mlp.fc2_weight.T, True)
layer_pd.layernorm_mlp.ln_weight.stop_gradient = no_wgrad
layer_pd.layernorm_mlp.fc1_weight.stop_gradient = no_wgrad
layer_pd.layernorm_mlp.fc2_weight.stop_gradient = no_wgrad
layer_te.layernorm_mlp.ln_weight.stop_gradient = no_wgrad
layer_te.layernorm_mlp.fc1_weight.stop_gradient = no_wgrad
layer_te.layernorm_mlp.fc2_weight.stop_gradient = no_wgrad
if has_ln_bias:
layer_pd.layernorm_mlp.ln_bias.copy_(layer_te.layernorm_mlp.ln_bias, True)
layer_pd.layernorm_mlp.ln_bias.stop_gradient = no_dbias
layer_te.layernorm_mlp.ln_bias.stop_gradient = no_dbias
if has_bias:
layer_pd.layernorm_mlp.fc1_bias.copy_(layer_te.layernorm_mlp.fc1_bias, True)
layer_pd.layernorm_mlp.fc2_bias.copy_(layer_te.layernorm_mlp.fc2_bias, True)
layer_pd.layernorm_mlp.fc1_bias.stop_gradient = no_dbias
layer_pd.layernorm_mlp.fc2_bias.stop_gradient = no_dbias
layer_te.layernorm_mlp.fc1_bias.stop_gradient = no_dbias
layer_te.layernorm_mlp.fc2_bias.stop_gradient = no_dbias
if output_layernorm:
layer_pd.layernorm.weight.copy_(layer_te.layernorm.weight, True)
layer_pd.layernorm.bias.copy_(layer_te.layernorm.bias, True)
layer_pd.layernorm.weight.stop_gradient = no_wgrad
layer_pd.layernorm.bias.stop_gradient = no_dbias
layer_te.layernorm.weight.stop_gradient = no_wgrad
layer_te.layernorm.bias.stop_gradient = no_dbias
def calc_transformer_output_and_grad(layer, encoder_input, mask, dout):
_encoder_input = paddle.to_tensor(encoder_input, stop_gradient=False)
out = layer(_encoder_input, mask)
out.backward(dout)
return out, _encoder_input.grad
out_ref, grad_input_ref = calc_transformer_output_and_grad(
layer_pd, encoder_input, attn_mask, grad_out
)
out, grad_input = calc_transformer_output_and_grad(layer_te, encoder_input, attn_mask, grad_out)
assert_allclose(out, out_ref, rtol=rtol, atol=atol)
assert_allclose(grad_input, grad_input_ref, rtol=rtol, atol=atol)
if not no_wgrad:
if output_layernorm:
assert_allclose(
layer_te.self_attention.qkv.weight.grad,
layer_pd.self_attention.qkv.weight.grad.T,
rtol=rtol,
atol=atol,
)
else:
assert_allclose(
layer_te.self_attention.layernorm_qkv.weight.grad,
layer_pd.self_attention.layernorm_qkv.weight.grad.T,
rtol=rtol,
atol=atol,
)
if not no_dbias:
if output_layernorm:
assert_allclose(
layer_te.self_attention.qkv.bias.grad,
layer_pd.self_attention.qkv.bias.grad,
rtol=0.01,
atol=0.5,
)
else:
assert_allclose(
layer_te.self_attention.layernorm_qkv.bias.grad,
layer_pd.self_attention.layernorm_qkv.bias.grad,
rtol=0.01,
atol=0.5,
)
@pytest.mark.parametrize("bs", [1, 2])
@pytest.mark.parametrize("num_gqa_groups", [1, 2, 4])
@pytest.mark.parametrize("hidden_size, num_heads, ffn_hidden_size", [[256, 4, 1024]])
@pytest.mark.parametrize("q_seqlen, kv_seqlen", [[1024, 1024]])
@pytest.mark.parametrize("has_bias, no_dbias", [[False, True], [True, True], [True, False]])
@pytest.mark.parametrize("no_wgrad", [True, False])
@pytest.mark.parametrize("mask_type", ["causal", "padding"])
@pytest.mark.parametrize("math_dtype", ["bfloat16", "float16"])
@pytest.mark.parametrize("output_layernorm", [True, False])
@pytest.mark.parametrize("return_layernorm_output", [True, False])
@pytest.mark.parametrize("recompute_core_attention", [True, False])
@pytest.mark.parametrize("normalization", ["RMSNorm", "LayerNorm"])
def test_transformer_decoder_layer(
bs,
hidden_size,
num_heads,
num_gqa_groups,
ffn_hidden_size,
has_bias,
no_dbias,
no_wgrad,
q_seqlen,
kv_seqlen,
mask_type,
math_dtype,
output_layernorm,
return_layernorm_output,
recompute_core_attention,
normalization,
):
"""
Test Transformer Decoder Layer
"""
paddle.set_default_dtype(math_dtype)
rtol = 5e-2
atol = 6e-2
eps = 1e-3
has_ln_bias = normalization == "LayerNorm"
# Skip if cuDNN fused attention is not supported
if not is_fused_attention_supported(
num_heads=num_heads,
num_gqa_groups=num_gqa_groups,
q_seqlen=q_seqlen,
kv_seqlen=kv_seqlen,
head_size=hidden_size // num_heads,
dtype=math_dtype,
dropout=0.0,
qkv_layout="bshd_bshd_bshd",
bias_type="no_bias",
mask_type=mask_type,
):
pytest.skip("cuDNN fused attention is not supported")
encoder_input = paddle.normal(mean=0.0, std=0.1, shape=(bs, q_seqlen, hidden_size)).astype(
math_dtype
)
encoder_output = paddle.normal(mean=0.0, std=0.1, shape=(bs, kv_seqlen, hidden_size)).astype(
math_dtype
)
q_actual_seqlen = paddle.ones(shape=(bs,), dtype="int32") * q_seqlen
kv_actual_seqlen = q_actual_seqlen
attn_mask = paddle.ones(shape=(bs, 1, q_seqlen, kv_seqlen), dtype="bool")
grad_out = paddle.normal(mean=0.0, std=0.01, shape=(bs, q_seqlen, hidden_size)).astype(
"float32"
)
# Round inputs to 3 decimal places to limit accumulated rounding differences between backends
encoder_input = paddle.round(encoder_input * 1000) / 1000
encoder_output = paddle.round(encoder_output * 1000) / 1000
grad_out = paddle.round(grad_out * 1000) / 1000
for i in range(0, bs):
grad_out[i, q_actual_seqlen[i] :, :] = 0
grad_out = grad_out.astype(math_dtype)
for i in range(0, bs):
attn_mask[i, 0, 0 : q_actual_seqlen[i], 0 : kv_actual_seqlen[i]] = False
layer_te = te.TransformerLayer(
hidden_size,
ffn_hidden_size,
num_heads,
num_gqa_groups=num_gqa_groups,
layernorm_epsilon=eps,
hidden_dropout=0.0,
attention_dropout=0.0,
weight_attr=None,
bias_attr=None if has_bias else False,
self_attn_mask_type=mask_type,
apply_residual_connection_post_layernorm=return_layernorm_output,
output_layernorm=output_layernorm,
layer_type="decoder",
normalization=normalization,
backend="transformer_engine",
)
layer_pd = te.TransformerLayer(
hidden_size,
ffn_hidden_size,
num_heads,
num_gqa_groups=num_gqa_groups,
layernorm_epsilon=eps,
hidden_dropout=0.0,
attention_dropout=0.0,
weight_attr=None,
bias_attr=None if has_bias else False,
self_attn_mask_type=mask_type,
apply_residual_connection_post_layernorm=return_layernorm_output,
output_layernorm=output_layernorm,
layer_type="decoder",
normalization=normalization,
backend="paddle",
)
# MultiHeadAttention params - self attn
if output_layernorm:
layer_pd.self_attention.qkv.weight.copy_(layer_te.self_attention.qkv.weight.T, True)
layer_pd.self_attention.qkv.weight.stop_gradient = no_wgrad
layer_te.self_attention.qkv.weight.stop_gradient = no_wgrad
if has_bias:
layer_pd.self_attention.qkv.bias.copy_(layer_te.self_attention.qkv.bias, True)
layer_pd.self_attention.qkv.bias.stop_gradient = no_dbias
layer_te.self_attention.qkv.bias.stop_gradient = no_dbias
else:
layer_pd.self_attention.layernorm_qkv.ln_weight.copy_(
layer_te.self_attention.layernorm_qkv.ln_weight, True
)
layer_pd.self_attention.layernorm_qkv.weight.copy_(
layer_te.self_attention.layernorm_qkv.weight.T, True
)
layer_pd.self_attention.layernorm_qkv.ln_weight.stop_gradient = no_wgrad
layer_pd.self_attention.layernorm_qkv.weight.stop_gradient = no_wgrad
layer_te.self_attention.layernorm_qkv.ln_weight.stop_gradient = no_wgrad
layer_te.self_attention.layernorm_qkv.weight.stop_gradient = no_wgrad
if has_ln_bias:
layer_pd.self_attention.layernorm_qkv.ln_bias.copy_(
layer_te.self_attention.layernorm_qkv.ln_bias, True
)
layer_pd.self_attention.layernorm_qkv.ln_bias.stop_gradient = no_dbias
layer_te.self_attention.layernorm_qkv.ln_bias.stop_gradient = no_dbias
if has_bias:
layer_pd.self_attention.layernorm_qkv.bias.copy_(
layer_te.self_attention.layernorm_qkv.bias, True
)
layer_pd.self_attention.layernorm_qkv.bias.stop_gradient = no_dbias
layer_te.self_attention.layernorm_qkv.bias.stop_gradient = no_dbias
layer_pd.self_attention.proj.weight.copy_(layer_te.self_attention.proj.weight.T, True)
layer_pd.self_attention.proj.weight.stop_gradient = no_wgrad
layer_te.self_attention.proj.weight.stop_gradient = no_wgrad
if has_bias:
layer_pd.self_attention.proj.bias.copy_(layer_te.self_attention.proj.bias, True)
layer_pd.self_attention.proj.bias.stop_gradient = no_dbias
layer_te.self_attention.proj.bias.stop_gradient = no_dbias
# MultiHeadAttention params - cross attn
layer_pd.inter_attention.layernorm_query.ln_weight.copy_(
layer_te.inter_attention.layernorm_query.ln_weight, True
)
layer_pd.inter_attention.layernorm_query.weight.copy_(
layer_te.inter_attention.layernorm_query.weight.T, True
)
layer_pd.inter_attention.layernorm_query.ln_weight.stop_gradient = no_wgrad
layer_pd.inter_attention.layernorm_query.weight.stop_gradient = no_wgrad
layer_te.inter_attention.layernorm_query.ln_weight.stop_gradient = no_wgrad
layer_te.inter_attention.layernorm_query.weight.stop_gradient = no_wgrad
if has_ln_bias:
layer_pd.inter_attention.layernorm_query.ln_bias.copy_(
layer_te.inter_attention.layernorm_query.ln_bias, True
)
layer_pd.inter_attention.layernorm_query.ln_bias.stop_gradient = no_dbias
layer_te.inter_attention.layernorm_query.ln_bias.stop_gradient = no_dbias
if has_bias:
layer_pd.inter_attention.layernorm_query.bias.copy_(
layer_te.inter_attention.layernorm_query.bias, True
)
layer_pd.inter_attention.layernorm_query.bias.stop_gradient = no_dbias
layer_te.inter_attention.layernorm_query.bias.stop_gradient = no_dbias
layer_pd.inter_attention.key_value.weight.copy_(
layer_te.inter_attention.key_value.weight.T, True
)
layer_pd.inter_attention.key_value.weight.stop_gradient = no_wgrad
layer_te.inter_attention.key_value.weight.stop_gradient = no_wgrad
layer_pd.inter_attention.proj.weight.copy_(layer_te.inter_attention.proj.weight.T, True)
layer_pd.inter_attention.proj.weight.stop_gradient = no_wgrad
layer_te.inter_attention.proj.weight.stop_gradient = no_wgrad
if has_bias:
layer_pd.inter_attention.key_value.bias.copy_(layer_te.inter_attention.key_value.bias, True)
layer_pd.inter_attention.key_value.bias.stop_gradient = no_dbias
layer_te.inter_attention.key_value.bias.stop_gradient = no_dbias
layer_pd.inter_attention.proj.bias.copy_(layer_te.inter_attention.proj.bias, True)
layer_pd.inter_attention.proj.bias.stop_gradient = no_dbias
layer_te.inter_attention.proj.bias.stop_gradient = no_dbias
# LayerNorm MLP params
layer_pd.layernorm_mlp.ln_weight.copy_(layer_te.layernorm_mlp.ln_weight, True)
layer_pd.layernorm_mlp.fc1_weight.copy_(layer_te.layernorm_mlp.fc1_weight.T, True)
layer_pd.layernorm_mlp.fc2_weight.copy_(layer_te.layernorm_mlp.fc2_weight.T, True)
layer_pd.layernorm_mlp.ln_weight.stop_gradient = no_wgrad
layer_pd.layernorm_mlp.fc1_weight.stop_gradient = no_wgrad
layer_pd.layernorm_mlp.fc2_weight.stop_gradient = no_wgrad
layer_te.layernorm_mlp.ln_weight.stop_gradient = no_wgrad
layer_te.layernorm_mlp.fc1_weight.stop_gradient = no_wgrad
layer_te.layernorm_mlp.fc2_weight.stop_gradient = no_wgrad
if has_ln_bias:
layer_pd.layernorm_mlp.ln_bias.copy_(layer_te.layernorm_mlp.ln_bias, True)
layer_pd.layernorm_mlp.ln_bias.stop_gradient = no_dbias
layer_te.layernorm_mlp.ln_bias.stop_gradient = no_dbias
if has_bias:
layer_pd.layernorm_mlp.fc1_bias.copy_(layer_te.layernorm_mlp.fc1_bias, True)
layer_pd.layernorm_mlp.fc2_bias.copy_(layer_te.layernorm_mlp.fc2_bias, True)
layer_pd.layernorm_mlp.fc1_bias.stop_gradient = no_dbias
layer_pd.layernorm_mlp.fc2_bias.stop_gradient = no_dbias
layer_te.layernorm_mlp.fc1_bias.stop_gradient = no_dbias
layer_te.layernorm_mlp.fc2_bias.stop_gradient = no_dbias
if output_layernorm:
layer_pd.layernorm.weight.copy_(layer_te.layernorm.weight, True)
layer_pd.layernorm.bias.copy_(layer_te.layernorm.bias, True)
layer_pd.layernorm.weight.stop_gradient = no_wgrad
layer_pd.layernorm.bias.stop_gradient = no_dbias
layer_te.layernorm.weight.stop_gradient = no_wgrad
layer_te.layernorm.bias.stop_gradient = no_dbias
def calc_transformer_output_and_grad(
layer,
encoder_input,
mask,
encoder_output,
enc_dec_attn_mask,
dout,
recompute_core_attention=False,
):
_encoder_input = paddle.to_tensor(encoder_input, stop_gradient=False)
_encoder_output = paddle.to_tensor(encoder_output, stop_gradient=False)
out = layer(
_encoder_input,
mask,
_encoder_output,
enc_dec_attn_mask,
recompute_core_attention=recompute_core_attention,
)
out.backward(dout)
return out, _encoder_input.grad, _encoder_output.grad
out_ref, grad_encoder_input_ref, grad_encoder_output_ref = calc_transformer_output_and_grad(
layer_pd, encoder_input, attn_mask, encoder_output, attn_mask, grad_out
)
out, grad_encoder_input, grad_encoder_output = calc_transformer_output_and_grad(
layer_te,
encoder_input,
attn_mask,
encoder_output,
attn_mask,
grad_out,
recompute_core_attention=recompute_core_attention,
)
assert_allclose(out, out_ref, rtol=rtol, atol=atol)
assert_allclose(grad_encoder_input, grad_encoder_input_ref, rtol=rtol, atol=atol)
assert_allclose(grad_encoder_output, grad_encoder_output_ref, rtol=rtol, atol=atol)
if not no_wgrad:
if output_layernorm:
assert_allclose(
layer_te.self_attention.qkv.weight.grad,
layer_pd.self_attention.qkv.weight.grad.T,
rtol=rtol,
atol=atol,
)
else:
assert_allclose(
layer_te.self_attention.layernorm_qkv.weight.grad,
layer_pd.self_attention.layernorm_qkv.weight.grad.T,
rtol=rtol,
atol=atol,
)
assert_allclose(
layer_te.inter_attention.layernorm_query.weight.grad,
layer_pd.inter_attention.layernorm_query.weight.grad.T,
rtol=rtol,
atol=atol,
)
if not no_dbias:
if output_layernorm:
assert_allclose(
layer_te.self_attention.qkv.bias.grad,
layer_pd.self_attention.qkv.bias.grad,
rtol=0.5,
atol=0.6,
)
else:
assert_allclose(
layer_te.self_attention.layernorm_qkv.bias.grad,
layer_pd.self_attention.layernorm_qkv.bias.grad,
rtol=0.01,
atol=0.5,
)
assert_allclose(
layer_te.inter_attention.layernorm_query.bias.grad,
layer_pd.inter_attention.layernorm_query.bias.grad,
rtol=rtol,
atol=atol,
)
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest.mark.parametrize("bs", [8])
@pytest.mark.parametrize("hidden_size, num_heads, ffn_hidden_size", [[1024, 16, 4096]])
@pytest.mark.parametrize("q_seqlen, kv_seqlen", [[128, 128]])
@pytest.mark.parametrize("mask_type", ["causal"])
@pytest.mark.parametrize("math_dtype", ["bfloat16"])
@pytest.mark.parametrize("num_microbatch", [8])
def test_transformer_encoder_layer_microbatch(
bs,
hidden_size,
num_heads,
ffn_hidden_size,
q_seqlen,
kv_seqlen,
mask_type,
math_dtype,
num_microbatch,
):
"""
Test Transformer Encoder Layer with FP8 weight caching
"""
paddle.set_default_dtype(math_dtype)
rtol = 1e-5
atol = 1e-5
eps = 1e-3
# Skip if cuDNN fused attention is not supported
if not is_fused_attention_supported(
num_heads=num_heads,
num_gqa_groups=num_heads,
q_seqlen=q_seqlen,
kv_seqlen=kv_seqlen,
head_size=hidden_size // num_heads,
dtype=math_dtype,
dropout=0.0,
qkv_layout="bs3hd",
bias_type="no_bias",
mask_type=mask_type,
):
pytest.skip("cuDNN fused attention is not supported")
layer_cached = te.TransformerLayer(
hidden_size,
ffn_hidden_size,
num_heads,
layernorm_epsilon=eps,
hidden_dropout=0.0,
attention_dropout=0.0,
weight_attr=None,
bias_attr=None,
self_attn_mask_type=mask_type,
layer_type="encoder",
)
layer_normal = te.TransformerLayer(
hidden_size,
ffn_hidden_size,
num_heads,
layernorm_epsilon=eps,
hidden_dropout=0.0,
attention_dropout=0.0,
weight_attr=None,
bias_attr=None,
self_attn_mask_type=mask_type,
layer_type="encoder",
)
layer_normal.self_attention.layernorm_qkv.ln_weight.copy_(
layer_cached.self_attention.layernorm_qkv.ln_weight, True
)
layer_normal.self_attention.layernorm_qkv.ln_bias.copy_(
layer_cached.self_attention.layernorm_qkv.ln_bias, True
)
layer_normal.self_attention.layernorm_qkv.weight.copy_(
layer_cached.self_attention.layernorm_qkv.weight, True
)
layer_normal.self_attention.layernorm_qkv.bias.copy_(
layer_cached.self_attention.layernorm_qkv.bias, True
)
layer_normal.self_attention.proj.weight.copy_(layer_cached.self_attention.proj.weight, True)
layer_normal.self_attention.proj.bias.copy_(layer_cached.self_attention.proj.bias, True)
# LayerNorm MLP params
layer_normal.layernorm_mlp.ln_weight.copy_(layer_cached.layernorm_mlp.ln_weight, True)
layer_normal.layernorm_mlp.ln_bias.copy_(layer_cached.layernorm_mlp.ln_bias, True)
layer_normal.layernorm_mlp.fc1_weight.copy_(layer_cached.layernorm_mlp.fc1_weight, True)
layer_normal.layernorm_mlp.fc2_weight.copy_(layer_cached.layernorm_mlp.fc2_weight, True)
layer_normal.layernorm_mlp.fc1_bias.copy_(layer_cached.layernorm_mlp.fc1_bias, True)
layer_normal.layernorm_mlp.fc2_bias.copy_(layer_cached.layernorm_mlp.fc2_bias, True)
recipe = DelayedScaling()
def generate_input():
encoder_input = paddle.uniform(shape=(bs, q_seqlen, hidden_size), dtype=math_dtype)
q_actual_seqlen = paddle.ones(shape=(bs,), dtype="int32") * q_seqlen
kv_actual_seqlen = q_actual_seqlen
attn_mask = paddle.ones(shape=(bs, 1, q_seqlen, kv_seqlen), dtype="bool")
grad_out = paddle.normal(mean=0.0, std=0.02, shape=(bs, q_seqlen, hidden_size)).astype(
"float32"
)
for i in range(0, bs):
grad_out[i, q_actual_seqlen[i] :, :] = 0
grad_out = grad_out.astype(math_dtype)
for i in range(0, bs):
attn_mask[i, 0, 0 : q_actual_seqlen[i], 0 : kv_actual_seqlen[i]] = False
return encoder_input, attn_mask, grad_out
# Calibration to make sure weight scale is the same
encoder_input, mask, _ = generate_input()
with fp8_autocast(enabled=False, calibrating=True, fp8_recipe=recipe):
_ = layer_cached(encoder_input, mask)
with fp8_autocast(enabled=False, calibrating=True, fp8_recipe=recipe):
_ = layer_normal(encoder_input, mask)
for iteration in range(num_microbatch):
encoder_input, mask, grad_out = generate_input()
with fp8_autocast(enabled=True, fp8_recipe=recipe):
out = layer_cached(encoder_input, mask, is_first_microbatch=(iteration == 0))
out.backward(grad_out)
with fp8_autocast(enabled=True, fp8_recipe=recipe):
out_ref = layer_normal(encoder_input, mask)
out_ref.backward(grad_out)
assert_allclose(out, out_ref, rtol=rtol, atol=atol)
assert_allclose(
layer_cached.self_attention.layernorm_qkv.weight.grad,
layer_normal.self_attention.layernorm_qkv.weight.grad,
rtol=rtol,
atol=atol,
)
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Test TransformerLayer encoder main_grad"""
import numpy as np
import pytest
import paddle
from paddle.distributed.fleet.utils import mix_precision_utils
import transformer_engine.paddle as te
from transformer_engine.paddle.fp8 import is_fp8_available
is_fp8_supported, reason = is_fp8_available()
def create_optimizer(model, use_pure_bf16, use_main_grad):
"""Create optimizer"""
if use_main_grad:
assert use_pure_bf16
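# MixPrecisionLayer keeps the BF16 parameters but accumulates gradients into FP32 main_grad buffers.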
model = mix_precision_utils.MixPrecisionLayer(model, dtype="bfloat16")
optimizer = paddle.optimizer.AdamW(
parameters=model.parameters(),
learning_rate=0.0001,
multi_precision=use_pure_bf16,
)
if use_main_grad:
optimizer = mix_precision_utils.MixPrecisionOptimizer(optimizer)
return optimizer
class Net(paddle.nn.Layer):
"""Network use for main_grad testing"""
def __init__(self, fuse_wgrad_accumulation):
super().__init__()
self.layer = te.TransformerLayer(
4096,
16384,
32,
layer_type="encoder",
fuse_wgrad_accumulation=fuse_wgrad_accumulation,
)
def forward(self, inp):
out = self.layer(inp)
return out
def train(enable_master_grad, fuse_wgrad_accumulation=False):
"""Train function"""
paddle.seed(10)
accumulate_steps = 4
if fuse_wgrad_accumulation:
assert enable_master_grad, "fuse_wgrad_accumulation requires enable_master_grad"
model = Net(fuse_wgrad_accumulation)
optimizer = create_optimizer(model, use_pure_bf16=True, use_main_grad=enable_master_grad)
loss_list = []
for step_id in range(16):
inp = paddle.uniform([2, 1024, 4096], dtype="float32")
inp.stop_gradient = False
with te.fp8_autocast(enabled=True):
out = model(inp)
loss = out.mean()
loss_list.append(loss)
loss.backward()
# gradient accumulation
if (step_id + 1) % accumulate_steps == 0:
optimizer.step()
optimizer.clear_grad()
return loss_list
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
def test_master_grad():
"""Test main_grad"""
paddle.set_default_dtype("float32")
loss1 = train(enable_master_grad=False)
loss2 = train(enable_master_grad=True)
loss3 = train(enable_master_grad=True, fuse_wgrad_accumulation=True)
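# Master grads change only the precision of gradient accumulation, so all three runs
# should produce closely matching loss curves.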
np.testing.assert_allclose(loss1, loss2, rtol=1e-5, atol=1e-5)
np.testing.assert_allclose(loss1, loss3, rtol=1e-5, atol=1e-5)
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Test TE operators"""
import struct
import numpy as np
import paddle
import paddle.nn.functional as F
import pytest
from utils import (
assert_allclose,
create_fp8_meta,
get_fused_attention_backend,
is_fused_attention_supported,
)
from transformer_engine import transformer_engine_paddle as tex
from transformer_engine.paddle.cpp_extensions import (
cast_to_fp8,
cast_from_fp8,
gemm,
fp8_gemm,
transpose,
cast_transpose,
cast_transpose_bgrad,
te_gelu,
gelu_fp8,
swiglu,
swiglu_fp8,
swiglu_pd,
dswiglu,
dgelu_cast_transpose_bgrad_fp8,
layernorm_fwd_fp8,
layernorm_fwd,
layernorm_bwd,
rmsnorm_fwd_fp8,
rmsnorm_fwd,
rmsnorm_bwd,
fused_attn_fwd_qkvpacked,
fused_attn_bwd_qkvpacked,
fused_attn_fwd_kvpacked,
fused_attn_bwd_kvpacked,
fused_attn_fwd,
fused_attn_bwd,
scaled_softmax_forward,
scaled_softmax_backward,
scaled_masked_softmax_forward,
scaled_masked_softmax_backward,
scaled_upper_triang_masked_softmax_forward,
scaled_upper_triang_masked_softmax_backward,
)
from transformer_engine.paddle.fp8 import is_fp8_available
from transformer_engine.paddle.constants import FP8FwdTensors
from transformer_engine.common.recipe import DelayedScaling
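# GEMM problem sizes, given as (m, n, k).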
GEMM_CASES = [
(256, 256, 512),
(32, 32, 32),
(16384, 1024, 2816),
(16384, 2816, 1024),
(16384, 1024, 1024),
]
is_fp8_supported, reason = is_fp8_available()
SELF_ATTN_CASES = [(2, 512, 12, 64)]
CROSS_ATTN_CASES = [(2, 128, 512, 12, 64)]
FLASH_ATTN_CASES = [(2, 1024, 16, 64), (2, 2048, 16, 128)]
ATTN_DTYPES = [tex.DType.kFloat16, tex.DType.kBFloat16]
@pytest.fixture(autouse=True)
def setup():
"""Setup random seed before each test"""
np.random.seed(10)
paddle.seed(11)
yield
@pytest.mark.parametrize("fp8_dtype", [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2])
@pytest.mark.parametrize("inplace", [True, False])
def test_quantize_dequantize(fp8_dtype, inplace):
"""
Test cast_to_fp8 and cast_from_fp8
"""
a = paddle.rand(shape=(32, 32), dtype="float32")
# Init fp8_meta
fp8_meta = create_fp8_meta()
a_fp8 = paddle.zeros(shape=a.shape, dtype=paddle.uint8) if inplace else None
a_fp8 = cast_to_fp8(a, fp8_meta, FP8FwdTensors.GEMM1_OUTPUT, otype=fp8_dtype, out=a_fp8)
b = cast_from_fp8(
a_fp8,
fp8_meta,
FP8FwdTensors.GEMM1_OUTPUT,
itype=fp8_dtype,
otype=tex.DType.kFloat32,
)
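# FP8 round-trip casting is lossy (E4M3 has 3 mantissa bits, E5M2 has 2), hence the loose tolerance.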
assert_allclose(a, b, rtol=5e-2, atol=5e-2)
def copy_bits_from_float_to_uint16(f):
"""
Reinterpret a float32 as uint32 and keep the upper 16 bits (its bfloat16 bit pattern)
"""
return struct.unpack("<I", struct.pack("<f", f))[0] >> 16
def convert_float_to_uint16(float_list):
"""
Convert a float32 array to its bfloat16 bit pattern, stored as uint16
"""
new_output = []
for x in np.nditer(float_list):
new_output.append(np.uint16(copy_bits_from_float_to_uint16(x)))
new_output = np.reshape(new_output, float_list.shape).view(np.uint16)
return new_output
class TestTranspose:
"""
Test transpose operators
"""
@staticmethod
def test_transpose_bf16():
"""
Test BF16 transpose
"""
a = paddle.rand(shape=(16, 32), dtype="bfloat16")
a_transposed = transpose(a, otype=tex.DType.kBFloat16)
assert_allclose(a_transposed, a.T)
@staticmethod
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest.mark.parametrize("fp8_dtype", [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2])
def test_transpose_fp8(fp8_dtype):
"""
Test FP8 transpose
"""
min_val = -8
max_val = 8
a = paddle.cast(paddle.randint(min_val, max_val, shape=(16, 32)), "float32")
fp8_meta = create_fp8_meta()
a_fp8 = cast_to_fp8(a, fp8_meta, FP8FwdTensors.GEMM1_INPUT, otype=fp8_dtype)
a_fp8_transposed = transpose(a_fp8, otype=fp8_dtype)
a_transposed = cast_from_fp8(
a_fp8_transposed,
fp8_meta,
FP8FwdTensors.GEMM1_INPUT,
itype=fp8_dtype,
otype=tex.DType.kFloat32,
)
assert_allclose(a_transposed, a.T)
@staticmethod
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest.mark.parametrize("fp8_dtype", [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2])
@pytest.mark.parametrize("inplace", [True, False])
def test_cast_transpose(fp8_dtype, inplace):
"""
Test cast_transpose
"""
min_val = -8
max_val = 8
a = paddle.cast(paddle.randint(min_val, max_val, shape=(16, 32)), "float32")
fp8_meta = create_fp8_meta()
a_fp8_casted, a_fp8_transposed = None, None
if inplace:
a_fp8_casted = paddle.zeros(shape=a.shape, dtype=paddle.uint8)
a_fp8_transposed = paddle.zeros(shape=a.T.shape, dtype=paddle.uint8)
a_fp8_casted, a_fp8_transposed = cast_transpose(
a,
fp8_meta,
FP8FwdTensors.GEMM1_INPUT,
otype=fp8_dtype,
cast_out=a_fp8_casted,
transpose_out=a_fp8_transposed,
)
a_transposed = cast_from_fp8(
a_fp8_transposed,
fp8_meta,
FP8FwdTensors.GEMM1_INPUT,
itype=fp8_dtype,
otype=tex.DType.kFloat32,
)
a_casted = cast_from_fp8(
a_fp8_casted,
fp8_meta,
FP8FwdTensors.GEMM1_INPUT,
itype=fp8_dtype,
otype=tex.DType.kFloat32,
)
assert_allclose(a_casted, a)
assert_allclose(a_transposed, a.T)
@staticmethod
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest.mark.parametrize("fp8_dtype", [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2])
def test_cast_transpose_bgrad(fp8_dtype):
"""
Test cast_transpose_bgrad
"""
min_val = -8
max_val = 8
a = paddle.cast(paddle.randint(min_val, max_val, shape=(16, 32)), "float32")
fp8_meta = create_fp8_meta()
bgrad, a_fp8_casted, a_fp8_transposed = cast_transpose_bgrad(
a, fp8_meta, FP8FwdTensors.GEMM1_INPUT, otype=fp8_dtype
)
a_transposed = cast_from_fp8(
a_fp8_transposed,
fp8_meta,
FP8FwdTensors.GEMM1_INPUT,
itype=fp8_dtype,
otype=tex.DType.kFloat32,
)
a_casted = cast_from_fp8(
a_fp8_casted,
fp8_meta,
FP8FwdTensors.GEMM1_INPUT,
itype=fp8_dtype,
otype=tex.DType.kFloat32,
)
assert_allclose(a_casted, a)
assert_allclose(a_transposed, a.T)
assert_allclose(bgrad, a.sum(axis=0))
class TestActivation:
"""
Test activation operators
"""
@staticmethod
def test_gelu_bf16():
"""
Test BF16 GELU Forward
"""
a = paddle.rand(shape=(16, 32), dtype="bfloat16") * 2 - 1
gelu_out = te_gelu(a, otype=tex.DType.kBFloat16)
gelu_ref = paddle.nn.GELU()(a)
assert_allclose(gelu_out, gelu_ref, rtol=1e-2)
@staticmethod
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest.mark.parametrize("fp8_dtype", [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2])
def test_gelu_fp8(fp8_dtype):
"""
Test FP8 GELU Forward
"""
a = paddle.rand(shape=(16, 32), dtype="float32") * 2 - 1
fp8_meta = create_fp8_meta()
gelu_out_fp8 = gelu_fp8(a, fp8_meta, FP8FwdTensors.GEMM1_INPUT, otype=fp8_dtype)
gelu_out = cast_from_fp8(
gelu_out_fp8,
fp8_meta,
FP8FwdTensors.GEMM1_INPUT,
itype=fp8_dtype,
otype=tex.DType.kFloat32,
)
gelu_ref = paddle.nn.GELU()(a)
assert_allclose(gelu_out, gelu_ref, rtol=0.1, atol=0.01)
@staticmethod
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest.mark.parametrize("fp8_dtype", [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2])
def test_gelu_bwd_fp8(fp8_dtype):
"""
Test FP8 GELU Backward
"""
# y = GELU(x), calculate ref
x = paddle.rand(shape=(16, 32), dtype="float32") * 2 - 1
x.stop_gradient = False
y = paddle.nn.GELU()(x)
y_grad = paddle.rand(shape=(16, 32), dtype="float32") * 2 - 1
paddle.autograd.backward([y], [y_grad], True)
# calculate fp8
fp8_meta = create_fp8_meta()
x_grad_fp8, x_grad_t_fp8, dbias = dgelu_cast_transpose_bgrad_fp8(
y_grad, x, fp8_meta, FP8FwdTensors.GEMM1_INPUT, otype=fp8_dtype
)
x_grad = cast_from_fp8(
x_grad_fp8,
fp8_meta,
FP8FwdTensors.GEMM1_INPUT,
itype=fp8_dtype,
otype=tex.DType.kFloat32,
)
x_grad_t = cast_from_fp8(
x_grad_t_fp8,
fp8_meta,
FP8FwdTensors.GEMM1_INPUT,
itype=fp8_dtype,
otype=tex.DType.kFloat32,
)
assert_allclose(x_grad, x.grad, rtol=0.1, atol=0.01)
assert_allclose(x_grad_t, x.grad.T, rtol=0.1, atol=0.01)
assert_allclose(dbias, x.grad.sum(axis=0), rtol=0.1, atol=0.01)
@staticmethod
def test_swiglu_bf16():
"""
Test BF16 SwiGLU Forward
"""
a = paddle.rand(shape=(16, 32), dtype="bfloat16") * 2 - 1
swiglu_out = swiglu(a, otype=tex.DType.kBFloat16)
swiglu_ref = swiglu_pd(a)
assert_allclose(swiglu_out, swiglu_ref, rtol=1e-2)
@staticmethod
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest.mark.parametrize("fp8_dtype", [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2])
def test_swiglu_fp8(fp8_dtype):
"""
Test FP8 SwiGLU Forward
"""
a = paddle.rand(shape=(16, 32), dtype="float32") * 2 - 1
fp8_meta = create_fp8_meta()
swiglu_out_fp8 = swiglu_fp8(a, fp8_meta, FP8FwdTensors.GEMM1_INPUT, otype=fp8_dtype)
swiglu_out = cast_from_fp8(
swiglu_out_fp8,
fp8_meta,
FP8FwdTensors.GEMM1_INPUT,
itype=fp8_dtype,
otype=tex.DType.kFloat32,
)
swiglu_ref = swiglu_pd(a)
assert_allclose(swiglu_out, swiglu_ref, rtol=0.1, atol=0.01)
@staticmethod
def test_swiglu_bwd():
"""
Test SwiGLU Backward
"""
# y = SwiGLU(x), calculate ref
x = paddle.rand(shape=(16, 32), dtype="bfloat16") * 2 - 1
x.stop_gradient = False
y = swiglu_pd(x)
y_grad = paddle.rand(shape=(16, 16), dtype="bfloat16") * 2 - 1
paddle.autograd.backward([y], [y_grad], True)
# calculate using the TE dswiglu kernel (BF16, not FP8)
x_grad = dswiglu(y_grad, x, otype=tex.DType.kBFloat16)
assert_allclose(x_grad, x.grad, rtol=0.1, atol=0.01)
class TestGemm:
"""
Tests for gemm(cuBLASLt) operator
"""
@staticmethod
@pytest.mark.skipif(
paddle.device.cuda.get_device_capability() < (8, 0), reason="BF16 GEMM requires Ampere+ GPU"
)
@pytest.mark.parametrize("m,n,k", GEMM_CASES)
def test_bf16(m, n, k):
"""
Test "TN" BF16 GEMM
"""
a = paddle.rand(shape=(m, k), dtype="bfloat16")
b = paddle.rand(shape=(n, k), dtype="bfloat16")
workspace = paddle.zeros(shape=[33_554_432], dtype="uint8")
ref_out = paddle.matmul(a, b.T)
# CublasLt inside tex.te_gemm assumes inputs are column major.
# Mathematically, A@B=C is equivalent to B^T@A^T=C^T, where X^T is the
# transpose of X.
# Here we perform "TN" GEMM in column major, i.e., b@a^T = C^T,
# which is equivalent to a@b^T = C in row major.
actual_out, _, _ = gemm(
b, a, paddle.bfloat16, workspace, False, None, False, False, "TN", None, None, False
)
assert_allclose(actual_out, ref_out, rtol=1.6e-2, atol=1e-5)
@staticmethod
@pytest.mark.skipif(
paddle.device.cuda.get_device_capability() < (8, 0), reason="BF16 GEMM requires Ampere+ GPU"
)
@pytest.mark.parametrize("m,n,k", GEMM_CASES)
def test_bf16_inplace(m, n, k):
"""
Test "TN" BF16 GEMM, with accumulate=True
"""
min_val = -16
max_val = 16
a = paddle.rand(shape=(m, k), dtype="bfloat16")
b = paddle.rand(shape=(n, k), dtype="bfloat16")
c = paddle.cast(paddle.randint(min_val, max_val, shape=(m, n)), "bfloat16")
workspace = paddle.zeros(shape=[33_554_432], dtype="uint8")
ref_out = c + paddle.matmul(a, b.T)
actual_out = paddle.clone(c)
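# With accumulate=True, the GEMM adds into the provided output buffer: actual_out += a @ b.T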
_, _, _ = gemm(
b,
a,
paddle.bfloat16,
workspace,
False,
None,
False,
True,
"TN",
actual_out,
None,
False,
)
assert_allclose(actual_out, ref_out, rtol=5e-2, atol=5e-2)
@staticmethod
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest.mark.parametrize("m,n,k", GEMM_CASES)
def test_fp8_randint(m, n, k):
"""
Test "TN" FP8 GEMM
"""
min_val = -4
max_val = 4
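# Integers in [-4, 4) are exactly representable in FP8 E4M3, so the FP8 GEMM matches
# the FP32 reference within the default tolerance.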
fp8_dtype = tex.DType.kFloat8E4M3
out_dtype = paddle.float32
fp8_meta = create_fp8_meta(num_gemms=1)
a = paddle.cast(paddle.randint(min_val, max_val, shape=(m, k)), "float32")
a_casted = cast_to_fp8(a, fp8_meta, FP8FwdTensors.GEMM1_INPUT, otype=fp8_dtype)
b = paddle.cast(paddle.randint(min_val, max_val, shape=(n, k)), "float32")
b_casted = cast_to_fp8(b, fp8_meta, FP8FwdTensors.GEMM1_WEIGHT, otype=fp8_dtype)
workspace = paddle.zeros(shape=[33_554_432], dtype="uint8")
ref_out = paddle.matmul(a, b.T)
actual_out, _ = fp8_gemm(
b_casted,
fp8_meta.scale_inv,
FP8FwdTensors.GEMM1_WEIGHT,
fp8_dtype,
a_casted,
fp8_meta.scale_inv,
FP8FwdTensors.GEMM1_INPUT,
fp8_dtype,
out_dtype,
workspace,
)
assert_allclose(actual_out, ref_out)
class TestLayerNorm:
"""
Test layernorm operators
"""
@staticmethod
def calc_fwd_ref(x, eps, gamma, beta):
"""
Calculate reference using paddle layer_norm op
"""
y = paddle.nn.functional.layer_norm(
x=x, normalized_shape=x.shape[1:], weight=gamma, bias=beta, epsilon=eps
)
mean = paddle.mean(x, axis=-1)
var = paddle.var(x, axis=-1)
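# Note: this reference omits eps and uses paddle.var's default unbiased variance,
# so rsigma is only compared with a loose tolerance.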
inv_var = paddle.rsqrt(var)
return y, mean, inv_var
@staticmethod
def calc_bwd_ref(x, eps, gamma, beta, dy):
"""
Calculate reference using paddle layer_norm op
"""
x.stop_gradient = False
gamma.stop_gradient = False
beta.stop_gradient = False
y = paddle.nn.functional.layer_norm(
x=x, normalized_shape=x.shape[1:], weight=gamma, bias=beta, epsilon=eps
)
paddle.autograd.backward([y], [dy], True)
return x.grad, gamma.grad, beta.grad
def test_layernorm_fwd(self):
"""
Test BF16 LayerNorm Forward
"""
N, H = (16, 32)
eps = 1e-3
x = paddle.uniform(shape=(N, H), dtype="bfloat16")
gamma = paddle.uniform(shape=(H,), dtype="bfloat16")
beta = paddle.uniform(shape=(H,), dtype="bfloat16")
y, mu, rsigma = layernorm_fwd(x, gamma, beta, eps, tex.DType.kBFloat16)
y_ref, mu_ref, rsigma_ref = self.calc_fwd_ref(x, eps, gamma, beta)
assert_allclose(y, y_ref, rtol=1e-4, atol=1e-4)
assert_allclose(mu, mu_ref, rtol=1e-3, atol=1e-3)
assert_allclose(rsigma, rsigma_ref, rtol=5e-2, atol=5e-2)
@staticmethod
def test_layernorm_fwd_fp8():
"""
Test FP8 LayerNorm Forward
"""
fp8_dtype = tex.DType.kFloat8E4M3
N, H = (16, 32)
eps = 1e-3
x = paddle.uniform(shape=(N, H), dtype="float32")
gamma = paddle.uniform(shape=(H,), dtype="float32")
beta = paddle.uniform(shape=(H,), dtype="float32")
fp8_tensor = FP8FwdTensors.GEMM1_INPUT
fp8_meta = create_fp8_meta()
y_ref, mu_ref, rsigma_ref = layernorm_fwd(x, gamma, beta, eps, tex.DType.kFloat32)
y_fp8, mu, rsigma = layernorm_fwd_fp8(x, gamma, beta, eps, fp8_meta, fp8_tensor, fp8_dtype)
y = cast_from_fp8(y_fp8, fp8_meta, fp8_tensor, itype=fp8_dtype, otype=tex.DType.kFloat32)
assert_allclose(y, y_ref, rtol=0.1, atol=0.01)
assert_allclose(mu, mu_ref)
assert_allclose(rsigma, rsigma_ref)
def test_layernorm_bwd(self):
"""
Test BF16 LayerNorm Backward
"""
N, H = (16, 32)
eps = 1e-3
x = paddle.uniform(shape=(N, H), dtype="bfloat16")
dy = paddle.uniform(shape=(N, H), dtype="bfloat16")
gamma = paddle.uniform(shape=(H,), dtype="bfloat16")
beta = paddle.uniform(shape=(H,), dtype="bfloat16")
dx_ref, dgamma_ref, dbeta_ref = self.calc_bwd_ref(x, eps, gamma, beta, dy)
_, mu, rsigma = layernorm_fwd(x, gamma, beta, eps, tex.DType.kBFloat16)
dx, dgamma, dbeta = layernorm_bwd(dy, x, mu, rsigma, gamma)
assert_allclose(dx, dx_ref, rtol=1e-5, atol=1e-5)
assert_allclose(dgamma, dgamma_ref, rtol=1e-5, atol=1e-5)
assert_allclose(dbeta, dbeta_ref, rtol=1e-5, atol=1e-5)
class TestRMSNorm:
"""
Test rmsnorm operators
"""
@staticmethod
def calc_fwd_ref(x, eps, gamma):
"""
Calculate rmsnorm reference using paddle op
"""
norm = paddle.rsqrt(paddle.mean(x**2, axis=-1, keepdim=True) + eps)
y = x * norm * gamma
return y
def calc_bwd_ref(self, x, eps, gamma, dy):
"""
Calculate rmsnorm bwd reference using paddle op
"""
x.stop_gradient = False
gamma.stop_gradient = False
y = self.calc_fwd_ref(x, eps, gamma)
paddle.autograd.backward([y], [dy], True)
return x.grad, gamma.grad
def test_rmsnorm_fwd(self):
"""
Test BF16 RMSNorm Forward
"""
N, H = (16, 32)
eps = 1e-3
x = paddle.uniform(shape=(N, H), dtype="bfloat16")
gamma = paddle.uniform(shape=(H,), dtype="bfloat16")
y, _ = rmsnorm_fwd(x, gamma, eps, tex.DType.kBFloat16)
y_ref = self.calc_fwd_ref(x, eps, gamma)
assert_allclose(y, y_ref, rtol=1e-2, atol=1e-2)
@staticmethod
def test_rmsnorm_fwd_fp8():
"""
Test FP8 RMSNorm Forward
"""
fp8_dtype = tex.DType.kFloat8E4M3
N, H = (16, 32)
eps = 1e-3
x = paddle.uniform(shape=(N, H), dtype="float32")
gamma = paddle.uniform(shape=(H,), dtype="float32")
fp8_tensor = FP8FwdTensors.GEMM1_INPUT
fp8_meta = create_fp8_meta()
y_ref, rsigma_ref = rmsnorm_fwd(x, gamma, eps, tex.DType.kFloat32)
y_fp8, rsigma = rmsnorm_fwd_fp8(x, gamma, eps, fp8_meta, fp8_tensor, fp8_dtype)
y = cast_from_fp8(y_fp8, fp8_meta, fp8_tensor, itype=fp8_dtype, otype=tex.DType.kFloat32)
assert_allclose(y, y_ref, rtol=0.1, atol=0.01)
assert_allclose(rsigma, rsigma_ref)
def test_rmsnorm_bwd(self):
"""
Test BF16 RMSNorm Backward
"""
N, H = (16, 32)
eps = 1e-3
x = paddle.uniform(shape=(N, H), dtype="bfloat16")
dy = paddle.uniform(shape=(N, H), dtype="bfloat16")
gamma = paddle.uniform(shape=(H,), dtype="bfloat16")
dx_ref, dgamma_ref = self.calc_bwd_ref(x, eps, gamma, dy)
_, rsigma = rmsnorm_fwd(x, gamma, eps, tex.DType.kBFloat16)
dx, dgamma = rmsnorm_bwd(dy, x, rsigma, gamma)
assert_allclose(dx, dx_ref, rtol=1e-2, atol=1e-2)
assert_allclose(dgamma, dgamma_ref, rtol=1e-2, atol=5e-2)
class TestFusedAttn:
"""
Test fused attention operators
"""
def set_input(self, b, s_q, s_kv, h, d, dtype, attn_mode="self_attn", is_causal_masking=False):
"""
Set test inputs
"""
def _random(shape):
if self.dtype == "bfloat16":
data = np.random.normal(loc=0.0, scale=0.02, size=shape).astype("float32")
return convert_float_to_uint16(data)
return np.random.normal(loc=0.0, scale=0.02, size=shape).astype(self.dtype)
self.batch_size = b
self.q_seqlen = s_q
self.kv_seqlen = s_kv
self.num_heads = h
self.head_size = d
self.dropout_prob = 0.0
self.scaling_factor = 1.0 / np.sqrt(d)
self.q_shape = (b, s_q, h, d)
self.kv_shape = (b, s_kv, h, d)
self.fuse_qkv_shape = (b, s_q, 3, h, d)
self.fuse_kv_shape = (b, s_kv, 2, h, d)
self.bias_shape = (1, h, s_q, s_kv)
self.attn_mode = attn_mode
self.dtype = dtype
self.is_causal_masking = is_causal_masking
self.q = _random(self.q_shape)
if self.attn_mode == "self_attn":
assert self.q_seqlen == self.kv_seqlen, "self attention requires q_seqlen == kv_seqlen"
self.kv = self.q
else:
self.kv = _random(self.kv_shape)
self.q_actual_seqlen = None
if self.is_causal_masking:
self.q_actual_seqlen = np.full(
self.batch_size,
self.q_seqlen,
dtype=np.int32,
)
else:
self.q_actual_seqlen = np.random.randint(
low=20,
high=self.q_seqlen,
size=(self.batch_size,),
dtype=np.int32,
)
self.kv_actual_seqlen = self.q_actual_seqlen
self.q_cu_seqlen = np.cumsum(self.q_actual_seqlen)
self.q_cu_seqlen = np.insert(self.q_cu_seqlen, 0, 0)
self.kv_cu_seqlen = np.cumsum(self.kv_actual_seqlen)
self.kv_cu_seqlen = np.insert(self.kv_cu_seqlen, 0, 0)
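# Mask convention: 1 marks a masked-out position, 0 marks a position
# that participates in attention (the valid region is zeroed below).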
self.attn_mask = np.ones(
shape=(self.batch_size, 1, self.q_seqlen, self.kv_seqlen),
dtype=np.int32,
)
if self.is_causal_masking:
assert attn_mode == "self_attn", "causal masking is only supported for self attention"
for i in range(0, self.batch_size):
for j in range(self.q_actual_seqlen[i]):
self.attn_mask[i, :, j, : j + 1] = 0
else:
for i in range(0, self.batch_size):
self.attn_mask[i, :, : self.q_actual_seqlen[i], : self.kv_actual_seqlen[i]] = 0
dout = _random((self.batch_size, self.q_seqlen, self.num_heads, self.head_size))
self.dout = paddle.to_tensor(dout, dtype=self.dtype)
def _get_reference_out(self):
paddle.disable_static(place=paddle.CUDAPlace(0))
q_tensor = paddle.to_tensor(self.q, stop_gradient=False)
k_tensor = paddle.to_tensor(self.kv, stop_gradient=False)
v_tensor = paddle.to_tensor(self.kv, stop_gradient=False)
q_out = paddle.transpose(x=q_tensor, perm=[0, 2, 1, 3]) # [b, s, h, d] -> [b, h, s, d]
k_out = paddle.transpose(x=k_tensor, perm=[0, 2, 1, 3]) # [b, s, h, d] -> [b, h, s, d]
v_out = paddle.transpose(x=v_tensor, perm=[0, 2, 1, 3]) # [b, s, h, d] -> [b, h, s, d]
qk_out = paddle.matmul(
x=q_out * self.scaling_factor,
y=k_out,
transpose_x=False,
transpose_y=True,
)
attn_mask = paddle.to_tensor(self.attn_mask, stop_gradient=True).cast("bool")
attn_mask_vals = paddle.full(qk_out.shape, -1e4, qk_out.dtype)
attn_mask_out = paddle.where(attn_mask, attn_mask_vals, qk_out)
attn_mask_out = paddle.cast(attn_mask_out, "float32")
softmax_out = F.softmax(attn_mask_out)
softmax_out = paddle.cast(softmax_out, self.dtype)
if self.dropout_prob:
dropout_out = F.dropout(
softmax_out,
self.dropout_prob,
training=self.training,
mode="upscale_in_train",
)
qkv_out = paddle.matmul(dropout_out, v_out)
else:
qkv_out = paddle.matmul(softmax_out, v_out)
out = paddle.transpose(qkv_out, perm=[0, 2, 1, 3]) # [b, h, s, d] -> [b, s, h, d]
paddle.autograd.backward(
[out],
[self.dout],
retain_graph=True,
)
return out, q_tensor.grad, k_tensor.grad, v_tensor.grad
def _get_fused_attention_out(self):
paddle.disable_static(place=paddle.CUDAPlace(0))
if self.attn_mode == "self_attn":
qkv = np.stack([self.q, self.kv, self.kv], axis=2) # [b, s, 3, h, d]
qkv_tensor = paddle.to_tensor(qkv, stop_gradient=False)
else:
q_tensor = paddle.to_tensor(self.q, stop_gradient=False)
kv = np.stack([self.kv, self.kv], axis=2) # [b, s, 2, h, d]
kv_tensor = paddle.to_tensor(kv, stop_gradient=False)
q_cu_seqlen_tensor = paddle.to_tensor(self.q_cu_seqlen, dtype="int32", stop_gradient=True)
kv_cu_seqlen_tensor = paddle.to_tensor(self.kv_cu_seqlen, dtype="int32", stop_gradient=True)
qkv_layout = "bs3hd" if self.attn_mode == "self_attn" else "bshd_bs2hd"
fused_attention_backend = get_fused_attention_backend(
num_heads=self.num_heads,
num_gqa_groups=self.num_heads,
q_seqlen=self.q_seqlen,
kv_seqlen=self.kv_seqlen,
head_size=self.head_size,
dtype=self.dtype,
dropout=self.dropout_prob,
qkv_layout=qkv_layout,
bias_type="no_bias",
mask_type="causal" if self.is_causal_masking else "padding",
)
qkv_dtype = tex.DType.kBFloat16 if self.dtype == "bfloat16" else tex.DType.kFloat16
out, softmax_aux_tensor, q_grad, k_grad, v_grad = None, None, None, None, None
if self.attn_mode == "self_attn":
out, softmax_aux_tensor, rng_state = fused_attn_fwd_qkvpacked(
qkv_tensor,
q_cu_seqlen_tensor,
is_training=True,
max_seqlen=self.q_seqlen,
qkv_dtype=qkv_dtype,
fused_attention_backend=fused_attention_backend,
Bias=None,
attn_scale=self.scaling_factor,
dropout=self.dropout_prob,
set_zero=False,
attn_mask_type="causal" if self.is_causal_masking else "padding",
)
dqkv, _ = fused_attn_bwd_qkvpacked(
qkv_tensor,
q_cu_seqlen_tensor,
rng_state,
out,
self.dout,
softmax_aux_tensor,
max_seqlen=self.q_seqlen,
qkv_dtype=qkv_dtype,
fused_attention_backend=fused_attention_backend,
attn_scale=self.scaling_factor,
dropout=self.dropout_prob,
set_zero=False,
attn_mask_type="causal" if self.is_causal_masking else "padding",
)
q_grad = dqkv[:, :, 0, :, :]
k_grad = dqkv[:, :, 1, :, :]
v_grad = dqkv[:, :, 2, :, :]
else: # attn_mode == 'cross_attn'
out, softmax_aux_tensor, rng_state = fused_attn_fwd_kvpacked(
q_tensor,
kv_tensor,
q_cu_seqlen_tensor,
kv_cu_seqlen_tensor,
is_training=True,
max_seqlen_q=self.q_seqlen,
max_seqlen_kv=self.kv_seqlen,
qkv_dtype=qkv_dtype,
fused_attention_backend=fused_attention_backend,
Bias=None,
attn_scale=self.scaling_factor,
dropout=self.dropout_prob,
set_zero=False,
)
dq, dkv, _ = fused_attn_bwd_kvpacked(
q_tensor,
kv_tensor,
q_cu_seqlen_tensor,
kv_cu_seqlen_tensor,
rng_state,
out,
self.dout,
softmax_aux_tensor,
fused_attention_backend=fused_attention_backend,
max_seqlen_q=self.q_seqlen,
max_seqlen_kv=self.kv_seqlen,
qkv_dtype=qkv_dtype,
attn_scale=self.scaling_factor,
dropout=self.dropout_prob,
set_zero=False,
)
q_grad = dq
k_grad = dkv[:, :, 0, :, :]
v_grad = dkv[:, :, 1, :, :]
return out, q_grad, k_grad, v_grad
def _get_fused_attention_with_separate_qkv(self):
paddle.disable_static(place=paddle.CUDAPlace(0))
q_tensor = paddle.to_tensor(self.q, stop_gradient=False)
k_tensor = paddle.to_tensor(self.kv, stop_gradient=False)
v_tensor = paddle.to_tensor(self.kv, stop_gradient=False)
q_cu_seqlen_tensor = paddle.to_tensor(self.q_cu_seqlen, dtype="int32", stop_gradient=True)
kv_cu_seqlen_tensor = paddle.to_tensor(self.kv_cu_seqlen, dtype="int32", stop_gradient=True)
qkv_layout = "bshd_bshd_bshd"
fused_attention_backend = get_fused_attention_backend(
num_heads=self.num_heads,
num_gqa_groups=self.num_heads,
q_seqlen=self.q_seqlen,
kv_seqlen=self.kv_seqlen,
head_size=self.head_size,
dtype=self.dtype,
dropout=self.dropout_prob,
qkv_layout=qkv_layout,
bias_type="no_bias",
mask_type="causal" if self.is_causal_masking else "padding",
)
qkv_dtype = tex.DType.kBFloat16 if self.dtype == "bfloat16" else tex.DType.kFloat16
out, softmax_aux_tensor, rng_state = fused_attn_fwd(
q_tensor,
k_tensor,
v_tensor,
q_cu_seqlen_tensor,
kv_cu_seqlen_tensor,
is_training=True,
max_seqlen_q=self.q_seqlen,
max_seqlen_kv=self.kv_seqlen,
qkv_dtype=qkv_dtype,
fused_attention_backend=fused_attention_backend,
Bias=None,
attn_scale=self.scaling_factor,
dropout=self.dropout_prob,
set_zero=False,
qkv_layout=qkv_layout,
attn_mask_type="causal" if self.is_causal_masking else "padding",
)
dq, dk, dv, _ = fused_attn_bwd(
q_tensor,
k_tensor,
v_tensor,
q_cu_seqlen_tensor,
kv_cu_seqlen_tensor,
rng_state,
out,
self.dout,
softmax_aux_tensor,
fused_attention_backend=fused_attention_backend,
max_seqlen_q=self.q_seqlen,
max_seqlen_kv=self.kv_seqlen,
qkv_dtype=qkv_dtype,
attn_scale=self.scaling_factor,
dropout=self.dropout_prob,
set_zero=False,
qkv_layout=qkv_layout,
attn_mask_type="causal" if self.is_causal_masking else "padding",
)
return out, dq, dk, dv
@pytest.mark.parametrize("b, s, h, d", SELF_ATTN_CASES)
@pytest.mark.parametrize("dtype", ["float16", "bfloat16"])
@pytest.mark.parametrize("is_causal_masking", [True, False])
def test_self_attn_forward_backward(self, b, s, h, d, dtype, is_causal_masking):
"""
test self attention forward + backward
"""
if not is_fused_attention_supported(
num_heads=h,
num_gqa_groups=h,
q_seqlen=s,
kv_seqlen=s,
head_size=d,
dtype=dtype,
dropout=0.0,
qkv_layout="bs3hd",
bias_type="no_bias",
mask_type="causal" if is_causal_masking else "padding",
):
pytest.skip("cuDNN fused attention is not supported")
self.set_input(b, s, s, h, d, dtype, "self_attn", is_causal_masking)
reference_out, q_grad_ref, k_grad_ref, v_grad_ref = self._get_reference_out()
fused_attention_out, q_grad, k_grad, v_grad = self._get_fused_attention_out()
assert_allclose(reference_out, fused_attention_out, rtol=1e-3, atol=1e-2)
assert_allclose(q_grad_ref, q_grad, rtol=1e-3, atol=1e-2)
assert_allclose(k_grad_ref, k_grad, rtol=1e-3, atol=1e-2)
assert_allclose(v_grad_ref, v_grad, rtol=1e-3, atol=1e-2)
@pytest.mark.parametrize("b, s_q, s_kv, h, d", CROSS_ATTN_CASES)
@pytest.mark.parametrize("dtype", ["float16", "bfloat16"])
def test_cross_attn_forward_backward(self, b, s_q, s_kv, h, d, dtype):
"""
test cross attention forward + backward
"""
if not is_fused_attention_supported(
num_heads=h,
num_gqa_groups=h,
q_seqlen=s_q,
kv_seqlen=s_kv,
head_size=d,
dtype=dtype,
dropout=0.0,
qkv_layout="bshd_bs2hd",
bias_type="no_bias",
mask_type="padding",
):
pytest.skip("cuDNN fused attention is not supported")
self.set_input(b, s_q, s_kv, h, d, dtype, "cross_attn")
reference_out, q_grad_ref, k_grad_ref, v_grad_ref = self._get_reference_out()
fused_attention_out, q_grad, k_grad, v_grad = self._get_fused_attention_out()
assert_allclose(reference_out, fused_attention_out, rtol=1e-3, atol=1e-2)
assert_allclose(q_grad_ref, q_grad, rtol=1e-3, atol=1e-2)
assert_allclose(k_grad_ref, k_grad, rtol=1e-3, atol=1e-2)
assert_allclose(v_grad_ref, v_grad, rtol=1e-3, atol=1e-2)
@pytest.mark.parametrize("b, s, h, d", FLASH_ATTN_CASES)
@pytest.mark.parametrize("dtype", ["float16", "bfloat16"])
@pytest.mark.parametrize("is_causal_masking", [True])
def test_flash_attn_forward_backward(self, b, s, h, d, dtype, is_causal_masking):
"""
test flash attention forward + backward
"""
if not is_fused_attention_supported(
num_heads=h,
num_gqa_groups=h,
q_seqlen=s,
kv_seqlen=s,
head_size=d,
dtype=dtype,
dropout=0.0,
qkv_layout="bs3hd",
bias_type="no_bias",
mask_type="causal" if is_causal_masking else "padding",
):
pytest.skip("cuDNN fused attention is not supported")
self.set_input(b, s, s, h, d, dtype, "self_attn", is_causal_masking)
reference_out, q_grad_ref, k_grad_ref, v_grad_ref = self._get_reference_out()
fused_attention_out, q_grad, k_grad, v_grad = self._get_fused_attention_out()
assert_allclose(reference_out, fused_attention_out, rtol=1e-3, atol=1e-2)
assert_allclose(q_grad_ref, q_grad, rtol=1e-3, atol=1e-2)
assert_allclose(k_grad_ref, k_grad, rtol=1e-3, atol=1e-2)
assert_allclose(v_grad_ref, v_grad, rtol=1e-3, atol=1e-2)
@pytest.mark.parametrize("b, s, h, d", FLASH_ATTN_CASES)
@pytest.mark.parametrize("dtype", ["float16", "bfloat16"])
@pytest.mark.parametrize("is_causal_masking", [False, True])
def test_fused_attn_with_separate_qkv_forward_backward(
self, b, s, h, d, dtype, is_causal_masking
):
"""
test flash attention forward + backward with separate qkv inputs
"""
if not is_fused_attention_supported(
num_heads=h,
num_gqa_groups=h,
q_seqlen=s,
kv_seqlen=s,
head_size=d,
dtype=dtype,
dropout=0.0,
qkv_layout="bshd_bshd_bshd",
bias_type="no_bias",
mask_type="causal" if is_causal_masking else "padding",
):
pytest.skip("cuDNN fused attention is not supported")
self.set_input(b, s, s, h, d, dtype, "self_attn", is_causal_masking)
reference_out, q_grad_ref, k_grad_ref, v_grad_ref = self._get_reference_out()
fused_attention_out, q_grad, k_grad, v_grad = self._get_fused_attention_with_separate_qkv()
assert_allclose(reference_out, fused_attention_out, rtol=1e-3, atol=1e-2)
assert_allclose(q_grad_ref, q_grad, rtol=1e-3, atol=1e-2)
assert_allclose(k_grad_ref, k_grad, rtol=1e-3, atol=1e-2)
assert_allclose(v_grad_ref, v_grad, rtol=1e-3, atol=1e-2)
class TestSoftmax:
"""
Test softmax operators
"""
@staticmethod
@pytest.mark.parametrize("dtype", ["float16", "bfloat16"])
def test_scaled_softmax_fwd_bwd(dtype):
"""test scaled softmax"""
B, H, S = (16, 4, 32)
scale = 0.8
x = paddle.uniform(shape=(B, H, S, S), dtype=dtype)
x.stop_gradient = False
dy = paddle.uniform(shape=(B, H, S, S), dtype=dtype)
y_ref = F.softmax(scale * x)
y = scaled_softmax_forward(x, scale)
paddle.autograd.backward([y_ref], [dy], True)
dx_ref = x.grad
dx = scaled_softmax_backward(dy, y, scale)
assert_allclose(y_ref, y, rtol=1e-4, atol=1e-3)
assert_allclose(dx_ref, dx, rtol=1e-4, atol=1e-3)
@staticmethod
@pytest.mark.parametrize("dtype", ["float16", "bfloat16"])
def test_scaled_masked_softmax_fwd_bwd(dtype):
"""test scaled masked softmax"""
B, H, S = (16, 4, 32)
scale = 0.8
x = paddle.uniform(shape=(B, H, S, S), dtype=dtype)
x.stop_gradient = False
dy = paddle.uniform(shape=(B, H, S, S), dtype=dtype)
mask = paddle.reshape(x[0, 0] > 0.3, shape=(1, 1, S, S))
mask_flipped = x[0, 0] <= 0.3
mask_ref = (mask_flipped.astype(dtype) - 1.0) * 1e4
y_ref = F.softmax(scale * x + mask_ref)
y = scaled_masked_softmax_forward(x, mask, scale)
paddle.autograd.backward([y_ref], [dy], True)
dx_ref = x.grad
dx = scaled_masked_softmax_backward(dy, y, scale)
assert_allclose(y_ref, y, rtol=1e-4, atol=1e-3)
assert_allclose(dx_ref, dx, rtol=1e-4, atol=1e-3)
@staticmethod
@pytest.mark.parametrize("dtype", ["float16", "bfloat16"])
def test_scaled_upper_triang_masked_softmax_fwd_bwd(dtype):
"""test scaled upper triang masked softmax"""
B, S = (16, 32)
scale = 0.8
x = paddle.uniform(shape=(B, S, S), dtype=dtype)
x.stop_gradient = False
dy = paddle.uniform(shape=(B, S, S), dtype=dtype)
mask = paddle.tril(paddle.ones((S, S), dtype="int32"))
mask_ref = (mask.astype(dtype) - 1.0) * 1e4
y_ref = F.softmax(scale * x + mask_ref)
y = scaled_upper_triang_masked_softmax_forward(x, scale)
paddle.autograd.backward([y_ref], [dy], True)
dx_ref = x.grad
dx = scaled_upper_triang_masked_softmax_backward(dy, y, scale)
assert_allclose(y_ref, y, rtol=1e-4, atol=5e-3)
assert_allclose(dx_ref, dx, rtol=1e-4, atol=5e-3)
@pytest.mark.parametrize("update_weight_scale_inv", [True, False])
def test_amax_and_scale_update(update_weight_scale_inv):
"""Test update_scale"""
num_gemm = 6
history_len = 1024
recipe = DelayedScaling()
fp8_dtype = tex.DType.kFloat8E4M3
fp8_max = recipe.fp8_format.value.max_fwd
non_weight_mask = paddle.to_tensor([True, False] * (num_gemm // 2))
amax_history_tensor = paddle.rand(shape=[history_len, num_gemm], dtype="float32")
rolled_history_ref = paddle.roll(amax_history_tensor, -1, axis=0)
rolled_history_ref[0] = 0.0
amax_tensor = paddle.max(amax_history_tensor, axis=0)
scale_tensor = paddle.ones(shape=[num_gemm], dtype="float32")
def calc_ref(amax, scale, fp8_max, margin=0):
"""Calculate reference scale"""
sf = (fp8_max / amax) / (2**margin)
sf = paddle.where(amax > 0.0, sf, scale)
sf = paddle.where(paddle.isfinite(amax), sf, scale)
return sf
scale_ref = calc_ref(amax_tensor, scale_tensor, fp8_max, 0.0)
if update_weight_scale_inv:
scale_inv_ref = 1.0 / scale_ref
else:
scale_inv_ref = paddle.zeros_like(scale_tensor)
scale_inv_ref = paddle.where(non_weight_mask, 1.0 / scale_ref, scale_inv_ref)
# Placeholder outputs, updated in place by the fused kernel below
scale_actual = paddle.zeros_like(scale_tensor)
scale_inv_actual = paddle.zeros_like(scale_tensor)
if update_weight_scale_inv:
non_weight_mask = paddle.empty([0])
tex.amax_and_scale_update_inplace(
_amax_history=amax_history_tensor,
_scale=scale_actual,
_scale_inv=scale_inv_actual,
non_weight_mask=non_weight_mask,
fp8_dtype=int(fp8_dtype),
margin=0.0,
amax_compute="max",
)
assert_allclose(scale_actual, scale_ref, rtol=1e-7, atol=1e-7)
assert_allclose(scale_inv_actual, scale_inv_ref, rtol=1e-7, atol=1e-7)
assert_allclose(amax_history_tensor, rolled_history_ref, rtol=1e-7, atol=1e-7)
def test_update_latest_history():
"""Test update_latest_history"""
num_gemm = 6
history_len = 1024
amax_history_tensor = paddle.rand(shape=[history_len, num_gemm], dtype="float32")
amax = paddle.rand(shape=[num_gemm], dtype="float32")
tex.update_latest_amax_history_inplace(_history=amax_history_tensor, amax=amax)
assert_allclose(amax_history_tensor[0], amax, rtol=1e-7, atol=1e-7)
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Test TE Paddle Parallel"""
from pathlib import Path
import unittest
from dist_launcher import TestDistributed
from utils import is_devices_enough
from transformer_engine.paddle.fp8 import is_fp8_available
test_root = Path(__file__).resolve().parent
gpu_has_fp8, reason = is_fp8_available()
class TestParallelLinear(TestDistributed):
"""Test Linear in Parallel mode"""
@unittest.skipIf(not is_devices_enough(2), "TestParallelLinear needs 2 GPUs")
@unittest.skipIf(not gpu_has_fp8, reason)
def test_linear_tp(self):
"""Tests linear with tensor parallel in BF16"""
self.run_2gpu(str(test_root / "parallel_tests" / "linear_tp.py"))
class TestParallelLayerNormLinear(TestDistributed):
"""Test LayerNormLinear in Parallel mode"""
@unittest.skipIf(not is_devices_enough(2), "TestParallelLayerNormLinear needs 2 GPUs")
@unittest.skipIf(not gpu_has_fp8, reason)
def test_layernorm_linear_tp(self):
"""Tests layernorm_linear with tensor parallel in BF16"""
self.run_2gpu(str(test_root / "parallel_tests" / "layernorm_linear_tp.py"))
class TestParallelLayerNormMLP(TestDistributed):
"""Test LayerNormMLP in Parallel mode"""
@unittest.skipIf(not is_devices_enough(2), "TestParallelLayerNormMLP needs 2 GPUs")
@unittest.skipIf(not gpu_has_fp8, reason)
def test_layernorm_mlp_tp(self):
"""Tests layernorm_mlp with tensor parallel in BF16"""
self.run_2gpu(str(test_root / "parallel_tests" / "layernorm_mlp_tp.py"))
class TestAmaxReduction(TestDistributed):
"""Test amax reduction in dp mode"""
@unittest.skipIf(not is_devices_enough(2), "TestAmaxReduction needs 2 GPUs")
@unittest.skipIf(not gpu_has_fp8, reason)
def test_amax_reduction(self):
"""Tests amax reduction"""
self.run_2gpu(str(test_root / "parallel_tests" / "amax_reduction.py"))
class TestPipelineParallel(TestDistributed):
"""Test pipeline parallel"""
@unittest.skipIf(not is_devices_enough(2), "TestPipelineParallel needs 2 GPUs")
@unittest.skipIf(not gpu_has_fp8, reason)
def test_pipeline_parallel(self):
"""Tests pipeline parallel"""
self.run_2gpu(str(test_root / "parallel_tests" / "linear_pp.py"))
class TestGroupSharding(TestDistributed):
"""Test group sharding"""
@unittest.skipIf(not is_devices_enough(2), "TestGroupSharding needs 2 GPUs")
@unittest.skipIf(not gpu_has_fp8, reason)
def test_group_sharding(self):
"""Tests group sharding"""
self.run_2gpu(str(test_root / "parallel_tests" / "group_sharding.py"))
class TestParallelAttention(TestDistributed):
"""Test MultiHeadAttention Layer in Parallel mode"""
@unittest.skipIf(not is_devices_enough(2), "TestParallelAttention needs 2 GPUs")
@unittest.skipIf(not gpu_has_fp8, reason)
def test_attention_tp(self):
"""Tests TransMultiHeadAttentionformer Layer with tensor parallel in BF16"""
self.run_2gpu(str(test_root / "parallel_tests" / "attention_tp.py"))
class TestParallelTransformerLayer(TestDistributed):
"""Test Transformer Layer in Parallel mode"""
@unittest.skipIf(not is_devices_enough(2), "TestParallelTransformerLayer needs 2 GPUs")
@unittest.skipIf(not gpu_has_fp8, reason)
def test_transformer_tp(self):
"""Tests Transformer Layer with tensor parallel in BF16"""
self.run_2gpu(str(test_root / "parallel_tests" / "transformer_tp.py"))
if __name__ == "__main__":
unittest.main()
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Test TE Paddle Recompute"""
from pathlib import Path
import re
import subprocess
import numpy as np
import pytest
from transformer_engine.paddle.fp8 import is_fp8_available
test_root = Path(__file__).resolve().parent
is_fp8_supported, reason = is_fp8_available()
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest.mark.parametrize("use_reentrant", [False, True])
def test_transformer_encoder_recompute(use_reentrant):
"""
Test TransformerLayer encoder recompute
"""
rtol = 1e-5
atol = 1e-5
def launch_subprocess_and_check_output(enable_recompute):
"""Launch training in subprocess and check output"""
try:
cmd = [
"python",
str(test_root / "recompute_tests" / "recompute_transformer_encoder.py"),
str(int(enable_recompute)),
str(int(use_reentrant)),
]
result = subprocess.check_output(cmd, stderr=subprocess.STDOUT, universal_newlines=True)
print(result)
loss_match = re.search(r"Loss:\s+(-?\d+\.\d+)", result)
memory_match = re.search(r"Peak memory:\s+(\d+)", result)
assert loss_match and memory_match, f"Could not parse subprocess output:\n{result}"
loss_value = float(loss_match.group(1))
memory_value = int(memory_match.group(1))
return loss_value, memory_value
except subprocess.CalledProcessError as e:
raise ValueError(f"Subprocess failed with error: {e}") from e
loss_recompute, peak_memory_recompute = launch_subprocess_and_check_output(True)
loss_ref, peak_memory_ref = launch_subprocess_and_check_output(False)
assert peak_memory_recompute < peak_memory_ref
np.testing.assert_allclose(loss_recompute, loss_ref, rtol=rtol, atol=atol)
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Utils for testing"""
import random
from typing import Union
import numpy as np
import paddle
from paddle.distributed import fleet
from paddle.distributed.fleet.meta_parallel import get_rng_state_tracker
import transformer_engine # pylint: disable=unused-import
from transformer_engine.paddle.constants import (
TE_DType,
AttnBiasType,
AttnMaskType,
FusedAttnBackend,
)
from transformer_engine.paddle.fp8 import FP8TensorMeta
from transformer_engine import (
transformer_engine_paddle as tex,
) # pylint: disable=wrong-import-order
def create_fp8_meta(num_gemms=1, amax_history_len=10):
"""
Create and initialize FP8TensorMeta
"""
fp8_meta = FP8TensorMeta(is_forward=True)
fp8_meta.prepare(num_gemms, amax_history_len)
return fp8_meta
def assert_allclose(
actual, desired, rtol=1e-05, atol=1e-08, equal_nan=True, err_msg="", verbose=True
):
"""Compare two input paddle tensors"""
if isinstance(actual, paddle.Tensor):
actual = paddle.cast(actual, "float32")
if isinstance(desired, paddle.Tensor):
desired = paddle.cast(desired, "float32")
if len(actual.shape) == 0:
actual = actual.item()
desired = desired.item()
else:
actual = actual.numpy()
desired = desired.numpy()
np.testing.assert_allclose(actual, desired, rtol, atol, equal_nan, err_msg, verbose)
def assert_shape(inp, expected_shape):
"""Assert the shape of input tensor equals to expected shape"""
assert (
inp.shape == expected_shape
), f"Expected tensor shape: {expected_shape} != actual tensor shape: {inp.shape}"
def is_devices_enough(required):
"""If the number of device is enough"""
return paddle.device.cuda.device_count() >= required
def set_random_seed(seed):
"""Set random seed for reproducability."""
fleet.meta_parallel.model_parallel_random_seed(seed)
hcg = fleet.get_hybrid_communicate_group()
if paddle.distributed.get_world_size() > 1:
# obtain rank message of hybrid parallel
mp_rank = hcg.get_model_parallel_rank()
mp_size = hcg.get_model_parallel_world_size()
pp_rank = hcg.get_stage_id()
pp_size = hcg.get_pipe_parallel_world_size()
dp_rank = hcg.get_data_parallel_rank()
dp_size = hcg.get_data_parallel_world_size()
sharding_rank = hcg.get_sharding_parallel_rank()
else:
mp_rank, mp_size = 0, 1
pp_rank, pp_size = 0, 1
dp_rank, dp_size = 0, 1
sharding_rank = 0
random.seed(seed + 100 * pp_rank)
np.random.seed(seed + 100 * pp_rank)
seed_offset = seed + 1024 + paddle.distributed.get_world_size()
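# global_seed omits mp_rank, so it is shared across tensor-parallel
# ranks; local_seed below adds mp_rank so that per-rank randomness
# (e.g. dropout) is decorrelated across tensor-parallel shards.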
global_seed = (
seed_offset
+ pp_rank * (mp_size)
+ dp_rank * (mp_size * pp_size)
+ sharding_rank * (mp_size * pp_size * dp_size)
)
seed_offset += paddle.distributed.get_world_size()
local_seed = (
seed_offset
+ mp_rank
+ pp_rank * (mp_size)
+ dp_rank * (mp_size * pp_size)
+ sharding_rank * (mp_size * pp_size * dp_size)
)
tracker = get_rng_state_tracker()
# tracker.reset()
if "global_seed" not in tracker.states_:
tracker.add("global_seed", global_seed)
if "local_seed" not in tracker.states_:
tracker.add("local_seed", local_seed)
paddle.seed(global_seed)
def get_fused_attention_backend(
num_heads: int,
num_gqa_groups: int,
q_seqlen: int,
kv_seqlen: int,
head_size: int,
dtype: Union[paddle.dtype, str],
dropout: float,
qkv_layout: str = "bs3hd",
bias_type: str = "no_bias",
mask_type: str = "causal",
) -> tex.NVTE_Fused_Attn_Backend:
"""Get cuDNN fused attention backend for attention config"""
if isinstance(dtype, str):
dtype = dict(
float32=paddle.float32,
bfloat16=paddle.bfloat16,
float16=paddle.float16,
)[dtype]
return tex.get_fused_attn_backend(
TE_DType[dtype],
TE_DType[dtype],
tex.get_nvte_qkv_layout(qkv_layout),
AttnBiasType[bias_type],
AttnMaskType[mask_type],
dropout,
num_heads,
num_gqa_groups,
q_seqlen,
kv_seqlen,
head_size,
)
def is_fused_attention_supported(
num_heads: int,
num_gqa_groups: int,
q_seqlen: int,
kv_seqlen: int,
head_size: int,
dtype: Union[paddle.dtype, str],
dropout: float,
qkv_layout: str = "bs3hd",
bias_type: str = "no_bias",
mask_type: str = "causal",
) -> bool:
"""Check if cuDNN fused attention is supported for attention config"""
backend = get_fused_attention_backend(
num_heads=num_heads,
num_gqa_groups=num_gqa_groups,
q_seqlen=q_seqlen,
kv_seqlen=kv_seqlen,
head_size=head_size,
dtype=dtype,
dropout=dropout,
qkv_layout=qkv_layout,
bias_type=bias_type,
mask_type=mask_type,
)
return backend != FusedAttnBackend["No_Backend"]
def register_sequence_parallel_allreduce_hooks(model, accumulation_steps) -> None:
"""Register allreduce hooks for sequence parallel tensors"""
def is_sequence_parallel_parameter(parameter):
"""If input tensor is marked as sequence parallel tensor"""
out = getattr(parameter, "sequence_parallel", False)
return out
def create_allreduce_gradient_hook(param, accumulation_steps):
"""Create allreduce gradient hook"""
hcg = fleet.get_hybrid_communicate_group()
pg = hcg.get_model_parallel_group().process_group
step = [0]
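# step is a one-element list so the nested hook below can mutate the
# accumulation counter captured from this enclosing scope.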
@paddle.autograd.no_grad()
def __impl__():
step[0] += 1
if (step[0] % accumulation_steps) == 0:
if hasattr(param, "main_grad"):
pg.allreduce(param.main_grad).wait()
else:
pg.allreduce(param.grad).wait()
return __impl__
if accumulation_steps <= 0 or not paddle.distributed.is_initialized():
return
hcg = fleet.get_hybrid_communicate_group()
mp_group = hcg.get_model_parallel_group()
if mp_group.nranks <= 1:
return
params = []
for p in model.parameters():
if is_sequence_parallel_parameter(p):
params.append(p)
for p in params:
hook = create_allreduce_gradient_hook(p, accumulation_steps)
p._register_backward_hook(hook)
build
onnxruntime
libcustom_ort_ops.so
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
cmake_minimum_required(VERSION 3.21)
project(custom_ort_ops LANGUAGES CXX)
# Dependencies
find_package(CUDAToolkit REQUIRED)
set(ONNX_INCLUDE_DIR ${CMAKE_SOURCE_DIR}/onnxruntime/include)
if(NOT EXISTS "${ONNX_INCLUDE_DIR}")
message(FATAL_ERROR
"Could not find ONNX Runtime headers. "
"Please clone https://github.com/microsoft/onnxruntime "
"into TransformerEngine/tests/pytorch/onnx.")
endif()
include_directories(${ONNX_INCLUDE_DIR})
# Configure library
add_library(custom_ort_ops SHARED custom_op_library.cc)
target_link_libraries(custom_ort_ops PUBLIC CUDA::cudart)
target_include_directories(custom_ort_ops PUBLIC
${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES})
target_include_directories(custom_ort_ops PRIVATE
${ONNX_INCLUDE_DIR}/onnxruntime
${ONNX_INCLUDE_DIR}/onnxruntime/core/session)
# Install library
install(TARGETS custom_ort_ops DESTINATION .)
# Custom ONNX Runtime operators for Transformer Engine tests
This directory contains code that builds custom ONNX operators for use
in Transformer Engine tests. It includes basic, non-performant
implementations of the FP8 quantization and dequantization operators
that are used when exporting Transformer Engine models to ONNX.
For more information, see [the ONNX Runtime reference for custom
operators](https://onnxruntime.ai/docs/reference/operators/add-custom-op.html).
Much of the code has been adapted from [an ONNX Runtime
test](https://github.com/microsoft/onnxruntime/blob/de93f40240459953a6e3bbb86b6ad83eaeab681f/onnxruntime/test/testdata/custom_op_library/custom_op_library.cc).
## Usage
* Build the custom operators:
```bash
$ bash TransformerEngine/tests/pytorch/custom_ort_ops/build.sh
```
* Run the ONNX export tests with pytest:
```bash
$ python -m pytest TransformerEngine/tests/pytorch/test_onnx_export.py
```
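To load the built library from Python, register it on the session options before creating an inference session. A minimal sketch (the model path is hypothetical; the `.so` location follows the build step above):
```python
import onnxruntime as ort

# Register the custom op library so that the TRT_FP8QuantizeLinear and
# TRT_FP8DequantizeLinear ops can be resolved at model-load time.
so = ort.SessionOptions()
so.register_custom_ops_library(
    "TransformerEngine/tests/pytorch/custom_ort_ops/libcustom_ort_ops.so"
)
# Hypothetical model exported with Transformer Engine's ONNX export path.
session = ort.InferenceSession("model_with_fp8_ops.onnx", so)
```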
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
set -ex
: ${CUSTOM_ORT_OPS_PATH=$(dirname $(realpath $0))}
cd ${CUSTOM_ORT_OPS_PATH}
# Download ONNX Runtime source
git clone --depth=1 -b rel-1.19.2 --single-branch https://github.com/microsoft/onnxruntime.git || true
# Configure and build with CMake
mkdir -p build
cmake -S . -B build -DCMAKE_INSTALL_PREFIX=.
cmake --build build --verbose
cmake --install build --verbose
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include "custom_op_library.h"
#define ORT_API_MANUAL_INIT
#include "onnxruntime_c_api.h"
#include "onnxruntime_cxx_api.h"
#undef ORT_API_MANUAL_INIT
#include <exception>
#include <memory>
#include <mutex>
#include <utility>
#include <vector>
#include "core/common/common.h"
#include "core/session/onnxruntime_lite_custom_op.h"
#include <cuda_fp8.h>
namespace {
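// Reference (CPU) FP8 quantization: divide by scale_inv (i.e. multiply
// by the scale) and narrow to the FP8 output type; Dequantize below
// applies scale_inv to map the FP8 value back to the original range.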
template <typename IType, typename OType, typename CType>
void Quantize(OrtKernelContext* context,
const Ort::Custom::Tensor<IType>& input,
const Ort::Custom::Tensor<CType>& scale_inv,
Ort::Custom::Tensor<unsigned char>& output) {
auto raw_input = input.Data();
auto raw_scale_inv = scale_inv.Data();
auto raw_output = reinterpret_cast<OType*>(output.Allocate(input.Shape()));
const auto rs = static_cast<CType>(raw_scale_inv[0]);
const size_t N = input.NumberOfElement();
for (size_t i = 0; i < N; ++i) {
const auto x = static_cast<CType>(raw_input[i]);
raw_output[i] = static_cast<OType>(x / rs);
}
}
template <typename IType, typename OType, typename CType>
void Dequantize(OrtKernelContext* context,
const Ort::Custom::Tensor<unsigned char>& input,
const Ort::Custom::Tensor<CType>& scale_inv,
Ort::Custom::Tensor<OType>& output) {
auto raw_input = reinterpret_cast<const IType*>(input.Data());
auto raw_scale_inv = scale_inv.Data();
auto raw_output = output.Allocate(input.Shape());
const auto rs = static_cast<CType>(raw_scale_inv[0]);
const size_t N = input.NumberOfElement();
for (size_t i = 0; i < N; ++i) {
const auto x = rs * static_cast<CType>(raw_input[i]);
raw_output[i] = static_cast<OType>(x);
}
}
static void AddOrtCustomOpDomainToContainer(Ort::CustomOpDomain&& domain) {
static std::vector<Ort::CustomOpDomain> ort_custom_op_domain_container;
static std::mutex ort_custom_op_domain_mutex;
std::lock_guard<std::mutex> lock(ort_custom_op_domain_mutex);
ort_custom_op_domain_container.push_back(std::move(domain));
}
} // namespace
OrtStatus* ORT_API_CALL RegisterCustomOps(OrtSessionOptions* options, const OrtApiBase* api) {
Ort::Global<void>::api_ = api->GetApi(ORT_API_VERSION);
// Namespace for custom ops
static const char* c_OpDomain = "trt";
// Construct custom ops
static const std::unique_ptr<Ort::Custom::OrtLiteCustomOp> c_Quantize{
Ort::Custom::CreateLiteCustomOp("TRT_FP8QuantizeLinear",
"CPUExecutionProvider",
Quantize<float, __nv_fp8_e4m3, float>)
};
static const std::unique_ptr<Ort::Custom::OrtLiteCustomOp> c_Dequantize{
Ort::Custom::CreateLiteCustomOp("TRT_FP8DequantizeLinear",
"CPUExecutionProvider",
Dequantize<__nv_fp8_e4m3, float, float>)
};
// Register custom ops
OrtStatus* result = nullptr;
ORT_TRY {
Ort::CustomOpDomain domain{c_OpDomain};
domain.Add(c_Quantize.get());
domain.Add(c_Dequantize.get());
Ort::UnownedSessionOptions session_options(options);
session_options.Add(domain);
AddOrtCustomOpDomainToContainer(std::move(domain));
}
ORT_CATCH(const std::exception& e) {
ORT_HANDLE_EXCEPTION([&]() {
Ort::Status status{e};
result = status.release();
});
}
return result;
}