"vscode:/vscode.git/clone" did not exist on "b56b6ca0d650c653c80ec113e27d6a8e640a4b2f"
Unverified Commit 9416519d authored by Kirthi Shankar Sivamani's avatar Kirthi Shankar Sivamani Committed by GitHub
Browse files

Apply formatting (#929)



* Apply formatting
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* Apply formatting
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>
parent d99142a0
...@@ -21,15 +21,15 @@ from paddle.distributed.utils.launch_utils import ( ...@@ -21,15 +21,15 @@ from paddle.distributed.utils.launch_utils import (
watch_local_trainers, watch_local_trainers,
) )
__all__ = ['TestDistributed'] __all__ = ["TestDistributed"]
def get_cluster_from_args(selected_gpus): def get_cluster_from_args(selected_gpus):
"""Get node information from selected GPUs""" """Get node information from selected GPUs"""
cluster_node_ips = '127.0.0.1' cluster_node_ips = "127.0.0.1"
node_ip = '127.0.0.1' node_ip = "127.0.0.1"
node_ips = [x.strip() for x in cluster_node_ips.split(',')] node_ips = [x.strip() for x in cluster_node_ips.split(",")]
node_ips.index(node_ip) node_ips.index(node_ip)
...@@ -47,7 +47,7 @@ def get_cluster_from_args(selected_gpus): ...@@ -47,7 +47,7 @@ def get_cluster_from_args(selected_gpus):
def get_gpus(selected_gpus): def get_gpus(selected_gpus):
"""Get selected GPU string""" """Get selected GPU string"""
selected_gpus = [x.strip() for x in selected_gpus.split(',')] selected_gpus = [x.strip() for x in selected_gpus.split(",")]
return selected_gpus return selected_gpus
...@@ -86,7 +86,7 @@ def start_local_trainers( ...@@ -86,7 +86,7 @@ def start_local_trainers(
print(f"trainer proc env:{current_env}") print(f"trainer proc env:{current_env}")
if os.getenv('WITH_COVERAGE', 'OFF') == 'ON': if os.getenv("WITH_COVERAGE", "OFF") == "ON":
cmd = "python -m coverage run --branch -p " + training_script cmd = "python -m coverage run --branch -p " + training_script
else: else:
cmd = "python -u " + training_script cmd = "python -u " + training_script
...@@ -95,7 +95,9 @@ def start_local_trainers( ...@@ -95,7 +95,9 @@ def start_local_trainers(
fn = None fn = None
proc = subprocess.Popen(cmd.split(" ") + training_script_args, env=current_env) # pylint: disable=consider-using-with proc = subprocess.Popen(
cmd.split(" ") + training_script_args, env=current_env
) # pylint: disable=consider-using-with
tp = TrainerProc() tp = TrainerProc()
tp.proc = proc tp.proc = proc
...@@ -117,10 +119,10 @@ class TestDistributed(unittest.TestCase): ...@@ -117,10 +119,10 @@ class TestDistributed(unittest.TestCase):
allocator_strategy="auto_growth", allocator_strategy="auto_growth",
): ):
"""Run target file in subprocesses""" """Run target file in subprocesses"""
if (not core.is_compiled_with_cuda() or core.get_cuda_device_count() == 0): if not core.is_compiled_with_cuda() or core.get_cuda_device_count() == 0:
return return
selected_gpus = get_gpus('0,1') selected_gpus = get_gpus("0,1")
cluster = None cluster = None
pod = None pod = None
......
...@@ -27,7 +27,7 @@ class TestAmaxReduction(unittest.TestCase): ...@@ -27,7 +27,7 @@ class TestAmaxReduction(unittest.TestCase):
def setUp(self): def setUp(self):
self.data_parallel_size = 2 self.data_parallel_size = 2
self.init_dist_env() self.init_dist_env()
self.global_dtype = 'bfloat16' self.global_dtype = "bfloat16"
paddle.set_default_dtype(self.global_dtype) paddle.set_default_dtype(self.global_dtype)
def init_dist_env(self): def init_dist_env(self):
...@@ -83,5 +83,5 @@ class TestAmaxReduction(unittest.TestCase): ...@@ -83,5 +83,5 @@ class TestAmaxReduction(unittest.TestCase):
assert_allclose_across_ranks(layer2.fp8_meta["scaling_bwd"].scale_inv) assert_allclose_across_ranks(layer2.fp8_meta["scaling_bwd"].scale_inv)
if __name__ == '__main__': if __name__ == "__main__":
unittest.main() unittest.main()
...@@ -44,8 +44,8 @@ class TestAttentionTp(unittest.TestCase): ...@@ -44,8 +44,8 @@ class TestAttentionTp(unittest.TestCase):
self.num_heads = 16 self.num_heads = 16
self.q_seqlen = 128 self.q_seqlen = 128
self.kv_seqlen = 128 self.kv_seqlen = 128
self.mask_type = 'padding' self.mask_type = "padding"
self.global_dtype = 'bfloat16' self.global_dtype = "bfloat16"
self.rtol = 5e-3 self.rtol = 5e-3
self.atol = 5e-3 self.atol = 5e-3
self.eps = 1e-3 self.eps = 1e-3
...@@ -56,7 +56,7 @@ class TestAttentionTp(unittest.TestCase): ...@@ -56,7 +56,7 @@ class TestAttentionTp(unittest.TestCase):
inp, mask = inp_list inp, mask = inp_list
if sequence_parallel: if sequence_parallel:
split_size = inp.shape[0] // self.world_size split_size = inp.shape[0] // self.world_size
input_parallel = inp[split_size * self.rank:split_size * (self.rank + 1), :] input_parallel = inp[split_size * self.rank : split_size * (self.rank + 1), :]
else: else:
input_parallel = inp input_parallel = inp
with te.fp8_autocast(enabled=fp8_enabled): with te.fp8_autocast(enabled=fp8_enabled):
...@@ -80,18 +80,20 @@ class TestAttentionTp(unittest.TestCase): ...@@ -80,18 +80,20 @@ class TestAttentionTp(unittest.TestCase):
self.num_heads, self.num_heads,
) )
common_kwargs = { common_kwargs = {
'layernorm_epsilon': self.eps, "layernorm_epsilon": self.eps,
'attention_dropout': 0.0, "attention_dropout": 0.0,
'attn_mask_type': self.mask_type, "attn_mask_type": self.mask_type,
'attention_type': 'self', "attention_type": "self",
"tp_group": self.tp_group, "tp_group": self.tp_group,
"input_layernorm": True, "input_layernorm": True,
} }
layer_tp = te.MultiHeadAttention(*common_args, layer_tp = te.MultiHeadAttention(
**common_kwargs, *common_args,
set_parallel_mode=True, **common_kwargs,
sequence_parallel=self.sequence_parallel) set_parallel_mode=True,
sequence_parallel=self.sequence_parallel,
)
layer_single = te.MultiHeadAttention(*common_args, **common_kwargs, set_parallel_mode=False) layer_single = te.MultiHeadAttention(*common_args, **common_kwargs, set_parallel_mode=False)
def _get_total_weight(local_weight, tp_group, axis, interleave=False): def _get_total_weight(local_weight, tp_group, axis, interleave=False):
...@@ -102,12 +104,15 @@ class TestAttentionTp(unittest.TestCase): ...@@ -102,12 +104,15 @@ class TestAttentionTp(unittest.TestCase):
# Due to the interleaved qkv layout, need to concat on num_head # Due to the interleaved qkv layout, need to concat on num_head
# dimension for column parallel linear in MultiHeadAttention layer # dimension for column parallel linear in MultiHeadAttention layer
assert axis == 0 assert axis == 0
assert [3 * self.hidden_size // self.world_size, assert [
self.hidden_size] == partial_weight.shape 3 * self.hidden_size // self.world_size,
self.hidden_size,
] == partial_weight.shape
local_num_head = self.num_heads // self.world_size local_num_head = self.num_heads // self.world_size
for idx, _ in enumerate(total_weight): for idx, _ in enumerate(total_weight):
total_weight[idx] = total_weight[idx].reshape( total_weight[idx] = total_weight[idx].reshape(
[3, local_num_head, -1, self.hidden_size]) [3, local_num_head, -1, self.hidden_size]
)
total_weight = paddle.concat(total_weight, axis=1).reshape([-1, self.hidden_size]) total_weight = paddle.concat(total_weight, axis=1).reshape([-1, self.hidden_size])
else: else:
total_weight = paddle.concat(total_weight, axis=axis) total_weight = paddle.concat(total_weight, axis=axis)
...@@ -123,42 +128,47 @@ class TestAttentionTp(unittest.TestCase): ...@@ -123,42 +128,47 @@ class TestAttentionTp(unittest.TestCase):
weight_dst = _get_weight(layer_dst, weight_names) weight_dst = _get_weight(layer_dst, weight_names)
if partition_mode is None: if partition_mode is None:
total_weight = weight_src total_weight = weight_src
elif partition_mode == 'column': elif partition_mode == "column":
total_weight = _get_total_weight(weight_src, total_weight = _get_total_weight(
tp_group=self.tp_group, weight_src, tp_group=self.tp_group, axis=0, interleave=interleave
axis=0, )
interleave=interleave) elif partition_mode == "row":
elif partition_mode == 'row':
total_weight = _get_total_weight(weight_src, tp_group=self.tp_group, axis=1) total_weight = _get_total_weight(weight_src, tp_group=self.tp_group, axis=1)
else: else:
raise ValueError(f"Partition Mode {partition_mode} is not supported.") raise ValueError(f"Partition Mode {partition_mode} is not supported.")
assert weight_dst.shape == total_weight.shape, \ assert (
f"Shapes of src:{total_weight.shape} and dst:{weight_dst.shape} do not match." weight_dst.shape == total_weight.shape
), f"Shapes of src:{total_weight.shape} and dst:{weight_dst.shape} do not match."
weight_dst.copy_(total_weight, True) weight_dst.copy_(total_weight, True)
copy_weight(layer_tp, layer_single, None, ['layernorm_qkv', 'ln_weight']) copy_weight(layer_tp, layer_single, None, ["layernorm_qkv", "ln_weight"])
copy_weight(layer_tp, layer_single, 'column', ['layernorm_qkv', 'weight'], interleave=True) copy_weight(layer_tp, layer_single, "column", ["layernorm_qkv", "weight"], interleave=True)
copy_weight(layer_tp, layer_single, 'row', ['proj', 'weight']) copy_weight(layer_tp, layer_single, "row", ["proj", "weight"])
if self.sequence_parallel: if self.sequence_parallel:
register_sequence_parallel_allreduce_hooks(layer_tp, accumulation_steps=1) register_sequence_parallel_allreduce_hooks(layer_tp, accumulation_steps=1)
optimizer_tp = paddle.optimizer.SGD(learning_rate=0.01, parameters=layer_tp.parameters()) optimizer_tp = paddle.optimizer.SGD(learning_rate=0.01, parameters=layer_tp.parameters())
optimizer_single = paddle.optimizer.SGD(learning_rate=0.01, optimizer_single = paddle.optimizer.SGD(
parameters=layer_single.parameters()) learning_rate=0.01, parameters=layer_single.parameters()
)
layer_tp = fleet.distributed_model(layer_tp) layer_tp = fleet.distributed_model(layer_tp)
optimizer_tp = fleet.distributed_optimizer(optimizer_tp) optimizer_tp = fleet.distributed_optimizer(optimizer_tp)
for _ in range(5): for _ in range(5):
inp = paddle.uniform([self.batch_size, self.q_seqlen, self.hidden_size], inp = paddle.uniform(
self.global_dtype) [self.batch_size, self.q_seqlen, self.hidden_size], self.global_dtype
mask = paddle.zeros(shape=(self.batch_size, 1, self.q_seqlen, self.kv_seqlen), )
dtype='bool') mask = paddle.zeros(
loss_tp, out_tp = self._train_one_step(layer_tp, [inp, mask], optimizer_tp, self.fp8, shape=(self.batch_size, 1, self.q_seqlen, self.kv_seqlen), dtype="bool"
self.sequence_parallel) )
loss_single, out_single = self._train_one_step(layer_single, [inp, mask], loss_tp, out_tp = self._train_one_step(
optimizer_single, self.fp8) layer_tp, [inp, mask], optimizer_tp, self.fp8, self.sequence_parallel
)
loss_single, out_single = self._train_one_step(
layer_single, [inp, mask], optimizer_single, self.fp8
)
assert_allclose(out_tp, out_single, rtol=self.rtol, atol=self.atol) assert_allclose(out_tp, out_single, rtol=self.rtol, atol=self.atol)
assert_allclose(loss_tp, loss_single, rtol=self.rtol, atol=self.atol) assert_allclose(loss_tp, loss_single, rtol=self.rtol, atol=self.atol)
...@@ -173,8 +183,8 @@ class TestAttentionTpFp8(TestAttentionTp): ...@@ -173,8 +183,8 @@ class TestAttentionTpFp8(TestAttentionTp):
self.num_heads = 16 self.num_heads = 16
self.q_seqlen = 128 self.q_seqlen = 128
self.kv_seqlen = 128 self.kv_seqlen = 128
self.mask_type = 'padding' self.mask_type = "padding"
self.global_dtype = 'bfloat16' self.global_dtype = "bfloat16"
self.rtol = 5e-2 self.rtol = 5e-2
self.atol = 5e-2 self.atol = 5e-2
self.eps = 1e-3 self.eps = 1e-3
...@@ -192,8 +202,8 @@ class TestAttentionSp(TestAttentionTp): ...@@ -192,8 +202,8 @@ class TestAttentionSp(TestAttentionTp):
self.num_heads = 16 self.num_heads = 16
self.q_seqlen = 128 self.q_seqlen = 128
self.kv_seqlen = 128 self.kv_seqlen = 128
self.mask_type = 'padding' self.mask_type = "padding"
self.global_dtype = 'bfloat16' self.global_dtype = "bfloat16"
self.rtol = 5e-3 self.rtol = 5e-3
self.atol = 5e-3 self.atol = 5e-3
self.eps = 1e-3 self.eps = 1e-3
...@@ -211,8 +221,8 @@ class TestAttentionSpFp8(TestAttentionTp): ...@@ -211,8 +221,8 @@ class TestAttentionSpFp8(TestAttentionTp):
self.num_heads = 16 self.num_heads = 16
self.q_seqlen = 128 self.q_seqlen = 128
self.kv_seqlen = 128 self.kv_seqlen = 128
self.mask_type = 'padding' self.mask_type = "padding"
self.global_dtype = 'bfloat16' self.global_dtype = "bfloat16"
self.rtol = 5e-2 self.rtol = 5e-2
self.atol = 1e-1 self.atol = 1e-1
self.eps = 1e-3 self.eps = 1e-3
...@@ -220,5 +230,5 @@ class TestAttentionSpFp8(TestAttentionTp): ...@@ -220,5 +230,5 @@ class TestAttentionSpFp8(TestAttentionTp):
self.sequence_parallel = True self.sequence_parallel = True
if __name__ == '__main__': if __name__ == "__main__":
unittest.main() unittest.main()
...@@ -8,7 +8,8 @@ import unittest ...@@ -8,7 +8,8 @@ import unittest
import paddle import paddle
from paddle.distributed import fleet from paddle.distributed import fleet
from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer import ( from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer import (
DygraphShardingOptimizer,) DygraphShardingOptimizer,
)
from utils import assert_allclose, set_random_seed from utils import assert_allclose, set_random_seed
import transformer_engine.paddle as te import transformer_engine.paddle as te
...@@ -25,7 +26,7 @@ class TestGroupSharding(unittest.TestCase): ...@@ -25,7 +26,7 @@ class TestGroupSharding(unittest.TestCase):
def set_attr(self): def set_attr(self):
"""Set test configs""" """Set test configs"""
self.sharding_degree = 2 self.sharding_degree = 2
self.global_dtype = 'float32' self.global_dtype = "float32"
self.rtol = 1e-5 self.rtol = 1e-5
self.atol = 1e-5 self.atol = 1e-5
self.batch_size = 16 self.batch_size = 16
...@@ -57,11 +58,12 @@ class TestGroupSharding(unittest.TestCase): ...@@ -57,11 +58,12 @@ class TestGroupSharding(unittest.TestCase):
optimizer = paddle.optimizer.AdamW(learning_rate=0.01, parameters=model.parameters()) optimizer = paddle.optimizer.AdamW(learning_rate=0.01, parameters=model.parameters())
group = fleet.get_hybrid_communicate_group().get_sharding_parallel_group() group = fleet.get_hybrid_communicate_group().get_sharding_parallel_group()
class ShardingLevel: # pylint: disable=too-few-public-methods, class ShardingLevel: # pylint: disable=too-few-public-methods,
"""Paddle sharding options""" """Paddle sharding options"""
kStage1 = 'os'
kStage2 = 'os_g' kStage1 = "os"
kStage3 = 'p_g_os' kStage2 = "os_g"
kStage3 = "p_g_os"
level = ShardingLevel.kStage3 if stage == 3 else ShardingLevel.kStage2 level = ShardingLevel.kStage3 if stage == 3 else ShardingLevel.kStage2
model, optimizer, _ = paddle.distributed.sharding.group_sharded_parallel( model, optimizer, _ = paddle.distributed.sharding.group_sharded_parallel(
...@@ -104,8 +106,9 @@ class TestGroupSharding(unittest.TestCase): ...@@ -104,8 +106,9 @@ class TestGroupSharding(unittest.TestCase):
loss_pd = train_one_step(model_pd, inp, optimizer_pd) loss_pd = train_one_step(model_pd, inp, optimizer_pd)
assert_allclose(loss_te, loss_pd, rtol=self.rtol, atol=self.atol) assert_allclose(loss_te, loss_pd, rtol=self.rtol, atol=self.atol)
assert len(optimizer_te.state_dict()) == 4, \ assert (
"Expect each rank to hold 4 optimizer state entries." len(optimizer_te.state_dict()) == 4
), "Expect each rank to hold 4 optimizer state entries."
def test_group_sharding_stage2(self): def test_group_sharding_stage2(self):
"""Tests group sharding training""" """Tests group sharding training"""
...@@ -141,8 +144,9 @@ class TestGroupSharding(unittest.TestCase): ...@@ -141,8 +144,9 @@ class TestGroupSharding(unittest.TestCase):
loss_pd = train_one_step(model_pd, inp, optimizer_pd) loss_pd = train_one_step(model_pd, inp, optimizer_pd)
assert_allclose(loss_te, loss_pd, rtol=self.rtol, atol=self.atol) assert_allclose(loss_te, loss_pd, rtol=self.rtol, atol=self.atol)
assert len(optimizer_te.state_dict()) == 4, \ assert (
"Expect each rank to hold 4 optimizer state entries." len(optimizer_te.state_dict()) == 4
), "Expect each rank to hold 4 optimizer state entries."
def test_group_sharding_stage3(self): def test_group_sharding_stage3(self):
"""Tests group sharding training""" """Tests group sharding training"""
...@@ -174,11 +178,11 @@ class TestGroupSharding(unittest.TestCase): ...@@ -174,11 +178,11 @@ class TestGroupSharding(unittest.TestCase):
assert_allclose(loss_te, loss_pd, rtol=self.rtol, atol=self.atol) assert_allclose(loss_te, loss_pd, rtol=self.rtol, atol=self.atol)
for name, value in optimizer_te.state_dict().items(): for name, value in optimizer_te.state_dict().items():
if name.endswith('w_0_moment1_0'): if name.endswith("w_0_moment1_0"):
assert value.numel() == \ assert (
self.in_channels * self.out_channels // self.sharding_degree, \ value.numel() == self.in_channels * self.out_channels // self.sharding_degree
"Expect optimizer state to be sharded across trainers." ), "Expect optimizer state to be sharded across trainers."
if __name__ == '__main__': if __name__ == "__main__":
unittest.main() unittest.main()
...@@ -42,22 +42,22 @@ class TestLayerNormLinearTp(unittest.TestCase): ...@@ -42,22 +42,22 @@ class TestLayerNormLinearTp(unittest.TestCase):
self.batch_size = 16 self.batch_size = 16
self.in_features = 32 self.in_features = 32
self.out_features = 64 self.out_features = 64
self.global_dtype = 'bfloat16' self.global_dtype = "bfloat16"
self.rtol = 1e-3 self.rtol = 1e-3
self.atol = 1e-3 self.atol = 1e-3
self.eps = 1e-3 self.eps = 1e-3
self.fp8 = False self.fp8 = False
self.sequence_parallel = False self.sequence_parallel = False
def _train_one_step(self, layer, inp, optimizer, split_input='none', gather_output=False): def _train_one_step(self, layer, inp, optimizer, split_input="none", gather_output=False):
inp = paddle.to_tensor(inp, stop_gradient=True) inp = paddle.to_tensor(inp, stop_gradient=True)
assert split_input in ['none', 'column', 'row'] assert split_input in ["none", "column", "row"]
if split_input == 'column': if split_input == "column":
split_size = inp.shape[1] // self.world_size split_size = inp.shape[1] // self.world_size
input_parallel = inp[:, split_size * self.rank:split_size * (self.rank + 1)] input_parallel = inp[:, split_size * self.rank : split_size * (self.rank + 1)]
elif split_input == 'row': elif split_input == "row":
split_size = inp.shape[0] // self.world_size split_size = inp.shape[0] // self.world_size
input_parallel = inp[split_size * self.rank:split_size * (self.rank + 1), :] input_parallel = inp[split_size * self.rank : split_size * (self.rank + 1), :]
else: else:
input_parallel = inp input_parallel = inp
input_parallel.stop_gradient = False input_parallel.stop_gradient = False
...@@ -70,12 +70,12 @@ class TestLayerNormLinearTp(unittest.TestCase): ...@@ -70,12 +70,12 @@ class TestLayerNormLinearTp(unittest.TestCase):
loss.backward() loss.backward()
optimizer.step() optimizer.step()
optimizer.clear_grad() optimizer.clear_grad()
if split_input != 'none': if split_input != "none":
grad_input = [] grad_input = []
paddle.distributed.all_gather(grad_input, input_parallel.grad, group=self.tp_group) paddle.distributed.all_gather(grad_input, input_parallel.grad, group=self.tp_group)
if split_input == 'column': if split_input == "column":
grad_input = paddle.concat(grad_input, axis=1) grad_input = paddle.concat(grad_input, axis=1)
elif split_input == 'row': elif split_input == "row":
grad_input = paddle.concat(grad_input, axis=0) grad_input = paddle.concat(grad_input, axis=0)
else: else:
grad_input = input_parallel.grad grad_input = input_parallel.grad
...@@ -88,14 +88,14 @@ class TestLayerNormLinearTp(unittest.TestCase): ...@@ -88,14 +88,14 @@ class TestLayerNormLinearTp(unittest.TestCase):
self.in_features, self.in_features,
self.out_features, self.out_features,
eps=self.eps, eps=self.eps,
parallel_mode='column', parallel_mode="column",
sequence_parallel=self.sequence_parallel, sequence_parallel=self.sequence_parallel,
) )
layer_pd = te.LayerNormLinear( layer_pd = te.LayerNormLinear(
self.in_features, self.in_features,
self.out_features, self.out_features,
eps=self.eps, eps=self.eps,
backend='paddle', backend="paddle",
) )
# Get total weight # Get total weight
total_weight = [] total_weight = []
...@@ -104,8 +104,9 @@ class TestLayerNormLinearTp(unittest.TestCase): ...@@ -104,8 +104,9 @@ class TestLayerNormLinearTp(unittest.TestCase):
total_weight = paddle.concat(total_weight, axis=0) total_weight = paddle.concat(total_weight, axis=0)
layer_pd.weight.copy_(total_weight.T, True) layer_pd.weight.copy_(total_weight.T, True)
assert_shape(layer_te.weight, assert_shape(
[self.out_features // self.model_parallel_size, self.in_features]) layer_te.weight, [self.out_features // self.model_parallel_size, self.in_features]
)
assert_shape(layer_te.bias, [self.out_features // self.model_parallel_size]) assert_shape(layer_te.bias, [self.out_features // self.model_parallel_size])
optimizer_te = paddle.optimizer.SGD(learning_rate=0.001, parameters=layer_te.parameters()) optimizer_te = paddle.optimizer.SGD(learning_rate=0.001, parameters=layer_te.parameters())
...@@ -121,8 +122,9 @@ class TestLayerNormLinearTp(unittest.TestCase): ...@@ -121,8 +122,9 @@ class TestLayerNormLinearTp(unittest.TestCase):
layer_te, layer_te,
inp, inp,
optimizer_te, optimizer_te,
split_input='row' if self.sequence_parallel else 'none', split_input="row" if self.sequence_parallel else "none",
gather_output=True) gather_output=True,
)
loss_ref, grad_input_ref = self._train_one_step(layer_pd, inp, optimizer_pd) loss_ref, grad_input_ref = self._train_one_step(layer_pd, inp, optimizer_pd)
assert_allclose(loss_tp, loss_ref, rtol=self.rtol, atol=self.atol) assert_allclose(loss_tp, loss_ref, rtol=self.rtol, atol=self.atol)
assert_allclose(grad_input, grad_input_ref, rtol=self.rtol, atol=self.atol) assert_allclose(grad_input, grad_input_ref, rtol=self.rtol, atol=self.atol)
...@@ -136,7 +138,7 @@ class TestLayerNormLinearTpFp8(TestLayerNormLinearTp): ...@@ -136,7 +138,7 @@ class TestLayerNormLinearTpFp8(TestLayerNormLinearTp):
self.batch_size = 16 self.batch_size = 16
self.in_features = 32 self.in_features = 32
self.out_features = 64 self.out_features = 64
self.global_dtype = 'bfloat16' self.global_dtype = "bfloat16"
self.rtol = 1e-2 self.rtol = 1e-2
self.atol = 1e-2 self.atol = 1e-2
self.eps = 1e-3 self.eps = 1e-3
...@@ -152,7 +154,7 @@ class TestLayerNormLinearSp(TestLayerNormLinearTp): ...@@ -152,7 +154,7 @@ class TestLayerNormLinearSp(TestLayerNormLinearTp):
self.batch_size = 16 self.batch_size = 16
self.in_features = 32 self.in_features = 32
self.out_features = 64 self.out_features = 64
self.global_dtype = 'bfloat16' self.global_dtype = "bfloat16"
self.rtol = 1e-3 self.rtol = 1e-3
self.atol = 1e-3 self.atol = 1e-3
self.eps = 1e-3 self.eps = 1e-3
...@@ -168,7 +170,7 @@ class TestLayerNormLinearSpFp8(TestLayerNormLinearTp): ...@@ -168,7 +170,7 @@ class TestLayerNormLinearSpFp8(TestLayerNormLinearTp):
self.batch_size = 16 self.batch_size = 16
self.in_features = 32 self.in_features = 32
self.out_features = 64 self.out_features = 64
self.global_dtype = 'bfloat16' self.global_dtype = "bfloat16"
self.rtol = 1e-2 self.rtol = 1e-2
self.atol = 1e-2 self.atol = 1e-2
self.eps = 1e-3 self.eps = 1e-3
...@@ -176,5 +178,5 @@ class TestLayerNormLinearSpFp8(TestLayerNormLinearTp): ...@@ -176,5 +178,5 @@ class TestLayerNormLinearSpFp8(TestLayerNormLinearTp):
self.sequence_parallel = True self.sequence_parallel = True
if __name__ == '__main__': if __name__ == "__main__":
unittest.main() unittest.main()
...@@ -42,22 +42,22 @@ class TestLayerNormMLPTp(unittest.TestCase): ...@@ -42,22 +42,22 @@ class TestLayerNormMLPTp(unittest.TestCase):
self.batch_size = 16 self.batch_size = 16
self.hidden_size = 32 self.hidden_size = 32
self.ffn_hidden_size = 64 self.ffn_hidden_size = 64
self.global_dtype = 'bfloat16' self.global_dtype = "bfloat16"
self.rtol = 1e-3 self.rtol = 1e-3
self.atol = 1e-3 self.atol = 1e-3
self.eps = 1e-3 self.eps = 1e-3
self.fp8 = False self.fp8 = False
self.sequence_parallel = False self.sequence_parallel = False
def _train_one_step(self, layer, inp, optimizer, split_input='none', gather_output=False): def _train_one_step(self, layer, inp, optimizer, split_input="none", gather_output=False):
inp = paddle.to_tensor(inp, stop_gradient=True) inp = paddle.to_tensor(inp, stop_gradient=True)
assert split_input in ['none', 'column', 'row'] assert split_input in ["none", "column", "row"]
if split_input == 'column': if split_input == "column":
split_size = inp.shape[1] // self.world_size split_size = inp.shape[1] // self.world_size
input_parallel = inp[:, split_size * self.rank:split_size * (self.rank + 1)] input_parallel = inp[:, split_size * self.rank : split_size * (self.rank + 1)]
elif split_input == 'row': elif split_input == "row":
split_size = inp.shape[0] // self.world_size split_size = inp.shape[0] // self.world_size
input_parallel = inp[split_size * self.rank:split_size * (self.rank + 1), :] input_parallel = inp[split_size * self.rank : split_size * (self.rank + 1), :]
else: else:
input_parallel = inp input_parallel = inp
input_parallel.stop_gradient = False input_parallel.stop_gradient = False
...@@ -71,12 +71,12 @@ class TestLayerNormMLPTp(unittest.TestCase): ...@@ -71,12 +71,12 @@ class TestLayerNormMLPTp(unittest.TestCase):
loss.backward() loss.backward()
optimizer.step() optimizer.step()
optimizer.clear_grad() optimizer.clear_grad()
if split_input != 'none': if split_input != "none":
grad_input = [] grad_input = []
paddle.distributed.all_gather(grad_input, input_parallel.grad, group=self.tp_group) paddle.distributed.all_gather(grad_input, input_parallel.grad, group=self.tp_group)
if split_input == 'column': if split_input == "column":
grad_input = paddle.concat(grad_input, axis=1) grad_input = paddle.concat(grad_input, axis=1)
elif split_input == 'row': elif split_input == "row":
grad_input = paddle.concat(grad_input, axis=0) grad_input = paddle.concat(grad_input, axis=0)
else: else:
grad_input = input_parallel.grad grad_input = input_parallel.grad
...@@ -97,7 +97,7 @@ class TestLayerNormMLPTp(unittest.TestCase): ...@@ -97,7 +97,7 @@ class TestLayerNormMLPTp(unittest.TestCase):
ffn_hidden_size=self.ffn_hidden_size, ffn_hidden_size=self.ffn_hidden_size,
eps=self.eps, eps=self.eps,
set_parallel_mode=False, set_parallel_mode=False,
backend='paddle', backend="paddle",
) )
def _get_total_weight(local_weight, tp_group, axis): def _get_total_weight(local_weight, tp_group, axis):
...@@ -113,11 +113,15 @@ class TestLayerNormMLPTp(unittest.TestCase): ...@@ -113,11 +113,15 @@ class TestLayerNormMLPTp(unittest.TestCase):
layer_pd.fc1_weight.copy_(total_fc1_weight.T, True) layer_pd.fc1_weight.copy_(total_fc1_weight.T, True)
layer_pd.fc2_weight.copy_(total_fc2_weight.T, True) layer_pd.fc2_weight.copy_(total_fc2_weight.T, True)
assert_shape(layer_te.fc1_weight, assert_shape(
[self.ffn_hidden_size // self.model_parallel_size, self.hidden_size]) layer_te.fc1_weight,
[self.ffn_hidden_size // self.model_parallel_size, self.hidden_size],
)
assert_shape(layer_te.fc1_bias, [self.ffn_hidden_size // self.model_parallel_size]) assert_shape(layer_te.fc1_bias, [self.ffn_hidden_size // self.model_parallel_size])
assert_shape(layer_te.fc2_weight, assert_shape(
[self.hidden_size, self.ffn_hidden_size // self.model_parallel_size]) layer_te.fc2_weight,
[self.hidden_size, self.ffn_hidden_size // self.model_parallel_size],
)
assert_shape(layer_te.fc2_bias, [self.hidden_size]) assert_shape(layer_te.fc2_bias, [self.hidden_size])
optimizer_te = paddle.optimizer.SGD(learning_rate=0.001, parameters=layer_te.parameters()) optimizer_te = paddle.optimizer.SGD(learning_rate=0.001, parameters=layer_te.parameters())
...@@ -133,8 +137,9 @@ class TestLayerNormMLPTp(unittest.TestCase): ...@@ -133,8 +137,9 @@ class TestLayerNormMLPTp(unittest.TestCase):
layer_te, layer_te,
inp, inp,
optimizer_te, optimizer_te,
split_input='row' if self.sequence_parallel else 'none', split_input="row" if self.sequence_parallel else "none",
gather_output=self.sequence_parallel) gather_output=self.sequence_parallel,
)
loss_ref, grad_input_ref = self._train_one_step(layer_pd, inp, optimizer_pd) loss_ref, grad_input_ref = self._train_one_step(layer_pd, inp, optimizer_pd)
assert_allclose(loss_tp, loss_ref, rtol=self.rtol, atol=self.atol) assert_allclose(loss_tp, loss_ref, rtol=self.rtol, atol=self.atol)
assert_allclose(grad_input, grad_input_ref, rtol=self.rtol, atol=self.atol) assert_allclose(grad_input, grad_input_ref, rtol=self.rtol, atol=self.atol)
...@@ -148,7 +153,7 @@ class TestLayerNormMLPTpFp8(TestLayerNormMLPTp): ...@@ -148,7 +153,7 @@ class TestLayerNormMLPTpFp8(TestLayerNormMLPTp):
self.batch_size = 16 self.batch_size = 16
self.hidden_size = 32 self.hidden_size = 32
self.ffn_hidden_size = 64 self.ffn_hidden_size = 64
self.global_dtype = 'bfloat16' self.global_dtype = "bfloat16"
self.rtol = 1e-2 self.rtol = 1e-2
self.atol = 1e-2 self.atol = 1e-2
self.eps = 1e-3 self.eps = 1e-3
...@@ -164,7 +169,7 @@ class TestLayerNormMLPSp(TestLayerNormMLPTp): ...@@ -164,7 +169,7 @@ class TestLayerNormMLPSp(TestLayerNormMLPTp):
self.batch_size = 16 self.batch_size = 16
self.hidden_size = 32 self.hidden_size = 32
self.ffn_hidden_size = 64 self.ffn_hidden_size = 64
self.global_dtype = 'bfloat16' self.global_dtype = "bfloat16"
self.rtol = 1e-3 self.rtol = 1e-3
self.atol = 1e-3 self.atol = 1e-3
self.eps = 1e-3 self.eps = 1e-3
...@@ -180,7 +185,7 @@ class TestLayerNormMLPSpFp8(TestLayerNormMLPTp): ...@@ -180,7 +185,7 @@ class TestLayerNormMLPSpFp8(TestLayerNormMLPTp):
self.batch_size = 16 self.batch_size = 16
self.hidden_size = 32 self.hidden_size = 32
self.ffn_hidden_size = 64 self.ffn_hidden_size = 64
self.global_dtype = 'bfloat16' self.global_dtype = "bfloat16"
self.rtol = 1e-2 self.rtol = 1e-2
self.atol = 1e-2 self.atol = 1e-2
self.eps = 1e-3 self.eps = 1e-3
...@@ -188,5 +193,5 @@ class TestLayerNormMLPSpFp8(TestLayerNormMLPTp): ...@@ -188,5 +193,5 @@ class TestLayerNormMLPSpFp8(TestLayerNormMLPTp):
self.sequence_parallel = True self.sequence_parallel = True
if __name__ == '__main__': if __name__ == "__main__":
unittest.main() unittest.main()
...@@ -23,14 +23,14 @@ class TELinear(te.Linear): ...@@ -23,14 +23,14 @@ class TELinear(te.Linear):
"""To pass is_first_microbatch""" """To pass is_first_microbatch"""
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
assert 'accumulate_steps' in kwargs assert "accumulate_steps" in kwargs
self.accumulate_steps = kwargs['accumulate_steps'] self.accumulate_steps = kwargs["accumulate_steps"]
del kwargs['accumulate_steps'] del kwargs["accumulate_steps"]
self._micro_batch_id = 0 self._micro_batch_id = 0
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
def forward(self, *args, **kwargs): def forward(self, *args, **kwargs):
kwargs['is_first_microbatch'] = (self._micro_batch_id % self.accumulate_steps) == 0 kwargs["is_first_microbatch"] = (self._micro_batch_id % self.accumulate_steps) == 0
if paddle.is_grad_enabled() and self.training: if paddle.is_grad_enabled() and self.training:
self._micro_batch_id += 1 self._micro_batch_id += 1
return super().forward(*args, **kwargs) return super().forward(*args, **kwargs)
...@@ -39,14 +39,16 @@ class TELinear(te.Linear): ...@@ -39,14 +39,16 @@ class TELinear(te.Linear):
class TEPipelineModel(PipelineLayer): class TEPipelineModel(PipelineLayer):
"""Model for pipeline parallel test""" """Model for pipeline parallel test"""
def __init__(self, def __init__(
in_features, self,
hidden_features, in_features,
weight_attrs, hidden_features,
use_te=True, weight_attrs,
use_fp8=False, use_te=True,
accumulate_steps=1, use_fp8=False,
**kwargs): accumulate_steps=1,
**kwargs,
):
self.in_features = in_features self.in_features = in_features
self.hidden_features = hidden_features self.hidden_features = hidden_features
self.fp8 = use_fp8 self.fp8 = use_fp8
...@@ -56,19 +58,23 @@ class TEPipelineModel(PipelineLayer): ...@@ -56,19 +58,23 @@ class TEPipelineModel(PipelineLayer):
Linear = TELinear if use_te else paddle.nn.Linear Linear = TELinear if use_te else paddle.nn.Linear
extra_kwargs = {} extra_kwargs = {}
if use_te: if use_te:
extra_kwargs['accumulate_steps'] = accumulate_steps extra_kwargs["accumulate_steps"] = accumulate_steps
model_desc = [ model_desc = [
LayerDesc(Linear, LayerDesc(
self.in_features, Linear,
self.hidden_features, self.in_features,
weight_attr=weight_attrs[0], self.hidden_features,
**extra_kwargs), weight_attr=weight_attrs[0],
LayerDesc(Linear, **extra_kwargs,
self.hidden_features, ),
self.in_features, LayerDesc(
weight_attr=weight_attrs[1], Linear,
**extra_kwargs), self.hidden_features,
self.in_features,
weight_attr=weight_attrs[1],
**extra_kwargs,
),
] ]
super().__init__(layers=model_desc, loss_fn=paddle.nn.CrossEntropyLoss(), **kwargs) super().__init__(layers=model_desc, loss_fn=paddle.nn.CrossEntropyLoss(), **kwargs)
...@@ -129,7 +135,7 @@ class TestLinearPipelineParallel(unittest.TestCase): ...@@ -129,7 +135,7 @@ class TestLinearPipelineParallel(unittest.TestCase):
self.micro_batch_size = 16 self.micro_batch_size = 16
self.in_features = 32 self.in_features = 32
self.hidden_features = 64 self.hidden_features = 64
self.global_dtype = 'float32' self.global_dtype = "float32"
self.rtol = 1e-5 self.rtol = 1e-5
self.atol = 1e-5 self.atol = 1e-5
self.iter = 10 self.iter = 10
...@@ -164,16 +170,18 @@ class TestLinearPipelineParallel(unittest.TestCase): ...@@ -164,16 +170,18 @@ class TestLinearPipelineParallel(unittest.TestCase):
# Check if model is split across ranks as expected # Check if model is split across ranks as expected
for name, sublayer in pipe_model.named_sublayers(): for name, sublayer in pipe_model.named_sublayers():
if name in ('_loss_fn', 'shared_layers'): if name in ("_loss_fn", "shared_layers"):
continue continue
if self.rank == 0: if self.rank == 0:
assert tuple(sublayer.weight.shape) == weight1_np.T.shape, \ assert tuple(sublayer.weight.shape) == weight1_np.T.shape, (
f"Shape does not match, expect: {weight1_np.T.shape}, " \ f"Shape does not match, expect: {weight1_np.T.shape}, "
f"actual: {tuple(sublayer.weight.shape)}" f"actual: {tuple(sublayer.weight.shape)}"
)
elif self.rank == 1: elif self.rank == 1:
assert tuple(sublayer.weight.shape) == weight2_np.T.shape, \ assert tuple(sublayer.weight.shape) == weight2_np.T.shape, (
f"Shape does not match, expect: {weight2_np.T.shape}, " \ f"Shape does not match, expect: {weight2_np.T.shape}, "
f"actual: {tuple(sublayer.weight.shape)}" f"actual: {tuple(sublayer.weight.shape)}"
)
standalone_model = StandaloneModel( standalone_model = StandaloneModel(
self.in_features, self.in_features,
...@@ -182,8 +190,9 @@ class TestLinearPipelineParallel(unittest.TestCase): ...@@ -182,8 +190,9 @@ class TestLinearPipelineParallel(unittest.TestCase):
) )
optimizer_te = paddle.optimizer.SGD(learning_rate=0.1, parameters=pipe_model.parameters()) optimizer_te = paddle.optimizer.SGD(learning_rate=0.1, parameters=pipe_model.parameters())
optimizer_pd = paddle.optimizer.SGD(learning_rate=0.1, optimizer_pd = paddle.optimizer.SGD(
parameters=standalone_model.parameters()) learning_rate=0.1, parameters=standalone_model.parameters()
)
pipe_model = fleet.distributed_model(pipe_model) pipe_model = fleet.distributed_model(pipe_model)
optimizer_te = fleet.distributed_optimizer(optimizer_te) optimizer_te = fleet.distributed_optimizer(optimizer_te)
...@@ -196,8 +205,9 @@ class TestLinearPipelineParallel(unittest.TestCase): ...@@ -196,8 +205,9 @@ class TestLinearPipelineParallel(unittest.TestCase):
return loss return loss
for i in range(self.iter): for i in range(self.iter):
inp = paddle.to_tensor(np.random.normal(size=[self.batch_size, self.in_features]), inp = paddle.to_tensor(
dtype=self.global_dtype) np.random.normal(size=[self.batch_size, self.in_features]), dtype=self.global_dtype
)
label = paddle.to_tensor(np.random.randint(self.in_features, size=[self.batch_size, 1])) label = paddle.to_tensor(np.random.randint(self.in_features, size=[self.batch_size, 1]))
loss_te = pipe_model.train_batch([inp, label], optimizer_te) loss_te = pipe_model.train_batch([inp, label], optimizer_te)
loss_pd = train_one_step(standalone_model, [inp, label], optimizer_pd) loss_pd = train_one_step(standalone_model, [inp, label], optimizer_pd)
...@@ -214,12 +224,12 @@ class TestLinearPipelineParallelFP8(TestLinearPipelineParallel): ...@@ -214,12 +224,12 @@ class TestLinearPipelineParallelFP8(TestLinearPipelineParallel):
self.micro_batch_size = 16 self.micro_batch_size = 16
self.in_features = 32 self.in_features = 32
self.hidden_features = 64 self.hidden_features = 64
self.global_dtype = 'float32' self.global_dtype = "float32"
self.rtol = 5e-2 self.rtol = 5e-2
self.atol = 5e-2 self.atol = 5e-2
self.iter = 10 self.iter = 10
self.fp8 = True self.fp8 = True
if __name__ == '__main__': if __name__ == "__main__":
unittest.main() unittest.main()
...@@ -42,21 +42,21 @@ class TestLinearTp(unittest.TestCase): ...@@ -42,21 +42,21 @@ class TestLinearTp(unittest.TestCase):
self.batch_size = 16 self.batch_size = 16
self.in_features = 32 self.in_features = 32
self.out_features = 64 self.out_features = 64
self.global_dtype = 'bfloat16' self.global_dtype = "bfloat16"
self.rtol = 1e-3 self.rtol = 1e-3
self.atol = 1e-3 self.atol = 1e-3
self.fp8 = False self.fp8 = False
self.sequence_parallel = False self.sequence_parallel = False
def _train_one_step(self, layer, inp, optimizer, split_input='none', gather_output=False): def _train_one_step(self, layer, inp, optimizer, split_input="none", gather_output=False):
inp = paddle.to_tensor(inp, stop_gradient=True) inp = paddle.to_tensor(inp, stop_gradient=True)
assert split_input in ['none', 'column', 'row'] assert split_input in ["none", "column", "row"]
if split_input == 'column': if split_input == "column":
split_size = inp.shape[1] // self.world_size split_size = inp.shape[1] // self.world_size
input_parallel = inp[:, split_size * self.rank:split_size * (self.rank + 1)] input_parallel = inp[:, split_size * self.rank : split_size * (self.rank + 1)]
elif split_input == 'row': elif split_input == "row":
split_size = inp.shape[0] // self.world_size split_size = inp.shape[0] // self.world_size
input_parallel = inp[split_size * self.rank:split_size * (self.rank + 1), :] input_parallel = inp[split_size * self.rank : split_size * (self.rank + 1), :]
else: else:
input_parallel = inp input_parallel = inp
input_parallel.stop_gradient = False input_parallel.stop_gradient = False
...@@ -69,12 +69,12 @@ class TestLinearTp(unittest.TestCase): ...@@ -69,12 +69,12 @@ class TestLinearTp(unittest.TestCase):
loss.backward() loss.backward()
optimizer.step() optimizer.step()
optimizer.clear_grad() optimizer.clear_grad()
if split_input != 'none': if split_input != "none":
grad_input = [] grad_input = []
paddle.distributed.all_gather(grad_input, input_parallel.grad, group=self.tp_group) paddle.distributed.all_gather(grad_input, input_parallel.grad, group=self.tp_group)
if split_input == 'column': if split_input == "column":
grad_input = paddle.concat(grad_input, axis=1) grad_input = paddle.concat(grad_input, axis=1)
elif split_input == 'row': elif split_input == "row":
grad_input = paddle.concat(grad_input, axis=0) grad_input = paddle.concat(grad_input, axis=0)
else: else:
grad_input = input_parallel.grad grad_input = input_parallel.grad
...@@ -86,13 +86,13 @@ class TestLinearTp(unittest.TestCase): ...@@ -86,13 +86,13 @@ class TestLinearTp(unittest.TestCase):
layer_te = te.Linear( layer_te = te.Linear(
self.in_features, self.in_features,
self.out_features, self.out_features,
parallel_mode='column', parallel_mode="column",
sequence_parallel=self.sequence_parallel, sequence_parallel=self.sequence_parallel,
) )
layer_pd = te.Linear( layer_pd = te.Linear(
self.in_features, self.in_features,
self.out_features, self.out_features,
backend='paddle', backend="paddle",
) )
# Get total weight # Get total weight
total_weight = [] total_weight = []
...@@ -101,8 +101,9 @@ class TestLinearTp(unittest.TestCase): ...@@ -101,8 +101,9 @@ class TestLinearTp(unittest.TestCase):
total_weight = paddle.concat(total_weight, axis=0) total_weight = paddle.concat(total_weight, axis=0)
layer_pd.weight.copy_(total_weight.T, True) layer_pd.weight.copy_(total_weight.T, True)
assert_shape(layer_te.weight, assert_shape(
[self.out_features // self.model_parallel_size, self.in_features]) layer_te.weight, [self.out_features // self.model_parallel_size, self.in_features]
)
assert_shape(layer_te.bias, [self.out_features // self.model_parallel_size]) assert_shape(layer_te.bias, [self.out_features // self.model_parallel_size])
optimizer_te = paddle.optimizer.SGD(learning_rate=0.001, parameters=layer_te.parameters()) optimizer_te = paddle.optimizer.SGD(learning_rate=0.001, parameters=layer_te.parameters())
...@@ -118,8 +119,9 @@ class TestLinearTp(unittest.TestCase): ...@@ -118,8 +119,9 @@ class TestLinearTp(unittest.TestCase):
layer_te, layer_te,
inp, inp,
optimizer_te, optimizer_te,
split_input='row' if self.sequence_parallel else 'none', split_input="row" if self.sequence_parallel else "none",
gather_output=True) gather_output=True,
)
loss_ref, grad_input_ref = self._train_one_step(layer_pd, inp, optimizer_pd) loss_ref, grad_input_ref = self._train_one_step(layer_pd, inp, optimizer_pd)
assert_allclose(loss_tp, loss_ref, rtol=self.rtol, atol=self.atol) assert_allclose(loss_tp, loss_ref, rtol=self.rtol, atol=self.atol)
assert_allclose(grad_input, grad_input_ref, rtol=self.rtol, atol=self.atol) assert_allclose(grad_input, grad_input_ref, rtol=self.rtol, atol=self.atol)
...@@ -130,13 +132,13 @@ class TestLinearTp(unittest.TestCase): ...@@ -130,13 +132,13 @@ class TestLinearTp(unittest.TestCase):
layer_te = te.Linear( layer_te = te.Linear(
self.in_features, self.in_features,
self.out_features, self.out_features,
parallel_mode='row', parallel_mode="row",
sequence_parallel=self.sequence_parallel, sequence_parallel=self.sequence_parallel,
) )
layer_pd = te.Linear( layer_pd = te.Linear(
self.in_features, self.in_features,
self.out_features, self.out_features,
backend='paddle', backend="paddle",
) )
# Get total weight # Get total weight
total_weight = [] total_weight = []
...@@ -145,8 +147,9 @@ class TestLinearTp(unittest.TestCase): ...@@ -145,8 +147,9 @@ class TestLinearTp(unittest.TestCase):
total_weight = paddle.concat(total_weight, axis=1) total_weight = paddle.concat(total_weight, axis=1)
layer_pd.weight.copy_(total_weight.T, True) layer_pd.weight.copy_(total_weight.T, True)
assert_shape(layer_te.weight, assert_shape(
[self.out_features, self.in_features // self.model_parallel_size]) layer_te.weight, [self.out_features, self.in_features // self.model_parallel_size]
)
assert_shape(layer_te.bias, [self.out_features]) assert_shape(layer_te.bias, [self.out_features])
optimizer_te = paddle.optimizer.SGD(learning_rate=0.001, parameters=layer_te.parameters()) optimizer_te = paddle.optimizer.SGD(learning_rate=0.001, parameters=layer_te.parameters())
...@@ -158,11 +161,13 @@ class TestLinearTp(unittest.TestCase): ...@@ -158,11 +161,13 @@ class TestLinearTp(unittest.TestCase):
for _ in range(5): for _ in range(5):
inp = paddle.uniform([self.batch_size, self.in_features], self.global_dtype) inp = paddle.uniform([self.batch_size, self.in_features], self.global_dtype)
with te.fp8_autocast(enabled=self.fp8): with te.fp8_autocast(enabled=self.fp8):
loss_tp, grad_input = self._train_one_step(layer_te, loss_tp, grad_input = self._train_one_step(
inp, layer_te,
optimizer_te, inp,
split_input='column', optimizer_te,
gather_output=self.sequence_parallel) split_input="column",
gather_output=self.sequence_parallel,
)
loss_ref, grad_input_ref = self._train_one_step(layer_pd, inp, optimizer_pd) loss_ref, grad_input_ref = self._train_one_step(layer_pd, inp, optimizer_pd)
assert_allclose(loss_tp, loss_ref, rtol=self.rtol, atol=self.atol) assert_allclose(loss_tp, loss_ref, rtol=self.rtol, atol=self.atol)
assert_allclose(grad_input, grad_input_ref, rtol=self.rtol, atol=self.atol) assert_allclose(grad_input, grad_input_ref, rtol=self.rtol, atol=self.atol)
...@@ -176,7 +181,7 @@ class TestLinearTpFP8(TestLinearTp): ...@@ -176,7 +181,7 @@ class TestLinearTpFP8(TestLinearTp):
self.batch_size = 16 self.batch_size = 16
self.in_features = 32 self.in_features = 32
self.out_features = 64 self.out_features = 64
self.global_dtype = 'bfloat16' self.global_dtype = "bfloat16"
self.rtol = 1e-2 self.rtol = 1e-2
self.atol = 1e-2 self.atol = 1e-2
self.fp8 = True self.fp8 = True
...@@ -191,7 +196,7 @@ class TestLinearSp(TestLinearTp): ...@@ -191,7 +196,7 @@ class TestLinearSp(TestLinearTp):
self.batch_size = 16 self.batch_size = 16
self.in_features = 32 self.in_features = 32
self.out_features = 64 self.out_features = 64
self.global_dtype = 'bfloat16' self.global_dtype = "bfloat16"
self.rtol = 1e-3 self.rtol = 1e-3
self.atol = 1e-3 self.atol = 1e-3
self.fp8 = False self.fp8 = False
...@@ -206,12 +211,12 @@ class TestLinearSpFP8(TestLinearTp): ...@@ -206,12 +211,12 @@ class TestLinearSpFP8(TestLinearTp):
self.batch_size = 16 self.batch_size = 16
self.in_features = 32 self.in_features = 32
self.out_features = 64 self.out_features = 64
self.global_dtype = 'bfloat16' self.global_dtype = "bfloat16"
self.rtol = 1e-2 self.rtol = 1e-2
self.atol = 1e-2 self.atol = 1e-2
self.fp8 = True self.fp8 = True
self.sequence_parallel = True self.sequence_parallel = True
if __name__ == '__main__': if __name__ == "__main__":
unittest.main() unittest.main()
...@@ -45,9 +45,9 @@ class TestTransformerTp(unittest.TestCase): ...@@ -45,9 +45,9 @@ class TestTransformerTp(unittest.TestCase):
self.ffn_hidden_size = 4096 self.ffn_hidden_size = 4096
self.q_seqlen = 128 self.q_seqlen = 128
self.kv_seqlen = 128 self.kv_seqlen = 128
self.mask_type = 'padding' self.mask_type = "padding"
self.layer_type = 'encoder' self.layer_type = "encoder"
self.global_dtype = 'bfloat16' self.global_dtype = "bfloat16"
self.rtol = 5e-2 self.rtol = 5e-2
self.atol = 5e-2 self.atol = 5e-2
self.eps = 1e-3 self.eps = 1e-3
...@@ -58,7 +58,7 @@ class TestTransformerTp(unittest.TestCase): ...@@ -58,7 +58,7 @@ class TestTransformerTp(unittest.TestCase):
inp, mask = inp_list inp, mask = inp_list
if sequence_parallel: if sequence_parallel:
split_size = inp.shape[0] // self.world_size split_size = inp.shape[0] // self.world_size
input_parallel = inp[split_size * self.rank:split_size * (self.rank + 1), :] input_parallel = inp[split_size * self.rank : split_size * (self.rank + 1), :]
else: else:
input_parallel = inp input_parallel = inp
with te.fp8_autocast(enabled=fp8_enabled): with te.fp8_autocast(enabled=fp8_enabled):
...@@ -83,16 +83,18 @@ class TestTransformerTp(unittest.TestCase): ...@@ -83,16 +83,18 @@ class TestTransformerTp(unittest.TestCase):
self.num_heads, self.num_heads,
] ]
common_kwargs = { common_kwargs = {
'layernorm_epsilon': self.eps, "layernorm_epsilon": self.eps,
'hidden_dropout': 0.0, "hidden_dropout": 0.0,
'attention_dropout': 0.0, "attention_dropout": 0.0,
'self_attn_mask_type': self.mask_type, "self_attn_mask_type": self.mask_type,
'layer_type': self.layer_type, "layer_type": self.layer_type,
} }
layer_tp = te.TransformerLayer(*common_args, layer_tp = te.TransformerLayer(
**common_kwargs, *common_args,
set_parallel_mode=True, **common_kwargs,
sequence_parallel=self.sequence_parallel) set_parallel_mode=True,
sequence_parallel=self.sequence_parallel,
)
layer_single = te.TransformerLayer(*common_args, **common_kwargs, set_parallel_mode=False) layer_single = te.TransformerLayer(*common_args, **common_kwargs, set_parallel_mode=False)
def _get_total_weight(local_weight, tp_group, axis, interleave=False): def _get_total_weight(local_weight, tp_group, axis, interleave=False):
...@@ -103,12 +105,15 @@ class TestTransformerTp(unittest.TestCase): ...@@ -103,12 +105,15 @@ class TestTransformerTp(unittest.TestCase):
# Due to the interleaved qkv layout, need to concat on num_head # Due to the interleaved qkv layout, need to concat on num_head
# dimension for column parallel linear in MultiHeadAttention layer # dimension for column parallel linear in MultiHeadAttention layer
assert axis == 0 assert axis == 0
assert [3 * self.hidden_size // self.world_size, assert [
self.hidden_size] == partial_weight.shape 3 * self.hidden_size // self.world_size,
self.hidden_size,
] == partial_weight.shape
local_num_head = self.num_heads // self.world_size local_num_head = self.num_heads // self.world_size
for idx, _ in enumerate(total_weight): for idx, _ in enumerate(total_weight):
total_weight[idx] = total_weight[idx].reshape( total_weight[idx] = total_weight[idx].reshape(
[3, local_num_head, -1, self.hidden_size]) [3, local_num_head, -1, self.hidden_size]
)
total_weight = paddle.concat(total_weight, axis=1).reshape([-1, self.hidden_size]) total_weight = paddle.concat(total_weight, axis=1).reshape([-1, self.hidden_size])
else: else:
total_weight = paddle.concat(total_weight, axis=axis) total_weight = paddle.concat(total_weight, axis=axis)
...@@ -124,48 +129,56 @@ class TestTransformerTp(unittest.TestCase): ...@@ -124,48 +129,56 @@ class TestTransformerTp(unittest.TestCase):
weight_dst = _get_weight(layer_dst, weight_names) weight_dst = _get_weight(layer_dst, weight_names)
if partition_mode is None: if partition_mode is None:
total_weight = weight_src total_weight = weight_src
elif partition_mode == 'column': elif partition_mode == "column":
total_weight = _get_total_weight(weight_src, total_weight = _get_total_weight(
tp_group=self.tp_group, weight_src, tp_group=self.tp_group, axis=0, interleave=interleave
axis=0, )
interleave=interleave) elif partition_mode == "row":
elif partition_mode == 'row':
total_weight = _get_total_weight(weight_src, tp_group=self.tp_group, axis=1) total_weight = _get_total_weight(weight_src, tp_group=self.tp_group, axis=1)
else: else:
raise ValueError(f"Partition Mode {partition_mode} is not supported.") raise ValueError(f"Partition Mode {partition_mode} is not supported.")
assert weight_dst.shape == total_weight.shape, \ assert (
f"Shapes of src:{total_weight.shape} and dst:{weight_dst.shape} do not match." weight_dst.shape == total_weight.shape
), f"Shapes of src:{total_weight.shape} and dst:{weight_dst.shape} do not match."
weight_dst.copy_(total_weight, True) weight_dst.copy_(total_weight, True)
copy_weight(layer_tp, layer_single, None, ['self_attention', 'layernorm_qkv', 'ln_weight']) copy_weight(layer_tp, layer_single, None, ["self_attention", "layernorm_qkv", "ln_weight"])
copy_weight(layer_tp, copy_weight(
layer_single, layer_tp,
'column', ['self_attention', 'layernorm_qkv', 'weight'], layer_single,
interleave=True) "column",
copy_weight(layer_tp, layer_single, 'row', ['self_attention', 'proj', 'weight']) ["self_attention", "layernorm_qkv", "weight"],
copy_weight(layer_tp, layer_single, None, ['layernorm_mlp', 'ln_weight']) interleave=True,
copy_weight(layer_tp, layer_single, 'column', ['layernorm_mlp', 'fc1_weight']) )
copy_weight(layer_tp, layer_single, 'row', ['layernorm_mlp', 'fc2_weight']) copy_weight(layer_tp, layer_single, "row", ["self_attention", "proj", "weight"])
copy_weight(layer_tp, layer_single, None, ["layernorm_mlp", "ln_weight"])
copy_weight(layer_tp, layer_single, "column", ["layernorm_mlp", "fc1_weight"])
copy_weight(layer_tp, layer_single, "row", ["layernorm_mlp", "fc2_weight"])
if self.sequence_parallel: if self.sequence_parallel:
register_sequence_parallel_allreduce_hooks(layer_tp, accumulation_steps=1) register_sequence_parallel_allreduce_hooks(layer_tp, accumulation_steps=1)
optimizer_tp = paddle.optimizer.SGD(learning_rate=0.01, parameters=layer_tp.parameters()) optimizer_tp = paddle.optimizer.SGD(learning_rate=0.01, parameters=layer_tp.parameters())
optimizer_single = paddle.optimizer.SGD(learning_rate=0.01, optimizer_single = paddle.optimizer.SGD(
parameters=layer_single.parameters()) learning_rate=0.01, parameters=layer_single.parameters()
)
layer_tp = fleet.distributed_model(layer_tp) layer_tp = fleet.distributed_model(layer_tp)
optimizer_tp = fleet.distributed_optimizer(optimizer_tp) optimizer_tp = fleet.distributed_optimizer(optimizer_tp)
for _ in range(5): for _ in range(5):
inp = paddle.uniform([self.batch_size, self.q_seqlen, self.hidden_size], inp = paddle.uniform(
self.global_dtype) [self.batch_size, self.q_seqlen, self.hidden_size], self.global_dtype
mask = paddle.zeros(shape=(self.batch_size, 1, self.q_seqlen, self.kv_seqlen), )
dtype='bool') mask = paddle.zeros(
loss_tp, out_tp = self._train_one_step(layer_tp, [inp, mask], optimizer_tp, self.fp8, shape=(self.batch_size, 1, self.q_seqlen, self.kv_seqlen), dtype="bool"
self.sequence_parallel) )
loss_single, out_single = self._train_one_step(layer_single, [inp, mask], loss_tp, out_tp = self._train_one_step(
optimizer_single, self.fp8) layer_tp, [inp, mask], optimizer_tp, self.fp8, self.sequence_parallel
)
loss_single, out_single = self._train_one_step(
layer_single, [inp, mask], optimizer_single, self.fp8
)
assert_allclose(out_tp, out_single, rtol=self.rtol, atol=self.atol) assert_allclose(out_tp, out_single, rtol=self.rtol, atol=self.atol)
assert_allclose(loss_tp, loss_single, rtol=self.rtol, atol=self.atol) assert_allclose(loss_tp, loss_single, rtol=self.rtol, atol=self.atol)
...@@ -181,9 +194,9 @@ class TestTransformerTpFp8(TestTransformerTp): ...@@ -181,9 +194,9 @@ class TestTransformerTpFp8(TestTransformerTp):
self.ffn_hidden_size = 4096 self.ffn_hidden_size = 4096
self.q_seqlen = 128 self.q_seqlen = 128
self.kv_seqlen = 128 self.kv_seqlen = 128
self.mask_type = 'padding' self.mask_type = "padding"
self.layer_type = 'encoder' self.layer_type = "encoder"
self.global_dtype = 'bfloat16' self.global_dtype = "bfloat16"
self.rtol = 5e-2 self.rtol = 5e-2
self.atol = 0.5 self.atol = 0.5
self.eps = 1e-3 self.eps = 1e-3
...@@ -202,9 +215,9 @@ class TestTransformerSp(TestTransformerTp): ...@@ -202,9 +215,9 @@ class TestTransformerSp(TestTransformerTp):
self.ffn_hidden_size = 4096 self.ffn_hidden_size = 4096
self.q_seqlen = 128 self.q_seqlen = 128
self.kv_seqlen = 128 self.kv_seqlen = 128
self.mask_type = 'padding' self.mask_type = "padding"
self.layer_type = 'encoder' self.layer_type = "encoder"
self.global_dtype = 'bfloat16' self.global_dtype = "bfloat16"
self.rtol = 5e-2 self.rtol = 5e-2
self.atol = 5e-2 self.atol = 5e-2
self.eps = 1e-3 self.eps = 1e-3
...@@ -223,9 +236,9 @@ class TestTransformerSpFp8(TestTransformerSp): ...@@ -223,9 +236,9 @@ class TestTransformerSpFp8(TestTransformerSp):
self.ffn_hidden_size = 4096 self.ffn_hidden_size = 4096
self.q_seqlen = 128 self.q_seqlen = 128
self.kv_seqlen = 128 self.kv_seqlen = 128
self.mask_type = 'padding' self.mask_type = "padding"
self.layer_type = 'encoder' self.layer_type = "encoder"
self.global_dtype = 'bfloat16' self.global_dtype = "bfloat16"
self.rtol = 5e-2 self.rtol = 5e-2
self.atol = 0.5 self.atol = 0.5
self.eps = 1e-3 self.eps = 1e-3
...@@ -233,5 +246,5 @@ class TestTransformerSpFp8(TestTransformerSp): ...@@ -233,5 +246,5 @@ class TestTransformerSpFp8(TestTransformerSp):
self.sequence_parallel = True self.sequence_parallel = True
if __name__ == '__main__': if __name__ == "__main__":
unittest.main() unittest.main()
...@@ -37,14 +37,17 @@ def main(): ...@@ -37,14 +37,17 @@ def main():
enable_recompute = int(sys.argv[1]) enable_recompute = int(sys.argv[1])
use_reentrant = int(sys.argv[2]) use_reentrant = int(sys.argv[2])
layers = paddle.nn.LayerList([ layers = paddle.nn.LayerList(
te.TransformerLayer( [
hidden_size, te.TransformerLayer(
ffn_hidden_size, hidden_size,
num_heads, ffn_hidden_size,
layer_type='encoder', num_heads,
) for _ in range(num_layers) layer_type="encoder",
]) )
for _ in range(num_layers)
]
)
model = Net(layers) model = Net(layers)
optimizer = paddle.optimizer.AdamW(learning_rate=0.001, parameters=model.parameters()) optimizer = paddle.optimizer.AdamW(learning_rate=0.001, parameters=model.parameters())
...@@ -52,7 +55,7 @@ def main(): ...@@ -52,7 +55,7 @@ def main():
for _ in range(10): for _ in range(10):
inp = paddle.uniform([batch_size, q_seqlen, hidden_size]) inp = paddle.uniform([batch_size, q_seqlen, hidden_size])
inp.stop_gradient = False inp.stop_gradient = False
mask = paddle.zeros(shape=(batch_size, 1, q_seqlen, kv_seqlen), dtype='bool') mask = paddle.zeros(shape=(batch_size, 1, q_seqlen, kv_seqlen), dtype="bool")
with te.fp8_autocast(enabled=True): with te.fp8_autocast(enabled=True):
out = model(inp, mask, enable_recompute, use_reentrant) out = model(inp, mask, enable_recompute, use_reentrant)
loss = out.mean() loss = out.mean()
......
...@@ -8,4 +8,4 @@ def test_import(): ...@@ -8,4 +8,4 @@ def test_import():
""" """
Test if Paddle extension can be imported normally Test if Paddle extension can be imported normally
""" """
import transformer_engine.paddle # pylint: disable=unused-import import transformer_engine.paddle # pylint: disable=unused-import
...@@ -26,14 +26,14 @@ def setup(): ...@@ -26,14 +26,14 @@ def setup():
@pytest.mark.skipif(not is_fp8_supported, reason=reason) @pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest.mark.parametrize('use_fp8', [True, False]) @pytest.mark.parametrize("use_fp8", [True, False])
def test_checkpoint(use_fp8): def test_checkpoint(use_fp8):
"""Test checkpoint save / load""" """Test checkpoint save / load"""
bs = 16 bs = 16
in_features = 16 in_features = 16
out_features = 32 out_features = 32
file_name = "model.pdparams" file_name = "model.pdparams"
input_tensor = paddle.uniform(shape=(bs, in_features), dtype='float32') input_tensor = paddle.uniform(shape=(bs, in_features), dtype="float32")
model = te.Linear(in_features, out_features) model = te.Linear(in_features, out_features)
model_loaded = te.Linear(in_features, out_features) model_loaded = te.Linear(in_features, out_features)
# Populate amax_history # Populate amax_history
...@@ -91,15 +91,18 @@ class TestLinear: ...@@ -91,15 +91,18 @@ class TestLinear:
""" """
@staticmethod @staticmethod
@pytest.mark.skipif(paddle.device.cuda.get_device_capability() < (8, 0), @pytest.mark.skipif(
reason="BF16 Linear requires Ampere+ GPU") paddle.device.cuda.get_device_capability() < (8, 0),
@pytest.mark.parametrize('bs,in_features,out_features', LINEAR_CASES) reason="BF16 Linear requires Ampere+ GPU",
@pytest.mark.parametrize('has_bias,no_dbias', [[True, False], [True, True], [False, False]]) )
@pytest.mark.parametrize('no_dgrad', [True, False]) @pytest.mark.parametrize("bs,in_features,out_features", LINEAR_CASES)
@pytest.mark.parametrize('no_wgrad', [True, False]) @pytest.mark.parametrize("has_bias,no_dbias", [[True, False], [True, True], [False, False]])
@pytest.mark.parametrize('activation_dtype', ['bfloat16', 'float32']) @pytest.mark.parametrize("no_dgrad", [True, False])
def test_linear_bf16(bs, in_features, out_features, has_bias, no_dbias, no_dgrad, no_wgrad, @pytest.mark.parametrize("no_wgrad", [True, False])
activation_dtype): @pytest.mark.parametrize("activation_dtype", ["bfloat16", "float32"])
def test_linear_bf16(
bs, in_features, out_features, has_bias, no_dbias, no_dgrad, no_wgrad, activation_dtype
):
""" """
Test BF16 Linear Test BF16 Linear
""" """
...@@ -112,10 +115,9 @@ class TestLinear: ...@@ -112,10 +115,9 @@ class TestLinear:
paddle.set_default_dtype(activation_dtype) paddle.set_default_dtype(activation_dtype)
layer_te = te.Linear(in_features, out_features, bias_attr=None if has_bias else False) layer_te = te.Linear(in_features, out_features, bias_attr=None if has_bias else False)
layer_pd = te.Linear(in_features, layer_pd = te.Linear(
out_features, in_features, out_features, bias_attr=None if has_bias else False, backend="paddle"
bias_attr=None if has_bias else False, )
backend='paddle')
layer_pd.weight.copy_(layer_te.weight.T, True) layer_pd.weight.copy_(layer_te.weight.T, True)
if has_bias: if has_bias:
layer_pd.bias.copy_(layer_te.bias, True) layer_pd.bias.copy_(layer_te.bias, True)
...@@ -139,15 +141,25 @@ class TestLinear: ...@@ -139,15 +141,25 @@ class TestLinear:
@staticmethod @staticmethod
@pytest.mark.skipif(not is_fp8_supported, reason=reason) @pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest.mark.parametrize('bs,in_features,out_features', LINEAR_CASES) @pytest.mark.parametrize("bs,in_features,out_features", LINEAR_CASES)
@pytest.mark.parametrize('has_bias,no_dbias', [[True, False], [True, True], [False, False]]) @pytest.mark.parametrize("has_bias,no_dbias", [[True, False], [True, True], [False, False]])
@pytest.mark.parametrize('no_dgrad', [True, False]) @pytest.mark.parametrize("no_dgrad", [True, False])
@pytest.mark.parametrize('no_wgrad', [True, False]) @pytest.mark.parametrize("no_wgrad", [True, False])
@pytest.mark.parametrize('fp8_wgrad', [True, False]) @pytest.mark.parametrize("fp8_wgrad", [True, False])
@pytest.mark.parametrize('do_calibration', [True, False]) @pytest.mark.parametrize("do_calibration", [True, False])
@pytest.mark.parametrize('activation_dtype', ['bfloat16', 'float32']) @pytest.mark.parametrize("activation_dtype", ["bfloat16", "float32"])
def test_linear_fp8(bs, in_features, out_features, has_bias, no_dbias, no_dgrad, no_wgrad, def test_linear_fp8(
fp8_wgrad, do_calibration, activation_dtype): bs,
in_features,
out_features,
has_bias,
no_dbias,
no_dgrad,
no_wgrad,
fp8_wgrad,
do_calibration,
activation_dtype,
):
""" """
Test FP8 Linear Test FP8 Linear
""" """
...@@ -170,7 +182,7 @@ class TestLinear: ...@@ -170,7 +182,7 @@ class TestLinear:
in_features=in_features, in_features=in_features,
out_features=out_features, out_features=out_features,
bias_attr=None if has_bias else False, bias_attr=None if has_bias else False,
backend='paddle', backend="paddle",
) )
layer_pd.weight.copy_(layer_te.weight.T, True) layer_pd.weight.copy_(layer_te.weight.T, True)
if has_bias: if has_bias:
...@@ -182,8 +194,9 @@ class TestLinear: ...@@ -182,8 +194,9 @@ class TestLinear:
layer_te.bias.stop_gradient = no_dbias layer_te.bias.stop_gradient = no_dbias
layer_pd.bias.stop_gradient = no_dbias layer_pd.bias.stop_gradient = no_dbias
with fp8_autocast(enabled=not do_calibration, calibrating=do_calibration, with fp8_autocast(
fp8_recipe=recipe): enabled=not do_calibration, calibrating=do_calibration, fp8_recipe=recipe
):
out_ref, grad_input_ref = calc_output_and_grad(layer_pd, input_tensor, grad_out) out_ref, grad_input_ref = calc_output_and_grad(layer_pd, input_tensor, grad_out)
out, grad_input = calc_output_and_grad(layer_te, input_tensor, grad_out) out, grad_input = calc_output_and_grad(layer_te, input_tensor, grad_out)
...@@ -199,9 +212,9 @@ class TestLinear: ...@@ -199,9 +212,9 @@ class TestLinear:
@staticmethod @staticmethod
@pytest.mark.skipif(not is_fp8_supported, reason=reason) @pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest.mark.parametrize('bs,in_features,out_features', LINEAR_CASES) @pytest.mark.parametrize("bs,in_features,out_features", LINEAR_CASES)
@pytest.mark.parametrize('activation_dtype', ['bfloat16']) @pytest.mark.parametrize("activation_dtype", ["bfloat16"])
@pytest.mark.parametrize('num_microbatch', [8]) @pytest.mark.parametrize("num_microbatch", [8])
def test_linear_fp8_microbatch(bs, in_features, out_features, activation_dtype, num_microbatch): def test_linear_fp8_microbatch(bs, in_features, out_features, activation_dtype, num_microbatch):
""" """
Test FP8 Linear Test FP8 Linear
...@@ -236,17 +249,16 @@ class TestLinear: ...@@ -236,17 +249,16 @@ class TestLinear:
out_ref.backward(grad_out) out_ref.backward(grad_out)
assert_allclose(out, out_ref, rtol=rtol, atol=atol) assert_allclose(out, out_ref, rtol=rtol, atol=atol)
assert_allclose(layer_cached.weight.grad, assert_allclose(
layer_normal.weight.grad, layer_cached.weight.grad, layer_normal.weight.grad, rtol=rtol, atol=atol
rtol=rtol, )
atol=atol)
@pytest.mark.parametrize('bs,hidden_size', NORM_CASES) @pytest.mark.parametrize("bs,hidden_size", NORM_CASES)
@pytest.mark.parametrize('has_bias,no_dbias', [[True, False], [True, True], [False, False]]) @pytest.mark.parametrize("has_bias,no_dbias", [[True, False], [True, True], [False, False]])
@pytest.mark.parametrize('no_dgrad', [True, False]) @pytest.mark.parametrize("no_dgrad", [True, False])
@pytest.mark.parametrize('no_wgrad', [True, False]) @pytest.mark.parametrize("no_wgrad", [True, False])
@pytest.mark.parametrize('activation_dtype', ['bfloat16', 'float32']) @pytest.mark.parametrize("activation_dtype", ["bfloat16", "float32"])
def test_layernorm_bf16(bs, hidden_size, has_bias, no_dbias, no_dgrad, no_wgrad, activation_dtype): def test_layernorm_bf16(bs, hidden_size, has_bias, no_dbias, no_dgrad, no_wgrad, activation_dtype):
""" """
Test BF16 LayerNorm Test BF16 LayerNorm
...@@ -261,10 +273,9 @@ def test_layernorm_bf16(bs, hidden_size, has_bias, no_dbias, no_dgrad, no_wgrad, ...@@ -261,10 +273,9 @@ def test_layernorm_bf16(bs, hidden_size, has_bias, no_dbias, no_dgrad, no_wgrad,
paddle.set_default_dtype(activation_dtype) paddle.set_default_dtype(activation_dtype)
layer_te = te.LayerNorm(hidden_size=hidden_size, eps=eps, bias_attr=None if has_bias else False) layer_te = te.LayerNorm(hidden_size=hidden_size, eps=eps, bias_attr=None if has_bias else False)
layer_pd = te.LayerNorm(hidden_size=hidden_size, layer_pd = te.LayerNorm(
eps=eps, hidden_size=hidden_size, eps=eps, bias_attr=None if has_bias else False, backend="paddle"
bias_attr=None if has_bias else False, )
backend='paddle')
layer_pd.weight.copy_(layer_te.weight, True) layer_pd.weight.copy_(layer_te.weight, True)
if has_bias: if has_bias:
layer_pd.bias.copy_(layer_te.bias, True) layer_pd.bias.copy_(layer_te.bias, True)
...@@ -293,17 +304,29 @@ class TestLayerNormLinear: ...@@ -293,17 +304,29 @@ class TestLayerNormLinear:
""" """
@staticmethod @staticmethod
@pytest.mark.skipif(paddle.device.cuda.get_device_capability() < (8, 0), @pytest.mark.skipif(
reason="BF16 Linear requires Ampere+ GPU") paddle.device.cuda.get_device_capability() < (8, 0),
@pytest.mark.parametrize('bs,in_features,out_features', LINEAR_CASES) reason="BF16 Linear requires Ampere+ GPU",
@pytest.mark.parametrize('has_bias,no_dbias', [[True, False], [True, True], [False, False]]) )
@pytest.mark.parametrize('no_dgrad', [True, False]) @pytest.mark.parametrize("bs,in_features,out_features", LINEAR_CASES)
@pytest.mark.parametrize('no_wgrad', [True, False]) @pytest.mark.parametrize("has_bias,no_dbias", [[True, False], [True, True], [False, False]])
@pytest.mark.parametrize('return_ln_out', [True, False]) @pytest.mark.parametrize("no_dgrad", [True, False])
@pytest.mark.parametrize('activation_dtype', ['bfloat16', 'float32']) @pytest.mark.parametrize("no_wgrad", [True, False])
@pytest.mark.parametrize('normalization', ['RMSNorm', 'LayerNorm']) @pytest.mark.parametrize("return_ln_out", [True, False])
def test_layernorm_linear_bf16(bs, in_features, out_features, has_bias, no_dbias, no_dgrad, @pytest.mark.parametrize("activation_dtype", ["bfloat16", "float32"])
no_wgrad, return_ln_out, activation_dtype, normalization): @pytest.mark.parametrize("normalization", ["RMSNorm", "LayerNorm"])
def test_layernorm_linear_bf16(
bs,
in_features,
out_features,
has_bias,
no_dbias,
no_dgrad,
no_wgrad,
return_ln_out,
activation_dtype,
normalization,
):
""" """
Test BF16 LayerNormLinear Layer Test BF16 LayerNormLinear Layer
""" """
...@@ -315,7 +338,7 @@ class TestLayerNormLinear: ...@@ -315,7 +338,7 @@ class TestLayerNormLinear:
input_tensor.stop_gradient = no_dgrad input_tensor.stop_gradient = no_dgrad
grad_out = paddle.uniform(shape=(bs, out_features), dtype=activation_dtype) grad_out = paddle.uniform(shape=(bs, out_features), dtype=activation_dtype)
eps = 1e-3 eps = 1e-3
has_ln_bias = normalization == 'LayerNorm' has_ln_bias = normalization == "LayerNorm"
layer_te = te.LayerNormLinear( layer_te = te.LayerNormLinear(
in_features=in_features, in_features=in_features,
...@@ -333,7 +356,7 @@ class TestLayerNormLinear: ...@@ -333,7 +356,7 @@ class TestLayerNormLinear:
normalization=normalization, normalization=normalization,
bias_attr=None if has_bias else False, bias_attr=None if has_bias else False,
return_layernorm_output=return_ln_out, return_layernorm_output=return_ln_out,
backend='paddle', backend="paddle",
) )
layer_pd.ln_weight.copy_(layer_te.ln_weight, True) layer_pd.ln_weight.copy_(layer_te.ln_weight, True)
...@@ -355,11 +378,11 @@ class TestLayerNormLinear: ...@@ -355,11 +378,11 @@ class TestLayerNormLinear:
layer_pd.bias.stop_gradient = no_dbias layer_pd.bias.stop_gradient = no_dbias
out_ref, ln_out_ref, grad_input_ref = calc_output_and_grad_ln_out( out_ref, ln_out_ref, grad_input_ref = calc_output_and_grad_ln_out(
layer_pd, input_tensor, grad_out, return_ln_out=return_ln_out) layer_pd, input_tensor, grad_out, return_ln_out=return_ln_out
out, ln_out, grad_input = calc_output_and_grad_ln_out(layer_te, )
input_tensor, out, ln_out, grad_input = calc_output_and_grad_ln_out(
grad_out, layer_te, input_tensor, grad_out, return_ln_out=return_ln_out
return_ln_out=return_ln_out) )
assert_allclose(out, out_ref, rtol=rtol, atol=atol) assert_allclose(out, out_ref, rtol=rtol, atol=atol)
if not no_dgrad: if not no_dgrad:
...@@ -377,18 +400,29 @@ class TestLayerNormLinear: ...@@ -377,18 +400,29 @@ class TestLayerNormLinear:
@staticmethod @staticmethod
@pytest.mark.skipif(not is_fp8_supported, reason=reason) @pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest.mark.parametrize('bs,in_features,out_features', LINEAR_CASES) @pytest.mark.parametrize("bs,in_features,out_features", LINEAR_CASES)
@pytest.mark.parametrize('has_bias,no_dbias', [[True, False], [True, True], [False, False]]) @pytest.mark.parametrize("has_bias,no_dbias", [[True, False], [True, True], [False, False]])
@pytest.mark.parametrize('no_dgrad', [True, False]) @pytest.mark.parametrize("no_dgrad", [True, False])
@pytest.mark.parametrize('no_wgrad', [True, False]) @pytest.mark.parametrize("no_wgrad", [True, False])
@pytest.mark.parametrize('fp8_wgrad', [True, False]) @pytest.mark.parametrize("fp8_wgrad", [True, False])
@pytest.mark.parametrize('do_calibration', [True, False]) @pytest.mark.parametrize("do_calibration", [True, False])
@pytest.mark.parametrize('return_ln_out', [True, False]) @pytest.mark.parametrize("return_ln_out", [True, False])
@pytest.mark.parametrize('activation_dtype', ['bfloat16', 'float32']) @pytest.mark.parametrize("activation_dtype", ["bfloat16", "float32"])
@pytest.mark.parametrize('normalization', ['RMSNorm', 'LayerNorm']) @pytest.mark.parametrize("normalization", ["RMSNorm", "LayerNorm"])
def test_layernorm_linear_fp8(bs, in_features, out_features, has_bias, no_dbias, no_dgrad, def test_layernorm_linear_fp8(
no_wgrad, fp8_wgrad, do_calibration, return_ln_out, bs,
activation_dtype, normalization): in_features,
out_features,
has_bias,
no_dbias,
no_dgrad,
no_wgrad,
fp8_wgrad,
do_calibration,
return_ln_out,
activation_dtype,
normalization,
):
""" """
Test FP8 LayerNormLinear Layer Test FP8 LayerNormLinear Layer
""" """
...@@ -400,7 +434,7 @@ class TestLayerNormLinear: ...@@ -400,7 +434,7 @@ class TestLayerNormLinear:
input_tensor.stop_gradient = no_dgrad input_tensor.stop_gradient = no_dgrad
grad_out = paddle.uniform(shape=(bs, out_features), dtype=activation_dtype) grad_out = paddle.uniform(shape=(bs, out_features), dtype=activation_dtype)
eps = 1e-3 eps = 1e-3
has_ln_bias = normalization == 'LayerNorm' has_ln_bias = normalization == "LayerNorm"
recipe = DelayedScaling(override_linear_precision=(False, False, not fp8_wgrad)) recipe = DelayedScaling(override_linear_precision=(False, False, not fp8_wgrad))
...@@ -420,7 +454,7 @@ class TestLayerNormLinear: ...@@ -420,7 +454,7 @@ class TestLayerNormLinear:
normalization=normalization, normalization=normalization,
bias_attr=None if has_bias else False, bias_attr=None if has_bias else False,
return_layernorm_output=return_ln_out, return_layernorm_output=return_ln_out,
backend='paddle', backend="paddle",
) )
layer_pd.ln_weight.copy_(layer_te.ln_weight, True) layer_pd.ln_weight.copy_(layer_te.ln_weight, True)
...@@ -441,14 +475,15 @@ class TestLayerNormLinear: ...@@ -441,14 +475,15 @@ class TestLayerNormLinear:
layer_te.bias.stop_gradient = no_dbias layer_te.bias.stop_gradient = no_dbias
layer_pd.bias.stop_gradient = no_dbias layer_pd.bias.stop_gradient = no_dbias
with fp8_autocast(enabled=not do_calibration, calibrating=do_calibration, with fp8_autocast(
fp8_recipe=recipe): enabled=not do_calibration, calibrating=do_calibration, fp8_recipe=recipe
):
out_ref, ln_out_ref, grad_input_ref = calc_output_and_grad_ln_out( out_ref, ln_out_ref, grad_input_ref = calc_output_and_grad_ln_out(
layer_pd, input_tensor, grad_out, return_ln_out=return_ln_out) layer_pd, input_tensor, grad_out, return_ln_out=return_ln_out
out, ln_out, grad_input = calc_output_and_grad_ln_out(layer_te, )
input_tensor, out, ln_out, grad_input = calc_output_and_grad_ln_out(
grad_out, layer_te, input_tensor, grad_out, return_ln_out=return_ln_out
return_ln_out=return_ln_out) )
assert_allclose(out, out_ref, rtol=rtol, atol=atol) assert_allclose(out, out_ref, rtol=rtol, atol=atol)
if not no_dgrad: if not no_dgrad:
...@@ -468,11 +503,12 @@ class TestLayerNormLinear: ...@@ -468,11 +503,12 @@ class TestLayerNormLinear:
@staticmethod @staticmethod
@pytest.mark.skipif(not is_fp8_supported, reason=reason) @pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest.mark.parametrize('bs,in_features,out_features', LINEAR_CASES) @pytest.mark.parametrize("bs,in_features,out_features", LINEAR_CASES)
@pytest.mark.parametrize('activation_dtype', ['bfloat16']) @pytest.mark.parametrize("activation_dtype", ["bfloat16"])
@pytest.mark.parametrize('num_microbatch', [8]) @pytest.mark.parametrize("num_microbatch", [8])
def test_layernorm_linear_fp8_microbatch(bs, in_features, out_features, activation_dtype, def test_layernorm_linear_fp8_microbatch(
num_microbatch): bs, in_features, out_features, activation_dtype, num_microbatch
):
""" """
Test FP8 LayerNormLinear Layer Test FP8 LayerNormLinear Layer
""" """
...@@ -513,14 +549,12 @@ class TestLayerNormLinear: ...@@ -513,14 +549,12 @@ class TestLayerNormLinear:
out_ref.backward(grad_out) out_ref.backward(grad_out)
assert_allclose(out, out_ref, rtol=rtol, atol=atol) assert_allclose(out, out_ref, rtol=rtol, atol=atol)
assert_allclose(layer_cached.weight.grad, assert_allclose(
layer_normal.weight.grad, layer_cached.weight.grad, layer_normal.weight.grad, rtol=rtol, atol=atol
rtol=rtol, )
atol=atol) assert_allclose(
assert_allclose(layer_cached.ln_weight.grad, layer_cached.ln_weight.grad, layer_normal.ln_weight.grad, rtol=rtol, atol=atol
layer_normal.ln_weight.grad, )
rtol=rtol,
atol=atol)
class TestLayerNormMLP: class TestLayerNormMLP:
...@@ -529,19 +563,31 @@ class TestLayerNormMLP: ...@@ -529,19 +563,31 @@ class TestLayerNormMLP:
""" """
@staticmethod @staticmethod
@pytest.mark.skipif(paddle.device.cuda.get_device_capability() < (8, 0), @pytest.mark.skipif(
reason="BF16 Linear requires Ampere+ GPU") paddle.device.cuda.get_device_capability() < (8, 0),
@pytest.mark.parametrize('bs,hidden_size,ffn_hidden_size', LINEAR_CASES) reason="BF16 Linear requires Ampere+ GPU",
@pytest.mark.parametrize('has_bias,no_dbias', [[True, False], [True, True], [False, False]]) )
@pytest.mark.parametrize('no_dgrad', [True, False]) @pytest.mark.parametrize("bs,hidden_size,ffn_hidden_size", LINEAR_CASES)
@pytest.mark.parametrize('no_wgrad', [True, False]) @pytest.mark.parametrize("has_bias,no_dbias", [[True, False], [True, True], [False, False]])
@pytest.mark.parametrize('return_ln_out', [True, False]) @pytest.mark.parametrize("no_dgrad", [True, False])
@pytest.mark.parametrize('activation_dtype', ['bfloat16', 'float32']) @pytest.mark.parametrize("no_wgrad", [True, False])
@pytest.mark.parametrize('normalization', ['RMSNorm', 'LayerNorm']) @pytest.mark.parametrize("return_ln_out", [True, False])
@pytest.mark.parametrize('activation', ['gelu', 'swiglu']) @pytest.mark.parametrize("activation_dtype", ["bfloat16", "float32"])
def test_layernorm_mlp_bf16(bs, hidden_size, ffn_hidden_size, has_bias, no_dbias, no_dgrad, @pytest.mark.parametrize("normalization", ["RMSNorm", "LayerNorm"])
no_wgrad, return_ln_out, activation_dtype, normalization, @pytest.mark.parametrize("activation", ["gelu", "swiglu"])
activation): def test_layernorm_mlp_bf16(
bs,
hidden_size,
ffn_hidden_size,
has_bias,
no_dbias,
no_dgrad,
no_wgrad,
return_ln_out,
activation_dtype,
normalization,
activation,
):
""" """
Tests for TestLayerNormMLP layer Tests for TestLayerNormMLP layer
""" """
...@@ -553,7 +599,7 @@ class TestLayerNormMLP: ...@@ -553,7 +599,7 @@ class TestLayerNormMLP:
input_tensor.stop_gradient = no_dgrad input_tensor.stop_gradient = no_dgrad
grad_out = paddle.uniform(shape=(bs, hidden_size), dtype=activation_dtype) grad_out = paddle.uniform(shape=(bs, hidden_size), dtype=activation_dtype)
eps = 1e-3 eps = 1e-3
has_ln_bias = normalization == 'LayerNorm' has_ln_bias = normalization == "LayerNorm"
layer_te = te.LayerNormMLP( layer_te = te.LayerNormMLP(
hidden_size=hidden_size, hidden_size=hidden_size,
...@@ -572,7 +618,7 @@ class TestLayerNormMLP: ...@@ -572,7 +618,7 @@ class TestLayerNormMLP:
activation=activation, activation=activation,
bias_attr=None if has_bias else False, bias_attr=None if has_bias else False,
return_layernorm_output=return_ln_out, return_layernorm_output=return_ln_out,
backend='paddle', backend="paddle",
) )
layer_pd.ln_weight.copy_(layer_te.ln_weight, True) layer_pd.ln_weight.copy_(layer_te.ln_weight, True)
if has_ln_bias: if has_ln_bias:
...@@ -599,55 +645,63 @@ class TestLayerNormMLP: ...@@ -599,55 +645,63 @@ class TestLayerNormMLP:
layer_pd.fc2_bias.stop_gradient = no_dbias layer_pd.fc2_bias.stop_gradient = no_dbias
out_ref, ln_out_ref, grad_input_ref = calc_output_and_grad_ln_out( out_ref, ln_out_ref, grad_input_ref = calc_output_and_grad_ln_out(
layer_pd, input_tensor, grad_out, return_ln_out=return_ln_out) layer_pd, input_tensor, grad_out, return_ln_out=return_ln_out
out, ln_out, grad_input = calc_output_and_grad_ln_out(layer_te, )
input_tensor, out, ln_out, grad_input = calc_output_and_grad_ln_out(
grad_out, layer_te, input_tensor, grad_out, return_ln_out=return_ln_out
return_ln_out=return_ln_out) )
assert_allclose(out, out_ref, rtol=rtol, atol=atol) assert_allclose(out, out_ref, rtol=rtol, atol=atol)
if not no_dgrad: if not no_dgrad:
assert_allclose(grad_input, grad_input_ref, rtol=rtol, atol=atol) assert_allclose(grad_input, grad_input_ref, rtol=rtol, atol=atol)
if not no_wgrad: if not no_wgrad:
assert_allclose(layer_te.ln_weight.grad, layer_pd.ln_weight.grad, rtol=rtol, atol=atol) assert_allclose(layer_te.ln_weight.grad, layer_pd.ln_weight.grad, rtol=rtol, atol=atol)
assert_allclose(layer_te.fc1_weight.grad, assert_allclose(
layer_pd.fc1_weight.grad.T, layer_te.fc1_weight.grad, layer_pd.fc1_weight.grad.T, rtol=rtol, atol=atol
rtol=rtol, )
atol=atol) assert_allclose(
assert_allclose(layer_te.fc2_weight.grad, layer_te.fc2_weight.grad, layer_pd.fc2_weight.grad.T, rtol=rtol, atol=atol
layer_pd.fc2_weight.grad.T, )
rtol=rtol,
atol=atol)
if not no_dbias: if not no_dbias:
if has_ln_bias: if has_ln_bias:
assert_allclose(layer_te.ln_bias.grad, layer_pd.ln_bias.grad, rtol=rtol, atol=atol) assert_allclose(layer_te.ln_bias.grad, layer_pd.ln_bias.grad, rtol=rtol, atol=atol)
if has_bias: if has_bias:
assert_allclose(layer_te.fc1_bias.grad, assert_allclose(
layer_pd.fc1_bias.grad, layer_te.fc1_bias.grad, layer_pd.fc1_bias.grad, rtol=rtol, atol=atol
rtol=rtol, )
atol=atol) assert_allclose(
assert_allclose(layer_te.fc2_bias.grad, layer_te.fc2_bias.grad, layer_pd.fc2_bias.grad, rtol=rtol, atol=atol
layer_pd.fc2_bias.grad, )
rtol=rtol,
atol=atol)
if return_ln_out: if return_ln_out:
assert_allclose(ln_out, ln_out_ref, rtol=rtol, atol=atol) assert_allclose(ln_out, ln_out_ref, rtol=rtol, atol=atol)
@staticmethod @staticmethod
@pytest.mark.skipif(not is_fp8_supported, reason=reason) @pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest.mark.parametrize('bs,hidden_size,ffn_hidden_size', LINEAR_CASES) @pytest.mark.parametrize("bs,hidden_size,ffn_hidden_size", LINEAR_CASES)
@pytest.mark.parametrize('has_bias,no_dbias', [[True, False], [True, True], [False, False]]) @pytest.mark.parametrize("has_bias,no_dbias", [[True, False], [True, True], [False, False]])
@pytest.mark.parametrize('no_dgrad', [True, False]) @pytest.mark.parametrize("no_dgrad", [True, False])
@pytest.mark.parametrize('no_wgrad', [True, False]) @pytest.mark.parametrize("no_wgrad", [True, False])
@pytest.mark.parametrize('fp8_wgrad', [True, False]) @pytest.mark.parametrize("fp8_wgrad", [True, False])
@pytest.mark.parametrize('do_calibration', [True, False]) @pytest.mark.parametrize("do_calibration", [True, False])
@pytest.mark.parametrize('return_ln_out', [True, False]) @pytest.mark.parametrize("return_ln_out", [True, False])
@pytest.mark.parametrize('activation_dtype', ['bfloat16', 'float32']) @pytest.mark.parametrize("activation_dtype", ["bfloat16", "float32"])
@pytest.mark.parametrize('normalization', ['RMSNorm', 'LayerNorm']) @pytest.mark.parametrize("normalization", ["RMSNorm", "LayerNorm"])
@pytest.mark.parametrize('activation', ['gelu', 'swiglu']) @pytest.mark.parametrize("activation", ["gelu", "swiglu"])
def test_layernorm_mlp_fp8(bs, hidden_size, ffn_hidden_size, has_bias, no_dbias, no_dgrad, def test_layernorm_mlp_fp8(
no_wgrad, fp8_wgrad, do_calibration, return_ln_out, activation_dtype, bs,
normalization, activation): hidden_size,
ffn_hidden_size,
has_bias,
no_dbias,
no_dgrad,
no_wgrad,
fp8_wgrad,
do_calibration,
return_ln_out,
activation_dtype,
normalization,
activation,
):
""" """
Test FP8 LayerNormMLP Layer Test FP8 LayerNormMLP Layer
""" """
...@@ -659,7 +713,7 @@ class TestLayerNormMLP: ...@@ -659,7 +713,7 @@ class TestLayerNormMLP:
input_tensor.stop_gradient = no_dgrad input_tensor.stop_gradient = no_dgrad
grad_out = paddle.uniform(shape=(bs, hidden_size), dtype=activation_dtype) grad_out = paddle.uniform(shape=(bs, hidden_size), dtype=activation_dtype)
eps = 1e-3 eps = 1e-3
has_ln_bias = normalization == 'LayerNorm' has_ln_bias = normalization == "LayerNorm"
recipe = DelayedScaling(override_linear_precision=(False, False, not fp8_wgrad)) recipe = DelayedScaling(override_linear_precision=(False, False, not fp8_wgrad))
...@@ -681,7 +735,7 @@ class TestLayerNormMLP: ...@@ -681,7 +735,7 @@ class TestLayerNormMLP:
activation=activation, activation=activation,
bias_attr=None if has_bias else False, bias_attr=None if has_bias else False,
return_layernorm_output=return_ln_out, return_layernorm_output=return_ln_out,
backend='paddle', backend="paddle",
) )
layer_pd.ln_weight.copy_(layer_te.ln_weight, True) layer_pd.ln_weight.copy_(layer_te.ln_weight, True)
if has_ln_bias: if has_ln_bias:
...@@ -707,40 +761,37 @@ class TestLayerNormMLP: ...@@ -707,40 +761,37 @@ class TestLayerNormMLP:
layer_pd.fc1_bias.stop_gradient = no_dbias layer_pd.fc1_bias.stop_gradient = no_dbias
layer_pd.fc2_bias.stop_gradient = no_dbias layer_pd.fc2_bias.stop_gradient = no_dbias
with fp8_autocast(enabled=not do_calibration, calibrating=do_calibration, with fp8_autocast(
fp8_recipe=recipe): enabled=not do_calibration, calibrating=do_calibration, fp8_recipe=recipe
):
out_ref, ln_out_ref, grad_input_ref = calc_output_and_grad_ln_out( out_ref, ln_out_ref, grad_input_ref = calc_output_and_grad_ln_out(
layer_pd, input_tensor, grad_out, return_ln_out=return_ln_out) layer_pd, input_tensor, grad_out, return_ln_out=return_ln_out
out, ln_out, grad_input = calc_output_and_grad_ln_out(layer_te, )
input_tensor, out, ln_out, grad_input = calc_output_and_grad_ln_out(
grad_out, layer_te, input_tensor, grad_out, return_ln_out=return_ln_out
return_ln_out=return_ln_out) )
assert_allclose(out, out_ref, rtol=rtol, atol=atol) assert_allclose(out, out_ref, rtol=rtol, atol=atol)
if not no_dgrad: if not no_dgrad:
assert_allclose(grad_input, grad_input_ref, rtol=rtol, atol=atol) assert_allclose(grad_input, grad_input_ref, rtol=rtol, atol=atol)
if not no_wgrad: if not no_wgrad:
assert_allclose(layer_te.ln_weight.grad, layer_pd.ln_weight.grad, rtol=rtol, atol=atol) assert_allclose(layer_te.ln_weight.grad, layer_pd.ln_weight.grad, rtol=rtol, atol=atol)
assert_allclose(layer_te.fc1_weight.grad, assert_allclose(
layer_pd.fc1_weight.grad.T, layer_te.fc1_weight.grad, layer_pd.fc1_weight.grad.T, rtol=rtol, atol=atol
rtol=rtol, )
atol=atol) assert_allclose(
assert_allclose(layer_te.fc2_weight.grad, layer_te.fc2_weight.grad, layer_pd.fc2_weight.grad.T, rtol=rtol, atol=atol
layer_pd.fc2_weight.grad.T, )
rtol=rtol,
atol=atol)
if not no_dbias: if not no_dbias:
if has_ln_bias: if has_ln_bias:
assert_allclose(layer_te.ln_bias.grad, layer_pd.ln_bias.grad, rtol=rtol, atol=atol) assert_allclose(layer_te.ln_bias.grad, layer_pd.ln_bias.grad, rtol=rtol, atol=atol)
if has_bias: if has_bias:
assert_allclose(layer_te.fc1_bias.grad, assert_allclose(
layer_pd.fc1_bias.grad, layer_te.fc1_bias.grad, layer_pd.fc1_bias.grad, rtol=rtol, atol=atol
rtol=rtol, )
atol=atol) assert_allclose(
assert_allclose(layer_te.fc2_bias.grad, layer_te.fc2_bias.grad, layer_pd.fc2_bias.grad, rtol=rtol, atol=atol
layer_pd.fc2_bias.grad, )
rtol=rtol,
atol=atol)
if return_ln_out: if return_ln_out:
assert_allclose(ln_out, ln_out_ref, rtol=rtol, atol=atol) assert_allclose(ln_out, ln_out_ref, rtol=rtol, atol=atol)
...@@ -749,11 +800,12 @@ class TestLayerNormMLP: ...@@ -749,11 +800,12 @@ class TestLayerNormMLP:
@staticmethod @staticmethod
@pytest.mark.skipif(not is_fp8_supported, reason=reason) @pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest.mark.parametrize('bs,hidden_size,ffn_hidden_size', LINEAR_CASES) @pytest.mark.parametrize("bs,hidden_size,ffn_hidden_size", LINEAR_CASES)
@pytest.mark.parametrize('activation_dtype', ['bfloat16']) @pytest.mark.parametrize("activation_dtype", ["bfloat16"])
@pytest.mark.parametrize('num_microbatch', [8]) @pytest.mark.parametrize("num_microbatch", [8])
def test_layernorm_mlp_fp8_microbatch(bs, hidden_size, ffn_hidden_size, activation_dtype, def test_layernorm_mlp_fp8_microbatch(
num_microbatch): bs, hidden_size, ffn_hidden_size, activation_dtype, num_microbatch
):
""" """
Test FP8 LayerNormMLP Layer Test FP8 LayerNormMLP Layer
""" """
...@@ -803,28 +855,26 @@ class TestLayerNormMLP: ...@@ -803,28 +855,26 @@ class TestLayerNormMLP:
out_ref.backward(grad_out) out_ref.backward(grad_out)
assert_allclose(out, out_ref, rtol=rtol, atol=atol) assert_allclose(out, out_ref, rtol=rtol, atol=atol)
assert_allclose(layer_cached.ln_weight.grad, assert_allclose(
layer_normal.ln_weight.grad, layer_cached.ln_weight.grad, layer_normal.ln_weight.grad, rtol=rtol, atol=atol
rtol=rtol, )
atol=atol) assert_allclose(
assert_allclose(layer_cached.fc1_weight.grad, layer_cached.fc1_weight.grad, layer_normal.fc1_weight.grad, rtol=rtol, atol=atol
layer_normal.fc1_weight.grad, )
rtol=rtol, assert_allclose(
atol=atol) layer_cached.fc2_weight.grad, layer_normal.fc2_weight.grad, rtol=rtol, atol=atol
assert_allclose(layer_cached.fc2_weight.grad, )
layer_normal.fc2_weight.grad,
rtol=rtol,
atol=atol) @pytest.mark.parametrize("bs", [1, 2])
@pytest.mark.parametrize("hidden_size, num_heads", [[1024, 16]])
@pytest.mark.parametrize("q_seqlen, kv_seqlen", [[1024, 1024]])
@pytest.mark.parametrize('bs', [1, 2]) @pytest.mark.parametrize("attn_type", ["self", "cross"])
@pytest.mark.parametrize('hidden_size, num_heads', [[1024, 16]]) @pytest.mark.parametrize("mask_type", ["causal", "padding"])
@pytest.mark.parametrize('q_seqlen, kv_seqlen', [[1024, 1024]]) @pytest.mark.parametrize("math_dtype", ["bfloat16", "float16"])
@pytest.mark.parametrize('attn_type', ['self', 'cross']) def test_dot_product_attention(
@pytest.mark.parametrize('mask_type', ['causal', 'padding']) bs, hidden_size, num_heads, q_seqlen, kv_seqlen, attn_type, mask_type, math_dtype
@pytest.mark.parametrize('math_dtype', ['bfloat16', 'float16']) ):
def test_dot_product_attention(bs, hidden_size, num_heads, q_seqlen, kv_seqlen, attn_type,
mask_type, math_dtype):
""" """
Test DotProductAttention Layer Test DotProductAttention Layer
""" """
...@@ -835,53 +885,64 @@ def test_dot_product_attention(bs, hidden_size, num_heads, q_seqlen, kv_seqlen, ...@@ -835,53 +885,64 @@ def test_dot_product_attention(bs, hidden_size, num_heads, q_seqlen, kv_seqlen,
# Skip if cuDNN fused attention is not supported # Skip if cuDNN fused attention is not supported
if not is_fused_attention_supported( if not is_fused_attention_supported(
num_heads=num_heads, num_heads=num_heads,
num_gqa_groups=num_heads, num_gqa_groups=num_heads,
q_seqlen=q_seqlen, q_seqlen=q_seqlen,
kv_seqlen=kv_seqlen, kv_seqlen=kv_seqlen,
head_size=head_size, head_size=head_size,
dtype=math_dtype, dtype=math_dtype,
dropout=0.0, dropout=0.0,
qkv_layout="bshd_bshd_bshd", qkv_layout="bshd_bshd_bshd",
bias_type="no_bias", bias_type="no_bias",
mask_type=mask_type, mask_type=mask_type,
): ):
pytest.skip("cuDNN fused attention is not supported") pytest.skip("cuDNN fused attention is not supported")
attn_q_input = paddle.normal(mean=0.0, std=0.02, attn_q_input = paddle.normal(
shape=(bs, q_seqlen, num_heads, head_size)).astype(math_dtype) mean=0.0, std=0.02, shape=(bs, q_seqlen, num_heads, head_size)
attn_k_input = paddle.normal(mean=0.0, std=0.02, ).astype(math_dtype)
shape=(bs, kv_seqlen, num_heads, head_size)).astype(math_dtype) attn_k_input = paddle.normal(
attn_v_input = paddle.normal(mean=0.0, std=0.02, mean=0.0, std=0.02, shape=(bs, kv_seqlen, num_heads, head_size)
shape=(bs, kv_seqlen, num_heads, head_size)).astype(math_dtype) ).astype(math_dtype)
attn_v_input = paddle.normal(
q_actual_seqlen = paddle.randint(low=20, high=q_seqlen, shape=(bs,), dtype='int32') mean=0.0, std=0.02, shape=(bs, kv_seqlen, num_heads, head_size)
kv_actual_seqlen = paddle.randint(low=20, high=kv_seqlen, shape=(bs,), ).astype(math_dtype)
dtype='int32') if attn_type == 'cross' else q_actual_seqlen
attn_mask = paddle.ones(shape=(bs, 1, q_seqlen, kv_seqlen), dtype='bool') q_actual_seqlen = paddle.randint(low=20, high=q_seqlen, shape=(bs,), dtype="int32")
kv_actual_seqlen = (
grad_out = paddle.normal(mean=0.0, std=0.02, paddle.randint(low=20, high=kv_seqlen, shape=(bs,), dtype="int32")
shape=(bs, q_seqlen, num_heads, head_size)).astype('float32') if attn_type == "cross"
else q_actual_seqlen
)
attn_mask = paddle.ones(shape=(bs, 1, q_seqlen, kv_seqlen), dtype="bool")
grad_out = paddle.normal(mean=0.0, std=0.02, shape=(bs, q_seqlen, num_heads, head_size)).astype(
"float32"
)
for i in range(0, bs): for i in range(0, bs):
grad_out[i, q_actual_seqlen[i]:, :, :] = 0 grad_out[i, q_actual_seqlen[i] :, :, :] = 0
grad_out = grad_out.astype(math_dtype) grad_out = grad_out.astype(math_dtype)
for i in range(0, bs): for i in range(0, bs):
attn_mask[i, 0, 0:q_actual_seqlen[i], 0:kv_actual_seqlen[i]] = False attn_mask[i, 0, 0 : q_actual_seqlen[i], 0 : kv_actual_seqlen[i]] = False
head_size = hidden_size // num_heads head_size = hidden_size // num_heads
layer_te = te.DotProductAttention(num_heads, layer_te = te.DotProductAttention(
head_size, num_heads,
attention_dropout=0.0, head_size,
attn_mask_type=mask_type, attention_dropout=0.0,
attention_type=attn_type, attn_mask_type=mask_type,
backend='transformer_engine') attention_type=attn_type,
layer_pd = te.DotProductAttention(num_heads, backend="transformer_engine",
head_size, )
attention_dropout=0.0, layer_pd = te.DotProductAttention(
attn_mask_type=mask_type, num_heads,
attention_type=attn_type, head_size,
backend='paddle') attention_dropout=0.0,
attn_mask_type=mask_type,
attention_type=attn_type,
backend="paddle",
)
def calc_attn_output_and_grad(layer, q, k, v, mask, dout): def calc_attn_output_and_grad(layer, q, k, v, mask, dout):
_q = paddle.to_tensor(q, stop_gradient=False) _q = paddle.to_tensor(q, stop_gradient=False)
...@@ -892,23 +953,29 @@ def test_dot_product_attention(bs, hidden_size, num_heads, q_seqlen, kv_seqlen, ...@@ -892,23 +953,29 @@ def test_dot_product_attention(bs, hidden_size, num_heads, q_seqlen, kv_seqlen,
out.backward(dout) out.backward(dout)
return out, _q.grad, _k.grad, _v.grad return out, _q.grad, _k.grad, _v.grad
out, q_grad, k_grad, v_grad = calc_attn_output_and_grad(layer_te, attn_q_input, attn_k_input, out, q_grad, k_grad, v_grad = calc_attn_output_and_grad(
attn_v_input, attn_mask, grad_out) layer_te, attn_q_input, attn_k_input, attn_v_input, attn_mask, grad_out
)
out_ref, q_grad_ref, k_grad_ref, v_grad_ref = calc_attn_output_and_grad( out_ref, q_grad_ref, k_grad_ref, v_grad_ref = calc_attn_output_and_grad(
layer_pd, attn_q_input, attn_k_input, attn_v_input, attn_mask, grad_out) layer_pd, attn_q_input, attn_k_input, attn_v_input, attn_mask, grad_out
)
valid_out_ref = paddle.full_like(out_ref, 0) valid_out_ref = paddle.full_like(out_ref, 0)
for i in range(0, bs): for i in range(0, bs):
valid_out_ref[i, 0:q_actual_seqlen[i], :, :] = out_ref[i, 0:q_actual_seqlen[i], :, :] valid_out_ref[i, 0 : q_actual_seqlen[i], :, :] = out_ref[i, 0 : q_actual_seqlen[i], :, :]
valid_q_grad_ref = paddle.full_like(q_grad_ref, 0) valid_q_grad_ref = paddle.full_like(q_grad_ref, 0)
valid_k_grad_ref = paddle.full_like(k_grad_ref, 0) valid_k_grad_ref = paddle.full_like(k_grad_ref, 0)
valid_v_grad_ref = paddle.full_like(v_grad_ref, 0) valid_v_grad_ref = paddle.full_like(v_grad_ref, 0)
for i in range(0, bs): for i in range(0, bs):
valid_q_grad_ref[i, 0:q_actual_seqlen[i], :, :] = q_grad_ref[i, 0:q_actual_seqlen[i], :, :] valid_q_grad_ref[i, 0 : q_actual_seqlen[i], :, :] = q_grad_ref[
valid_k_grad_ref[i, 0:kv_actual_seqlen[i], :, :] = k_grad_ref[i, i, 0 : q_actual_seqlen[i], :, :
0:kv_actual_seqlen[i], :, :] ]
valid_v_grad_ref[i, 0:kv_actual_seqlen[i], :, :] = v_grad_ref[i, valid_k_grad_ref[i, 0 : kv_actual_seqlen[i], :, :] = k_grad_ref[
0:kv_actual_seqlen[i], :, :] i, 0 : kv_actual_seqlen[i], :, :
]
valid_v_grad_ref[i, 0 : kv_actual_seqlen[i], :, :] = v_grad_ref[
i, 0 : kv_actual_seqlen[i], :, :
]
assert_allclose(out, valid_out_ref, rtol=rtol, atol=atol) assert_allclose(out, valid_out_ref, rtol=rtol, atol=atol)
assert_allclose(q_grad, valid_q_grad_ref, rtol=rtol, atol=atol) assert_allclose(q_grad, valid_q_grad_ref, rtol=rtol, atol=atol)
...@@ -916,21 +983,34 @@ def test_dot_product_attention(bs, hidden_size, num_heads, q_seqlen, kv_seqlen, ...@@ -916,21 +983,34 @@ def test_dot_product_attention(bs, hidden_size, num_heads, q_seqlen, kv_seqlen,
assert_allclose(v_grad, valid_v_grad_ref, rtol=rtol, atol=atol) assert_allclose(v_grad, valid_v_grad_ref, rtol=rtol, atol=atol)
@pytest.mark.parametrize('bs', [1, 2]) @pytest.mark.parametrize("bs", [1, 2])
@pytest.mark.parametrize('num_gqa_groups', [1, 2, 4]) @pytest.mark.parametrize("num_gqa_groups", [1, 2, 4])
@pytest.mark.parametrize('hidden_size, num_heads, ffn_hidden_size', [[256, 4, 1024]]) @pytest.mark.parametrize("hidden_size, num_heads, ffn_hidden_size", [[256, 4, 1024]])
@pytest.mark.parametrize('q_seqlen, kv_seqlen', [[1024, 1024]]) @pytest.mark.parametrize("q_seqlen, kv_seqlen", [[1024, 1024]])
@pytest.mark.parametrize('has_bias, no_dbias', [[False, True], [True, True], [True, False]]) @pytest.mark.parametrize("has_bias, no_dbias", [[False, True], [True, True], [True, False]])
@pytest.mark.parametrize('no_wgrad', [True, False]) @pytest.mark.parametrize("no_wgrad", [True, False])
@pytest.mark.parametrize('mask_type', ['causal', 'padding']) @pytest.mark.parametrize("mask_type", ["causal", "padding"])
@pytest.mark.parametrize('math_dtype', ['bfloat16', 'float16']) @pytest.mark.parametrize("math_dtype", ["bfloat16", "float16"])
@pytest.mark.parametrize('output_layernorm', [True, False]) @pytest.mark.parametrize("output_layernorm", [True, False])
@pytest.mark.parametrize('return_layernorm_output', [True, False]) @pytest.mark.parametrize("return_layernorm_output", [True, False])
@pytest.mark.parametrize('normalization', ['RMSNorm', 'LayerNorm']) @pytest.mark.parametrize("normalization", ["RMSNorm", "LayerNorm"])
def test_transformer_encoder_layer(bs, hidden_size, num_heads, num_gqa_groups, ffn_hidden_size, def test_transformer_encoder_layer(
has_bias, no_dbias, no_wgrad, q_seqlen, kv_seqlen, mask_type, bs,
math_dtype, output_layernorm, return_layernorm_output, hidden_size,
normalization): num_heads,
num_gqa_groups,
ffn_hidden_size,
has_bias,
no_dbias,
no_wgrad,
q_seqlen,
kv_seqlen,
mask_type,
math_dtype,
output_layernorm,
return_layernorm_output,
normalization,
):
""" """
Test Transformer Encoder Layer Test Transformer Encoder Layer
""" """
...@@ -938,68 +1018,73 @@ def test_transformer_encoder_layer(bs, hidden_size, num_heads, num_gqa_groups, f ...@@ -938,68 +1018,73 @@ def test_transformer_encoder_layer(bs, hidden_size, num_heads, num_gqa_groups, f
rtol = 5e-2 rtol = 5e-2
atol = 5e-2 atol = 5e-2
eps = 1e-3 eps = 1e-3
has_ln_bias = normalization == 'LayerNorm' has_ln_bias = normalization == "LayerNorm"
# Skip if cuDNN fused attention is not supported # Skip if cuDNN fused attention is not supported
if not is_fused_attention_supported( if not is_fused_attention_supported(
num_heads=num_heads, num_heads=num_heads,
num_gqa_groups=num_gqa_groups, num_gqa_groups=num_gqa_groups,
q_seqlen=q_seqlen, q_seqlen=q_seqlen,
kv_seqlen=kv_seqlen, kv_seqlen=kv_seqlen,
head_size=hidden_size // num_heads, head_size=hidden_size // num_heads,
dtype=math_dtype, dtype=math_dtype,
dropout=0.0, dropout=0.0,
qkv_layout="bshd_bshd_bshd", qkv_layout="bshd_bshd_bshd",
bias_type="no_bias", bias_type="no_bias",
mask_type=mask_type, mask_type=mask_type,
): ):
pytest.skip("cuDNN fused attention is not supported") pytest.skip("cuDNN fused attention is not supported")
encoder_input = paddle.uniform(shape=(bs, q_seqlen, hidden_size), dtype=math_dtype) encoder_input = paddle.uniform(shape=(bs, q_seqlen, hidden_size), dtype=math_dtype)
q_actual_seqlen = paddle.ones(shape=(bs,), dtype='int32') * q_seqlen q_actual_seqlen = paddle.ones(shape=(bs,), dtype="int32") * q_seqlen
kv_actual_seqlen = q_actual_seqlen kv_actual_seqlen = q_actual_seqlen
attn_mask = paddle.ones(shape=(bs, 1, q_seqlen, kv_seqlen), dtype='bool') attn_mask = paddle.ones(shape=(bs, 1, q_seqlen, kv_seqlen), dtype="bool")
grad_out = paddle.normal(mean=0.0, std=0.02, grad_out = paddle.normal(mean=0.0, std=0.02, shape=(bs, q_seqlen, hidden_size)).astype(
shape=(bs, q_seqlen, hidden_size)).astype('float32') "float32"
)
for i in range(0, bs): for i in range(0, bs):
grad_out[i, q_actual_seqlen[i]:, :] = 0 grad_out[i, q_actual_seqlen[i] :, :] = 0
grad_out = grad_out.astype(math_dtype) grad_out = grad_out.astype(math_dtype)
for i in range(0, bs): for i in range(0, bs):
attn_mask[i, 0, 0:q_actual_seqlen[i], 0:kv_actual_seqlen[i]] = False attn_mask[i, 0, 0 : q_actual_seqlen[i], 0 : kv_actual_seqlen[i]] = False
layer_te = te.TransformerLayer(hidden_size, layer_te = te.TransformerLayer(
ffn_hidden_size, hidden_size,
num_heads, ffn_hidden_size,
num_gqa_groups=num_gqa_groups, num_heads,
layernorm_epsilon=eps, num_gqa_groups=num_gqa_groups,
hidden_dropout=0.0, layernorm_epsilon=eps,
attention_dropout=0.0, hidden_dropout=0.0,
weight_attr=None, attention_dropout=0.0,
bias_attr=None if has_bias else False, weight_attr=None,
self_attn_mask_type=mask_type, bias_attr=None if has_bias else False,
apply_residual_connection_post_layernorm=return_layernorm_output, self_attn_mask_type=mask_type,
output_layernorm=output_layernorm, apply_residual_connection_post_layernorm=return_layernorm_output,
layer_type='encoder', output_layernorm=output_layernorm,
normalization=normalization, layer_type="encoder",
backend='transformer_engine') normalization=normalization,
layer_pd = te.TransformerLayer(hidden_size, backend="transformer_engine",
ffn_hidden_size, )
num_heads, layer_pd = te.TransformerLayer(
num_gqa_groups=num_gqa_groups, hidden_size,
layernorm_epsilon=eps, ffn_hidden_size,
hidden_dropout=0.0, num_heads,
attention_dropout=0.0, num_gqa_groups=num_gqa_groups,
weight_attr=None, layernorm_epsilon=eps,
bias_attr=None if has_bias else False, hidden_dropout=0.0,
self_attn_mask_type=mask_type, attention_dropout=0.0,
apply_residual_connection_post_layernorm=return_layernorm_output, weight_attr=None,
output_layernorm=output_layernorm, bias_attr=None if has_bias else False,
layer_type='encoder', self_attn_mask_type=mask_type,
normalization=normalization, apply_residual_connection_post_layernorm=return_layernorm_output,
backend='paddle') output_layernorm=output_layernorm,
layer_type="encoder",
normalization=normalization,
backend="paddle",
)
# MultiHeadAttention params # MultiHeadAttention params
if output_layernorm: if output_layernorm:
...@@ -1012,21 +1097,25 @@ def test_transformer_encoder_layer(bs, hidden_size, num_heads, num_gqa_groups, f ...@@ -1012,21 +1097,25 @@ def test_transformer_encoder_layer(bs, hidden_size, num_heads, num_gqa_groups, f
layer_te.self_attention.qkv.bias.stop_gradient = no_dbias layer_te.self_attention.qkv.bias.stop_gradient = no_dbias
else: else:
layer_pd.self_attention.layernorm_qkv.ln_weight.copy_( layer_pd.self_attention.layernorm_qkv.ln_weight.copy_(
layer_te.self_attention.layernorm_qkv.ln_weight, True) layer_te.self_attention.layernorm_qkv.ln_weight, True
)
layer_pd.self_attention.layernorm_qkv.weight.copy_( layer_pd.self_attention.layernorm_qkv.weight.copy_(
layer_te.self_attention.layernorm_qkv.weight.T, True) layer_te.self_attention.layernorm_qkv.weight.T, True
)
layer_pd.self_attention.layernorm_qkv.ln_weight.stop_gradient = no_wgrad layer_pd.self_attention.layernorm_qkv.ln_weight.stop_gradient = no_wgrad
layer_pd.self_attention.layernorm_qkv.weight.stop_gradient = no_wgrad layer_pd.self_attention.layernorm_qkv.weight.stop_gradient = no_wgrad
layer_te.self_attention.layernorm_qkv.ln_weight.stop_gradient = no_wgrad layer_te.self_attention.layernorm_qkv.ln_weight.stop_gradient = no_wgrad
layer_te.self_attention.layernorm_qkv.weight.stop_gradient = no_wgrad layer_te.self_attention.layernorm_qkv.weight.stop_gradient = no_wgrad
if has_ln_bias: if has_ln_bias:
layer_pd.self_attention.layernorm_qkv.ln_bias.copy_( layer_pd.self_attention.layernorm_qkv.ln_bias.copy_(
layer_te.self_attention.layernorm_qkv.ln_bias, True) layer_te.self_attention.layernorm_qkv.ln_bias, True
)
layer_pd.self_attention.layernorm_qkv.ln_bias.stop_gradient = no_dbias layer_pd.self_attention.layernorm_qkv.ln_bias.stop_gradient = no_dbias
layer_te.self_attention.layernorm_qkv.ln_bias.stop_gradient = no_dbias layer_te.self_attention.layernorm_qkv.ln_bias.stop_gradient = no_dbias
if has_bias: if has_bias:
layer_pd.self_attention.layernorm_qkv.bias.copy_( layer_pd.self_attention.layernorm_qkv.bias.copy_(
layer_te.self_attention.layernorm_qkv.bias, True) layer_te.self_attention.layernorm_qkv.bias, True
)
layer_pd.self_attention.layernorm_qkv.bias.stop_gradient = no_dbias layer_pd.self_attention.layernorm_qkv.bias.stop_gradient = no_dbias
layer_te.self_attention.layernorm_qkv.bias.stop_gradient = no_dbias layer_te.self_attention.layernorm_qkv.bias.stop_gradient = no_dbias
...@@ -1074,52 +1163,75 @@ def test_transformer_encoder_layer(bs, hidden_size, num_heads, num_gqa_groups, f ...@@ -1074,52 +1163,75 @@ def test_transformer_encoder_layer(bs, hidden_size, num_heads, num_gqa_groups, f
out.backward(dout) out.backward(dout)
return out, _encoder_input.grad return out, _encoder_input.grad
out_ref, grad_input_ref = calc_transformer_output_and_grad(layer_pd, encoder_input, attn_mask, out_ref, grad_input_ref = calc_transformer_output_and_grad(
grad_out) layer_pd, encoder_input, attn_mask, grad_out
)
out, grad_input = calc_transformer_output_and_grad(layer_te, encoder_input, attn_mask, grad_out) out, grad_input = calc_transformer_output_and_grad(layer_te, encoder_input, attn_mask, grad_out)
assert_allclose(out, out_ref, rtol=rtol, atol=atol) assert_allclose(out, out_ref, rtol=rtol, atol=atol)
assert_allclose(grad_input, grad_input_ref, rtol=rtol, atol=atol) assert_allclose(grad_input, grad_input_ref, rtol=rtol, atol=atol)
if not no_wgrad: if not no_wgrad:
if output_layernorm: if output_layernorm:
assert_allclose(layer_te.self_attention.qkv.weight.grad, assert_allclose(
layer_pd.self_attention.qkv.weight.grad.T, layer_te.self_attention.qkv.weight.grad,
rtol=rtol, layer_pd.self_attention.qkv.weight.grad.T,
atol=atol) rtol=rtol,
atol=atol,
)
else: else:
assert_allclose(layer_te.self_attention.layernorm_qkv.weight.grad, assert_allclose(
layer_pd.self_attention.layernorm_qkv.weight.grad.T, layer_te.self_attention.layernorm_qkv.weight.grad,
rtol=rtol, layer_pd.self_attention.layernorm_qkv.weight.grad.T,
atol=atol) rtol=rtol,
atol=atol,
)
if not no_dbias: if not no_dbias:
if output_layernorm: if output_layernorm:
assert_allclose(layer_te.self_attention.qkv.bias.grad, assert_allclose(
layer_pd.self_attention.qkv.bias.grad, layer_te.self_attention.qkv.bias.grad,
rtol=0.01, layer_pd.self_attention.qkv.bias.grad,
atol=0.5) rtol=0.01,
atol=0.5,
)
else: else:
assert_allclose(layer_te.self_attention.layernorm_qkv.bias.grad, assert_allclose(
layer_pd.self_attention.layernorm_qkv.bias.grad, layer_te.self_attention.layernorm_qkv.bias.grad,
rtol=0.01, layer_pd.self_attention.layernorm_qkv.bias.grad,
atol=0.5) rtol=0.01,
atol=0.5,
)
@pytest.mark.parametrize('bs', [1, 2])
@pytest.mark.parametrize('num_gqa_groups', [1, 2, 4])
@pytest.mark.parametrize('hidden_size, num_heads, ffn_hidden_size', [[256, 4, 1024]]) @pytest.mark.parametrize("bs", [1, 2])
@pytest.mark.parametrize('q_seqlen, kv_seqlen', [[1024, 1024]]) @pytest.mark.parametrize("num_gqa_groups", [1, 2, 4])
@pytest.mark.parametrize('has_bias, no_dbias', [[False, True], [True, True], [True, False]]) @pytest.mark.parametrize("hidden_size, num_heads, ffn_hidden_size", [[256, 4, 1024]])
@pytest.mark.parametrize('no_wgrad', [True, False]) @pytest.mark.parametrize("q_seqlen, kv_seqlen", [[1024, 1024]])
@pytest.mark.parametrize('mask_type', ['causal', 'padding']) @pytest.mark.parametrize("has_bias, no_dbias", [[False, True], [True, True], [True, False]])
@pytest.mark.parametrize('math_dtype', ['bfloat16', 'float16']) @pytest.mark.parametrize("no_wgrad", [True, False])
@pytest.mark.parametrize('output_layernorm', [True, False]) @pytest.mark.parametrize("mask_type", ["causal", "padding"])
@pytest.mark.parametrize('return_layernorm_output', [True, False]) @pytest.mark.parametrize("math_dtype", ["bfloat16", "float16"])
@pytest.mark.parametrize('recompute_core_attention', [True, False]) @pytest.mark.parametrize("output_layernorm", [True, False])
@pytest.mark.parametrize('normalization', ['RMSNorm', 'LayerNorm']) @pytest.mark.parametrize("return_layernorm_output", [True, False])
def test_transformer_decoder_layer(bs, hidden_size, num_heads, num_gqa_groups, ffn_hidden_size, @pytest.mark.parametrize("recompute_core_attention", [True, False])
has_bias, no_dbias, no_wgrad, q_seqlen, kv_seqlen, mask_type, @pytest.mark.parametrize("normalization", ["RMSNorm", "LayerNorm"])
math_dtype, output_layernorm, return_layernorm_output, def test_transformer_decoder_layer(
recompute_core_attention, normalization): bs,
hidden_size,
num_heads,
num_gqa_groups,
ffn_hidden_size,
has_bias,
no_dbias,
no_wgrad,
q_seqlen,
kv_seqlen,
mask_type,
math_dtype,
output_layernorm,
return_layernorm_output,
recompute_core_attention,
normalization,
):
""" """
Test Transformer Decoder Layer Test Transformer Decoder Layer
""" """
...@@ -1127,34 +1239,37 @@ def test_transformer_decoder_layer(bs, hidden_size, num_heads, num_gqa_groups, f ...@@ -1127,34 +1239,37 @@ def test_transformer_decoder_layer(bs, hidden_size, num_heads, num_gqa_groups, f
rtol = 5e-2 rtol = 5e-2
atol = 6e-2 atol = 6e-2
eps = 1e-3 eps = 1e-3
has_ln_bias = normalization == 'LayerNorm' has_ln_bias = normalization == "LayerNorm"
# Skip if cuDNN fused attention is not supported # Skip if cuDNN fused attention is not supported
if not is_fused_attention_supported( if not is_fused_attention_supported(
num_heads=num_heads, num_heads=num_heads,
num_gqa_groups=num_gqa_groups, num_gqa_groups=num_gqa_groups,
q_seqlen=q_seqlen, q_seqlen=q_seqlen,
kv_seqlen=kv_seqlen, kv_seqlen=kv_seqlen,
head_size=hidden_size // num_heads, head_size=hidden_size // num_heads,
dtype=math_dtype, dtype=math_dtype,
dropout=0.0, dropout=0.0,
qkv_layout="bshd_bshd_bshd", qkv_layout="bshd_bshd_bshd",
bias_type="no_bias", bias_type="no_bias",
mask_type=mask_type, mask_type=mask_type,
): ):
pytest.skip("cuDNN fused attention is not supported") pytest.skip("cuDNN fused attention is not supported")
encoder_input = paddle.normal(mean=0.0, std=0.1, encoder_input = paddle.normal(mean=0.0, std=0.1, shape=(bs, q_seqlen, hidden_size)).astype(
shape=(bs, q_seqlen, hidden_size)).astype(math_dtype) math_dtype
encoder_output = paddle.normal(mean=0.0, std=0.1, )
shape=(bs, kv_seqlen, hidden_size)).astype(math_dtype) encoder_output = paddle.normal(mean=0.0, std=0.1, shape=(bs, kv_seqlen, hidden_size)).astype(
math_dtype
)
q_actual_seqlen = paddle.ones(shape=(bs,), dtype='int32') * q_seqlen q_actual_seqlen = paddle.ones(shape=(bs,), dtype="int32") * q_seqlen
kv_actual_seqlen = q_actual_seqlen kv_actual_seqlen = q_actual_seqlen
attn_mask = paddle.ones(shape=(bs, 1, q_seqlen, kv_seqlen), dtype='bool') attn_mask = paddle.ones(shape=(bs, 1, q_seqlen, kv_seqlen), dtype="bool")
grad_out = paddle.normal(mean=0.0, std=0.01, grad_out = paddle.normal(mean=0.0, std=0.01, shape=(bs, q_seqlen, hidden_size)).astype(
shape=(bs, q_seqlen, hidden_size)).astype('float32') "float32"
)
# rounding to avoid numerical issues # rounding to avoid numerical issues
encoder_input = paddle.round(encoder_input * 1000) / 1000 encoder_input = paddle.round(encoder_input * 1000) / 1000
...@@ -1162,42 +1277,46 @@ def test_transformer_decoder_layer(bs, hidden_size, num_heads, num_gqa_groups, f ...@@ -1162,42 +1277,46 @@ def test_transformer_decoder_layer(bs, hidden_size, num_heads, num_gqa_groups, f
grad_out = paddle.round(grad_out * 1000) / 1000 grad_out = paddle.round(grad_out * 1000) / 1000
for i in range(0, bs): for i in range(0, bs):
grad_out[i, q_actual_seqlen[i]:, :] = 0 grad_out[i, q_actual_seqlen[i] :, :] = 0
grad_out = grad_out.astype(math_dtype) grad_out = grad_out.astype(math_dtype)
for i in range(0, bs): for i in range(0, bs):
attn_mask[i, 0, 0:q_actual_seqlen[i], 0:kv_actual_seqlen[i]] = False attn_mask[i, 0, 0 : q_actual_seqlen[i], 0 : kv_actual_seqlen[i]] = False
layer_te = te.TransformerLayer(hidden_size, layer_te = te.TransformerLayer(
ffn_hidden_size, hidden_size,
num_heads, ffn_hidden_size,
num_gqa_groups=num_gqa_groups, num_heads,
layernorm_epsilon=eps, num_gqa_groups=num_gqa_groups,
hidden_dropout=0.0, layernorm_epsilon=eps,
attention_dropout=0.0, hidden_dropout=0.0,
weight_attr=None, attention_dropout=0.0,
bias_attr=None if has_bias else False, weight_attr=None,
self_attn_mask_type=mask_type, bias_attr=None if has_bias else False,
apply_residual_connection_post_layernorm=return_layernorm_output, self_attn_mask_type=mask_type,
output_layernorm=output_layernorm, apply_residual_connection_post_layernorm=return_layernorm_output,
layer_type='decoder', output_layernorm=output_layernorm,
normalization=normalization, layer_type="decoder",
backend='transformer_engine') normalization=normalization,
layer_pd = te.TransformerLayer(hidden_size, backend="transformer_engine",
ffn_hidden_size, )
num_heads, layer_pd = te.TransformerLayer(
num_gqa_groups=num_gqa_groups, hidden_size,
layernorm_epsilon=eps, ffn_hidden_size,
hidden_dropout=0.0, num_heads,
attention_dropout=0.0, num_gqa_groups=num_gqa_groups,
weight_attr=None, layernorm_epsilon=eps,
bias_attr=None if has_bias else False, hidden_dropout=0.0,
self_attn_mask_type=mask_type, attention_dropout=0.0,
apply_residual_connection_post_layernorm=return_layernorm_output, weight_attr=None,
output_layernorm=output_layernorm, bias_attr=None if has_bias else False,
layer_type='decoder', self_attn_mask_type=mask_type,
normalization=normalization, apply_residual_connection_post_layernorm=return_layernorm_output,
backend='paddle') output_layernorm=output_layernorm,
layer_type="decoder",
normalization=normalization,
backend="paddle",
)
# MultiHeadAttention params - self attn # MultiHeadAttention params - self attn
if output_layernorm: if output_layernorm:
...@@ -1210,21 +1329,25 @@ def test_transformer_decoder_layer(bs, hidden_size, num_heads, num_gqa_groups, f ...@@ -1210,21 +1329,25 @@ def test_transformer_decoder_layer(bs, hidden_size, num_heads, num_gqa_groups, f
layer_te.self_attention.qkv.bias.stop_gradient = no_dbias layer_te.self_attention.qkv.bias.stop_gradient = no_dbias
else: else:
layer_pd.self_attention.layernorm_qkv.ln_weight.copy_( layer_pd.self_attention.layernorm_qkv.ln_weight.copy_(
layer_te.self_attention.layernorm_qkv.ln_weight, True) layer_te.self_attention.layernorm_qkv.ln_weight, True
)
layer_pd.self_attention.layernorm_qkv.weight.copy_( layer_pd.self_attention.layernorm_qkv.weight.copy_(
layer_te.self_attention.layernorm_qkv.weight.T, True) layer_te.self_attention.layernorm_qkv.weight.T, True
)
layer_pd.self_attention.layernorm_qkv.ln_weight.stop_gradient = no_wgrad layer_pd.self_attention.layernorm_qkv.ln_weight.stop_gradient = no_wgrad
layer_pd.self_attention.layernorm_qkv.weight.stop_gradient = no_wgrad layer_pd.self_attention.layernorm_qkv.weight.stop_gradient = no_wgrad
layer_te.self_attention.layernorm_qkv.ln_weight.stop_gradient = no_wgrad layer_te.self_attention.layernorm_qkv.ln_weight.stop_gradient = no_wgrad
layer_te.self_attention.layernorm_qkv.weight.stop_gradient = no_wgrad layer_te.self_attention.layernorm_qkv.weight.stop_gradient = no_wgrad
if has_ln_bias: if has_ln_bias:
layer_pd.self_attention.layernorm_qkv.ln_bias.copy_( layer_pd.self_attention.layernorm_qkv.ln_bias.copy_(
layer_te.self_attention.layernorm_qkv.ln_bias, True) layer_te.self_attention.layernorm_qkv.ln_bias, True
)
layer_pd.self_attention.layernorm_qkv.ln_bias.stop_gradient = no_dbias layer_pd.self_attention.layernorm_qkv.ln_bias.stop_gradient = no_dbias
layer_te.self_attention.layernorm_qkv.ln_bias.stop_gradient = no_dbias layer_te.self_attention.layernorm_qkv.ln_bias.stop_gradient = no_dbias
if has_bias: if has_bias:
layer_pd.self_attention.layernorm_qkv.bias.copy_( layer_pd.self_attention.layernorm_qkv.bias.copy_(
layer_te.self_attention.layernorm_qkv.bias, True) layer_te.self_attention.layernorm_qkv.bias, True
)
layer_pd.self_attention.layernorm_qkv.bias.stop_gradient = no_dbias layer_pd.self_attention.layernorm_qkv.bias.stop_gradient = no_dbias
layer_te.self_attention.layernorm_qkv.bias.stop_gradient = no_dbias layer_te.self_attention.layernorm_qkv.bias.stop_gradient = no_dbias
...@@ -1238,26 +1361,31 @@ def test_transformer_decoder_layer(bs, hidden_size, num_heads, num_gqa_groups, f ...@@ -1238,26 +1361,31 @@ def test_transformer_decoder_layer(bs, hidden_size, num_heads, num_gqa_groups, f
# MultiHeadAttention params - cross attn # MultiHeadAttention params - cross attn
layer_pd.inter_attention.layernorm_query.ln_weight.copy_( layer_pd.inter_attention.layernorm_query.ln_weight.copy_(
layer_te.inter_attention.layernorm_query.ln_weight, True) layer_te.inter_attention.layernorm_query.ln_weight, True
)
layer_pd.inter_attention.layernorm_query.weight.copy_( layer_pd.inter_attention.layernorm_query.weight.copy_(
layer_te.inter_attention.layernorm_query.weight.T, True) layer_te.inter_attention.layernorm_query.weight.T, True
)
layer_pd.inter_attention.layernorm_query.ln_weight.stop_gradient = no_wgrad layer_pd.inter_attention.layernorm_query.ln_weight.stop_gradient = no_wgrad
layer_pd.inter_attention.layernorm_query.weight.stop_gradient = no_wgrad layer_pd.inter_attention.layernorm_query.weight.stop_gradient = no_wgrad
layer_te.inter_attention.layernorm_query.ln_weight.stop_gradient = no_wgrad layer_te.inter_attention.layernorm_query.ln_weight.stop_gradient = no_wgrad
layer_te.inter_attention.layernorm_query.weight.stop_gradient = no_wgrad layer_te.inter_attention.layernorm_query.weight.stop_gradient = no_wgrad
if has_ln_bias: if has_ln_bias:
layer_pd.inter_attention.layernorm_query.ln_bias.copy_( layer_pd.inter_attention.layernorm_query.ln_bias.copy_(
layer_te.inter_attention.layernorm_query.ln_bias, True) layer_te.inter_attention.layernorm_query.ln_bias, True
)
layer_pd.inter_attention.layernorm_query.ln_bias.stop_gradient = no_dbias layer_pd.inter_attention.layernorm_query.ln_bias.stop_gradient = no_dbias
layer_te.inter_attention.layernorm_query.ln_bias.stop_gradient = no_dbias layer_te.inter_attention.layernorm_query.ln_bias.stop_gradient = no_dbias
if has_bias: if has_bias:
layer_pd.inter_attention.layernorm_query.bias.copy_( layer_pd.inter_attention.layernorm_query.bias.copy_(
layer_te.inter_attention.layernorm_query.bias, True) layer_te.inter_attention.layernorm_query.bias, True
)
layer_pd.inter_attention.layernorm_query.bias.stop_gradient = no_dbias layer_pd.inter_attention.layernorm_query.bias.stop_gradient = no_dbias
layer_te.inter_attention.layernorm_query.bias.stop_gradient = no_dbias layer_te.inter_attention.layernorm_query.bias.stop_gradient = no_dbias
layer_pd.inter_attention.key_value.weight.copy_(layer_te.inter_attention.key_value.weight.T, layer_pd.inter_attention.key_value.weight.copy_(
True) layer_te.inter_attention.key_value.weight.T, True
)
layer_pd.inter_attention.key_value.weight.stop_gradient = no_wgrad layer_pd.inter_attention.key_value.weight.stop_gradient = no_wgrad
layer_te.inter_attention.key_value.weight.stop_gradient = no_wgrad layer_te.inter_attention.key_value.weight.stop_gradient = no_wgrad
layer_pd.inter_attention.proj.weight.copy_(layer_te.inter_attention.proj.weight.T, True) layer_pd.inter_attention.proj.weight.copy_(layer_te.inter_attention.proj.weight.T, True)
...@@ -1301,25 +1429,30 @@ def test_transformer_decoder_layer(bs, hidden_size, num_heads, num_gqa_groups, f ...@@ -1301,25 +1429,30 @@ def test_transformer_decoder_layer(bs, hidden_size, num_heads, num_gqa_groups, f
layer_te.layernorm.weight.stop_gradient = no_wgrad layer_te.layernorm.weight.stop_gradient = no_wgrad
layer_te.layernorm.bias.stop_gradient = no_dbias layer_te.layernorm.bias.stop_gradient = no_dbias
def calc_transformer_output_and_grad(layer, def calc_transformer_output_and_grad(
encoder_input, layer,
mask, encoder_input,
encoder_output, mask,
enc_dec_attn_mask, encoder_output,
dout, enc_dec_attn_mask,
recompute_core_attention=False): dout,
recompute_core_attention=False,
):
_encoder_input = paddle.to_tensor(encoder_input, stop_gradient=False) _encoder_input = paddle.to_tensor(encoder_input, stop_gradient=False)
_encoder_output = paddle.to_tensor(encoder_output, stop_gradient=False) _encoder_output = paddle.to_tensor(encoder_output, stop_gradient=False)
out = layer(_encoder_input, out = layer(
mask, _encoder_input,
_encoder_output, mask,
enc_dec_attn_mask, _encoder_output,
recompute_core_attention=recompute_core_attention) enc_dec_attn_mask,
recompute_core_attention=recompute_core_attention,
)
out.backward(dout) out.backward(dout)
return out, _encoder_input.grad, _encoder_output.grad return out, _encoder_input.grad, _encoder_output.grad
out_ref, grad_encoder_input_ref, grad_encoder_output_ref = calc_transformer_output_and_grad( out_ref, grad_encoder_input_ref, grad_encoder_output_ref = calc_transformer_output_and_grad(
layer_pd, encoder_input, attn_mask, encoder_output, attn_mask, grad_out) layer_pd, encoder_input, attn_mask, encoder_output, attn_mask, grad_out
)
out, grad_encoder_input, grad_encoder_output = calc_transformer_output_and_grad( out, grad_encoder_input, grad_encoder_output = calc_transformer_output_and_grad(
layer_te, layer_te,
encoder_input, encoder_input,
...@@ -1327,52 +1460,74 @@ def test_transformer_decoder_layer(bs, hidden_size, num_heads, num_gqa_groups, f ...@@ -1327,52 +1460,74 @@ def test_transformer_decoder_layer(bs, hidden_size, num_heads, num_gqa_groups, f
encoder_output, encoder_output,
attn_mask, attn_mask,
grad_out, grad_out,
recompute_core_attention=recompute_core_attention) recompute_core_attention=recompute_core_attention,
)
assert_allclose(out, out_ref, rtol=rtol, atol=atol) assert_allclose(out, out_ref, rtol=rtol, atol=atol)
assert_allclose(grad_encoder_input, grad_encoder_input_ref, rtol=rtol, atol=atol) assert_allclose(grad_encoder_input, grad_encoder_input_ref, rtol=rtol, atol=atol)
assert_allclose(grad_encoder_output, grad_encoder_output_ref, rtol=rtol, atol=atol) assert_allclose(grad_encoder_output, grad_encoder_output_ref, rtol=rtol, atol=atol)
if not no_wgrad: if not no_wgrad:
if output_layernorm: if output_layernorm:
assert_allclose(layer_te.self_attention.qkv.weight.grad, assert_allclose(
layer_pd.self_attention.qkv.weight.grad.T, layer_te.self_attention.qkv.weight.grad,
rtol=rtol, layer_pd.self_attention.qkv.weight.grad.T,
atol=atol) rtol=rtol,
atol=atol,
)
else: else:
assert_allclose(layer_te.self_attention.layernorm_qkv.weight.grad, assert_allclose(
layer_pd.self_attention.layernorm_qkv.weight.grad.T, layer_te.self_attention.layernorm_qkv.weight.grad,
rtol=rtol, layer_pd.self_attention.layernorm_qkv.weight.grad.T,
atol=atol) rtol=rtol,
assert_allclose(layer_te.inter_attention.layernorm_query.weight.grad, atol=atol,
layer_pd.inter_attention.layernorm_query.weight.grad.T, )
rtol=rtol, assert_allclose(
atol=atol) layer_te.inter_attention.layernorm_query.weight.grad,
layer_pd.inter_attention.layernorm_query.weight.grad.T,
rtol=rtol,
atol=atol,
)
if not no_dbias: if not no_dbias:
if output_layernorm: if output_layernorm:
assert_allclose(layer_te.self_attention.qkv.bias.grad, assert_allclose(
layer_pd.self_attention.qkv.bias.grad, layer_te.self_attention.qkv.bias.grad,
rtol=0.5, layer_pd.self_attention.qkv.bias.grad,
atol=0.6) rtol=0.5,
atol=0.6,
)
else: else:
assert_allclose(layer_te.self_attention.layernorm_qkv.bias.grad, assert_allclose(
layer_pd.self_attention.layernorm_qkv.bias.grad, layer_te.self_attention.layernorm_qkv.bias.grad,
rtol=0.01, layer_pd.self_attention.layernorm_qkv.bias.grad,
atol=0.5) rtol=0.01,
assert_allclose(layer_te.inter_attention.layernorm_query.bias.grad, atol=0.5,
layer_pd.inter_attention.layernorm_query.bias.grad, )
rtol=rtol, assert_allclose(
atol=atol) layer_te.inter_attention.layernorm_query.bias.grad,
layer_pd.inter_attention.layernorm_query.bias.grad,
rtol=rtol,
atol=atol,
)
@pytest.mark.skipif(not is_fp8_supported, reason=reason) @pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest.mark.parametrize('bs', [8]) @pytest.mark.parametrize("bs", [8])
@pytest.mark.parametrize('hidden_size, num_heads, ffn_hidden_size', [[1024, 16, 4096]]) @pytest.mark.parametrize("hidden_size, num_heads, ffn_hidden_size", [[1024, 16, 4096]])
@pytest.mark.parametrize('q_seqlen, kv_seqlen', [[128, 128]]) @pytest.mark.parametrize("q_seqlen, kv_seqlen", [[128, 128]])
@pytest.mark.parametrize('mask_type', ['causal']) @pytest.mark.parametrize("mask_type", ["causal"])
@pytest.mark.parametrize('math_dtype', ['bfloat16']) @pytest.mark.parametrize("math_dtype", ["bfloat16"])
@pytest.mark.parametrize('num_microbatch', [8]) @pytest.mark.parametrize("num_microbatch", [8])
def test_transformer_encoder_layer_microbatch(bs, hidden_size, num_heads, ffn_hidden_size, q_seqlen, def test_transformer_encoder_layer_microbatch(
kv_seqlen, mask_type, math_dtype, num_microbatch): bs,
hidden_size,
num_heads,
ffn_hidden_size,
q_seqlen,
kv_seqlen,
mask_type,
math_dtype,
num_microbatch,
):
""" """
Test Transformer Encoder Layer with FP8 weight caching Test Transformer Encoder Layer with FP8 weight caching
""" """
...@@ -1383,48 +1538,56 @@ def test_transformer_encoder_layer_microbatch(bs, hidden_size, num_heads, ffn_hi ...@@ -1383,48 +1538,56 @@ def test_transformer_encoder_layer_microbatch(bs, hidden_size, num_heads, ffn_hi
# Skip if cuDNN fused attention is not supported # Skip if cuDNN fused attention is not supported
if not is_fused_attention_supported( if not is_fused_attention_supported(
num_heads=num_heads, num_heads=num_heads,
num_gqa_groups=num_heads, num_gqa_groups=num_heads,
q_seqlen=q_seqlen, q_seqlen=q_seqlen,
kv_seqlen=kv_seqlen, kv_seqlen=kv_seqlen,
head_size=hidden_size // num_heads, head_size=hidden_size // num_heads,
dtype=math_dtype, dtype=math_dtype,
dropout=0.0, dropout=0.0,
qkv_layout="bs3hd", qkv_layout="bs3hd",
bias_type="no_bias", bias_type="no_bias",
mask_type=mask_type, mask_type=mask_type,
): ):
pytest.skip("cuDNN fused attention is not supported") pytest.skip("cuDNN fused attention is not supported")
layer_cached = te.TransformerLayer(hidden_size, layer_cached = te.TransformerLayer(
ffn_hidden_size, hidden_size,
num_heads, ffn_hidden_size,
layernorm_epsilon=eps, num_heads,
hidden_dropout=0.0, layernorm_epsilon=eps,
attention_dropout=0.0, hidden_dropout=0.0,
weight_attr=None, attention_dropout=0.0,
bias_attr=None, weight_attr=None,
self_attn_mask_type=mask_type, bias_attr=None,
layer_type='encoder') self_attn_mask_type=mask_type,
layer_normal = te.TransformerLayer(hidden_size, layer_type="encoder",
ffn_hidden_size, )
num_heads, layer_normal = te.TransformerLayer(
layernorm_epsilon=eps, hidden_size,
hidden_dropout=0.0, ffn_hidden_size,
attention_dropout=0.0, num_heads,
weight_attr=None, layernorm_epsilon=eps,
bias_attr=None, hidden_dropout=0.0,
self_attn_mask_type=mask_type, attention_dropout=0.0,
layer_type='encoder') weight_attr=None,
bias_attr=None,
self_attn_mask_type=mask_type,
layer_type="encoder",
)
layer_normal.self_attention.layernorm_qkv.ln_weight.copy_( layer_normal.self_attention.layernorm_qkv.ln_weight.copy_(
layer_cached.self_attention.layernorm_qkv.ln_weight, True) layer_cached.self_attention.layernorm_qkv.ln_weight, True
)
layer_normal.self_attention.layernorm_qkv.ln_bias.copy_( layer_normal.self_attention.layernorm_qkv.ln_bias.copy_(
layer_cached.self_attention.layernorm_qkv.ln_bias, True) layer_cached.self_attention.layernorm_qkv.ln_bias, True
)
layer_normal.self_attention.layernorm_qkv.weight.copy_( layer_normal.self_attention.layernorm_qkv.weight.copy_(
layer_cached.self_attention.layernorm_qkv.weight, True) layer_cached.self_attention.layernorm_qkv.weight, True
)
layer_normal.self_attention.layernorm_qkv.bias.copy_( layer_normal.self_attention.layernorm_qkv.bias.copy_(
layer_cached.self_attention.layernorm_qkv.bias, True) layer_cached.self_attention.layernorm_qkv.bias, True
)
layer_normal.self_attention.proj.weight.copy_(layer_cached.self_attention.proj.weight, True) layer_normal.self_attention.proj.weight.copy_(layer_cached.self_attention.proj.weight, True)
layer_normal.self_attention.proj.bias.copy_(layer_cached.self_attention.proj.bias, True) layer_normal.self_attention.proj.bias.copy_(layer_cached.self_attention.proj.bias, True)
...@@ -1442,18 +1605,19 @@ def test_transformer_encoder_layer_microbatch(bs, hidden_size, num_heads, ffn_hi ...@@ -1442,18 +1605,19 @@ def test_transformer_encoder_layer_microbatch(bs, hidden_size, num_heads, ffn_hi
def generate_input(): def generate_input():
encoder_input = paddle.uniform(shape=(bs, q_seqlen, hidden_size), dtype=math_dtype) encoder_input = paddle.uniform(shape=(bs, q_seqlen, hidden_size), dtype=math_dtype)
q_actual_seqlen = paddle.ones(shape=(bs,), dtype='int32') * q_seqlen q_actual_seqlen = paddle.ones(shape=(bs,), dtype="int32") * q_seqlen
kv_actual_seqlen = q_actual_seqlen kv_actual_seqlen = q_actual_seqlen
attn_mask = paddle.ones(shape=(bs, 1, q_seqlen, kv_seqlen), dtype='bool') attn_mask = paddle.ones(shape=(bs, 1, q_seqlen, kv_seqlen), dtype="bool")
grad_out = paddle.normal(mean=0.0, std=0.02, grad_out = paddle.normal(mean=0.0, std=0.02, shape=(bs, q_seqlen, hidden_size)).astype(
shape=(bs, q_seqlen, hidden_size)).astype('float32') "float32"
)
for i in range(0, bs): for i in range(0, bs):
grad_out[i, q_actual_seqlen[i]:, :] = 0 grad_out[i, q_actual_seqlen[i] :, :] = 0
grad_out = grad_out.astype(math_dtype) grad_out = grad_out.astype(math_dtype)
for i in range(0, bs): for i in range(0, bs):
attn_mask[i, 0, 0:q_actual_seqlen[i], 0:kv_actual_seqlen[i]] = False attn_mask[i, 0, 0 : q_actual_seqlen[i], 0 : kv_actual_seqlen[i]] = False
return encoder_input, attn_mask, grad_out return encoder_input, attn_mask, grad_out
...@@ -1477,7 +1641,9 @@ def test_transformer_encoder_layer_microbatch(bs, hidden_size, num_heads, ffn_hi ...@@ -1477,7 +1641,9 @@ def test_transformer_encoder_layer_microbatch(bs, hidden_size, num_heads, ffn_hi
out_ref.backward(grad_out) out_ref.backward(grad_out)
assert_allclose(out, out_ref, rtol=rtol, atol=atol) assert_allclose(out, out_ref, rtol=rtol, atol=atol)
assert_allclose(layer_cached.self_attention.layernorm_qkv.weight.grad, assert_allclose(
layer_normal.self_attention.layernorm_qkv.weight.grad, layer_cached.self_attention.layernorm_qkv.weight.grad,
rtol=rtol, layer_normal.self_attention.layernorm_qkv.weight.grad,
atol=atol) rtol=rtol,
atol=atol,
)
...@@ -16,7 +16,7 @@ is_fp8_supported, reason = is_fp8_available() ...@@ -16,7 +16,7 @@ is_fp8_supported, reason = is_fp8_available()
def create_optimizer(model, use_pure_bf16, use_main_grad): def create_optimizer(model, use_pure_bf16, use_main_grad):
'''Create optimizer''' """Create optimizer"""
if use_main_grad: if use_main_grad:
assert use_pure_bf16 assert use_pure_bf16
model = mix_precision_utils.MixPrecisionLayer(model, dtype="bfloat16") model = mix_precision_utils.MixPrecisionLayer(model, dtype="bfloat16")
...@@ -32,7 +32,7 @@ def create_optimizer(model, use_pure_bf16, use_main_grad): ...@@ -32,7 +32,7 @@ def create_optimizer(model, use_pure_bf16, use_main_grad):
class Net(paddle.nn.Layer): class Net(paddle.nn.Layer):
'''Network use for main_grad testing''' """Network use for main_grad testing"""
def __init__(self, fuse_wgrad_accumulation): def __init__(self, fuse_wgrad_accumulation):
super().__init__() super().__init__()
...@@ -40,7 +40,7 @@ class Net(paddle.nn.Layer): ...@@ -40,7 +40,7 @@ class Net(paddle.nn.Layer):
4096, 4096,
16384, 16384,
32, 32,
layer_type='encoder', layer_type="encoder",
fuse_wgrad_accumulation=fuse_wgrad_accumulation, fuse_wgrad_accumulation=fuse_wgrad_accumulation,
) )
...@@ -50,7 +50,7 @@ class Net(paddle.nn.Layer): ...@@ -50,7 +50,7 @@ class Net(paddle.nn.Layer):
def train(enable_master_grad, fuse_wgrad_accumulation=False): def train(enable_master_grad, fuse_wgrad_accumulation=False):
'''Train function''' """Train function"""
paddle.seed(10) paddle.seed(10)
accumulate_steps = 4 accumulate_steps = 4
...@@ -64,7 +64,7 @@ def train(enable_master_grad, fuse_wgrad_accumulation=False): ...@@ -64,7 +64,7 @@ def train(enable_master_grad, fuse_wgrad_accumulation=False):
loss_list = [] loss_list = []
for step_id in range(16): for step_id in range(16):
inp = paddle.uniform([2, 1024, 4096], dtype='float32') inp = paddle.uniform([2, 1024, 4096], dtype="float32")
inp.stop_gradient = False inp.stop_gradient = False
with te.fp8_autocast(enabled=True): with te.fp8_autocast(enabled=True):
out = model(inp) out = model(inp)
...@@ -82,8 +82,8 @@ def train(enable_master_grad, fuse_wgrad_accumulation=False): ...@@ -82,8 +82,8 @@ def train(enable_master_grad, fuse_wgrad_accumulation=False):
@pytest.mark.skipif(not is_fp8_supported, reason=reason) @pytest.mark.skipif(not is_fp8_supported, reason=reason)
def test_master_grad(): def test_master_grad():
'''Test main_grad''' """Test main_grad"""
paddle.set_default_dtype('float32') paddle.set_default_dtype("float32")
loss1 = train(enable_master_grad=False) loss1 = train(enable_master_grad=False)
loss2 = train(enable_master_grad=True) loss2 = train(enable_master_grad=True)
loss3 = train(enable_master_grad=True, fuse_wgrad_accumulation=True) loss3 = train(enable_master_grad=True, fuse_wgrad_accumulation=True)
......
...@@ -56,8 +56,13 @@ from transformer_engine.paddle.fp8 import is_fp8_available ...@@ -56,8 +56,13 @@ from transformer_engine.paddle.fp8 import is_fp8_available
from transformer_engine.paddle.constants import FP8FwdTensors from transformer_engine.paddle.constants import FP8FwdTensors
from transformer_engine.common.recipe import DelayedScaling from transformer_engine.common.recipe import DelayedScaling
GEMM_CASES = [(256, 256, 512), (32, 32, 32), (16384, 1024, 2816), (16384, 2816, 1024), GEMM_CASES = [
(16384, 1024, 1024)] (256, 256, 512),
(32, 32, 32),
(16384, 1024, 2816),
(16384, 2816, 1024),
(16384, 1024, 1024),
]
is_fp8_supported, reason = is_fp8_available() is_fp8_supported, reason = is_fp8_available()
SELF_ATTN_CASES = [(2, 512, 12, 64)] SELF_ATTN_CASES = [(2, 512, 12, 64)]
...@@ -74,13 +79,13 @@ def setup(): ...@@ -74,13 +79,13 @@ def setup():
yield yield
@pytest.mark.parametrize('fp8_dtype', [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2]) @pytest.mark.parametrize("fp8_dtype", [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2])
@pytest.mark.parametrize('inplace', [True, False]) @pytest.mark.parametrize("inplace", [True, False])
def test_quantize_dequantize(fp8_dtype, inplace): def test_quantize_dequantize(fp8_dtype, inplace):
""" """
Test cast_to_fp8 and cast_from_fp8 Test cast_to_fp8 and cast_from_fp8
""" """
a = paddle.rand(shape=(32, 32), dtype='float32') a = paddle.rand(shape=(32, 32), dtype="float32")
# Init fp8_meta # Init fp8_meta
fp8_meta = create_fp8_meta() fp8_meta = create_fp8_meta()
a_fp8 = paddle.zeros(shape=a.shape, dtype=paddle.uint8) if inplace else None a_fp8 = paddle.zeros(shape=a.shape, dtype=paddle.uint8) if inplace else None
...@@ -99,7 +104,7 @@ def copy_bits_from_float_to_uint16(f): ...@@ -99,7 +104,7 @@ def copy_bits_from_float_to_uint16(f):
""" """
Copy bits Copy bits
""" """
return struct.unpack('<I', struct.pack('<f', f))[0] >> 16 return struct.unpack("<I", struct.pack("<f", f))[0] >> 16
def convert_float_to_uint16(float_list): def convert_float_to_uint16(float_list):
...@@ -124,95 +129,106 @@ class TestTranspose: ...@@ -124,95 +129,106 @@ class TestTranspose:
""" """
Test BF16 transpose Test BF16 transpose
""" """
a = paddle.rand(shape=(16, 32), dtype='bfloat16') a = paddle.rand(shape=(16, 32), dtype="bfloat16")
a_transposed = transpose(a, otype=tex.DType.kBFloat16) a_transposed = transpose(a, otype=tex.DType.kBFloat16)
assert_allclose(a_transposed, a.T) assert_allclose(a_transposed, a.T)
@staticmethod @staticmethod
@pytest.mark.skipif(not is_fp8_supported, reason=reason) @pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest.mark.parametrize('fp8_dtype', [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2]) @pytest.mark.parametrize("fp8_dtype", [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2])
def test_transpose_fp8(fp8_dtype): def test_transpose_fp8(fp8_dtype):
""" """
Test FP8 transpose Test FP8 transpose
""" """
min_val = -8 min_val = -8
max_val = 8 max_val = 8
a = paddle.cast(paddle.randint(min_val, max_val, shape=(16, 32)), 'float32') a = paddle.cast(paddle.randint(min_val, max_val, shape=(16, 32)), "float32")
fp8_meta = create_fp8_meta() fp8_meta = create_fp8_meta()
a_fp8 = cast_to_fp8(a, fp8_meta, FP8FwdTensors.GEMM1_INPUT, otype=fp8_dtype) a_fp8 = cast_to_fp8(a, fp8_meta, FP8FwdTensors.GEMM1_INPUT, otype=fp8_dtype)
a_fp8_transposed = transpose(a_fp8, otype=fp8_dtype) a_fp8_transposed = transpose(a_fp8, otype=fp8_dtype)
a_transposed = cast_from_fp8(a_fp8_transposed, a_transposed = cast_from_fp8(
fp8_meta, a_fp8_transposed,
FP8FwdTensors.GEMM1_INPUT, fp8_meta,
itype=fp8_dtype, FP8FwdTensors.GEMM1_INPUT,
otype=tex.DType.kFloat32) itype=fp8_dtype,
otype=tex.DType.kFloat32,
)
assert_allclose(a_transposed, a.T) assert_allclose(a_transposed, a.T)
@staticmethod @staticmethod
@pytest.mark.skipif(not is_fp8_supported, reason=reason) @pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest.mark.parametrize('fp8_dtype', [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2]) @pytest.mark.parametrize("fp8_dtype", [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2])
@pytest.mark.parametrize('inplace', [True, False]) @pytest.mark.parametrize("inplace", [True, False])
def test_cast_transpose(fp8_dtype, inplace): def test_cast_transpose(fp8_dtype, inplace):
""" """
Test cast_transpose Test cast_transpose
""" """
min_val = -8 min_val = -8
max_val = 8 max_val = 8
a = paddle.cast(paddle.randint(min_val, max_val, shape=(16, 32)), 'float32') a = paddle.cast(paddle.randint(min_val, max_val, shape=(16, 32)), "float32")
fp8_meta = create_fp8_meta() fp8_meta = create_fp8_meta()
a_fp8_casted, a_fp8_transposed = None, None a_fp8_casted, a_fp8_transposed = None, None
if inplace: if inplace:
a_fp8_casted = paddle.zeros(shape=a.shape, dtype=paddle.uint8) a_fp8_casted = paddle.zeros(shape=a.shape, dtype=paddle.uint8)
a_fp8_transposed = paddle.zeros(shape=a.T.shape, dtype=paddle.uint8) a_fp8_transposed = paddle.zeros(shape=a.T.shape, dtype=paddle.uint8)
a_fp8_casted, a_fp8_transposed = cast_transpose(a, a_fp8_casted, a_fp8_transposed = cast_transpose(
fp8_meta, a,
FP8FwdTensors.GEMM1_INPUT, fp8_meta,
otype=fp8_dtype, FP8FwdTensors.GEMM1_INPUT,
cast_out=a_fp8_casted, otype=fp8_dtype,
transpose_out=a_fp8_transposed) cast_out=a_fp8_casted,
transpose_out=a_fp8_transposed,
a_transposed = cast_from_fp8(a_fp8_transposed, )
fp8_meta,
FP8FwdTensors.GEMM1_INPUT, a_transposed = cast_from_fp8(
itype=fp8_dtype, a_fp8_transposed,
otype=tex.DType.kFloat32) fp8_meta,
FP8FwdTensors.GEMM1_INPUT,
a_casted = cast_from_fp8(a_fp8_casted, itype=fp8_dtype,
fp8_meta, otype=tex.DType.kFloat32,
FP8FwdTensors.GEMM1_INPUT, )
itype=fp8_dtype,
otype=tex.DType.kFloat32) a_casted = cast_from_fp8(
a_fp8_casted,
fp8_meta,
FP8FwdTensors.GEMM1_INPUT,
itype=fp8_dtype,
otype=tex.DType.kFloat32,
)
assert_allclose(a_casted, a) assert_allclose(a_casted, a)
assert_allclose(a_transposed, a.T) assert_allclose(a_transposed, a.T)
@staticmethod @staticmethod
@pytest.mark.skipif(not is_fp8_supported, reason=reason) @pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest.mark.parametrize('fp8_dtype', [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2]) @pytest.mark.parametrize("fp8_dtype", [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2])
def test_cast_transpose_bgrad(fp8_dtype): def test_cast_transpose_bgrad(fp8_dtype):
""" """
Test cast_transpose_bgrad Test cast_transpose_bgrad
""" """
min_val = -8 min_val = -8
max_val = 8 max_val = 8
a = paddle.cast(paddle.randint(min_val, max_val, shape=(16, 32)), 'float32') a = paddle.cast(paddle.randint(min_val, max_val, shape=(16, 32)), "float32")
fp8_meta = create_fp8_meta() fp8_meta = create_fp8_meta()
bgrad, a_fp8_casted, a_fp8_transposed = cast_transpose_bgrad(a, bgrad, a_fp8_casted, a_fp8_transposed = cast_transpose_bgrad(
fp8_meta, a, fp8_meta, FP8FwdTensors.GEMM1_INPUT, otype=fp8_dtype
FP8FwdTensors.GEMM1_INPUT, )
otype=fp8_dtype)
a_transposed = cast_from_fp8(
a_transposed = cast_from_fp8(a_fp8_transposed, a_fp8_transposed,
fp8_meta, fp8_meta,
FP8FwdTensors.GEMM1_INPUT, FP8FwdTensors.GEMM1_INPUT,
itype=fp8_dtype, itype=fp8_dtype,
otype=tex.DType.kFloat32) otype=tex.DType.kFloat32,
)
a_casted = cast_from_fp8(a_fp8_casted,
fp8_meta, a_casted = cast_from_fp8(
FP8FwdTensors.GEMM1_INPUT, a_fp8_casted,
itype=fp8_dtype, fp8_meta,
otype=tex.DType.kFloat32) FP8FwdTensors.GEMM1_INPUT,
itype=fp8_dtype,
otype=tex.DType.kFloat32,
)
assert_allclose(a_casted, a) assert_allclose(a_casted, a)
assert_allclose(a_transposed, a.T) assert_allclose(a_transposed, a.T)
...@@ -229,7 +245,7 @@ class TestActivation: ...@@ -229,7 +245,7 @@ class TestActivation:
""" """
Test BF16 GELU Forward Test BF16 GELU Forward
""" """
a = paddle.rand(shape=(16, 32), dtype='bfloat16') * 2 - 1 a = paddle.rand(shape=(16, 32), dtype="bfloat16") * 2 - 1
gelu_out = te_gelu(a, otype=tex.DType.kBFloat16) gelu_out = te_gelu(a, otype=tex.DType.kBFloat16)
gelu_ref = paddle.nn.GELU()(a) gelu_ref = paddle.nn.GELU()(a)
...@@ -237,21 +253,23 @@ class TestActivation: ...@@ -237,21 +253,23 @@ class TestActivation:
@staticmethod @staticmethod
@pytest.mark.skipif(not is_fp8_supported, reason=reason) @pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest.mark.parametrize('fp8_dtype', [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2]) @pytest.mark.parametrize("fp8_dtype", [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2])
def test_gelu_fp8(fp8_dtype): def test_gelu_fp8(fp8_dtype):
""" """
Test FP8 GELU Forward Test FP8 GELU Forward
""" """
a = paddle.rand(shape=(16, 32), dtype='float32') * 2 - 1 a = paddle.rand(shape=(16, 32), dtype="float32") * 2 - 1
fp8_meta = create_fp8_meta() fp8_meta = create_fp8_meta()
gelu_out_fp8 = gelu_fp8(a, fp8_meta, FP8FwdTensors.GEMM1_INPUT, otype=fp8_dtype) gelu_out_fp8 = gelu_fp8(a, fp8_meta, FP8FwdTensors.GEMM1_INPUT, otype=fp8_dtype)
gelu_out = cast_from_fp8(gelu_out_fp8, gelu_out = cast_from_fp8(
fp8_meta, gelu_out_fp8,
FP8FwdTensors.GEMM1_INPUT, fp8_meta,
itype=fp8_dtype, FP8FwdTensors.GEMM1_INPUT,
otype=tex.DType.kFloat32) itype=fp8_dtype,
otype=tex.DType.kFloat32,
)
gelu_ref = paddle.nn.GELU()(a) gelu_ref = paddle.nn.GELU()(a)
...@@ -259,36 +277,38 @@ class TestActivation: ...@@ -259,36 +277,38 @@ class TestActivation:
@staticmethod @staticmethod
@pytest.mark.skipif(not is_fp8_supported, reason=reason) @pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest.mark.parametrize('fp8_dtype', [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2]) @pytest.mark.parametrize("fp8_dtype", [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2])
def test_gelu_bwd_fp8(fp8_dtype): def test_gelu_bwd_fp8(fp8_dtype):
""" """
Test FP8 GELU Backward Test FP8 GELU Backward
""" """
# y = GELU(x), calculate ref # y = GELU(x), calculate ref
x = paddle.rand(shape=(16, 32), dtype='float32') * 2 - 1 x = paddle.rand(shape=(16, 32), dtype="float32") * 2 - 1
x.stop_gradient = False x.stop_gradient = False
y = paddle.nn.GELU()(x) y = paddle.nn.GELU()(x)
y_grad = paddle.rand(shape=(16, 32), dtype='float32') * 2 - 1 y_grad = paddle.rand(shape=(16, 32), dtype="float32") * 2 - 1
paddle.autograd.backward([y], [y_grad], True) paddle.autograd.backward([y], [y_grad], True)
# calculate fp8 # calculate fp8
fp8_meta = create_fp8_meta() fp8_meta = create_fp8_meta()
x_grad_fp8, x_grad_t_fp8, dbias = dgelu_cast_transpose_bgrad_fp8(y_grad, x_grad_fp8, x_grad_t_fp8, dbias = dgelu_cast_transpose_bgrad_fp8(
x, y_grad, x, fp8_meta, FP8FwdTensors.GEMM1_INPUT, otype=fp8_dtype
fp8_meta, )
FP8FwdTensors.GEMM1_INPUT,
otype=fp8_dtype) x_grad = cast_from_fp8(
x_grad_fp8,
x_grad = cast_from_fp8(x_grad_fp8, fp8_meta,
fp8_meta, FP8FwdTensors.GEMM1_INPUT,
FP8FwdTensors.GEMM1_INPUT, itype=fp8_dtype,
itype=fp8_dtype, otype=tex.DType.kFloat32,
otype=tex.DType.kFloat32) )
x_grad_t = cast_from_fp8(x_grad_t_fp8, x_grad_t = cast_from_fp8(
fp8_meta, x_grad_t_fp8,
FP8FwdTensors.GEMM1_INPUT, fp8_meta,
itype=fp8_dtype, FP8FwdTensors.GEMM1_INPUT,
otype=tex.DType.kFloat32) itype=fp8_dtype,
otype=tex.DType.kFloat32,
)
assert_allclose(x_grad, x.grad, rtol=0.1, atol=0.01) assert_allclose(x_grad, x.grad, rtol=0.1, atol=0.01)
assert_allclose(x_grad_t, x.grad.T, rtol=0.1, atol=0.01) assert_allclose(x_grad_t, x.grad.T, rtol=0.1, atol=0.01)
...@@ -299,7 +319,7 @@ class TestActivation: ...@@ -299,7 +319,7 @@ class TestActivation:
""" """
Test BF16 SwiGLU Forward Test BF16 SwiGLU Forward
""" """
a = paddle.rand(shape=(16, 32), dtype='bfloat16') * 2 - 1 a = paddle.rand(shape=(16, 32), dtype="bfloat16") * 2 - 1
swiglu_out = swiglu(a, otype=tex.DType.kBFloat16) swiglu_out = swiglu(a, otype=tex.DType.kBFloat16)
swiglu_ref = swiglu_pd(a) swiglu_ref = swiglu_pd(a)
...@@ -307,21 +327,23 @@ class TestActivation: ...@@ -307,21 +327,23 @@ class TestActivation:
@staticmethod @staticmethod
@pytest.mark.skipif(not is_fp8_supported, reason=reason) @pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest.mark.parametrize('fp8_dtype', [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2]) @pytest.mark.parametrize("fp8_dtype", [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2])
def test_swiglu_fp8(fp8_dtype): def test_swiglu_fp8(fp8_dtype):
""" """
Test FP8 SwiGLU Forward Test FP8 SwiGLU Forward
""" """
a = paddle.rand(shape=(16, 32), dtype='float32') * 2 - 1 a = paddle.rand(shape=(16, 32), dtype="float32") * 2 - 1
fp8_meta = create_fp8_meta() fp8_meta = create_fp8_meta()
swiglu_out_fp8 = swiglu_fp8(a, fp8_meta, FP8FwdTensors.GEMM1_INPUT, otype=fp8_dtype) swiglu_out_fp8 = swiglu_fp8(a, fp8_meta, FP8FwdTensors.GEMM1_INPUT, otype=fp8_dtype)
swiglu_out = cast_from_fp8(swiglu_out_fp8, swiglu_out = cast_from_fp8(
fp8_meta, swiglu_out_fp8,
FP8FwdTensors.GEMM1_INPUT, fp8_meta,
itype=fp8_dtype, FP8FwdTensors.GEMM1_INPUT,
otype=tex.DType.kFloat32) itype=fp8_dtype,
otype=tex.DType.kFloat32,
)
swiglu_ref = swiglu_pd(a) swiglu_ref = swiglu_pd(a)
...@@ -333,10 +355,10 @@ class TestActivation: ...@@ -333,10 +355,10 @@ class TestActivation:
Test SwiGLU Backward Test SwiGLU Backward
""" """
# y = SwiGLU(x), calculate ref # y = SwiGLU(x), calculate ref
x = paddle.rand(shape=(16, 32), dtype='bfloat16') * 2 - 1 x = paddle.rand(shape=(16, 32), dtype="bfloat16") * 2 - 1
x.stop_gradient = False x.stop_gradient = False
y = swiglu_pd(x) y = swiglu_pd(x)
y_grad = paddle.rand(shape=(16, 16), dtype='bfloat16') * 2 - 1 y_grad = paddle.rand(shape=(16, 16), dtype="bfloat16") * 2 - 1
paddle.autograd.backward([y], [y_grad], True) paddle.autograd.backward([y], [y_grad], True)
# calculate fp8 # calculate fp8
x_grad = dswiglu(y_grad, x, otype=tex.DType.kBFloat16) x_grad = dswiglu(y_grad, x, otype=tex.DType.kBFloat16)
...@@ -350,17 +372,18 @@ class TestGemm: ...@@ -350,17 +372,18 @@ class TestGemm:
""" """
@staticmethod @staticmethod
@pytest.mark.skipif(paddle.device.cuda.get_device_capability() < (8, 0), @pytest.mark.skipif(
reason="BF16 GEMM requires Ampere+ GPU") paddle.device.cuda.get_device_capability() < (8, 0), reason="BF16 GEMM requires Ampere+ GPU"
@pytest.mark.parametrize('m,n,k', GEMM_CASES) )
@pytest.mark.parametrize("m,n,k", GEMM_CASES)
def test_bf16(m, n, k): def test_bf16(m, n, k):
""" """
Test "TN" BF16 GEMM Test "TN" BF16 GEMM
""" """
a = paddle.rand(shape=(m, k), dtype='bfloat16') a = paddle.rand(shape=(m, k), dtype="bfloat16")
b = paddle.rand(shape=(n, k), dtype='bfloat16') b = paddle.rand(shape=(n, k), dtype="bfloat16")
workspace = paddle.zeros(shape=[33_554_432], dtype='uint8') workspace = paddle.zeros(shape=[33_554_432], dtype="uint8")
ref_out = paddle.matmul(a, b.T) ref_out = paddle.matmul(a, b.T)
# CublasLt inside tex.te_gemm assumes inputs are column major. # CublasLt inside tex.te_gemm assumes inputs are column major.
...@@ -368,37 +391,51 @@ class TestGemm: ...@@ -368,37 +391,51 @@ class TestGemm:
# transpose of X. # transpose of X.
# Here we perform "TN" GEMM in column major, i.e., b@a^T = C^T, # Here we perform "TN" GEMM in column major, i.e., b@a^T = C^T,
# which is equivalent to a@b^T = C in row major. # which is equivalent to a@b^T = C in row major.
actual_out, _, _ = gemm(b, a, paddle.bfloat16, workspace, False, None, False, False, "TN", actual_out, _, _ = gemm(
None, None, False) b, a, paddle.bfloat16, workspace, False, None, False, False, "TN", None, None, False
)
assert_allclose(actual_out, ref_out, rtol=1.6e-2, atol=1e-5) assert_allclose(actual_out, ref_out, rtol=1.6e-2, atol=1e-5)
@staticmethod @staticmethod
@pytest.mark.skipif(paddle.device.cuda.get_device_capability() < (8, 0), @pytest.mark.skipif(
reason="BF16 GEMM requires Ampere+ GPU") paddle.device.cuda.get_device_capability() < (8, 0), reason="BF16 GEMM requires Ampere+ GPU"
@pytest.mark.parametrize('m,n,k', GEMM_CASES) )
@pytest.mark.parametrize("m,n,k", GEMM_CASES)
def test_bf16_inplace(m, n, k): def test_bf16_inplace(m, n, k):
""" """
Test "TN" BF16 GEMM, with accumulate=True Test "TN" BF16 GEMM, with accumulate=True
""" """
min_val = -16 min_val = -16
max_val = 16 max_val = 16
a = paddle.rand(shape=(m, k), dtype='bfloat16') a = paddle.rand(shape=(m, k), dtype="bfloat16")
b = paddle.rand(shape=(n, k), dtype='bfloat16') b = paddle.rand(shape=(n, k), dtype="bfloat16")
c = paddle.cast(paddle.randint(min_val, max_val, shape=(m, n)), 'bfloat16') c = paddle.cast(paddle.randint(min_val, max_val, shape=(m, n)), "bfloat16")
workspace = paddle.zeros(shape=[33_554_432], dtype='uint8') workspace = paddle.zeros(shape=[33_554_432], dtype="uint8")
ref_out = c + paddle.matmul(a, b.T) ref_out = c + paddle.matmul(a, b.T)
actual_out = paddle.clone(c) actual_out = paddle.clone(c)
_, _, _ = gemm(b, a, paddle.bfloat16, workspace, False, None, False, True, "TN", actual_out, _, _, _ = gemm(
None, False) b,
a,
paddle.bfloat16,
workspace,
False,
None,
False,
True,
"TN",
actual_out,
None,
False,
)
assert_allclose(actual_out, ref_out, rtol=5e-2, atol=5e-2) assert_allclose(actual_out, ref_out, rtol=5e-2, atol=5e-2)
@staticmethod @staticmethod
@pytest.mark.skipif(not is_fp8_supported, reason=reason) @pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest.mark.parametrize('m,n,k', GEMM_CASES) @pytest.mark.parametrize("m,n,k", GEMM_CASES)
def test_fp8_randint(m, n, k): def test_fp8_randint(m, n, k):
""" """
Test "TN" FP8 GEMM Test "TN" FP8 GEMM
...@@ -409,17 +446,26 @@ class TestGemm: ...@@ -409,17 +446,26 @@ class TestGemm:
out_dtype = paddle.float32 out_dtype = paddle.float32
fp8_meta = create_fp8_meta(num_gemms=1) fp8_meta = create_fp8_meta(num_gemms=1)
a = paddle.cast(paddle.randint(min_val, max_val, shape=(m, k)), 'float32') a = paddle.cast(paddle.randint(min_val, max_val, shape=(m, k)), "float32")
a_casted = cast_to_fp8(a, fp8_meta, FP8FwdTensors.GEMM1_INPUT, otype=fp8_dtype) a_casted = cast_to_fp8(a, fp8_meta, FP8FwdTensors.GEMM1_INPUT, otype=fp8_dtype)
b = paddle.cast(paddle.randint(min_val, max_val, shape=(n, k)), 'float32') b = paddle.cast(paddle.randint(min_val, max_val, shape=(n, k)), "float32")
b_casted = cast_to_fp8(b, fp8_meta, FP8FwdTensors.GEMM1_WEIGHT, otype=fp8_dtype) b_casted = cast_to_fp8(b, fp8_meta, FP8FwdTensors.GEMM1_WEIGHT, otype=fp8_dtype)
workspace = paddle.zeros(shape=[33_554_432], dtype='uint8') workspace = paddle.zeros(shape=[33_554_432], dtype="uint8")
ref_out = paddle.matmul(a, b.T) ref_out = paddle.matmul(a, b.T)
actual_out, _ = fp8_gemm(b_casted, fp8_meta.scale_inv, FP8FwdTensors.GEMM1_WEIGHT, actual_out, _ = fp8_gemm(
fp8_dtype, a_casted, fp8_meta.scale_inv, FP8FwdTensors.GEMM1_INPUT, b_casted,
fp8_dtype, out_dtype, workspace) fp8_meta.scale_inv,
FP8FwdTensors.GEMM1_WEIGHT,
fp8_dtype,
a_casted,
fp8_meta.scale_inv,
FP8FwdTensors.GEMM1_INPUT,
fp8_dtype,
out_dtype,
workspace,
)
assert_allclose(actual_out, ref_out) assert_allclose(actual_out, ref_out)
...@@ -434,14 +480,12 @@ class TestLayerNorm: ...@@ -434,14 +480,12 @@ class TestLayerNorm:
""" """
Calculate reference using paddle layer_norm op Calculate reference using paddle layer_norm op
""" """
y = paddle.nn.functional.layer_norm(x=x, y = paddle.nn.functional.layer_norm(
normalized_shape=x.shape[1:], x=x, normalized_shape=x.shape[1:], weight=gamma, bias=beta, epsilon=eps
weight=gamma, )
bias=beta,
epsilon=eps)
mean = paddle.mean(x, axis=-1) mean = paddle.mean(x, axis=-1)
var = paddle.var(x, axis=-1) var = paddle.var(x, axis=-1)
inv_var = paddle.sqrt(1. / var) inv_var = paddle.sqrt(1.0 / var)
return y, mean, inv_var return y, mean, inv_var
@staticmethod @staticmethod
...@@ -453,11 +497,9 @@ class TestLayerNorm: ...@@ -453,11 +497,9 @@ class TestLayerNorm:
gamma.stop_gradient = False gamma.stop_gradient = False
beta.stop_gradient = False beta.stop_gradient = False
y = paddle.nn.functional.layer_norm(x=x, y = paddle.nn.functional.layer_norm(
normalized_shape=x.shape[1:], x=x, normalized_shape=x.shape[1:], weight=gamma, bias=beta, epsilon=eps
weight=gamma, )
bias=beta,
epsilon=eps)
paddle.autograd.backward([y], [dy], True) paddle.autograd.backward([y], [dy], True)
...@@ -469,9 +511,9 @@ class TestLayerNorm: ...@@ -469,9 +511,9 @@ class TestLayerNorm:
""" """
N, H = (16, 32) N, H = (16, 32)
eps = 1e-3 eps = 1e-3
x = paddle.uniform(shape=(N, H), dtype='bfloat16') x = paddle.uniform(shape=(N, H), dtype="bfloat16")
gamma = paddle.uniform(shape=(H,), dtype='bfloat16') gamma = paddle.uniform(shape=(H,), dtype="bfloat16")
beta = paddle.uniform(shape=(H,), dtype='bfloat16') beta = paddle.uniform(shape=(H,), dtype="bfloat16")
y, mu, rsigma = layernorm_fwd(x, gamma, beta, eps, tex.DType.kBFloat16) y, mu, rsigma = layernorm_fwd(x, gamma, beta, eps, tex.DType.kBFloat16)
...@@ -490,9 +532,9 @@ class TestLayerNorm: ...@@ -490,9 +532,9 @@ class TestLayerNorm:
N, H = (16, 32) N, H = (16, 32)
eps = 1e-3 eps = 1e-3
x = paddle.uniform(shape=(N, H), dtype='float32') x = paddle.uniform(shape=(N, H), dtype="float32")
gamma = paddle.uniform(shape=(H,), dtype='float32') gamma = paddle.uniform(shape=(H,), dtype="float32")
beta = paddle.uniform(shape=(H,), dtype='float32') beta = paddle.uniform(shape=(H,), dtype="float32")
fp8_tensor = FP8FwdTensors.GEMM1_INPUT fp8_tensor = FP8FwdTensors.GEMM1_INPUT
fp8_meta = create_fp8_meta() fp8_meta = create_fp8_meta()
...@@ -513,10 +555,10 @@ class TestLayerNorm: ...@@ -513,10 +555,10 @@ class TestLayerNorm:
""" """
N, H = (16, 32) N, H = (16, 32)
eps = 1e-3 eps = 1e-3
x = paddle.uniform(shape=(N, H), dtype='bfloat16') x = paddle.uniform(shape=(N, H), dtype="bfloat16")
dy = paddle.uniform(shape=(N, H), dtype='bfloat16') dy = paddle.uniform(shape=(N, H), dtype="bfloat16")
gamma = paddle.uniform(shape=(H,), dtype='bfloat16') gamma = paddle.uniform(shape=(H,), dtype="bfloat16")
beta = paddle.uniform(shape=(H,), dtype='bfloat16') beta = paddle.uniform(shape=(H,), dtype="bfloat16")
dx_ref, dgamma_ref, dbeta_ref = self.calc_bwd_ref(x, eps, gamma, beta, dy) dx_ref, dgamma_ref, dbeta_ref = self.calc_bwd_ref(x, eps, gamma, beta, dy)
...@@ -563,8 +605,8 @@ class TestRMSNorm: ...@@ -563,8 +605,8 @@ class TestRMSNorm:
""" """
N, H = (16, 32) N, H = (16, 32)
eps = 1e-3 eps = 1e-3
x = paddle.uniform(shape=(N, H), dtype='bfloat16') x = paddle.uniform(shape=(N, H), dtype="bfloat16")
gamma = paddle.uniform(shape=(H,), dtype='bfloat16') gamma = paddle.uniform(shape=(H,), dtype="bfloat16")
y, _ = rmsnorm_fwd(x, gamma, eps, tex.DType.kBFloat16) y, _ = rmsnorm_fwd(x, gamma, eps, tex.DType.kBFloat16)
...@@ -581,8 +623,8 @@ class TestRMSNorm: ...@@ -581,8 +623,8 @@ class TestRMSNorm:
N, H = (16, 32) N, H = (16, 32)
eps = 1e-3 eps = 1e-3
x = paddle.uniform(shape=(N, H), dtype='float32') x = paddle.uniform(shape=(N, H), dtype="float32")
gamma = paddle.uniform(shape=(H,), dtype='float32') gamma = paddle.uniform(shape=(H,), dtype="float32")
fp8_tensor = FP8FwdTensors.GEMM1_INPUT fp8_tensor = FP8FwdTensors.GEMM1_INPUT
fp8_meta = create_fp8_meta() fp8_meta = create_fp8_meta()
...@@ -602,9 +644,9 @@ class TestRMSNorm: ...@@ -602,9 +644,9 @@ class TestRMSNorm:
""" """
N, H = (16, 32) N, H = (16, 32)
eps = 1e-3 eps = 1e-3
x = paddle.uniform(shape=(N, H), dtype='bfloat16') x = paddle.uniform(shape=(N, H), dtype="bfloat16")
dy = paddle.uniform(shape=(N, H), dtype='bfloat16') dy = paddle.uniform(shape=(N, H), dtype="bfloat16")
gamma = paddle.uniform(shape=(H,), dtype='bfloat16') gamma = paddle.uniform(shape=(H,), dtype="bfloat16")
dx_ref, dgamma_ref = self.calc_bwd_ref(x, eps, gamma, dy) dx_ref, dgamma_ref = self.calc_bwd_ref(x, eps, gamma, dy)
...@@ -620,7 +662,7 @@ class TestFusedAttn: ...@@ -620,7 +662,7 @@ class TestFusedAttn:
Test fused attention operators Test fused attention operators
""" """
def set_input(self, b, s_q, s_kv, h, d, dtype, attn_mode='self_attn', is_causal_masking=False): def set_input(self, b, s_q, s_kv, h, d, dtype, attn_mode="self_attn", is_causal_masking=False):
""" """
set test input set test input
""" """
...@@ -682,10 +724,10 @@ class TestFusedAttn: ...@@ -682,10 +724,10 @@ class TestFusedAttn:
assert attn_mode == "self_attn", "only support causal masking for self attention" assert attn_mode == "self_attn", "only support causal masking for self attention"
for i in range(0, self.batch_size): for i in range(0, self.batch_size):
for j in range(self.q_actual_seqlen[i]): for j in range(self.q_actual_seqlen[i]):
self.attn_mask[i, :, j, :j + 1] = 0 self.attn_mask[i, :, j, : j + 1] = 0
else: else:
for i in range(0, self.batch_size): for i in range(0, self.batch_size):
self.attn_mask[i, :, :self.q_actual_seqlen[i], :self.kv_actual_seqlen[i]] = 0 self.attn_mask[i, :, : self.q_actual_seqlen[i], : self.kv_actual_seqlen[i]] = 0
dout = _random((self.batch_size, self.q_seqlen, self.num_heads, self.head_size)) dout = _random((self.batch_size, self.q_seqlen, self.num_heads, self.head_size))
self.dout = paddle.to_tensor(dout, dtype=self.dtype) self.dout = paddle.to_tensor(dout, dtype=self.dtype)
...@@ -696,9 +738,9 @@ class TestFusedAttn: ...@@ -696,9 +738,9 @@ class TestFusedAttn:
k_tensor = paddle.to_tensor(self.kv, stop_gradient=False) k_tensor = paddle.to_tensor(self.kv, stop_gradient=False)
v_tensor = paddle.to_tensor(self.kv, stop_gradient=False) v_tensor = paddle.to_tensor(self.kv, stop_gradient=False)
q_out = paddle.transpose(x=q_tensor, perm=[0, 2, 1, 3]) # [b, s, h, d] -> [b, h, s, d] q_out = paddle.transpose(x=q_tensor, perm=[0, 2, 1, 3]) # [b, s, h, d] -> [b, h, s, d]
k_out = paddle.transpose(x=k_tensor, perm=[0, 2, 1, 3]) # [b, s, h, d] -> [b, h, s, d] k_out = paddle.transpose(x=k_tensor, perm=[0, 2, 1, 3]) # [b, s, h, d] -> [b, h, s, d]
v_out = paddle.transpose(x=v_tensor, perm=[0, 2, 1, 3]) # [b, s, h, d] -> [b, h, s, d] v_out = paddle.transpose(x=v_tensor, perm=[0, 2, 1, 3]) # [b, s, h, d] -> [b, h, s, d]
qk_out = paddle.matmul( qk_out = paddle.matmul(
x=q_out * self.scaling_factor, x=q_out * self.scaling_factor,
...@@ -707,10 +749,10 @@ class TestFusedAttn: ...@@ -707,10 +749,10 @@ class TestFusedAttn:
transpose_y=True, transpose_y=True,
) )
attn_mask = paddle.to_tensor(self.attn_mask, stop_gradient=True).cast('bool') attn_mask = paddle.to_tensor(self.attn_mask, stop_gradient=True).cast("bool")
attn_mask_vals = paddle.full(qk_out.shape, -1e4, qk_out.dtype) attn_mask_vals = paddle.full(qk_out.shape, -1e4, qk_out.dtype)
attn_mask_out = paddle.where(attn_mask, attn_mask_vals, qk_out) attn_mask_out = paddle.where(attn_mask, attn_mask_vals, qk_out)
attn_mask_out = paddle.cast(attn_mask_out, 'float32') attn_mask_out = paddle.cast(attn_mask_out, "float32")
softmax_out = F.softmax(attn_mask_out) softmax_out = F.softmax(attn_mask_out)
softmax_out = paddle.cast(softmax_out, self.dtype) softmax_out = paddle.cast(softmax_out, self.dtype)
...@@ -725,7 +767,7 @@ class TestFusedAttn: ...@@ -725,7 +767,7 @@ class TestFusedAttn:
else: else:
qkv_out = paddle.matmul(softmax_out, v_out) qkv_out = paddle.matmul(softmax_out, v_out)
out = paddle.transpose(qkv_out, perm=[0, 2, 1, 3]) # [b, h, s, d] -> [b, s, h, d] out = paddle.transpose(qkv_out, perm=[0, 2, 1, 3]) # [b, h, s, d] -> [b, s, h, d]
paddle.autograd.backward( paddle.autograd.backward(
[out], [out],
...@@ -738,17 +780,17 @@ class TestFusedAttn: ...@@ -738,17 +780,17 @@ class TestFusedAttn:
paddle.disable_static(place=paddle.CUDAPlace(0)) paddle.disable_static(place=paddle.CUDAPlace(0))
if self.attn_mode == "self_attn": if self.attn_mode == "self_attn":
qkv = np.stack([self.q, self.kv, self.kv], axis=2) # [b, s, 3, h, d] qkv = np.stack([self.q, self.kv, self.kv], axis=2) # [b, s, 3, h, d]
qkv_tensor = paddle.to_tensor(qkv, stop_gradient=False) qkv_tensor = paddle.to_tensor(qkv, stop_gradient=False)
else: else:
q_tensor = paddle.to_tensor(self.q, stop_gradient=False) q_tensor = paddle.to_tensor(self.q, stop_gradient=False)
kv = np.stack([self.kv, self.kv], axis=2) # [b, s, 2, h, d] kv = np.stack([self.kv, self.kv], axis=2) # [b, s, 2, h, d]
kv_tensor = paddle.to_tensor(kv, stop_gradient=False) kv_tensor = paddle.to_tensor(kv, stop_gradient=False)
q_cu_seqlen_tensor = paddle.to_tensor(self.q_cu_seqlen, dtype="int32", stop_gradient=True) q_cu_seqlen_tensor = paddle.to_tensor(self.q_cu_seqlen, dtype="int32", stop_gradient=True)
kv_cu_seqlen_tensor = paddle.to_tensor(self.kv_cu_seqlen, dtype="int32", stop_gradient=True) kv_cu_seqlen_tensor = paddle.to_tensor(self.kv_cu_seqlen, dtype="int32", stop_gradient=True)
qkv_layout = ("bs3hd" if self.attn_mode == "self_attn" else "bshd_bs2hd") qkv_layout = "bs3hd" if self.attn_mode == "self_attn" else "bshd_bs2hd"
fused_attention_backend = get_fused_attention_backend( fused_attention_backend = get_fused_attention_backend(
num_heads=self.num_heads, num_heads=self.num_heads,
num_gqa_groups=self.num_heads, num_gqa_groups=self.num_heads,
...@@ -764,7 +806,7 @@ class TestFusedAttn: ...@@ -764,7 +806,7 @@ class TestFusedAttn:
qkv_dtype = tex.DType.kBFloat16 if self.dtype == "bfloat16" else tex.DType.kFloat16 qkv_dtype = tex.DType.kBFloat16 if self.dtype == "bfloat16" else tex.DType.kFloat16
out, softmax_aux_tensor, q_grad, k_grad, v_grad = None, None, None, None, None out, softmax_aux_tensor, q_grad, k_grad, v_grad = None, None, None, None, None
if self.attn_mode == 'self_attn': if self.attn_mode == "self_attn":
out, softmax_aux_tensor, rng_state = fused_attn_fwd_qkvpacked( out, softmax_aux_tensor, rng_state = fused_attn_fwd_qkvpacked(
qkv_tensor, qkv_tensor,
q_cu_seqlen_tensor, q_cu_seqlen_tensor,
...@@ -776,7 +818,8 @@ class TestFusedAttn: ...@@ -776,7 +818,8 @@ class TestFusedAttn:
attn_scale=self.scaling_factor, attn_scale=self.scaling_factor,
dropout=self.dropout_prob, dropout=self.dropout_prob,
set_zero=False, set_zero=False,
attn_mask_type="causal" if self.is_causal_masking else "padding") attn_mask_type="causal" if self.is_causal_masking else "padding",
)
dqkv, _ = fused_attn_bwd_qkvpacked( dqkv, _ = fused_attn_bwd_qkvpacked(
qkv_tensor, qkv_tensor,
q_cu_seqlen_tensor, q_cu_seqlen_tensor,
...@@ -790,11 +833,12 @@ class TestFusedAttn: ...@@ -790,11 +833,12 @@ class TestFusedAttn:
attn_scale=self.scaling_factor, attn_scale=self.scaling_factor,
dropout=self.dropout_prob, dropout=self.dropout_prob,
set_zero=False, set_zero=False,
attn_mask_type="causal" if self.is_causal_masking else "padding") attn_mask_type="causal" if self.is_causal_masking else "padding",
)
q_grad = dqkv[:, :, 0, :, :] q_grad = dqkv[:, :, 0, :, :]
k_grad = dqkv[:, :, 1, :, :] k_grad = dqkv[:, :, 1, :, :]
v_grad = dqkv[:, :, 2, :, :] v_grad = dqkv[:, :, 2, :, :]
else: # attn_mode == 'cross_attn' else: # attn_mode == 'cross_attn'
out, softmax_aux_tensor, rng_state = fused_attn_fwd_kvpacked( out, softmax_aux_tensor, rng_state = fused_attn_fwd_kvpacked(
q_tensor, q_tensor,
kv_tensor, kv_tensor,
...@@ -808,22 +852,25 @@ class TestFusedAttn: ...@@ -808,22 +852,25 @@ class TestFusedAttn:
Bias=None, Bias=None,
attn_scale=self.scaling_factor, attn_scale=self.scaling_factor,
dropout=self.dropout_prob, dropout=self.dropout_prob,
set_zero=False) set_zero=False,
dq, dkv, _ = fused_attn_bwd_kvpacked(q_tensor, )
kv_tensor, dq, dkv, _ = fused_attn_bwd_kvpacked(
q_cu_seqlen_tensor, q_tensor,
kv_cu_seqlen_tensor, kv_tensor,
rng_state, q_cu_seqlen_tensor,
out, kv_cu_seqlen_tensor,
self.dout, rng_state,
softmax_aux_tensor, out,
fused_attention_backend=fused_attention_backend, self.dout,
max_seqlen_q=self.q_seqlen, softmax_aux_tensor,
max_seqlen_kv=self.kv_seqlen, fused_attention_backend=fused_attention_backend,
qkv_dtype=qkv_dtype, max_seqlen_q=self.q_seqlen,
attn_scale=self.scaling_factor, max_seqlen_kv=self.kv_seqlen,
dropout=self.dropout_prob, qkv_dtype=qkv_dtype,
set_zero=False) attn_scale=self.scaling_factor,
dropout=self.dropout_prob,
set_zero=False,
)
q_grad = dq q_grad = dq
k_grad = dkv[:, :, 0, :, :] k_grad = dkv[:, :, 0, :, :]
v_grad = dkv[:, :, 1, :, :] v_grad = dkv[:, :, 1, :, :]
...@@ -871,7 +918,8 @@ class TestFusedAttn: ...@@ -871,7 +918,8 @@ class TestFusedAttn:
dropout=self.dropout_prob, dropout=self.dropout_prob,
set_zero=False, set_zero=False,
qkv_layout=qkv_layout, qkv_layout=qkv_layout,
attn_mask_type="causal" if self.is_causal_masking else "padding") attn_mask_type="causal" if self.is_causal_masking else "padding",
)
dq, dk, dv, _ = fused_attn_bwd( dq, dk, dv, _ = fused_attn_bwd(
q_tensor, q_tensor,
k_tensor, k_tensor,
...@@ -890,28 +938,29 @@ class TestFusedAttn: ...@@ -890,28 +938,29 @@ class TestFusedAttn:
dropout=self.dropout_prob, dropout=self.dropout_prob,
set_zero=False, set_zero=False,
qkv_layout=qkv_layout, qkv_layout=qkv_layout,
attn_mask_type="causal" if self.is_causal_masking else "padding") attn_mask_type="causal" if self.is_causal_masking else "padding",
)
return out, dq, dk, dv return out, dq, dk, dv
@pytest.mark.parametrize('b, s, h, d', SELF_ATTN_CASES) @pytest.mark.parametrize("b, s, h, d", SELF_ATTN_CASES)
@pytest.mark.parametrize('dtype', ['float16', 'bfloat16']) @pytest.mark.parametrize("dtype", ["float16", "bfloat16"])
@pytest.mark.parametrize('is_causal_masking', [True, False]) @pytest.mark.parametrize("is_causal_masking", [True, False])
def test_self_attn_forward_backward(self, b, s, h, d, dtype, is_causal_masking): def test_self_attn_forward_backward(self, b, s, h, d, dtype, is_causal_masking):
""" """
test self attention forward + backward test self attention forward + backward
""" """
if not is_fused_attention_supported( if not is_fused_attention_supported(
num_heads=h, num_heads=h,
num_gqa_groups=h, num_gqa_groups=h,
q_seqlen=s, q_seqlen=s,
kv_seqlen=s, kv_seqlen=s,
head_size=d, head_size=d,
dtype=dtype, dtype=dtype,
dropout=0.0, dropout=0.0,
qkv_layout="bs3hd", qkv_layout="bs3hd",
bias_type="no_bias", bias_type="no_bias",
mask_type="causal" if is_causal_masking else "padding", mask_type="causal" if is_causal_masking else "padding",
): ):
pytest.skip("cuDNN fused attention is not supported") pytest.skip("cuDNN fused attention is not supported")
self.set_input(b, s, s, h, d, dtype, "self_attn", is_causal_masking) self.set_input(b, s, s, h, d, dtype, "self_attn", is_causal_masking)
...@@ -922,23 +971,23 @@ class TestFusedAttn: ...@@ -922,23 +971,23 @@ class TestFusedAttn:
assert_allclose(k_grad_ref, k_grad, rtol=1e-3, atol=1e-2) assert_allclose(k_grad_ref, k_grad, rtol=1e-3, atol=1e-2)
assert_allclose(v_grad_ref, v_grad, rtol=1e-3, atol=1e-2) assert_allclose(v_grad_ref, v_grad, rtol=1e-3, atol=1e-2)
@pytest.mark.parametrize('b, s_q, s_kv, h, d', CROSS_ATTN_CASES) @pytest.mark.parametrize("b, s_q, s_kv, h, d", CROSS_ATTN_CASES)
@pytest.mark.parametrize('dtype', ['float16', 'bfloat16']) @pytest.mark.parametrize("dtype", ["float16", "bfloat16"])
def test_cross_attn_forward_backward(self, b, s_q, s_kv, h, d, dtype): def test_cross_attn_forward_backward(self, b, s_q, s_kv, h, d, dtype):
""" """
test cross attention forward + backward test cross attention forward + backward
""" """
if not is_fused_attention_supported( if not is_fused_attention_supported(
num_heads=h, num_heads=h,
num_gqa_groups=h, num_gqa_groups=h,
q_seqlen=s_q, q_seqlen=s_q,
kv_seqlen=s_kv, kv_seqlen=s_kv,
head_size=d, head_size=d,
dtype=dtype, dtype=dtype,
dropout=0.0, dropout=0.0,
qkv_layout="bshd_bs2hd", qkv_layout="bshd_bs2hd",
bias_type="no_bias", bias_type="no_bias",
mask_type="padding", mask_type="padding",
): ):
pytest.skip("cuDNN fused attention is not supported") pytest.skip("cuDNN fused attention is not supported")
self.set_input(b, s_q, s_kv, h, d, dtype, "cross_attn") self.set_input(b, s_q, s_kv, h, d, dtype, "cross_attn")
...@@ -949,24 +998,24 @@ class TestFusedAttn: ...@@ -949,24 +998,24 @@ class TestFusedAttn:
assert_allclose(k_grad_ref, k_grad, rtol=1e-3, atol=1e-2) assert_allclose(k_grad_ref, k_grad, rtol=1e-3, atol=1e-2)
assert_allclose(v_grad_ref, v_grad, rtol=1e-3, atol=1e-2) assert_allclose(v_grad_ref, v_grad, rtol=1e-3, atol=1e-2)
@pytest.mark.parametrize('b, s, h, d', FLASH_ATTN_CASES) @pytest.mark.parametrize("b, s, h, d", FLASH_ATTN_CASES)
@pytest.mark.parametrize('dtype', ['float16', 'bfloat16']) @pytest.mark.parametrize("dtype", ["float16", "bfloat16"])
@pytest.mark.parametrize('is_causal_masking', [True]) @pytest.mark.parametrize("is_causal_masking", [True])
def test_flash_attn_forward_backward(self, b, s, h, d, dtype, is_causal_masking): def test_flash_attn_forward_backward(self, b, s, h, d, dtype, is_causal_masking):
""" """
test flash attention forward + backward test flash attention forward + backward
""" """
if not is_fused_attention_supported( if not is_fused_attention_supported(
num_heads=h, num_heads=h,
num_gqa_groups=h, num_gqa_groups=h,
q_seqlen=s, q_seqlen=s,
kv_seqlen=s, kv_seqlen=s,
head_size=d, head_size=d,
dtype=dtype, dtype=dtype,
dropout=0.0, dropout=0.0,
qkv_layout="bs3hd", qkv_layout="bs3hd",
bias_type="no_bias", bias_type="no_bias",
mask_type="causal" if is_causal_masking else "padding", mask_type="causal" if is_causal_masking else "padding",
): ):
pytest.skip("cuDNN fused attention is not supported") pytest.skip("cuDNN fused attention is not supported")
self.set_input(b, s, s, h, d, dtype, "self_attn", is_causal_masking) self.set_input(b, s, s, h, d, dtype, "self_attn", is_causal_masking)
...@@ -977,25 +1026,26 @@ class TestFusedAttn: ...@@ -977,25 +1026,26 @@ class TestFusedAttn:
assert_allclose(k_grad_ref, k_grad, rtol=1e-3, atol=1e-2) assert_allclose(k_grad_ref, k_grad, rtol=1e-3, atol=1e-2)
assert_allclose(v_grad_ref, v_grad, rtol=1e-3, atol=1e-2) assert_allclose(v_grad_ref, v_grad, rtol=1e-3, atol=1e-2)
@pytest.mark.parametrize('b, s, h, d', FLASH_ATTN_CASES) @pytest.mark.parametrize("b, s, h, d", FLASH_ATTN_CASES)
@pytest.mark.parametrize('dtype', ['float16', 'bfloat16']) @pytest.mark.parametrize("dtype", ["float16", "bfloat16"])
@pytest.mark.parametrize('is_causal_masking', [False, True]) @pytest.mark.parametrize("is_causal_masking", [False, True])
def test_fused_attn_with_separate_qkv_forward_backward(self, b, s, h, d, dtype, def test_fused_attn_with_separate_qkv_forward_backward(
is_causal_masking): self, b, s, h, d, dtype, is_causal_masking
):
""" """
test flash attention forward + backward with separate qkv inputs test flash attention forward + backward with separate qkv inputs
""" """
if not is_fused_attention_supported( if not is_fused_attention_supported(
num_heads=h, num_heads=h,
num_gqa_groups=h, num_gqa_groups=h,
q_seqlen=s, q_seqlen=s,
kv_seqlen=s, kv_seqlen=s,
head_size=d, head_size=d,
dtype=dtype, dtype=dtype,
dropout=0.0, dropout=0.0,
qkv_layout="bshd_bshd_bshd", qkv_layout="bshd_bshd_bshd",
bias_type="no_bias", bias_type="no_bias",
mask_type="causal" if is_causal_masking else "padding", mask_type="causal" if is_causal_masking else "padding",
): ):
pytest.skip("cuDNN fused attention is not supported") pytest.skip("cuDNN fused attention is not supported")
self.set_input(b, s, s, h, d, dtype, "self_attn", is_causal_masking) self.set_input(b, s, s, h, d, dtype, "self_attn", is_causal_masking)
...@@ -1013,7 +1063,7 @@ class TestSoftmax: ...@@ -1013,7 +1063,7 @@ class TestSoftmax:
""" """
@staticmethod @staticmethod
@pytest.mark.parametrize('dtype', ['float16', 'bfloat16']) @pytest.mark.parametrize("dtype", ["float16", "bfloat16"])
def test_scaled_softmax_fwd_bwd(dtype): def test_scaled_softmax_fwd_bwd(dtype):
"""test scaled softmax""" """test scaled softmax"""
B, H, S = (16, 4, 32) B, H, S = (16, 4, 32)
...@@ -1034,7 +1084,7 @@ class TestSoftmax: ...@@ -1034,7 +1084,7 @@ class TestSoftmax:
assert_allclose(dx_ref, dx, rtol=1e-4, atol=1e-3) assert_allclose(dx_ref, dx, rtol=1e-4, atol=1e-3)
@staticmethod @staticmethod
@pytest.mark.parametrize('dtype', ['float16', 'bfloat16']) @pytest.mark.parametrize("dtype", ["float16", "bfloat16"])
def test_scaled_masked_softmax_fwd_bwd(dtype): def test_scaled_masked_softmax_fwd_bwd(dtype):
"""test scaled masked softmax""" """test scaled masked softmax"""
B, H, S = (16, 4, 32) B, H, S = (16, 4, 32)
...@@ -1058,7 +1108,7 @@ class TestSoftmax: ...@@ -1058,7 +1108,7 @@ class TestSoftmax:
assert_allclose(dx_ref, dx, rtol=1e-4, atol=1e-3) assert_allclose(dx_ref, dx, rtol=1e-4, atol=1e-3)
@staticmethod @staticmethod
@pytest.mark.parametrize('dtype', ['float16', 'bfloat16']) @pytest.mark.parametrize("dtype", ["float16", "bfloat16"])
def test_scaled_upper_triang_masked_softmax_fwd_bwd(dtype): def test_scaled_upper_triang_masked_softmax_fwd_bwd(dtype):
"""test scaled upper triang masked softmax""" """test scaled upper triang masked softmax"""
B, S = (16, 32) B, S = (16, 32)
...@@ -1068,7 +1118,7 @@ class TestSoftmax: ...@@ -1068,7 +1118,7 @@ class TestSoftmax:
x.stop_gradient = False x.stop_gradient = False
dy = paddle.uniform(shape=(B, S, S), dtype=dtype) dy = paddle.uniform(shape=(B, S, S), dtype=dtype)
mask = paddle.ones((S, S), dtype='int32') mask = paddle.ones((S, S), dtype="int32")
col_beg, col_end = 1, S col_beg, col_end = 1, S
for row in range(0, S): for row in range(0, S):
mask[row, col_beg:col_end] = 0 mask[row, col_beg:col_end] = 0
...@@ -1087,7 +1137,7 @@ class TestSoftmax: ...@@ -1087,7 +1137,7 @@ class TestSoftmax:
assert_allclose(dx_ref, dx, rtol=1e-4, atol=5e-3) assert_allclose(dx_ref, dx, rtol=1e-4, atol=5e-3)
@pytest.mark.parametrize('update_weight_scale_inv', [True, False]) @pytest.mark.parametrize("update_weight_scale_inv", [True, False])
def test_amax_and_scale_update(update_weight_scale_inv): def test_amax_and_scale_update(update_weight_scale_inv):
"""Test update_scale""" """Test update_scale"""
num_gemm = 6 num_gemm = 6
...@@ -1097,11 +1147,11 @@ def test_amax_and_scale_update(update_weight_scale_inv): ...@@ -1097,11 +1147,11 @@ def test_amax_and_scale_update(update_weight_scale_inv):
fp8_max = recipe.fp8_format.value.max_fwd fp8_max = recipe.fp8_format.value.max_fwd
non_weight_mask = paddle.to_tensor([True, False] * (num_gemm // 2)) non_weight_mask = paddle.to_tensor([True, False] * (num_gemm // 2))
amax_history_tensor = paddle.rand(shape=[history_len, num_gemm], dtype='float32') amax_history_tensor = paddle.rand(shape=[history_len, num_gemm], dtype="float32")
rolled_history_ref = paddle.roll(amax_history_tensor, -1, axis=0) rolled_history_ref = paddle.roll(amax_history_tensor, -1, axis=0)
rolled_history_ref[0] = 0.0 rolled_history_ref[0] = 0.0
amax_tensor = paddle.max(amax_history_tensor, axis=0) amax_tensor = paddle.max(amax_history_tensor, axis=0)
scale_tensor = paddle.ones(shape=[num_gemm], dtype='float32') scale_tensor = paddle.ones(shape=[num_gemm], dtype="float32")
def calc_ref(amax, scale, fp8_max, margin=0): def calc_ref(amax, scale, fp8_max, margin=0):
"""Calculate reference scale""" """Calculate reference scale"""
...@@ -1110,12 +1160,12 @@ def test_amax_and_scale_update(update_weight_scale_inv): ...@@ -1110,12 +1160,12 @@ def test_amax_and_scale_update(update_weight_scale_inv):
sf = paddle.where(paddle.isfinite(amax), sf, scale) sf = paddle.where(paddle.isfinite(amax), sf, scale)
return sf return sf
scale_ref = calc_ref(amax_tensor, scale_tensor, fp8_max, 0.) scale_ref = calc_ref(amax_tensor, scale_tensor, fp8_max, 0.0)
if update_weight_scale_inv: if update_weight_scale_inv:
scale_inv_ref = 1. / scale_ref scale_inv_ref = 1.0 / scale_ref
else: else:
scale_inv_ref = paddle.zeros_like(scale_tensor) scale_inv_ref = paddle.zeros_like(scale_tensor)
scale_inv_ref = paddle.where(non_weight_mask, 1. / scale_ref, scale_inv_ref) scale_inv_ref = paddle.where(non_weight_mask, 1.0 / scale_ref, scale_inv_ref)
# Placeholder # Placeholder
scale_actual = paddle.zeros_like(scale_tensor) scale_actual = paddle.zeros_like(scale_tensor)
...@@ -1123,13 +1173,15 @@ def test_amax_and_scale_update(update_weight_scale_inv): ...@@ -1123,13 +1173,15 @@ def test_amax_and_scale_update(update_weight_scale_inv):
if update_weight_scale_inv: if update_weight_scale_inv:
non_weight_mask = paddle.empty([0]) non_weight_mask = paddle.empty([0])
tex.amax_and_scale_update_inplace(_amax_history=amax_history_tensor, tex.amax_and_scale_update_inplace(
_scale=scale_actual, _amax_history=amax_history_tensor,
_scale_inv=scale_inv_actual, _scale=scale_actual,
non_weight_mask=non_weight_mask, _scale_inv=scale_inv_actual,
fp8_dtype=int(fp8_dtype), non_weight_mask=non_weight_mask,
margin=0., fp8_dtype=int(fp8_dtype),
amax_compute="max") margin=0.0,
amax_compute="max",
)
assert_allclose(scale_actual, scale_ref, rtol=1e-7, atol=1e-7) assert_allclose(scale_actual, scale_ref, rtol=1e-7, atol=1e-7)
assert_allclose(scale_inv_actual, scale_inv_ref, rtol=1e-7, atol=1e-7) assert_allclose(scale_inv_actual, scale_inv_ref, rtol=1e-7, atol=1e-7)
...@@ -1141,8 +1193,8 @@ def test_update_latest_history(): ...@@ -1141,8 +1193,8 @@ def test_update_latest_history():
num_gemm = 6 num_gemm = 6
history_len = 1024 history_len = 1024
amax_history_tensor = paddle.rand(shape=[history_len, num_gemm], dtype='float32') amax_history_tensor = paddle.rand(shape=[history_len, num_gemm], dtype="float32")
amax = paddle.rand(shape=[num_gemm], dtype='float32') amax = paddle.rand(shape=[num_gemm], dtype="float32")
tex.update_latest_amax_history_inplace(_history=amax_history_tensor, amax=amax) tex.update_latest_amax_history_inplace(_history=amax_history_tensor, amax=amax)
......
...@@ -22,7 +22,7 @@ class TestParallelLinear(TestDistributed): ...@@ -22,7 +22,7 @@ class TestParallelLinear(TestDistributed):
@unittest.skipIf(not gpu_has_fp8, reason) @unittest.skipIf(not gpu_has_fp8, reason)
def test_linear_tp(self): def test_linear_tp(self):
"""Tests linear with tensor parallel in BF16""" """Tests linear with tensor parallel in BF16"""
self.run_2gpu(str(test_root / 'parallel_tests' / 'linear_tp.py')) self.run_2gpu(str(test_root / "parallel_tests" / "linear_tp.py"))
class TestParallelLayerNormLinear(TestDistributed): class TestParallelLayerNormLinear(TestDistributed):
...@@ -32,7 +32,7 @@ class TestParallelLayerNormLinear(TestDistributed): ...@@ -32,7 +32,7 @@ class TestParallelLayerNormLinear(TestDistributed):
@unittest.skipIf(not gpu_has_fp8, reason) @unittest.skipIf(not gpu_has_fp8, reason)
def test_layernorm_linear_tp(self): def test_layernorm_linear_tp(self):
"""Tests layernorm_linear with tensor parallel in BF16""" """Tests layernorm_linear with tensor parallel in BF16"""
self.run_2gpu(str(test_root / 'parallel_tests' / 'layernorm_linear_tp.py')) self.run_2gpu(str(test_root / "parallel_tests" / "layernorm_linear_tp.py"))
class TestParallelLayerNormMLP(TestDistributed): class TestParallelLayerNormMLP(TestDistributed):
...@@ -42,7 +42,7 @@ class TestParallelLayerNormMLP(TestDistributed): ...@@ -42,7 +42,7 @@ class TestParallelLayerNormMLP(TestDistributed):
@unittest.skipIf(not gpu_has_fp8, reason) @unittest.skipIf(not gpu_has_fp8, reason)
def test_layernorm_mlp_tp(self): def test_layernorm_mlp_tp(self):
"""Tests layernorm_mlp with tensor parallel in BF16""" """Tests layernorm_mlp with tensor parallel in BF16"""
self.run_2gpu(str(test_root / 'parallel_tests' / 'layernorm_mlp_tp.py')) self.run_2gpu(str(test_root / "parallel_tests" / "layernorm_mlp_tp.py"))
class TestAmaxReduction(TestDistributed): class TestAmaxReduction(TestDistributed):
...@@ -52,7 +52,7 @@ class TestAmaxReduction(TestDistributed): ...@@ -52,7 +52,7 @@ class TestAmaxReduction(TestDistributed):
@unittest.skipIf(not gpu_has_fp8, reason) @unittest.skipIf(not gpu_has_fp8, reason)
def test_amax_reduction(self): def test_amax_reduction(self):
"""Tests amax reduction""" """Tests amax reduction"""
self.run_2gpu(str(test_root / 'parallel_tests' / 'amax_reduction.py')) self.run_2gpu(str(test_root / "parallel_tests" / "amax_reduction.py"))
class TestPipelineParallel(TestDistributed): class TestPipelineParallel(TestDistributed):
...@@ -62,7 +62,7 @@ class TestPipelineParallel(TestDistributed): ...@@ -62,7 +62,7 @@ class TestPipelineParallel(TestDistributed):
@unittest.skipIf(not gpu_has_fp8, reason) @unittest.skipIf(not gpu_has_fp8, reason)
def test_pipeline_parallel(self): def test_pipeline_parallel(self):
"""Tests pipeline parallel""" """Tests pipeline parallel"""
self.run_2gpu(str(test_root / 'parallel_tests' / 'linear_pp.py')) self.run_2gpu(str(test_root / "parallel_tests" / "linear_pp.py"))
class TestGroupSharding(TestDistributed): class TestGroupSharding(TestDistributed):
...@@ -72,7 +72,7 @@ class TestGroupSharding(TestDistributed): ...@@ -72,7 +72,7 @@ class TestGroupSharding(TestDistributed):
@unittest.skipIf(not gpu_has_fp8, reason) @unittest.skipIf(not gpu_has_fp8, reason)
def test_group_sharding(self): def test_group_sharding(self):
"""Tests group sharding""" """Tests group sharding"""
self.run_2gpu(str(test_root / 'parallel_tests' / 'group_sharding.py')) self.run_2gpu(str(test_root / "parallel_tests" / "group_sharding.py"))
class TestParallelAttention(TestDistributed): class TestParallelAttention(TestDistributed):
...@@ -82,7 +82,7 @@ class TestParallelAttention(TestDistributed): ...@@ -82,7 +82,7 @@ class TestParallelAttention(TestDistributed):
@unittest.skipIf(not gpu_has_fp8, reason) @unittest.skipIf(not gpu_has_fp8, reason)
def test_attention_tp(self): def test_attention_tp(self):
"""Tests TransMultiHeadAttentionformer Layer with tensor parallel in BF16""" """Tests TransMultiHeadAttentionformer Layer with tensor parallel in BF16"""
self.run_2gpu(str(test_root / 'parallel_tests' / 'attention_tp.py')) self.run_2gpu(str(test_root / "parallel_tests" / "attention_tp.py"))
class TestParallelTransformerLayer(TestDistributed): class TestParallelTransformerLayer(TestDistributed):
...@@ -92,8 +92,8 @@ class TestParallelTransformerLayer(TestDistributed): ...@@ -92,8 +92,8 @@ class TestParallelTransformerLayer(TestDistributed):
@unittest.skipIf(not gpu_has_fp8, reason) @unittest.skipIf(not gpu_has_fp8, reason)
def test_transformer_tp(self): def test_transformer_tp(self):
"""Tests Transformer Layer with tensor parallel in BF16""" """Tests Transformer Layer with tensor parallel in BF16"""
self.run_2gpu(str(test_root / 'parallel_tests' / 'transformer_tp.py')) self.run_2gpu(str(test_root / "parallel_tests" / "transformer_tp.py"))
if __name__ == '__main__': if __name__ == "__main__":
unittest.main() unittest.main()
...@@ -17,7 +17,7 @@ is_fp8_supported, reason = is_fp8_available() ...@@ -17,7 +17,7 @@ is_fp8_supported, reason = is_fp8_available()
@pytest.mark.skipif(not is_fp8_supported, reason=reason) @pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest.mark.parametrize('use_reentrant', [False, True]) @pytest.mark.parametrize("use_reentrant", [False, True])
def test_transformer_encoder_recompute(use_reentrant): def test_transformer_encoder_recompute(use_reentrant):
""" """
Test TransformerLayer encoder recompute Test TransformerLayer encoder recompute
...@@ -29,17 +29,17 @@ def test_transformer_encoder_recompute(use_reentrant): ...@@ -29,17 +29,17 @@ def test_transformer_encoder_recompute(use_reentrant):
"""Launch training in subprocess and check output""" """Launch training in subprocess and check output"""
try: try:
cmd = [ cmd = [
'python', "python",
str(test_root / 'recompute_tests' / 'recompute_transformer_encoder.py'), str(test_root / "recompute_tests" / "recompute_transformer_encoder.py"),
str(int(enable_recompute)), str(int(enable_recompute)),
str(int(use_reentrant)) str(int(use_reentrant)),
] ]
result = subprocess.check_output(cmd, stderr=subprocess.STDOUT, universal_newlines=True) result = subprocess.check_output(cmd, stderr=subprocess.STDOUT, universal_newlines=True)
print(result) print(result)
loss_match = re.search(r'Loss:\s+(-?\d+\.\d+)', result) loss_match = re.search(r"Loss:\s+(-?\d+\.\d+)", result)
memory_match = re.search(r'Peak memory:\s+(\d+)', result) memory_match = re.search(r"Peak memory:\s+(\d+)", result)
loss_value = float(loss_match.group(1)) loss_value = float(loss_match.group(1))
memory_value = int(memory_match.group(1)) memory_value = int(memory_match.group(1))
......
...@@ -3,4 +3,5 @@ ...@@ -3,4 +3,5 @@
# See LICENSE for license information. # See LICENSE for license information.
import transformer_engine.paddle import transformer_engine.paddle
print("OK") print("OK")
...@@ -11,7 +11,7 @@ import paddle ...@@ -11,7 +11,7 @@ import paddle
from paddle.distributed import fleet from paddle.distributed import fleet
from paddle.distributed.fleet.meta_parallel import get_rng_state_tracker from paddle.distributed.fleet.meta_parallel import get_rng_state_tracker
import transformer_engine # pylint: disable=unused-import import transformer_engine # pylint: disable=unused-import
from transformer_engine.paddle.constants import ( from transformer_engine.paddle.constants import (
TE_DType, TE_DType,
AttnBiasType, AttnBiasType,
...@@ -19,7 +19,9 @@ from transformer_engine.paddle.constants import ( ...@@ -19,7 +19,9 @@ from transformer_engine.paddle.constants import (
FusedAttnBackend, FusedAttnBackend,
) )
from transformer_engine.paddle.fp8 import FP8TensorMeta from transformer_engine.paddle.fp8 import FP8TensorMeta
from transformer_engine import transformer_engine_paddle as tex # pylint: disable=wrong-import-order from transformer_engine import (
transformer_engine_paddle as tex,
) # pylint: disable=wrong-import-order
def create_fp8_meta(num_gemms=1, amax_history_len=10): def create_fp8_meta(num_gemms=1, amax_history_len=10):
...@@ -31,18 +33,14 @@ def create_fp8_meta(num_gemms=1, amax_history_len=10): ...@@ -31,18 +33,14 @@ def create_fp8_meta(num_gemms=1, amax_history_len=10):
return fp8_meta return fp8_meta
def assert_allclose(actual, def assert_allclose(
desired, actual, desired, rtol=1e-05, atol=1e-08, equal_nan=True, err_msg="", verbose=True
rtol=1e-05, ):
atol=1e-08,
equal_nan=True,
err_msg='',
verbose=True):
"""Compare two input paddle tensors""" """Compare two input paddle tensors"""
if isinstance(actual, paddle.Tensor): if isinstance(actual, paddle.Tensor):
actual = paddle.cast(actual, 'float32') actual = paddle.cast(actual, "float32")
if isinstance(desired, paddle.Tensor): if isinstance(desired, paddle.Tensor):
desired = paddle.cast(desired, 'float32') desired = paddle.cast(desired, "float32")
if len(actual.shape) == 0: if len(actual.shape) == 0:
actual = actual.item() actual = actual.item()
desired = desired.item() desired = desired.item()
...@@ -54,8 +52,9 @@ def assert_allclose(actual, ...@@ -54,8 +52,9 @@ def assert_allclose(actual,
def assert_shape(inp, expected_shape): def assert_shape(inp, expected_shape):
"""Assert the shape of input tensor equals to expected shape""" """Assert the shape of input tensor equals to expected shape"""
assert inp.shape == expected_shape, f"Expected tensor shape: {expected_shape} != " \ assert (
f"actual tensor shape: {inp.shape}" inp.shape == expected_shape
), f"Expected tensor shape: {expected_shape} != actual tensor shape: {inp.shape}"
def is_devices_enough(required): def is_devices_enough(required):
...@@ -91,12 +90,21 @@ def set_random_seed(seed): ...@@ -91,12 +90,21 @@ def set_random_seed(seed):
np.random.seed(seed + 100 * pp_rank) np.random.seed(seed + 100 * pp_rank)
seed_offset = seed + 1024 + paddle.distributed.get_world_size() seed_offset = seed + 1024 + paddle.distributed.get_world_size()
global_seed = (seed_offset + pp_rank * (mp_size) + dp_rank * (mp_size * pp_size) + global_seed = (
sharding_rank * (mp_size * pp_size * dp_size)) seed_offset
+ pp_rank * (mp_size)
+ dp_rank * (mp_size * pp_size)
+ sharding_rank * (mp_size * pp_size * dp_size)
)
seed_offset += paddle.distributed.get_world_size() seed_offset += paddle.distributed.get_world_size()
local_seed = (seed_offset + mp_rank + pp_rank * (mp_size) + dp_rank * (mp_size * pp_size) + local_seed = (
sharding_rank * (mp_size * pp_size * dp_size)) seed_offset
+ mp_rank
+ pp_rank * (mp_size)
+ dp_rank * (mp_size * pp_size)
+ sharding_rank * (mp_size * pp_size * dp_size)
)
tracker = get_rng_state_tracker() tracker = get_rng_state_tracker()
# tracker.reset() # tracker.reset()
......
...@@ -112,13 +112,17 @@ def perf_and_loss_plots(): ...@@ -112,13 +112,17 @@ def perf_and_loss_plots():
lm_loss_data.append(lm_data["loss"]) lm_loss_data.append(lm_data["loss"])
lm_perf_data.append(lm_data["perf"]) lm_perf_data.append(lm_data["perf"])
save_plot( save_plot(
model_config + " loss", legend, model_config + " loss",
lm_loss_data, model_config + "_loss.png", legend,
lm_loss_data,
model_config + "_loss.png",
"LM-Loss", "LM-Loss",
) )
save_plot( save_plot(
model_config + " perf", model_config + " perf",
legend, lm_perf_data, model_config + "_perf.png", legend,
lm_perf_data,
model_config + "_perf.png",
"Time per step (ms)", "Time per step (ms)",
) )
......
...@@ -68,7 +68,9 @@ def get_filename( ...@@ -68,7 +68,9 @@ def get_filename(
config = f"gpt3_{model}_dp{dp}_tp{tp}_pp{pp}_sp{sp}" config = f"gpt3_{model}_dp{dp}_tp{tp}_pp{pp}_sp{sp}"
config_dir = os.path.join(mlm_log_dir, config) config_dir = os.path.join(mlm_log_dir, config)
os.makedirs(config_dir, exist_ok=True) os.makedirs(config_dir, exist_ok=True)
fname = f"{'te' if use_te else 'megatron'}" + (f"_fp8_{fp8_recipe}" if fp8_recipe else "") + ".txt" fname = (
f"{'te' if use_te else 'megatron'}" + (f"_fp8_{fp8_recipe}" if fp8_recipe else "") + ".txt"
)
return os.path.join(config_dir, fname) return os.path.join(config_dir, fname)
...@@ -106,4 +108,5 @@ def test_distributed(dtype, fp8_recipe, dp, tp, pp, sp, use_te, model): ...@@ -106,4 +108,5 @@ def test_distributed(dtype, fp8_recipe, dp, tp, pp, sp, use_te, model):
TRANSFORMER_IMPL="transformer_engine" if use_te else "local", TRANSFORMER_IMPL="transformer_engine" if use_te else "local",
**asdict(model_configs[model]), **asdict(model_configs[model]),
), ),
check=True) check=True,
)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment