Unverified commit 9416519d authored by Kirthi Shankar Sivamani, committed by GitHub

Apply formatting (#929)



* Apply formatting
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Apply formatting
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent d99142a0
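The style changes throughout this diff (double-quoted strings, magic trailing commas, wrapping near 100 characters) match the output of the black formatter. A minimal sketch of how such a repo-wide reformat could be reproduced is shown below; the choice of black, the 100-character line length, and the target path are assumptions for illustration and are not stated anywhere in this commit.

    # Hypothetical reproduction of a formatting pass like this one.
    # Assumes black as the formatter and a 100-character line length;
    # neither is confirmed by the commit itself.
    import subprocess

    subprocess.run(
        ["black", "--line-length", "100", "tests/paddle"],  # target path is an assumption
        check=True,
    )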
......@@ -21,15 +21,15 @@ from paddle.distributed.utils.launch_utils import (
watch_local_trainers,
)
__all__ = ['TestDistributed']
__all__ = ["TestDistributed"]
def get_cluster_from_args(selected_gpus):
"""Get node information from selected GPUs"""
cluster_node_ips = '127.0.0.1'
node_ip = '127.0.0.1'
cluster_node_ips = "127.0.0.1"
node_ip = "127.0.0.1"
node_ips = [x.strip() for x in cluster_node_ips.split(',')]
node_ips = [x.strip() for x in cluster_node_ips.split(",")]
node_ips.index(node_ip)
......@@ -47,7 +47,7 @@ def get_cluster_from_args(selected_gpus):
def get_gpus(selected_gpus):
"""Get selected GPU string"""
selected_gpus = [x.strip() for x in selected_gpus.split(',')]
selected_gpus = [x.strip() for x in selected_gpus.split(",")]
return selected_gpus
......@@ -86,7 +86,7 @@ def start_local_trainers(
print(f"trainer proc env:{current_env}")
if os.getenv('WITH_COVERAGE', 'OFF') == 'ON':
if os.getenv("WITH_COVERAGE", "OFF") == "ON":
cmd = "python -m coverage run --branch -p " + training_script
else:
cmd = "python -u " + training_script
......@@ -95,7 +95,9 @@ def start_local_trainers(
fn = None
proc = subprocess.Popen(cmd.split(" ") + training_script_args, env=current_env) # pylint: disable=consider-using-with
proc = subprocess.Popen(
cmd.split(" ") + training_script_args, env=current_env
) # pylint: disable=consider-using-with
tp = TrainerProc()
tp.proc = proc
......@@ -117,10 +119,10 @@ class TestDistributed(unittest.TestCase):
allocator_strategy="auto_growth",
):
"""Run target file in subprocesses"""
if (not core.is_compiled_with_cuda() or core.get_cuda_device_count() == 0):
if not core.is_compiled_with_cuda() or core.get_cuda_device_count() == 0:
return
selected_gpus = get_gpus('0,1')
selected_gpus = get_gpus("0,1")
cluster = None
pod = None
......@@ -27,7 +27,7 @@ class TestAmaxReduction(unittest.TestCase):
def setUp(self):
self.data_parallel_size = 2
self.init_dist_env()
self.global_dtype = 'bfloat16'
self.global_dtype = "bfloat16"
paddle.set_default_dtype(self.global_dtype)
def init_dist_env(self):
......@@ -83,5 +83,5 @@ class TestAmaxReduction(unittest.TestCase):
assert_allclose_across_ranks(layer2.fp8_meta["scaling_bwd"].scale_inv)
if __name__ == '__main__':
if __name__ == "__main__":
unittest.main()
......@@ -44,8 +44,8 @@ class TestAttentionTp(unittest.TestCase):
self.num_heads = 16
self.q_seqlen = 128
self.kv_seqlen = 128
self.mask_type = 'padding'
self.global_dtype = 'bfloat16'
self.mask_type = "padding"
self.global_dtype = "bfloat16"
self.rtol = 5e-3
self.atol = 5e-3
self.eps = 1e-3
......@@ -56,7 +56,7 @@ class TestAttentionTp(unittest.TestCase):
inp, mask = inp_list
if sequence_parallel:
split_size = inp.shape[0] // self.world_size
input_parallel = inp[split_size * self.rank:split_size * (self.rank + 1), :]
input_parallel = inp[split_size * self.rank : split_size * (self.rank + 1), :]
else:
input_parallel = inp
with te.fp8_autocast(enabled=fp8_enabled):
......@@ -80,18 +80,20 @@ class TestAttentionTp(unittest.TestCase):
self.num_heads,
)
common_kwargs = {
'layernorm_epsilon': self.eps,
'attention_dropout': 0.0,
'attn_mask_type': self.mask_type,
'attention_type': 'self',
"layernorm_epsilon": self.eps,
"attention_dropout": 0.0,
"attn_mask_type": self.mask_type,
"attention_type": "self",
"tp_group": self.tp_group,
"input_layernorm": True,
}
layer_tp = te.MultiHeadAttention(*common_args,
layer_tp = te.MultiHeadAttention(
*common_args,
**common_kwargs,
set_parallel_mode=True,
sequence_parallel=self.sequence_parallel)
sequence_parallel=self.sequence_parallel,
)
layer_single = te.MultiHeadAttention(*common_args, **common_kwargs, set_parallel_mode=False)
def _get_total_weight(local_weight, tp_group, axis, interleave=False):
......@@ -102,12 +104,15 @@ class TestAttentionTp(unittest.TestCase):
# Due to the interleaved qkv layout, need to concat on num_head
# dimension for column parallel linear in MultiHeadAttention layer
assert axis == 0
assert [3 * self.hidden_size // self.world_size,
self.hidden_size] == partial_weight.shape
assert [
3 * self.hidden_size // self.world_size,
self.hidden_size,
] == partial_weight.shape
local_num_head = self.num_heads // self.world_size
for idx, _ in enumerate(total_weight):
total_weight[idx] = total_weight[idx].reshape(
[3, local_num_head, -1, self.hidden_size])
[3, local_num_head, -1, self.hidden_size]
)
total_weight = paddle.concat(total_weight, axis=1).reshape([-1, self.hidden_size])
else:
total_weight = paddle.concat(total_weight, axis=axis)
......@@ -123,42 +128,47 @@ class TestAttentionTp(unittest.TestCase):
weight_dst = _get_weight(layer_dst, weight_names)
if partition_mode is None:
total_weight = weight_src
elif partition_mode == 'column':
total_weight = _get_total_weight(weight_src,
tp_group=self.tp_group,
axis=0,
interleave=interleave)
elif partition_mode == 'row':
elif partition_mode == "column":
total_weight = _get_total_weight(
weight_src, tp_group=self.tp_group, axis=0, interleave=interleave
)
elif partition_mode == "row":
total_weight = _get_total_weight(weight_src, tp_group=self.tp_group, axis=1)
else:
raise ValueError(f"Partition Mode {partition_mode} is not supported.")
assert weight_dst.shape == total_weight.shape, \
f"Shapes of src:{total_weight.shape} and dst:{weight_dst.shape} do not match."
assert (
weight_dst.shape == total_weight.shape
), f"Shapes of src:{total_weight.shape} and dst:{weight_dst.shape} do not match."
weight_dst.copy_(total_weight, True)
copy_weight(layer_tp, layer_single, None, ['layernorm_qkv', 'ln_weight'])
copy_weight(layer_tp, layer_single, 'column', ['layernorm_qkv', 'weight'], interleave=True)
copy_weight(layer_tp, layer_single, 'row', ['proj', 'weight'])
copy_weight(layer_tp, layer_single, None, ["layernorm_qkv", "ln_weight"])
copy_weight(layer_tp, layer_single, "column", ["layernorm_qkv", "weight"], interleave=True)
copy_weight(layer_tp, layer_single, "row", ["proj", "weight"])
if self.sequence_parallel:
register_sequence_parallel_allreduce_hooks(layer_tp, accumulation_steps=1)
optimizer_tp = paddle.optimizer.SGD(learning_rate=0.01, parameters=layer_tp.parameters())
optimizer_single = paddle.optimizer.SGD(learning_rate=0.01,
parameters=layer_single.parameters())
optimizer_single = paddle.optimizer.SGD(
learning_rate=0.01, parameters=layer_single.parameters()
)
layer_tp = fleet.distributed_model(layer_tp)
optimizer_tp = fleet.distributed_optimizer(optimizer_tp)
for _ in range(5):
inp = paddle.uniform([self.batch_size, self.q_seqlen, self.hidden_size],
self.global_dtype)
mask = paddle.zeros(shape=(self.batch_size, 1, self.q_seqlen, self.kv_seqlen),
dtype='bool')
loss_tp, out_tp = self._train_one_step(layer_tp, [inp, mask], optimizer_tp, self.fp8,
self.sequence_parallel)
loss_single, out_single = self._train_one_step(layer_single, [inp, mask],
optimizer_single, self.fp8)
inp = paddle.uniform(
[self.batch_size, self.q_seqlen, self.hidden_size], self.global_dtype
)
mask = paddle.zeros(
shape=(self.batch_size, 1, self.q_seqlen, self.kv_seqlen), dtype="bool"
)
loss_tp, out_tp = self._train_one_step(
layer_tp, [inp, mask], optimizer_tp, self.fp8, self.sequence_parallel
)
loss_single, out_single = self._train_one_step(
layer_single, [inp, mask], optimizer_single, self.fp8
)
assert_allclose(out_tp, out_single, rtol=self.rtol, atol=self.atol)
assert_allclose(loss_tp, loss_single, rtol=self.rtol, atol=self.atol)
......@@ -173,8 +183,8 @@ class TestAttentionTpFp8(TestAttentionTp):
self.num_heads = 16
self.q_seqlen = 128
self.kv_seqlen = 128
self.mask_type = 'padding'
self.global_dtype = 'bfloat16'
self.mask_type = "padding"
self.global_dtype = "bfloat16"
self.rtol = 5e-2
self.atol = 5e-2
self.eps = 1e-3
......@@ -192,8 +202,8 @@ class TestAttentionSp(TestAttentionTp):
self.num_heads = 16
self.q_seqlen = 128
self.kv_seqlen = 128
self.mask_type = 'padding'
self.global_dtype = 'bfloat16'
self.mask_type = "padding"
self.global_dtype = "bfloat16"
self.rtol = 5e-3
self.atol = 5e-3
self.eps = 1e-3
......@@ -211,8 +221,8 @@ class TestAttentionSpFp8(TestAttentionTp):
self.num_heads = 16
self.q_seqlen = 128
self.kv_seqlen = 128
self.mask_type = 'padding'
self.global_dtype = 'bfloat16'
self.mask_type = "padding"
self.global_dtype = "bfloat16"
self.rtol = 5e-2
self.atol = 1e-1
self.eps = 1e-3
......@@ -220,5 +230,5 @@ class TestAttentionSpFp8(TestAttentionTp):
self.sequence_parallel = True
if __name__ == '__main__':
if __name__ == "__main__":
unittest.main()
......@@ -8,7 +8,8 @@ import unittest
import paddle
from paddle.distributed import fleet
from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer import (
DygraphShardingOptimizer,)
DygraphShardingOptimizer,
)
from utils import assert_allclose, set_random_seed
import transformer_engine.paddle as te
......@@ -25,7 +26,7 @@ class TestGroupSharding(unittest.TestCase):
def set_attr(self):
"""Set test configs"""
self.sharding_degree = 2
self.global_dtype = 'float32'
self.global_dtype = "float32"
self.rtol = 1e-5
self.atol = 1e-5
self.batch_size = 16
......@@ -59,9 +60,10 @@ class TestGroupSharding(unittest.TestCase):
class ShardingLevel: # pylint: disable=too-few-public-methods,
"""Paddle sharding options"""
kStage1 = 'os'
kStage2 = 'os_g'
kStage3 = 'p_g_os'
kStage1 = "os"
kStage2 = "os_g"
kStage3 = "p_g_os"
level = ShardingLevel.kStage3 if stage == 3 else ShardingLevel.kStage2
model, optimizer, _ = paddle.distributed.sharding.group_sharded_parallel(
......@@ -104,8 +106,9 @@ class TestGroupSharding(unittest.TestCase):
loss_pd = train_one_step(model_pd, inp, optimizer_pd)
assert_allclose(loss_te, loss_pd, rtol=self.rtol, atol=self.atol)
assert len(optimizer_te.state_dict()) == 4, \
"Expect each rank to hold 4 optimizer state entries."
assert (
len(optimizer_te.state_dict()) == 4
), "Expect each rank to hold 4 optimizer state entries."
def test_group_sharding_stage2(self):
"""Tests group sharding training"""
......@@ -141,8 +144,9 @@ class TestGroupSharding(unittest.TestCase):
loss_pd = train_one_step(model_pd, inp, optimizer_pd)
assert_allclose(loss_te, loss_pd, rtol=self.rtol, atol=self.atol)
assert len(optimizer_te.state_dict()) == 4, \
"Expect each rank to hold 4 optimizer state entries."
assert (
len(optimizer_te.state_dict()) == 4
), "Expect each rank to hold 4 optimizer state entries."
def test_group_sharding_stage3(self):
"""Tests group sharding training"""
......@@ -174,11 +178,11 @@ class TestGroupSharding(unittest.TestCase):
assert_allclose(loss_te, loss_pd, rtol=self.rtol, atol=self.atol)
for name, value in optimizer_te.state_dict().items():
if name.endswith('w_0_moment1_0'):
assert value.numel() == \
self.in_channels * self.out_channels // self.sharding_degree, \
"Expect optimizer state to be sharded across trainers."
if name.endswith("w_0_moment1_0"):
assert (
value.numel() == self.in_channels * self.out_channels // self.sharding_degree
), "Expect optimizer state to be sharded across trainers."
if __name__ == '__main__':
if __name__ == "__main__":
unittest.main()
......@@ -42,22 +42,22 @@ class TestLayerNormLinearTp(unittest.TestCase):
self.batch_size = 16
self.in_features = 32
self.out_features = 64
self.global_dtype = 'bfloat16'
self.global_dtype = "bfloat16"
self.rtol = 1e-3
self.atol = 1e-3
self.eps = 1e-3
self.fp8 = False
self.sequence_parallel = False
def _train_one_step(self, layer, inp, optimizer, split_input='none', gather_output=False):
def _train_one_step(self, layer, inp, optimizer, split_input="none", gather_output=False):
inp = paddle.to_tensor(inp, stop_gradient=True)
assert split_input in ['none', 'column', 'row']
if split_input == 'column':
assert split_input in ["none", "column", "row"]
if split_input == "column":
split_size = inp.shape[1] // self.world_size
input_parallel = inp[:, split_size * self.rank:split_size * (self.rank + 1)]
elif split_input == 'row':
input_parallel = inp[:, split_size * self.rank : split_size * (self.rank + 1)]
elif split_input == "row":
split_size = inp.shape[0] // self.world_size
input_parallel = inp[split_size * self.rank:split_size * (self.rank + 1), :]
input_parallel = inp[split_size * self.rank : split_size * (self.rank + 1), :]
else:
input_parallel = inp
input_parallel.stop_gradient = False
......@@ -70,12 +70,12 @@ class TestLayerNormLinearTp(unittest.TestCase):
loss.backward()
optimizer.step()
optimizer.clear_grad()
if split_input != 'none':
if split_input != "none":
grad_input = []
paddle.distributed.all_gather(grad_input, input_parallel.grad, group=self.tp_group)
if split_input == 'column':
if split_input == "column":
grad_input = paddle.concat(grad_input, axis=1)
elif split_input == 'row':
elif split_input == "row":
grad_input = paddle.concat(grad_input, axis=0)
else:
grad_input = input_parallel.grad
......@@ -88,14 +88,14 @@ class TestLayerNormLinearTp(unittest.TestCase):
self.in_features,
self.out_features,
eps=self.eps,
parallel_mode='column',
parallel_mode="column",
sequence_parallel=self.sequence_parallel,
)
layer_pd = te.LayerNormLinear(
self.in_features,
self.out_features,
eps=self.eps,
backend='paddle',
backend="paddle",
)
# Get total weight
total_weight = []
......@@ -104,8 +104,9 @@ class TestLayerNormLinearTp(unittest.TestCase):
total_weight = paddle.concat(total_weight, axis=0)
layer_pd.weight.copy_(total_weight.T, True)
assert_shape(layer_te.weight,
[self.out_features // self.model_parallel_size, self.in_features])
assert_shape(
layer_te.weight, [self.out_features // self.model_parallel_size, self.in_features]
)
assert_shape(layer_te.bias, [self.out_features // self.model_parallel_size])
optimizer_te = paddle.optimizer.SGD(learning_rate=0.001, parameters=layer_te.parameters())
......@@ -121,8 +122,9 @@ class TestLayerNormLinearTp(unittest.TestCase):
layer_te,
inp,
optimizer_te,
split_input='row' if self.sequence_parallel else 'none',
gather_output=True)
split_input="row" if self.sequence_parallel else "none",
gather_output=True,
)
loss_ref, grad_input_ref = self._train_one_step(layer_pd, inp, optimizer_pd)
assert_allclose(loss_tp, loss_ref, rtol=self.rtol, atol=self.atol)
assert_allclose(grad_input, grad_input_ref, rtol=self.rtol, atol=self.atol)
......@@ -136,7 +138,7 @@ class TestLayerNormLinearTpFp8(TestLayerNormLinearTp):
self.batch_size = 16
self.in_features = 32
self.out_features = 64
self.global_dtype = 'bfloat16'
self.global_dtype = "bfloat16"
self.rtol = 1e-2
self.atol = 1e-2
self.eps = 1e-3
......@@ -152,7 +154,7 @@ class TestLayerNormLinearSp(TestLayerNormLinearTp):
self.batch_size = 16
self.in_features = 32
self.out_features = 64
self.global_dtype = 'bfloat16'
self.global_dtype = "bfloat16"
self.rtol = 1e-3
self.atol = 1e-3
self.eps = 1e-3
......@@ -168,7 +170,7 @@ class TestLayerNormLinearSpFp8(TestLayerNormLinearTp):
self.batch_size = 16
self.in_features = 32
self.out_features = 64
self.global_dtype = 'bfloat16'
self.global_dtype = "bfloat16"
self.rtol = 1e-2
self.atol = 1e-2
self.eps = 1e-3
......@@ -176,5 +178,5 @@ class TestLayerNormLinearSpFp8(TestLayerNormLinearTp):
self.sequence_parallel = True
if __name__ == '__main__':
if __name__ == "__main__":
unittest.main()
......@@ -42,22 +42,22 @@ class TestLayerNormMLPTp(unittest.TestCase):
self.batch_size = 16
self.hidden_size = 32
self.ffn_hidden_size = 64
self.global_dtype = 'bfloat16'
self.global_dtype = "bfloat16"
self.rtol = 1e-3
self.atol = 1e-3
self.eps = 1e-3
self.fp8 = False
self.sequence_parallel = False
def _train_one_step(self, layer, inp, optimizer, split_input='none', gather_output=False):
def _train_one_step(self, layer, inp, optimizer, split_input="none", gather_output=False):
inp = paddle.to_tensor(inp, stop_gradient=True)
assert split_input in ['none', 'column', 'row']
if split_input == 'column':
assert split_input in ["none", "column", "row"]
if split_input == "column":
split_size = inp.shape[1] // self.world_size
input_parallel = inp[:, split_size * self.rank:split_size * (self.rank + 1)]
elif split_input == 'row':
input_parallel = inp[:, split_size * self.rank : split_size * (self.rank + 1)]
elif split_input == "row":
split_size = inp.shape[0] // self.world_size
input_parallel = inp[split_size * self.rank:split_size * (self.rank + 1), :]
input_parallel = inp[split_size * self.rank : split_size * (self.rank + 1), :]
else:
input_parallel = inp
input_parallel.stop_gradient = False
......@@ -71,12 +71,12 @@ class TestLayerNormMLPTp(unittest.TestCase):
loss.backward()
optimizer.step()
optimizer.clear_grad()
if split_input != 'none':
if split_input != "none":
grad_input = []
paddle.distributed.all_gather(grad_input, input_parallel.grad, group=self.tp_group)
if split_input == 'column':
if split_input == "column":
grad_input = paddle.concat(grad_input, axis=1)
elif split_input == 'row':
elif split_input == "row":
grad_input = paddle.concat(grad_input, axis=0)
else:
grad_input = input_parallel.grad
......@@ -97,7 +97,7 @@ class TestLayerNormMLPTp(unittest.TestCase):
ffn_hidden_size=self.ffn_hidden_size,
eps=self.eps,
set_parallel_mode=False,
backend='paddle',
backend="paddle",
)
def _get_total_weight(local_weight, tp_group, axis):
......@@ -113,11 +113,15 @@ class TestLayerNormMLPTp(unittest.TestCase):
layer_pd.fc1_weight.copy_(total_fc1_weight.T, True)
layer_pd.fc2_weight.copy_(total_fc2_weight.T, True)
assert_shape(layer_te.fc1_weight,
[self.ffn_hidden_size // self.model_parallel_size, self.hidden_size])
assert_shape(
layer_te.fc1_weight,
[self.ffn_hidden_size // self.model_parallel_size, self.hidden_size],
)
assert_shape(layer_te.fc1_bias, [self.ffn_hidden_size // self.model_parallel_size])
assert_shape(layer_te.fc2_weight,
[self.hidden_size, self.ffn_hidden_size // self.model_parallel_size])
assert_shape(
layer_te.fc2_weight,
[self.hidden_size, self.ffn_hidden_size // self.model_parallel_size],
)
assert_shape(layer_te.fc2_bias, [self.hidden_size])
optimizer_te = paddle.optimizer.SGD(learning_rate=0.001, parameters=layer_te.parameters())
......@@ -133,8 +137,9 @@ class TestLayerNormMLPTp(unittest.TestCase):
layer_te,
inp,
optimizer_te,
split_input='row' if self.sequence_parallel else 'none',
gather_output=self.sequence_parallel)
split_input="row" if self.sequence_parallel else "none",
gather_output=self.sequence_parallel,
)
loss_ref, grad_input_ref = self._train_one_step(layer_pd, inp, optimizer_pd)
assert_allclose(loss_tp, loss_ref, rtol=self.rtol, atol=self.atol)
assert_allclose(grad_input, grad_input_ref, rtol=self.rtol, atol=self.atol)
......@@ -148,7 +153,7 @@ class TestLayerNormMLPTpFp8(TestLayerNormMLPTp):
self.batch_size = 16
self.hidden_size = 32
self.ffn_hidden_size = 64
self.global_dtype = 'bfloat16'
self.global_dtype = "bfloat16"
self.rtol = 1e-2
self.atol = 1e-2
self.eps = 1e-3
......@@ -164,7 +169,7 @@ class TestLayerNormMLPSp(TestLayerNormMLPTp):
self.batch_size = 16
self.hidden_size = 32
self.ffn_hidden_size = 64
self.global_dtype = 'bfloat16'
self.global_dtype = "bfloat16"
self.rtol = 1e-3
self.atol = 1e-3
self.eps = 1e-3
......@@ -180,7 +185,7 @@ class TestLayerNormMLPSpFp8(TestLayerNormMLPTp):
self.batch_size = 16
self.hidden_size = 32
self.ffn_hidden_size = 64
self.global_dtype = 'bfloat16'
self.global_dtype = "bfloat16"
self.rtol = 1e-2
self.atol = 1e-2
self.eps = 1e-3
......@@ -188,5 +193,5 @@ class TestLayerNormMLPSpFp8(TestLayerNormMLPTp):
self.sequence_parallel = True
if __name__ == '__main__':
if __name__ == "__main__":
unittest.main()
......@@ -23,14 +23,14 @@ class TELinear(te.Linear):
"""To pass is_first_microbatch"""
def __init__(self, *args, **kwargs):
assert 'accumulate_steps' in kwargs
self.accumulate_steps = kwargs['accumulate_steps']
del kwargs['accumulate_steps']
assert "accumulate_steps" in kwargs
self.accumulate_steps = kwargs["accumulate_steps"]
del kwargs["accumulate_steps"]
self._micro_batch_id = 0
super().__init__(*args, **kwargs)
def forward(self, *args, **kwargs):
kwargs['is_first_microbatch'] = (self._micro_batch_id % self.accumulate_steps) == 0
kwargs["is_first_microbatch"] = (self._micro_batch_id % self.accumulate_steps) == 0
if paddle.is_grad_enabled() and self.training:
self._micro_batch_id += 1
return super().forward(*args, **kwargs)
......@@ -39,14 +39,16 @@ class TELinear(te.Linear):
class TEPipelineModel(PipelineLayer):
"""Model for pipeline parallel test"""
def __init__(self,
def __init__(
self,
in_features,
hidden_features,
weight_attrs,
use_te=True,
use_fp8=False,
accumulate_steps=1,
**kwargs):
**kwargs,
):
self.in_features = in_features
self.hidden_features = hidden_features
self.fp8 = use_fp8
......@@ -56,19 +58,23 @@ class TEPipelineModel(PipelineLayer):
Linear = TELinear if use_te else paddle.nn.Linear
extra_kwargs = {}
if use_te:
extra_kwargs['accumulate_steps'] = accumulate_steps
extra_kwargs["accumulate_steps"] = accumulate_steps
model_desc = [
LayerDesc(Linear,
LayerDesc(
Linear,
self.in_features,
self.hidden_features,
weight_attr=weight_attrs[0],
**extra_kwargs),
LayerDesc(Linear,
**extra_kwargs,
),
LayerDesc(
Linear,
self.hidden_features,
self.in_features,
weight_attr=weight_attrs[1],
**extra_kwargs),
**extra_kwargs,
),
]
super().__init__(layers=model_desc, loss_fn=paddle.nn.CrossEntropyLoss(), **kwargs)
......@@ -129,7 +135,7 @@ class TestLinearPipelineParallel(unittest.TestCase):
self.micro_batch_size = 16
self.in_features = 32
self.hidden_features = 64
self.global_dtype = 'float32'
self.global_dtype = "float32"
self.rtol = 1e-5
self.atol = 1e-5
self.iter = 10
......@@ -164,16 +170,18 @@ class TestLinearPipelineParallel(unittest.TestCase):
# Check if model is split across ranks as expected
for name, sublayer in pipe_model.named_sublayers():
if name in ('_loss_fn', 'shared_layers'):
if name in ("_loss_fn", "shared_layers"):
continue
if self.rank == 0:
assert tuple(sublayer.weight.shape) == weight1_np.T.shape, \
f"Shape does not match, expect: {weight1_np.T.shape}, " \
assert tuple(sublayer.weight.shape) == weight1_np.T.shape, (
f"Shape does not match, expect: {weight1_np.T.shape}, "
f"actual: {tuple(sublayer.weight.shape)}"
)
elif self.rank == 1:
assert tuple(sublayer.weight.shape) == weight2_np.T.shape, \
f"Shape does not match, expect: {weight2_np.T.shape}, " \
assert tuple(sublayer.weight.shape) == weight2_np.T.shape, (
f"Shape does not match, expect: {weight2_np.T.shape}, "
f"actual: {tuple(sublayer.weight.shape)}"
)
standalone_model = StandaloneModel(
self.in_features,
......@@ -182,8 +190,9 @@ class TestLinearPipelineParallel(unittest.TestCase):
)
optimizer_te = paddle.optimizer.SGD(learning_rate=0.1, parameters=pipe_model.parameters())
optimizer_pd = paddle.optimizer.SGD(learning_rate=0.1,
parameters=standalone_model.parameters())
optimizer_pd = paddle.optimizer.SGD(
learning_rate=0.1, parameters=standalone_model.parameters()
)
pipe_model = fleet.distributed_model(pipe_model)
optimizer_te = fleet.distributed_optimizer(optimizer_te)
......@@ -196,8 +205,9 @@ class TestLinearPipelineParallel(unittest.TestCase):
return loss
for i in range(self.iter):
inp = paddle.to_tensor(np.random.normal(size=[self.batch_size, self.in_features]),
dtype=self.global_dtype)
inp = paddle.to_tensor(
np.random.normal(size=[self.batch_size, self.in_features]), dtype=self.global_dtype
)
label = paddle.to_tensor(np.random.randint(self.in_features, size=[self.batch_size, 1]))
loss_te = pipe_model.train_batch([inp, label], optimizer_te)
loss_pd = train_one_step(standalone_model, [inp, label], optimizer_pd)
......@@ -214,12 +224,12 @@ class TestLinearPipelineParallelFP8(TestLinearPipelineParallel):
self.micro_batch_size = 16
self.in_features = 32
self.hidden_features = 64
self.global_dtype = 'float32'
self.global_dtype = "float32"
self.rtol = 5e-2
self.atol = 5e-2
self.iter = 10
self.fp8 = True
if __name__ == '__main__':
if __name__ == "__main__":
unittest.main()
......@@ -42,21 +42,21 @@ class TestLinearTp(unittest.TestCase):
self.batch_size = 16
self.in_features = 32
self.out_features = 64
self.global_dtype = 'bfloat16'
self.global_dtype = "bfloat16"
self.rtol = 1e-3
self.atol = 1e-3
self.fp8 = False
self.sequence_parallel = False
def _train_one_step(self, layer, inp, optimizer, split_input='none', gather_output=False):
def _train_one_step(self, layer, inp, optimizer, split_input="none", gather_output=False):
inp = paddle.to_tensor(inp, stop_gradient=True)
assert split_input in ['none', 'column', 'row']
if split_input == 'column':
assert split_input in ["none", "column", "row"]
if split_input == "column":
split_size = inp.shape[1] // self.world_size
input_parallel = inp[:, split_size * self.rank:split_size * (self.rank + 1)]
elif split_input == 'row':
input_parallel = inp[:, split_size * self.rank : split_size * (self.rank + 1)]
elif split_input == "row":
split_size = inp.shape[0] // self.world_size
input_parallel = inp[split_size * self.rank:split_size * (self.rank + 1), :]
input_parallel = inp[split_size * self.rank : split_size * (self.rank + 1), :]
else:
input_parallel = inp
input_parallel.stop_gradient = False
......@@ -69,12 +69,12 @@ class TestLinearTp(unittest.TestCase):
loss.backward()
optimizer.step()
optimizer.clear_grad()
if split_input != 'none':
if split_input != "none":
grad_input = []
paddle.distributed.all_gather(grad_input, input_parallel.grad, group=self.tp_group)
if split_input == 'column':
if split_input == "column":
grad_input = paddle.concat(grad_input, axis=1)
elif split_input == 'row':
elif split_input == "row":
grad_input = paddle.concat(grad_input, axis=0)
else:
grad_input = input_parallel.grad
......@@ -86,13 +86,13 @@ class TestLinearTp(unittest.TestCase):
layer_te = te.Linear(
self.in_features,
self.out_features,
parallel_mode='column',
parallel_mode="column",
sequence_parallel=self.sequence_parallel,
)
layer_pd = te.Linear(
self.in_features,
self.out_features,
backend='paddle',
backend="paddle",
)
# Get total weight
total_weight = []
......@@ -101,8 +101,9 @@ class TestLinearTp(unittest.TestCase):
total_weight = paddle.concat(total_weight, axis=0)
layer_pd.weight.copy_(total_weight.T, True)
assert_shape(layer_te.weight,
[self.out_features // self.model_parallel_size, self.in_features])
assert_shape(
layer_te.weight, [self.out_features // self.model_parallel_size, self.in_features]
)
assert_shape(layer_te.bias, [self.out_features // self.model_parallel_size])
optimizer_te = paddle.optimizer.SGD(learning_rate=0.001, parameters=layer_te.parameters())
......@@ -118,8 +119,9 @@ class TestLinearTp(unittest.TestCase):
layer_te,
inp,
optimizer_te,
split_input='row' if self.sequence_parallel else 'none',
gather_output=True)
split_input="row" if self.sequence_parallel else "none",
gather_output=True,
)
loss_ref, grad_input_ref = self._train_one_step(layer_pd, inp, optimizer_pd)
assert_allclose(loss_tp, loss_ref, rtol=self.rtol, atol=self.atol)
assert_allclose(grad_input, grad_input_ref, rtol=self.rtol, atol=self.atol)
......@@ -130,13 +132,13 @@ class TestLinearTp(unittest.TestCase):
layer_te = te.Linear(
self.in_features,
self.out_features,
parallel_mode='row',
parallel_mode="row",
sequence_parallel=self.sequence_parallel,
)
layer_pd = te.Linear(
self.in_features,
self.out_features,
backend='paddle',
backend="paddle",
)
# Get total weight
total_weight = []
......@@ -145,8 +147,9 @@ class TestLinearTp(unittest.TestCase):
total_weight = paddle.concat(total_weight, axis=1)
layer_pd.weight.copy_(total_weight.T, True)
assert_shape(layer_te.weight,
[self.out_features, self.in_features // self.model_parallel_size])
assert_shape(
layer_te.weight, [self.out_features, self.in_features // self.model_parallel_size]
)
assert_shape(layer_te.bias, [self.out_features])
optimizer_te = paddle.optimizer.SGD(learning_rate=0.001, parameters=layer_te.parameters())
......@@ -158,11 +161,13 @@ class TestLinearTp(unittest.TestCase):
for _ in range(5):
inp = paddle.uniform([self.batch_size, self.in_features], self.global_dtype)
with te.fp8_autocast(enabled=self.fp8):
loss_tp, grad_input = self._train_one_step(layer_te,
loss_tp, grad_input = self._train_one_step(
layer_te,
inp,
optimizer_te,
split_input='column',
gather_output=self.sequence_parallel)
split_input="column",
gather_output=self.sequence_parallel,
)
loss_ref, grad_input_ref = self._train_one_step(layer_pd, inp, optimizer_pd)
assert_allclose(loss_tp, loss_ref, rtol=self.rtol, atol=self.atol)
assert_allclose(grad_input, grad_input_ref, rtol=self.rtol, atol=self.atol)
......@@ -176,7 +181,7 @@ class TestLinearTpFP8(TestLinearTp):
self.batch_size = 16
self.in_features = 32
self.out_features = 64
self.global_dtype = 'bfloat16'
self.global_dtype = "bfloat16"
self.rtol = 1e-2
self.atol = 1e-2
self.fp8 = True
......@@ -191,7 +196,7 @@ class TestLinearSp(TestLinearTp):
self.batch_size = 16
self.in_features = 32
self.out_features = 64
self.global_dtype = 'bfloat16'
self.global_dtype = "bfloat16"
self.rtol = 1e-3
self.atol = 1e-3
self.fp8 = False
......@@ -206,12 +211,12 @@ class TestLinearSpFP8(TestLinearTp):
self.batch_size = 16
self.in_features = 32
self.out_features = 64
self.global_dtype = 'bfloat16'
self.global_dtype = "bfloat16"
self.rtol = 1e-2
self.atol = 1e-2
self.fp8 = True
self.sequence_parallel = True
if __name__ == '__main__':
if __name__ == "__main__":
unittest.main()
......@@ -45,9 +45,9 @@ class TestTransformerTp(unittest.TestCase):
self.ffn_hidden_size = 4096
self.q_seqlen = 128
self.kv_seqlen = 128
self.mask_type = 'padding'
self.layer_type = 'encoder'
self.global_dtype = 'bfloat16'
self.mask_type = "padding"
self.layer_type = "encoder"
self.global_dtype = "bfloat16"
self.rtol = 5e-2
self.atol = 5e-2
self.eps = 1e-3
......@@ -58,7 +58,7 @@ class TestTransformerTp(unittest.TestCase):
inp, mask = inp_list
if sequence_parallel:
split_size = inp.shape[0] // self.world_size
input_parallel = inp[split_size * self.rank:split_size * (self.rank + 1), :]
input_parallel = inp[split_size * self.rank : split_size * (self.rank + 1), :]
else:
input_parallel = inp
with te.fp8_autocast(enabled=fp8_enabled):
......@@ -83,16 +83,18 @@ class TestTransformerTp(unittest.TestCase):
self.num_heads,
]
common_kwargs = {
'layernorm_epsilon': self.eps,
'hidden_dropout': 0.0,
'attention_dropout': 0.0,
'self_attn_mask_type': self.mask_type,
'layer_type': self.layer_type,
"layernorm_epsilon": self.eps,
"hidden_dropout": 0.0,
"attention_dropout": 0.0,
"self_attn_mask_type": self.mask_type,
"layer_type": self.layer_type,
}
layer_tp = te.TransformerLayer(*common_args,
layer_tp = te.TransformerLayer(
*common_args,
**common_kwargs,
set_parallel_mode=True,
sequence_parallel=self.sequence_parallel)
sequence_parallel=self.sequence_parallel,
)
layer_single = te.TransformerLayer(*common_args, **common_kwargs, set_parallel_mode=False)
def _get_total_weight(local_weight, tp_group, axis, interleave=False):
......@@ -103,12 +105,15 @@ class TestTransformerTp(unittest.TestCase):
# Due to the interleaved qkv layout, need to concat on num_head
# dimension for column parallel linear in MultiHeadAttention layer
assert axis == 0
assert [3 * self.hidden_size // self.world_size,
self.hidden_size] == partial_weight.shape
assert [
3 * self.hidden_size // self.world_size,
self.hidden_size,
] == partial_weight.shape
local_num_head = self.num_heads // self.world_size
for idx, _ in enumerate(total_weight):
total_weight[idx] = total_weight[idx].reshape(
[3, local_num_head, -1, self.hidden_size])
[3, local_num_head, -1, self.hidden_size]
)
total_weight = paddle.concat(total_weight, axis=1).reshape([-1, self.hidden_size])
else:
total_weight = paddle.concat(total_weight, axis=axis)
......@@ -124,48 +129,56 @@ class TestTransformerTp(unittest.TestCase):
weight_dst = _get_weight(layer_dst, weight_names)
if partition_mode is None:
total_weight = weight_src
elif partition_mode == 'column':
total_weight = _get_total_weight(weight_src,
tp_group=self.tp_group,
axis=0,
interleave=interleave)
elif partition_mode == 'row':
elif partition_mode == "column":
total_weight = _get_total_weight(
weight_src, tp_group=self.tp_group, axis=0, interleave=interleave
)
elif partition_mode == "row":
total_weight = _get_total_weight(weight_src, tp_group=self.tp_group, axis=1)
else:
raise ValueError(f"Partition Mode {partition_mode} is not supported.")
assert weight_dst.shape == total_weight.shape, \
f"Shapes of src:{total_weight.shape} and dst:{weight_dst.shape} do not match."
assert (
weight_dst.shape == total_weight.shape
), f"Shapes of src:{total_weight.shape} and dst:{weight_dst.shape} do not match."
weight_dst.copy_(total_weight, True)
copy_weight(layer_tp, layer_single, None, ['self_attention', 'layernorm_qkv', 'ln_weight'])
copy_weight(layer_tp,
copy_weight(layer_tp, layer_single, None, ["self_attention", "layernorm_qkv", "ln_weight"])
copy_weight(
layer_tp,
layer_single,
'column', ['self_attention', 'layernorm_qkv', 'weight'],
interleave=True)
copy_weight(layer_tp, layer_single, 'row', ['self_attention', 'proj', 'weight'])
copy_weight(layer_tp, layer_single, None, ['layernorm_mlp', 'ln_weight'])
copy_weight(layer_tp, layer_single, 'column', ['layernorm_mlp', 'fc1_weight'])
copy_weight(layer_tp, layer_single, 'row', ['layernorm_mlp', 'fc2_weight'])
"column",
["self_attention", "layernorm_qkv", "weight"],
interleave=True,
)
copy_weight(layer_tp, layer_single, "row", ["self_attention", "proj", "weight"])
copy_weight(layer_tp, layer_single, None, ["layernorm_mlp", "ln_weight"])
copy_weight(layer_tp, layer_single, "column", ["layernorm_mlp", "fc1_weight"])
copy_weight(layer_tp, layer_single, "row", ["layernorm_mlp", "fc2_weight"])
if self.sequence_parallel:
register_sequence_parallel_allreduce_hooks(layer_tp, accumulation_steps=1)
optimizer_tp = paddle.optimizer.SGD(learning_rate=0.01, parameters=layer_tp.parameters())
optimizer_single = paddle.optimizer.SGD(learning_rate=0.01,
parameters=layer_single.parameters())
optimizer_single = paddle.optimizer.SGD(
learning_rate=0.01, parameters=layer_single.parameters()
)
layer_tp = fleet.distributed_model(layer_tp)
optimizer_tp = fleet.distributed_optimizer(optimizer_tp)
for _ in range(5):
inp = paddle.uniform([self.batch_size, self.q_seqlen, self.hidden_size],
self.global_dtype)
mask = paddle.zeros(shape=(self.batch_size, 1, self.q_seqlen, self.kv_seqlen),
dtype='bool')
loss_tp, out_tp = self._train_one_step(layer_tp, [inp, mask], optimizer_tp, self.fp8,
self.sequence_parallel)
loss_single, out_single = self._train_one_step(layer_single, [inp, mask],
optimizer_single, self.fp8)
inp = paddle.uniform(
[self.batch_size, self.q_seqlen, self.hidden_size], self.global_dtype
)
mask = paddle.zeros(
shape=(self.batch_size, 1, self.q_seqlen, self.kv_seqlen), dtype="bool"
)
loss_tp, out_tp = self._train_one_step(
layer_tp, [inp, mask], optimizer_tp, self.fp8, self.sequence_parallel
)
loss_single, out_single = self._train_one_step(
layer_single, [inp, mask], optimizer_single, self.fp8
)
assert_allclose(out_tp, out_single, rtol=self.rtol, atol=self.atol)
assert_allclose(loss_tp, loss_single, rtol=self.rtol, atol=self.atol)
......@@ -181,9 +194,9 @@ class TestTransformerTpFp8(TestTransformerTp):
self.ffn_hidden_size = 4096
self.q_seqlen = 128
self.kv_seqlen = 128
self.mask_type = 'padding'
self.layer_type = 'encoder'
self.global_dtype = 'bfloat16'
self.mask_type = "padding"
self.layer_type = "encoder"
self.global_dtype = "bfloat16"
self.rtol = 5e-2
self.atol = 0.5
self.eps = 1e-3
......@@ -202,9 +215,9 @@ class TestTransformerSp(TestTransformerTp):
self.ffn_hidden_size = 4096
self.q_seqlen = 128
self.kv_seqlen = 128
self.mask_type = 'padding'
self.layer_type = 'encoder'
self.global_dtype = 'bfloat16'
self.mask_type = "padding"
self.layer_type = "encoder"
self.global_dtype = "bfloat16"
self.rtol = 5e-2
self.atol = 5e-2
self.eps = 1e-3
......@@ -223,9 +236,9 @@ class TestTransformerSpFp8(TestTransformerSp):
self.ffn_hidden_size = 4096
self.q_seqlen = 128
self.kv_seqlen = 128
self.mask_type = 'padding'
self.layer_type = 'encoder'
self.global_dtype = 'bfloat16'
self.mask_type = "padding"
self.layer_type = "encoder"
self.global_dtype = "bfloat16"
self.rtol = 5e-2
self.atol = 0.5
self.eps = 1e-3
......@@ -233,5 +246,5 @@ class TestTransformerSpFp8(TestTransformerSp):
self.sequence_parallel = True
if __name__ == '__main__':
if __name__ == "__main__":
unittest.main()
......@@ -37,14 +37,17 @@ def main():
enable_recompute = int(sys.argv[1])
use_reentrant = int(sys.argv[2])
layers = paddle.nn.LayerList([
layers = paddle.nn.LayerList(
[
te.TransformerLayer(
hidden_size,
ffn_hidden_size,
num_heads,
layer_type='encoder',
) for _ in range(num_layers)
])
layer_type="encoder",
)
for _ in range(num_layers)
]
)
model = Net(layers)
optimizer = paddle.optimizer.AdamW(learning_rate=0.001, parameters=model.parameters())
......@@ -52,7 +55,7 @@ def main():
for _ in range(10):
inp = paddle.uniform([batch_size, q_seqlen, hidden_size])
inp.stop_gradient = False
mask = paddle.zeros(shape=(batch_size, 1, q_seqlen, kv_seqlen), dtype='bool')
mask = paddle.zeros(shape=(batch_size, 1, q_seqlen, kv_seqlen), dtype="bool")
with te.fp8_autocast(enabled=True):
out = model(inp, mask, enable_recompute, use_reentrant)
loss = out.mean()
......@@ -26,14 +26,14 @@ def setup():
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest.mark.parametrize('use_fp8', [True, False])
@pytest.mark.parametrize("use_fp8", [True, False])
def test_checkpoint(use_fp8):
"""Test checkpoint save / load"""
bs = 16
in_features = 16
out_features = 32
file_name = "model.pdparams"
input_tensor = paddle.uniform(shape=(bs, in_features), dtype='float32')
input_tensor = paddle.uniform(shape=(bs, in_features), dtype="float32")
model = te.Linear(in_features, out_features)
model_loaded = te.Linear(in_features, out_features)
# Populate amax_history
......@@ -91,15 +91,18 @@ class TestLinear:
"""
@staticmethod
@pytest.mark.skipif(paddle.device.cuda.get_device_capability() < (8, 0),
reason="BF16 Linear requires Ampere+ GPU")
@pytest.mark.parametrize('bs,in_features,out_features', LINEAR_CASES)
@pytest.mark.parametrize('has_bias,no_dbias', [[True, False], [True, True], [False, False]])
@pytest.mark.parametrize('no_dgrad', [True, False])
@pytest.mark.parametrize('no_wgrad', [True, False])
@pytest.mark.parametrize('activation_dtype', ['bfloat16', 'float32'])
def test_linear_bf16(bs, in_features, out_features, has_bias, no_dbias, no_dgrad, no_wgrad,
activation_dtype):
@pytest.mark.skipif(
paddle.device.cuda.get_device_capability() < (8, 0),
reason="BF16 Linear requires Ampere+ GPU",
)
@pytest.mark.parametrize("bs,in_features,out_features", LINEAR_CASES)
@pytest.mark.parametrize("has_bias,no_dbias", [[True, False], [True, True], [False, False]])
@pytest.mark.parametrize("no_dgrad", [True, False])
@pytest.mark.parametrize("no_wgrad", [True, False])
@pytest.mark.parametrize("activation_dtype", ["bfloat16", "float32"])
def test_linear_bf16(
bs, in_features, out_features, has_bias, no_dbias, no_dgrad, no_wgrad, activation_dtype
):
"""
Test BF16 Linear
"""
......@@ -112,10 +115,9 @@ class TestLinear:
paddle.set_default_dtype(activation_dtype)
layer_te = te.Linear(in_features, out_features, bias_attr=None if has_bias else False)
layer_pd = te.Linear(in_features,
out_features,
bias_attr=None if has_bias else False,
backend='paddle')
layer_pd = te.Linear(
in_features, out_features, bias_attr=None if has_bias else False, backend="paddle"
)
layer_pd.weight.copy_(layer_te.weight.T, True)
if has_bias:
layer_pd.bias.copy_(layer_te.bias, True)
......@@ -139,15 +141,25 @@ class TestLinear:
@staticmethod
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest.mark.parametrize('bs,in_features,out_features', LINEAR_CASES)
@pytest.mark.parametrize('has_bias,no_dbias', [[True, False], [True, True], [False, False]])
@pytest.mark.parametrize('no_dgrad', [True, False])
@pytest.mark.parametrize('no_wgrad', [True, False])
@pytest.mark.parametrize('fp8_wgrad', [True, False])
@pytest.mark.parametrize('do_calibration', [True, False])
@pytest.mark.parametrize('activation_dtype', ['bfloat16', 'float32'])
def test_linear_fp8(bs, in_features, out_features, has_bias, no_dbias, no_dgrad, no_wgrad,
fp8_wgrad, do_calibration, activation_dtype):
@pytest.mark.parametrize("bs,in_features,out_features", LINEAR_CASES)
@pytest.mark.parametrize("has_bias,no_dbias", [[True, False], [True, True], [False, False]])
@pytest.mark.parametrize("no_dgrad", [True, False])
@pytest.mark.parametrize("no_wgrad", [True, False])
@pytest.mark.parametrize("fp8_wgrad", [True, False])
@pytest.mark.parametrize("do_calibration", [True, False])
@pytest.mark.parametrize("activation_dtype", ["bfloat16", "float32"])
def test_linear_fp8(
bs,
in_features,
out_features,
has_bias,
no_dbias,
no_dgrad,
no_wgrad,
fp8_wgrad,
do_calibration,
activation_dtype,
):
"""
Test FP8 Linear
"""
......@@ -170,7 +182,7 @@ class TestLinear:
in_features=in_features,
out_features=out_features,
bias_attr=None if has_bias else False,
backend='paddle',
backend="paddle",
)
layer_pd.weight.copy_(layer_te.weight.T, True)
if has_bias:
......@@ -182,8 +194,9 @@ class TestLinear:
layer_te.bias.stop_gradient = no_dbias
layer_pd.bias.stop_gradient = no_dbias
with fp8_autocast(enabled=not do_calibration, calibrating=do_calibration,
fp8_recipe=recipe):
with fp8_autocast(
enabled=not do_calibration, calibrating=do_calibration, fp8_recipe=recipe
):
out_ref, grad_input_ref = calc_output_and_grad(layer_pd, input_tensor, grad_out)
out, grad_input = calc_output_and_grad(layer_te, input_tensor, grad_out)
......@@ -199,9 +212,9 @@ class TestLinear:
@staticmethod
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest.mark.parametrize('bs,in_features,out_features', LINEAR_CASES)
@pytest.mark.parametrize('activation_dtype', ['bfloat16'])
@pytest.mark.parametrize('num_microbatch', [8])
@pytest.mark.parametrize("bs,in_features,out_features", LINEAR_CASES)
@pytest.mark.parametrize("activation_dtype", ["bfloat16"])
@pytest.mark.parametrize("num_microbatch", [8])
def test_linear_fp8_microbatch(bs, in_features, out_features, activation_dtype, num_microbatch):
"""
Test FP8 Linear
......@@ -236,17 +249,16 @@ class TestLinear:
out_ref.backward(grad_out)
assert_allclose(out, out_ref, rtol=rtol, atol=atol)
assert_allclose(layer_cached.weight.grad,
layer_normal.weight.grad,
rtol=rtol,
atol=atol)
assert_allclose(
layer_cached.weight.grad, layer_normal.weight.grad, rtol=rtol, atol=atol
)
@pytest.mark.parametrize('bs,hidden_size', NORM_CASES)
@pytest.mark.parametrize('has_bias,no_dbias', [[True, False], [True, True], [False, False]])
@pytest.mark.parametrize('no_dgrad', [True, False])
@pytest.mark.parametrize('no_wgrad', [True, False])
@pytest.mark.parametrize('activation_dtype', ['bfloat16', 'float32'])
@pytest.mark.parametrize("bs,hidden_size", NORM_CASES)
@pytest.mark.parametrize("has_bias,no_dbias", [[True, False], [True, True], [False, False]])
@pytest.mark.parametrize("no_dgrad", [True, False])
@pytest.mark.parametrize("no_wgrad", [True, False])
@pytest.mark.parametrize("activation_dtype", ["bfloat16", "float32"])
def test_layernorm_bf16(bs, hidden_size, has_bias, no_dbias, no_dgrad, no_wgrad, activation_dtype):
"""
Test BF16 LayerNorm
......@@ -261,10 +273,9 @@ def test_layernorm_bf16(bs, hidden_size, has_bias, no_dbias, no_dgrad, no_wgrad,
paddle.set_default_dtype(activation_dtype)
layer_te = te.LayerNorm(hidden_size=hidden_size, eps=eps, bias_attr=None if has_bias else False)
layer_pd = te.LayerNorm(hidden_size=hidden_size,
eps=eps,
bias_attr=None if has_bias else False,
backend='paddle')
layer_pd = te.LayerNorm(
hidden_size=hidden_size, eps=eps, bias_attr=None if has_bias else False, backend="paddle"
)
layer_pd.weight.copy_(layer_te.weight, True)
if has_bias:
layer_pd.bias.copy_(layer_te.bias, True)
......@@ -293,17 +304,29 @@ class TestLayerNormLinear:
"""
@staticmethod
@pytest.mark.skipif(paddle.device.cuda.get_device_capability() < (8, 0),
reason="BF16 Linear requires Ampere+ GPU")
@pytest.mark.parametrize('bs,in_features,out_features', LINEAR_CASES)
@pytest.mark.parametrize('has_bias,no_dbias', [[True, False], [True, True], [False, False]])
@pytest.mark.parametrize('no_dgrad', [True, False])
@pytest.mark.parametrize('no_wgrad', [True, False])
@pytest.mark.parametrize('return_ln_out', [True, False])
@pytest.mark.parametrize('activation_dtype', ['bfloat16', 'float32'])
@pytest.mark.parametrize('normalization', ['RMSNorm', 'LayerNorm'])
def test_layernorm_linear_bf16(bs, in_features, out_features, has_bias, no_dbias, no_dgrad,
no_wgrad, return_ln_out, activation_dtype, normalization):
@pytest.mark.skipif(
paddle.device.cuda.get_device_capability() < (8, 0),
reason="BF16 Linear requires Ampere+ GPU",
)
@pytest.mark.parametrize("bs,in_features,out_features", LINEAR_CASES)
@pytest.mark.parametrize("has_bias,no_dbias", [[True, False], [True, True], [False, False]])
@pytest.mark.parametrize("no_dgrad", [True, False])
@pytest.mark.parametrize("no_wgrad", [True, False])
@pytest.mark.parametrize("return_ln_out", [True, False])
@pytest.mark.parametrize("activation_dtype", ["bfloat16", "float32"])
@pytest.mark.parametrize("normalization", ["RMSNorm", "LayerNorm"])
def test_layernorm_linear_bf16(
bs,
in_features,
out_features,
has_bias,
no_dbias,
no_dgrad,
no_wgrad,
return_ln_out,
activation_dtype,
normalization,
):
"""
Test BF16 LayerNormLinear Layer
"""
......@@ -315,7 +338,7 @@ class TestLayerNormLinear:
input_tensor.stop_gradient = no_dgrad
grad_out = paddle.uniform(shape=(bs, out_features), dtype=activation_dtype)
eps = 1e-3
has_ln_bias = normalization == 'LayerNorm'
has_ln_bias = normalization == "LayerNorm"
layer_te = te.LayerNormLinear(
in_features=in_features,
......@@ -333,7 +356,7 @@ class TestLayerNormLinear:
normalization=normalization,
bias_attr=None if has_bias else False,
return_layernorm_output=return_ln_out,
backend='paddle',
backend="paddle",
)
layer_pd.ln_weight.copy_(layer_te.ln_weight, True)
......@@ -355,11 +378,11 @@ class TestLayerNormLinear:
layer_pd.bias.stop_gradient = no_dbias
out_ref, ln_out_ref, grad_input_ref = calc_output_and_grad_ln_out(
layer_pd, input_tensor, grad_out, return_ln_out=return_ln_out)
out, ln_out, grad_input = calc_output_and_grad_ln_out(layer_te,
input_tensor,
grad_out,
return_ln_out=return_ln_out)
layer_pd, input_tensor, grad_out, return_ln_out=return_ln_out
)
out, ln_out, grad_input = calc_output_and_grad_ln_out(
layer_te, input_tensor, grad_out, return_ln_out=return_ln_out
)
assert_allclose(out, out_ref, rtol=rtol, atol=atol)
if not no_dgrad:
......@@ -377,18 +400,29 @@ class TestLayerNormLinear:
@staticmethod
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest.mark.parametrize('bs,in_features,out_features', LINEAR_CASES)
@pytest.mark.parametrize('has_bias,no_dbias', [[True, False], [True, True], [False, False]])
@pytest.mark.parametrize('no_dgrad', [True, False])
@pytest.mark.parametrize('no_wgrad', [True, False])
@pytest.mark.parametrize('fp8_wgrad', [True, False])
@pytest.mark.parametrize('do_calibration', [True, False])
@pytest.mark.parametrize('return_ln_out', [True, False])
@pytest.mark.parametrize('activation_dtype', ['bfloat16', 'float32'])
@pytest.mark.parametrize('normalization', ['RMSNorm', 'LayerNorm'])
def test_layernorm_linear_fp8(bs, in_features, out_features, has_bias, no_dbias, no_dgrad,
no_wgrad, fp8_wgrad, do_calibration, return_ln_out,
activation_dtype, normalization):
@pytest.mark.parametrize("bs,in_features,out_features", LINEAR_CASES)
@pytest.mark.parametrize("has_bias,no_dbias", [[True, False], [True, True], [False, False]])
@pytest.mark.parametrize("no_dgrad", [True, False])
@pytest.mark.parametrize("no_wgrad", [True, False])
@pytest.mark.parametrize("fp8_wgrad", [True, False])
@pytest.mark.parametrize("do_calibration", [True, False])
@pytest.mark.parametrize("return_ln_out", [True, False])
@pytest.mark.parametrize("activation_dtype", ["bfloat16", "float32"])
@pytest.mark.parametrize("normalization", ["RMSNorm", "LayerNorm"])
def test_layernorm_linear_fp8(
bs,
in_features,
out_features,
has_bias,
no_dbias,
no_dgrad,
no_wgrad,
fp8_wgrad,
do_calibration,
return_ln_out,
activation_dtype,
normalization,
):
"""
Test FP8 LayerNormLinear Layer
"""
......@@ -400,7 +434,7 @@ class TestLayerNormLinear:
input_tensor.stop_gradient = no_dgrad
grad_out = paddle.uniform(shape=(bs, out_features), dtype=activation_dtype)
eps = 1e-3
has_ln_bias = normalization == 'LayerNorm'
has_ln_bias = normalization == "LayerNorm"
recipe = DelayedScaling(override_linear_precision=(False, False, not fp8_wgrad))
......@@ -420,7 +454,7 @@ class TestLayerNormLinear:
normalization=normalization,
bias_attr=None if has_bias else False,
return_layernorm_output=return_ln_out,
backend='paddle',
backend="paddle",
)
layer_pd.ln_weight.copy_(layer_te.ln_weight, True)
......@@ -441,14 +475,15 @@ class TestLayerNormLinear:
layer_te.bias.stop_gradient = no_dbias
layer_pd.bias.stop_gradient = no_dbias
with fp8_autocast(enabled=not do_calibration, calibrating=do_calibration,
fp8_recipe=recipe):
with fp8_autocast(
enabled=not do_calibration, calibrating=do_calibration, fp8_recipe=recipe
):
out_ref, ln_out_ref, grad_input_ref = calc_output_and_grad_ln_out(
layer_pd, input_tensor, grad_out, return_ln_out=return_ln_out)
out, ln_out, grad_input = calc_output_and_grad_ln_out(layer_te,
input_tensor,
grad_out,
return_ln_out=return_ln_out)
layer_pd, input_tensor, grad_out, return_ln_out=return_ln_out
)
out, ln_out, grad_input = calc_output_and_grad_ln_out(
layer_te, input_tensor, grad_out, return_ln_out=return_ln_out
)
assert_allclose(out, out_ref, rtol=rtol, atol=atol)
if not no_dgrad:
......@@ -468,11 +503,12 @@ class TestLayerNormLinear:
@staticmethod
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest.mark.parametrize('bs,in_features,out_features', LINEAR_CASES)
@pytest.mark.parametrize('activation_dtype', ['bfloat16'])
@pytest.mark.parametrize('num_microbatch', [8])
def test_layernorm_linear_fp8_microbatch(bs, in_features, out_features, activation_dtype,
num_microbatch):
@pytest.mark.parametrize("bs,in_features,out_features", LINEAR_CASES)
@pytest.mark.parametrize("activation_dtype", ["bfloat16"])
@pytest.mark.parametrize("num_microbatch", [8])
def test_layernorm_linear_fp8_microbatch(
bs, in_features, out_features, activation_dtype, num_microbatch
):
"""
Test FP8 LayerNormLinear Layer
"""
......@@ -513,14 +549,12 @@ class TestLayerNormLinear:
out_ref.backward(grad_out)
assert_allclose(out, out_ref, rtol=rtol, atol=atol)
assert_allclose(layer_cached.weight.grad,
layer_normal.weight.grad,
rtol=rtol,
atol=atol)
assert_allclose(layer_cached.ln_weight.grad,
layer_normal.ln_weight.grad,
rtol=rtol,
atol=atol)
assert_allclose(
layer_cached.weight.grad, layer_normal.weight.grad, rtol=rtol, atol=atol
)
assert_allclose(
layer_cached.ln_weight.grad, layer_normal.ln_weight.grad, rtol=rtol, atol=atol
)
class TestLayerNormMLP:
......@@ -529,19 +563,31 @@ class TestLayerNormMLP:
"""
@staticmethod
@pytest.mark.skipif(paddle.device.cuda.get_device_capability() < (8, 0),
reason="BF16 Linear requires Ampere+ GPU")
@pytest.mark.parametrize('bs,hidden_size,ffn_hidden_size', LINEAR_CASES)
@pytest.mark.parametrize('has_bias,no_dbias', [[True, False], [True, True], [False, False]])
@pytest.mark.parametrize('no_dgrad', [True, False])
@pytest.mark.parametrize('no_wgrad', [True, False])
@pytest.mark.parametrize('return_ln_out', [True, False])
@pytest.mark.parametrize('activation_dtype', ['bfloat16', 'float32'])
@pytest.mark.parametrize('normalization', ['RMSNorm', 'LayerNorm'])
@pytest.mark.parametrize('activation', ['gelu', 'swiglu'])
def test_layernorm_mlp_bf16(bs, hidden_size, ffn_hidden_size, has_bias, no_dbias, no_dgrad,
no_wgrad, return_ln_out, activation_dtype, normalization,
activation):
@pytest.mark.skipif(
paddle.device.cuda.get_device_capability() < (8, 0),
reason="BF16 Linear requires Ampere+ GPU",
)
@pytest.mark.parametrize("bs,hidden_size,ffn_hidden_size", LINEAR_CASES)
@pytest.mark.parametrize("has_bias,no_dbias", [[True, False], [True, True], [False, False]])
@pytest.mark.parametrize("no_dgrad", [True, False])
@pytest.mark.parametrize("no_wgrad", [True, False])
@pytest.mark.parametrize("return_ln_out", [True, False])
@pytest.mark.parametrize("activation_dtype", ["bfloat16", "float32"])
@pytest.mark.parametrize("normalization", ["RMSNorm", "LayerNorm"])
@pytest.mark.parametrize("activation", ["gelu", "swiglu"])
def test_layernorm_mlp_bf16(
bs,
hidden_size,
ffn_hidden_size,
has_bias,
no_dbias,
no_dgrad,
no_wgrad,
return_ln_out,
activation_dtype,
normalization,
activation,
):
"""
Tests for TestLayerNormMLP layer
"""
......@@ -553,7 +599,7 @@ class TestLayerNormMLP:
input_tensor.stop_gradient = no_dgrad
grad_out = paddle.uniform(shape=(bs, hidden_size), dtype=activation_dtype)
eps = 1e-3
has_ln_bias = normalization == 'LayerNorm'
has_ln_bias = normalization == "LayerNorm"
layer_te = te.LayerNormMLP(
hidden_size=hidden_size,
......@@ -572,7 +618,7 @@ class TestLayerNormMLP:
activation=activation,
bias_attr=None if has_bias else False,
return_layernorm_output=return_ln_out,
backend='paddle',
backend="paddle",
)
layer_pd.ln_weight.copy_(layer_te.ln_weight, True)
if has_ln_bias:
......@@ -599,55 +645,63 @@ class TestLayerNormMLP:
layer_pd.fc2_bias.stop_gradient = no_dbias
out_ref, ln_out_ref, grad_input_ref = calc_output_and_grad_ln_out(
layer_pd, input_tensor, grad_out, return_ln_out=return_ln_out)
out, ln_out, grad_input = calc_output_and_grad_ln_out(layer_te,
input_tensor,
grad_out,
return_ln_out=return_ln_out)
layer_pd, input_tensor, grad_out, return_ln_out=return_ln_out
)
out, ln_out, grad_input = calc_output_and_grad_ln_out(
layer_te, input_tensor, grad_out, return_ln_out=return_ln_out
)
assert_allclose(out, out_ref, rtol=rtol, atol=atol)
if not no_dgrad:
assert_allclose(grad_input, grad_input_ref, rtol=rtol, atol=atol)
if not no_wgrad:
assert_allclose(layer_te.ln_weight.grad, layer_pd.ln_weight.grad, rtol=rtol, atol=atol)
assert_allclose(layer_te.fc1_weight.grad,
layer_pd.fc1_weight.grad.T,
rtol=rtol,
atol=atol)
assert_allclose(layer_te.fc2_weight.grad,
layer_pd.fc2_weight.grad.T,
rtol=rtol,
atol=atol)
assert_allclose(
layer_te.fc1_weight.grad, layer_pd.fc1_weight.grad.T, rtol=rtol, atol=atol
)
assert_allclose(
layer_te.fc2_weight.grad, layer_pd.fc2_weight.grad.T, rtol=rtol, atol=atol
)
if not no_dbias:
if has_ln_bias:
assert_allclose(layer_te.ln_bias.grad, layer_pd.ln_bias.grad, rtol=rtol, atol=atol)
if has_bias:
assert_allclose(layer_te.fc1_bias.grad,
layer_pd.fc1_bias.grad,
rtol=rtol,
atol=atol)
assert_allclose(layer_te.fc2_bias.grad,
layer_pd.fc2_bias.grad,
rtol=rtol,
atol=atol)
assert_allclose(
layer_te.fc1_bias.grad, layer_pd.fc1_bias.grad, rtol=rtol, atol=atol
)
assert_allclose(
layer_te.fc2_bias.grad, layer_pd.fc2_bias.grad, rtol=rtol, atol=atol
)
if return_ln_out:
assert_allclose(ln_out, ln_out_ref, rtol=rtol, atol=atol)
@staticmethod
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest.mark.parametrize('bs,hidden_size,ffn_hidden_size', LINEAR_CASES)
@pytest.mark.parametrize('has_bias,no_dbias', [[True, False], [True, True], [False, False]])
@pytest.mark.parametrize('no_dgrad', [True, False])
@pytest.mark.parametrize('no_wgrad', [True, False])
@pytest.mark.parametrize('fp8_wgrad', [True, False])
@pytest.mark.parametrize('do_calibration', [True, False])
@pytest.mark.parametrize('return_ln_out', [True, False])
@pytest.mark.parametrize('activation_dtype', ['bfloat16', 'float32'])
@pytest.mark.parametrize('normalization', ['RMSNorm', 'LayerNorm'])
@pytest.mark.parametrize('activation', ['gelu', 'swiglu'])
def test_layernorm_mlp_fp8(bs, hidden_size, ffn_hidden_size, has_bias, no_dbias, no_dgrad,
no_wgrad, fp8_wgrad, do_calibration, return_ln_out, activation_dtype,
normalization, activation):
@pytest.mark.parametrize("bs,hidden_size,ffn_hidden_size", LINEAR_CASES)
@pytest.mark.parametrize("has_bias,no_dbias", [[True, False], [True, True], [False, False]])
@pytest.mark.parametrize("no_dgrad", [True, False])
@pytest.mark.parametrize("no_wgrad", [True, False])
@pytest.mark.parametrize("fp8_wgrad", [True, False])
@pytest.mark.parametrize("do_calibration", [True, False])
@pytest.mark.parametrize("return_ln_out", [True, False])
@pytest.mark.parametrize("activation_dtype", ["bfloat16", "float32"])
@pytest.mark.parametrize("normalization", ["RMSNorm", "LayerNorm"])
@pytest.mark.parametrize("activation", ["gelu", "swiglu"])
def test_layernorm_mlp_fp8(
bs,
hidden_size,
ffn_hidden_size,
has_bias,
no_dbias,
no_dgrad,
no_wgrad,
fp8_wgrad,
do_calibration,
return_ln_out,
activation_dtype,
normalization,
activation,
):
"""
Test FP8 LayerNormMLP Layer
"""
......@@ -659,7 +713,7 @@ class TestLayerNormMLP:
input_tensor.stop_gradient = no_dgrad
grad_out = paddle.uniform(shape=(bs, hidden_size), dtype=activation_dtype)
eps = 1e-3
has_ln_bias = normalization == 'LayerNorm'
has_ln_bias = normalization == "LayerNorm"
recipe = DelayedScaling(override_linear_precision=(False, False, not fp8_wgrad))
......@@ -681,7 +735,7 @@ class TestLayerNormMLP:
activation=activation,
bias_attr=None if has_bias else False,
return_layernorm_output=return_ln_out,
backend="paddle",
)
layer_pd.ln_weight.copy_(layer_te.ln_weight, True)
if has_ln_bias:
......@@ -707,40 +761,37 @@ class TestLayerNormMLP:
layer_pd.fc1_bias.stop_gradient = no_dbias
layer_pd.fc2_bias.stop_gradient = no_dbias
with fp8_autocast(
enabled=not do_calibration, calibrating=do_calibration, fp8_recipe=recipe
):
out_ref, ln_out_ref, grad_input_ref = calc_output_and_grad_ln_out(
layer_pd, input_tensor, grad_out, return_ln_out=return_ln_out
)
out, ln_out, grad_input = calc_output_and_grad_ln_out(
layer_te, input_tensor, grad_out, return_ln_out=return_ln_out
)
assert_allclose(out, out_ref, rtol=rtol, atol=atol)
if not no_dgrad:
assert_allclose(grad_input, grad_input_ref, rtol=rtol, atol=atol)
if not no_wgrad:
assert_allclose(layer_te.ln_weight.grad, layer_pd.ln_weight.grad, rtol=rtol, atol=atol)
assert_allclose(
layer_te.fc1_weight.grad, layer_pd.fc1_weight.grad.T, rtol=rtol, atol=atol
)
assert_allclose(
layer_te.fc2_weight.grad, layer_pd.fc2_weight.grad.T, rtol=rtol, atol=atol
)
if not no_dbias:
if has_ln_bias:
assert_allclose(layer_te.ln_bias.grad, layer_pd.ln_bias.grad, rtol=rtol, atol=atol)
if has_bias:
assert_allclose(
layer_te.fc1_bias.grad, layer_pd.fc1_bias.grad, rtol=rtol, atol=atol
)
assert_allclose(
layer_te.fc2_bias.grad, layer_pd.fc2_bias.grad, rtol=rtol, atol=atol
)
if return_ln_out:
assert_allclose(ln_out, ln_out_ref, rtol=rtol, atol=atol)
......@@ -749,11 +800,12 @@ class TestLayerNormMLP:
@staticmethod
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest.mark.parametrize("bs,hidden_size,ffn_hidden_size", LINEAR_CASES)
@pytest.mark.parametrize("activation_dtype", ["bfloat16"])
@pytest.mark.parametrize("num_microbatch", [8])
def test_layernorm_mlp_fp8_microbatch(
bs, hidden_size, ffn_hidden_size, activation_dtype, num_microbatch
):
"""
Test FP8 LayerNormMLP Layer
"""
......@@ -803,28 +855,26 @@ class TestLayerNormMLP:
out_ref.backward(grad_out)
assert_allclose(out, out_ref, rtol=rtol, atol=atol)
assert_allclose(
layer_cached.ln_weight.grad, layer_normal.ln_weight.grad, rtol=rtol, atol=atol
)
assert_allclose(
layer_cached.fc1_weight.grad, layer_normal.fc1_weight.grad, rtol=rtol, atol=atol
)
assert_allclose(
layer_cached.fc2_weight.grad, layer_normal.fc2_weight.grad, rtol=rtol, atol=atol
)
@pytest.mark.parametrize("bs", [1, 2])
@pytest.mark.parametrize("hidden_size, num_heads", [[1024, 16]])
@pytest.mark.parametrize("q_seqlen, kv_seqlen", [[1024, 1024]])
@pytest.mark.parametrize("attn_type", ["self", "cross"])
@pytest.mark.parametrize("mask_type", ["causal", "padding"])
@pytest.mark.parametrize("math_dtype", ["bfloat16", "float16"])
def test_dot_product_attention(
bs, hidden_size, num_heads, q_seqlen, kv_seqlen, attn_type, mask_type, math_dtype
):
"""
Test DotProductAttention Layer
"""
......@@ -848,40 +898,51 @@ def test_dot_product_attention(bs, hidden_size, num_heads, q_seqlen, kv_seqlen,
):
pytest.skip("cuDNN fused attention is not supported")
attn_q_input = paddle.normal(
mean=0.0, std=0.02, shape=(bs, q_seqlen, num_heads, head_size)
).astype(math_dtype)
attn_k_input = paddle.normal(
mean=0.0, std=0.02, shape=(bs, kv_seqlen, num_heads, head_size)
).astype(math_dtype)
attn_v_input = paddle.normal(
mean=0.0, std=0.02, shape=(bs, kv_seqlen, num_heads, head_size)
).astype(math_dtype)
q_actual_seqlen = paddle.randint(low=20, high=q_seqlen, shape=(bs,), dtype="int32")
kv_actual_seqlen = (
paddle.randint(low=20, high=kv_seqlen, shape=(bs,), dtype="int32")
if attn_type == "cross"
else q_actual_seqlen
)
attn_mask = paddle.ones(shape=(bs, 1, q_seqlen, kv_seqlen), dtype="bool")
grad_out = paddle.normal(mean=0.0, std=0.02, shape=(bs, q_seqlen, num_heads, head_size)).astype(
"float32"
)
for i in range(0, bs):
grad_out[i, q_actual_seqlen[i] :, :, :] = 0
grad_out = grad_out.astype(math_dtype)
for i in range(0, bs):
attn_mask[i, 0, 0 : q_actual_seqlen[i], 0 : kv_actual_seqlen[i]] = False
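# Mask convention in these tests: True marks positions excluded from
# attention, so the valid [0:q_len, 0:kv_len] block is cleared to False for
# each batch entry; only the remaining padded positions stay masked.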
head_size = hidden_size // num_heads
layer_te = te.DotProductAttention(
num_heads,
head_size,
attention_dropout=0.0,
attn_mask_type=mask_type,
attention_type=attn_type,
backend="transformer_engine",
)
layer_pd = te.DotProductAttention(
num_heads,
head_size,
attention_dropout=0.0,
attn_mask_type=mask_type,
attention_type=attn_type,
backend="paddle",
)
def calc_attn_output_and_grad(layer, q, k, v, mask, dout):
_q = paddle.to_tensor(q, stop_gradient=False)
......@@ -892,23 +953,29 @@ def test_dot_product_attention(bs, hidden_size, num_heads, q_seqlen, kv_seqlen,
out.backward(dout)
return out, _q.grad, _k.grad, _v.grad
out, q_grad, k_grad, v_grad = calc_attn_output_and_grad(
layer_te, attn_q_input, attn_k_input, attn_v_input, attn_mask, grad_out
)
out_ref, q_grad_ref, k_grad_ref, v_grad_ref = calc_attn_output_and_grad(
layer_pd, attn_q_input, attn_k_input, attn_v_input, attn_mask, grad_out
)
valid_out_ref = paddle.full_like(out_ref, 0)
for i in range(0, bs):
valid_out_ref[i, 0 : q_actual_seqlen[i], :, :] = out_ref[i, 0 : q_actual_seqlen[i], :, :]
valid_q_grad_ref = paddle.full_like(q_grad_ref, 0)
valid_k_grad_ref = paddle.full_like(k_grad_ref, 0)
valid_v_grad_ref = paddle.full_like(v_grad_ref, 0)
for i in range(0, bs):
valid_q_grad_ref[i, 0 : q_actual_seqlen[i], :, :] = q_grad_ref[
i, 0 : q_actual_seqlen[i], :, :
]
valid_k_grad_ref[i, 0 : kv_actual_seqlen[i], :, :] = k_grad_ref[
i, 0 : kv_actual_seqlen[i], :, :
]
valid_v_grad_ref[i, 0 : kv_actual_seqlen[i], :, :] = v_grad_ref[
i, 0 : kv_actual_seqlen[i], :, :
]
assert_allclose(out, valid_out_ref, rtol=rtol, atol=atol)
assert_allclose(q_grad, valid_q_grad_ref, rtol=rtol, atol=atol)
......@@ -916,21 +983,34 @@ def test_dot_product_attention(bs, hidden_size, num_heads, q_seqlen, kv_seqlen,
assert_allclose(v_grad, valid_v_grad_ref, rtol=rtol, atol=atol)
@pytest.mark.parametrize("bs", [1, 2])
@pytest.mark.parametrize("num_gqa_groups", [1, 2, 4])
@pytest.mark.parametrize("hidden_size, num_heads, ffn_hidden_size", [[256, 4, 1024]])
@pytest.mark.parametrize("q_seqlen, kv_seqlen", [[1024, 1024]])
@pytest.mark.parametrize("has_bias, no_dbias", [[False, True], [True, True], [True, False]])
@pytest.mark.parametrize("no_wgrad", [True, False])
@pytest.mark.parametrize("mask_type", ["causal", "padding"])
@pytest.mark.parametrize("math_dtype", ["bfloat16", "float16"])
@pytest.mark.parametrize("output_layernorm", [True, False])
@pytest.mark.parametrize("return_layernorm_output", [True, False])
@pytest.mark.parametrize("normalization", ["RMSNorm", "LayerNorm"])
def test_transformer_encoder_layer(
bs,
hidden_size,
num_heads,
num_gqa_groups,
ffn_hidden_size,
has_bias,
no_dbias,
no_wgrad,
q_seqlen,
kv_seqlen,
mask_type,
math_dtype,
output_layernorm,
return_layernorm_output,
normalization,
):
"""
Test Transformer Encoder Layer
"""
......@@ -938,7 +1018,7 @@ def test_transformer_encoder_layer(bs, hidden_size, num_heads, num_gqa_groups, f
rtol = 5e-2
atol = 5e-2
eps = 1e-3
has_ln_bias = normalization == "LayerNorm"
# Skip if cuDNN fused attention is not supported
if not is_fused_attention_supported(
......@@ -957,20 +1037,22 @@ def test_transformer_encoder_layer(bs, hidden_size, num_heads, num_gqa_groups, f
encoder_input = paddle.uniform(shape=(bs, q_seqlen, hidden_size), dtype=math_dtype)
q_actual_seqlen = paddle.ones(shape=(bs,), dtype="int32") * q_seqlen
kv_actual_seqlen = q_actual_seqlen
attn_mask = paddle.ones(shape=(bs, 1, q_seqlen, kv_seqlen), dtype="bool")
grad_out = paddle.normal(mean=0.0, std=0.02, shape=(bs, q_seqlen, hidden_size)).astype(
"float32"
)
for i in range(0, bs):
grad_out[i, q_actual_seqlen[i] :, :] = 0
grad_out = grad_out.astype(math_dtype)
for i in range(0, bs):
attn_mask[i, 0, 0 : q_actual_seqlen[i], 0 : kv_actual_seqlen[i]] = False
layer_te = te.TransformerLayer(
hidden_size,
ffn_hidden_size,
num_heads,
num_gqa_groups=num_gqa_groups,
......@@ -982,10 +1064,12 @@ def test_transformer_encoder_layer(bs, hidden_size, num_heads, num_gqa_groups, f
self_attn_mask_type=mask_type,
apply_residual_connection_post_layernorm=return_layernorm_output,
output_layernorm=output_layernorm,
layer_type="encoder",
normalization=normalization,
backend="transformer_engine",
)
layer_pd = te.TransformerLayer(
hidden_size,
ffn_hidden_size,
num_heads,
num_gqa_groups=num_gqa_groups,
......@@ -997,9 +1081,10 @@ def test_transformer_encoder_layer(bs, hidden_size, num_heads, num_gqa_groups, f
self_attn_mask_type=mask_type,
apply_residual_connection_post_layernorm=return_layernorm_output,
output_layernorm=output_layernorm,
layer_type="encoder",
normalization=normalization,
backend="paddle",
)
# MultiHeadAttention params
if output_layernorm:
......@@ -1012,21 +1097,25 @@ def test_transformer_encoder_layer(bs, hidden_size, num_heads, num_gqa_groups, f
layer_te.self_attention.qkv.bias.stop_gradient = no_dbias
else:
layer_pd.self_attention.layernorm_qkv.ln_weight.copy_(
layer_te.self_attention.layernorm_qkv.ln_weight, True
)
layer_pd.self_attention.layernorm_qkv.weight.copy_(
layer_te.self_attention.layernorm_qkv.weight.T, True
)
layer_pd.self_attention.layernorm_qkv.ln_weight.stop_gradient = no_wgrad
layer_pd.self_attention.layernorm_qkv.weight.stop_gradient = no_wgrad
layer_te.self_attention.layernorm_qkv.ln_weight.stop_gradient = no_wgrad
layer_te.self_attention.layernorm_qkv.weight.stop_gradient = no_wgrad
if has_ln_bias:
layer_pd.self_attention.layernorm_qkv.ln_bias.copy_(
layer_te.self_attention.layernorm_qkv.ln_bias, True
)
layer_pd.self_attention.layernorm_qkv.ln_bias.stop_gradient = no_dbias
layer_te.self_attention.layernorm_qkv.ln_bias.stop_gradient = no_dbias
if has_bias:
layer_pd.self_attention.layernorm_qkv.bias.copy_(
layer_te.self_attention.layernorm_qkv.bias, True
)
layer_pd.self_attention.layernorm_qkv.bias.stop_gradient = no_dbias
layer_te.self_attention.layernorm_qkv.bias.stop_gradient = no_dbias
......@@ -1074,52 +1163,75 @@ def test_transformer_encoder_layer(bs, hidden_size, num_heads, num_gqa_groups, f
out.backward(dout)
return out, _encoder_input.grad
out_ref, grad_input_ref = calc_transformer_output_and_grad(
layer_pd, encoder_input, attn_mask, grad_out
)
out, grad_input = calc_transformer_output_and_grad(layer_te, encoder_input, attn_mask, grad_out)
assert_allclose(out, out_ref, rtol=rtol, atol=atol)
assert_allclose(grad_input, grad_input_ref, rtol=rtol, atol=atol)
if not no_wgrad:
if output_layernorm:
assert_allclose(
layer_te.self_attention.qkv.weight.grad,
layer_pd.self_attention.qkv.weight.grad.T,
rtol=rtol,
atol=atol,
)
else:
assert_allclose(
layer_te.self_attention.layernorm_qkv.weight.grad,
layer_pd.self_attention.layernorm_qkv.weight.grad.T,
rtol=rtol,
atol=atol,
)
if not no_dbias:
if output_layernorm:
assert_allclose(
layer_te.self_attention.qkv.bias.grad,
layer_pd.self_attention.qkv.bias.grad,
rtol=0.01,
atol=0.5,
)
else:
assert_allclose(
layer_te.self_attention.layernorm_qkv.bias.grad,
layer_pd.self_attention.layernorm_qkv.bias.grad,
rtol=0.01,
atol=0.5,
)
@pytest.mark.parametrize("bs", [1, 2])
@pytest.mark.parametrize("num_gqa_groups", [1, 2, 4])
@pytest.mark.parametrize("hidden_size, num_heads, ffn_hidden_size", [[256, 4, 1024]])
@pytest.mark.parametrize("q_seqlen, kv_seqlen", [[1024, 1024]])
@pytest.mark.parametrize("has_bias, no_dbias", [[False, True], [True, True], [True, False]])
@pytest.mark.parametrize("no_wgrad", [True, False])
@pytest.mark.parametrize("mask_type", ["causal", "padding"])
@pytest.mark.parametrize("math_dtype", ["bfloat16", "float16"])
@pytest.mark.parametrize("output_layernorm", [True, False])
@pytest.mark.parametrize("return_layernorm_output", [True, False])
@pytest.mark.parametrize("recompute_core_attention", [True, False])
@pytest.mark.parametrize("normalization", ["RMSNorm", "LayerNorm"])
def test_transformer_decoder_layer(
bs,
hidden_size,
num_heads,
num_gqa_groups,
ffn_hidden_size,
has_bias,
no_dbias,
no_wgrad,
q_seqlen,
kv_seqlen,
mask_type,
math_dtype,
output_layernorm,
return_layernorm_output,
recompute_core_attention,
normalization,
):
"""
Test Transformer Decoder Layer
"""
......@@ -1127,7 +1239,7 @@ def test_transformer_decoder_layer(bs, hidden_size, num_heads, num_gqa_groups, f
rtol = 5e-2
atol = 6e-2
eps = 1e-3
has_ln_bias = normalization == "LayerNorm"
# Skip if cuDNN fused attention is not supported
if not is_fused_attention_supported(
......@@ -1144,17 +1256,20 @@ def test_transformer_decoder_layer(bs, hidden_size, num_heads, num_gqa_groups, f
):
pytest.skip("cuDNN fused attention is not supported")
encoder_input = paddle.normal(mean=0.0, std=0.1, shape=(bs, q_seqlen, hidden_size)).astype(
math_dtype
)
encoder_output = paddle.normal(mean=0.0, std=0.1, shape=(bs, kv_seqlen, hidden_size)).astype(
math_dtype
)
q_actual_seqlen = paddle.ones(shape=(bs,), dtype="int32") * q_seqlen
kv_actual_seqlen = q_actual_seqlen
attn_mask = paddle.ones(shape=(bs, 1, q_seqlen, kv_seqlen), dtype="bool")
grad_out = paddle.normal(mean=0.0, std=0.01, shape=(bs, q_seqlen, hidden_size)).astype(
"float32"
)
# rounding to avoid numerical issues
encoder_input = paddle.round(encoder_input * 1000) / 1000
......@@ -1162,13 +1277,14 @@ def test_transformer_decoder_layer(bs, hidden_size, num_heads, num_gqa_groups, f
grad_out = paddle.round(grad_out * 1000) / 1000
for i in range(0, bs):
grad_out[i, q_actual_seqlen[i] :, :] = 0
grad_out = grad_out.astype(math_dtype)
for i in range(0, bs):
attn_mask[i, 0, 0 : q_actual_seqlen[i], 0 : kv_actual_seqlen[i]] = False
layer_te = te.TransformerLayer(
hidden_size,
ffn_hidden_size,
num_heads,
num_gqa_groups=num_gqa_groups,
......@@ -1180,10 +1296,12 @@ def test_transformer_decoder_layer(bs, hidden_size, num_heads, num_gqa_groups, f
self_attn_mask_type=mask_type,
apply_residual_connection_post_layernorm=return_layernorm_output,
output_layernorm=output_layernorm,
layer_type="decoder",
normalization=normalization,
backend="transformer_engine",
)
layer_pd = te.TransformerLayer(
hidden_size,
ffn_hidden_size,
num_heads,
num_gqa_groups=num_gqa_groups,
......@@ -1195,9 +1313,10 @@ def test_transformer_decoder_layer(bs, hidden_size, num_heads, num_gqa_groups, f
self_attn_mask_type=mask_type,
apply_residual_connection_post_layernorm=return_layernorm_output,
output_layernorm=output_layernorm,
layer_type="decoder",
normalization=normalization,
backend="paddle",
)
# MultiHeadAttention params - self attn
if output_layernorm:
......@@ -1210,21 +1329,25 @@ def test_transformer_decoder_layer(bs, hidden_size, num_heads, num_gqa_groups, f
layer_te.self_attention.qkv.bias.stop_gradient = no_dbias
else:
layer_pd.self_attention.layernorm_qkv.ln_weight.copy_(
layer_te.self_attention.layernorm_qkv.ln_weight, True
)
layer_pd.self_attention.layernorm_qkv.weight.copy_(
layer_te.self_attention.layernorm_qkv.weight.T, True
)
layer_pd.self_attention.layernorm_qkv.ln_weight.stop_gradient = no_wgrad
layer_pd.self_attention.layernorm_qkv.weight.stop_gradient = no_wgrad
layer_te.self_attention.layernorm_qkv.ln_weight.stop_gradient = no_wgrad
layer_te.self_attention.layernorm_qkv.weight.stop_gradient = no_wgrad
if has_ln_bias:
layer_pd.self_attention.layernorm_qkv.ln_bias.copy_(
layer_te.self_attention.layernorm_qkv.ln_bias, True
)
layer_pd.self_attention.layernorm_qkv.ln_bias.stop_gradient = no_dbias
layer_te.self_attention.layernorm_qkv.ln_bias.stop_gradient = no_dbias
if has_bias:
layer_pd.self_attention.layernorm_qkv.bias.copy_(
layer_te.self_attention.layernorm_qkv.bias, True
)
layer_pd.self_attention.layernorm_qkv.bias.stop_gradient = no_dbias
layer_te.self_attention.layernorm_qkv.bias.stop_gradient = no_dbias
......@@ -1238,26 +1361,31 @@ def test_transformer_decoder_layer(bs, hidden_size, num_heads, num_gqa_groups, f
# MultiHeadAttention params - cross attn
layer_pd.inter_attention.layernorm_query.ln_weight.copy_(
layer_te.inter_attention.layernorm_query.ln_weight, True
)
layer_pd.inter_attention.layernorm_query.weight.copy_(
layer_te.inter_attention.layernorm_query.weight.T, True
)
layer_pd.inter_attention.layernorm_query.ln_weight.stop_gradient = no_wgrad
layer_pd.inter_attention.layernorm_query.weight.stop_gradient = no_wgrad
layer_te.inter_attention.layernorm_query.ln_weight.stop_gradient = no_wgrad
layer_te.inter_attention.layernorm_query.weight.stop_gradient = no_wgrad
if has_ln_bias:
layer_pd.inter_attention.layernorm_query.ln_bias.copy_(
layer_te.inter_attention.layernorm_query.ln_bias, True
)
layer_pd.inter_attention.layernorm_query.ln_bias.stop_gradient = no_dbias
layer_te.inter_attention.layernorm_query.ln_bias.stop_gradient = no_dbias
if has_bias:
layer_pd.inter_attention.layernorm_query.bias.copy_(
layer_te.inter_attention.layernorm_query.bias, True
)
layer_pd.inter_attention.layernorm_query.bias.stop_gradient = no_dbias
layer_te.inter_attention.layernorm_query.bias.stop_gradient = no_dbias
layer_pd.inter_attention.key_value.weight.copy_(
layer_te.inter_attention.key_value.weight.T, True
)
layer_pd.inter_attention.key_value.weight.stop_gradient = no_wgrad
layer_te.inter_attention.key_value.weight.stop_gradient = no_wgrad
layer_pd.inter_attention.proj.weight.copy_(layer_te.inter_attention.proj.weight.T, True)
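# The .T in these copies reconciles weight layouts (assumption): the
# transformer_engine backend stores Linear weights as [out_features,
# in_features] while the paddle backend keeps [in_features, out_features],
# so mirrored parameters are transposed when copied between the two layers.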
......@@ -1301,25 +1429,30 @@ def test_transformer_decoder_layer(bs, hidden_size, num_heads, num_gqa_groups, f
layer_te.layernorm.weight.stop_gradient = no_wgrad
layer_te.layernorm.bias.stop_gradient = no_dbias
def calc_transformer_output_and_grad(
layer,
encoder_input,
mask,
encoder_output,
enc_dec_attn_mask,
dout,
recompute_core_attention=False):
recompute_core_attention=False,
):
_encoder_input = paddle.to_tensor(encoder_input, stop_gradient=False)
_encoder_output = paddle.to_tensor(encoder_output, stop_gradient=False)
out = layer(
_encoder_input,
mask,
_encoder_output,
enc_dec_attn_mask,
recompute_core_attention=recompute_core_attention,
)
out.backward(dout)
return out, _encoder_input.grad, _encoder_output.grad
out_ref, grad_encoder_input_ref, grad_encoder_output_ref = calc_transformer_output_and_grad(
layer_pd, encoder_input, attn_mask, encoder_output, attn_mask, grad_out
)
out, grad_encoder_input, grad_encoder_output = calc_transformer_output_and_grad(
layer_te,
encoder_input,
......@@ -1327,52 +1460,74 @@ def test_transformer_decoder_layer(bs, hidden_size, num_heads, num_gqa_groups, f
encoder_output,
attn_mask,
grad_out,
recompute_core_attention=recompute_core_attention,
)
assert_allclose(out, out_ref, rtol=rtol, atol=atol)
assert_allclose(grad_encoder_input, grad_encoder_input_ref, rtol=rtol, atol=atol)
assert_allclose(grad_encoder_output, grad_encoder_output_ref, rtol=rtol, atol=atol)
if not no_wgrad:
if output_layernorm:
assert_allclose(
layer_te.self_attention.qkv.weight.grad,
layer_pd.self_attention.qkv.weight.grad.T,
rtol=rtol,
atol=atol,
)
else:
assert_allclose(
layer_te.self_attention.layernorm_qkv.weight.grad,
layer_pd.self_attention.layernorm_qkv.weight.grad.T,
rtol=rtol,
atol=atol,
)
assert_allclose(
layer_te.inter_attention.layernorm_query.weight.grad,
layer_pd.inter_attention.layernorm_query.weight.grad.T,
rtol=rtol,
atol=atol,
)
if not no_dbias:
if output_layernorm:
assert_allclose(
layer_te.self_attention.qkv.bias.grad,
layer_pd.self_attention.qkv.bias.grad,
rtol=0.5,
atol=0.6,
)
else:
assert_allclose(
layer_te.self_attention.layernorm_qkv.bias.grad,
layer_pd.self_attention.layernorm_qkv.bias.grad,
rtol=0.01,
atol=0.5,
)
assert_allclose(
layer_te.inter_attention.layernorm_query.bias.grad,
layer_pd.inter_attention.layernorm_query.bias.grad,
rtol=rtol,
atol=atol,
)
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest.mark.parametrize("bs", [8])
@pytest.mark.parametrize("hidden_size, num_heads, ffn_hidden_size", [[1024, 16, 4096]])
@pytest.mark.parametrize("q_seqlen, kv_seqlen", [[128, 128]])
@pytest.mark.parametrize("mask_type", ["causal"])
@pytest.mark.parametrize("math_dtype", ["bfloat16"])
@pytest.mark.parametrize("num_microbatch", [8])
def test_transformer_encoder_layer_microbatch(
bs,
hidden_size,
num_heads,
ffn_hidden_size,
q_seqlen,
kv_seqlen,
mask_type,
math_dtype,
num_microbatch,
):
"""
Test Transformer Encoder Layer with FP8 weight caching
"""
......@@ -1396,7 +1551,8 @@ def test_transformer_encoder_layer_microbatch(bs, hidden_size, num_heads, ffn_hi
):
pytest.skip("cuDNN fused attention is not supported")
layer_cached = te.TransformerLayer(
hidden_size,
ffn_hidden_size,
num_heads,
layernorm_epsilon=eps,
......@@ -1405,8 +1561,10 @@ def test_transformer_encoder_layer_microbatch(bs, hidden_size, num_heads, ffn_hi
weight_attr=None,
bias_attr=None,
self_attn_mask_type=mask_type,
layer_type="encoder",
)
layer_normal = te.TransformerLayer(
hidden_size,
ffn_hidden_size,
num_heads,
layernorm_epsilon=eps,
......@@ -1415,16 +1573,21 @@ def test_transformer_encoder_layer_microbatch(bs, hidden_size, num_heads, ffn_hi
weight_attr=None,
bias_attr=None,
self_attn_mask_type=mask_type,
layer_type="encoder",
)
layer_normal.self_attention.layernorm_qkv.ln_weight.copy_(
layer_cached.self_attention.layernorm_qkv.ln_weight, True
)
layer_normal.self_attention.layernorm_qkv.ln_bias.copy_(
layer_cached.self_attention.layernorm_qkv.ln_bias, True
)
layer_normal.self_attention.layernorm_qkv.weight.copy_(
layer_cached.self_attention.layernorm_qkv.weight, True
)
layer_normal.self_attention.layernorm_qkv.bias.copy_(
layer_cached.self_attention.layernorm_qkv.bias, True
)
layer_normal.self_attention.proj.weight.copy_(layer_cached.self_attention.proj.weight, True)
layer_normal.self_attention.proj.bias.copy_(layer_cached.self_attention.proj.bias, True)
......@@ -1442,18 +1605,19 @@ def test_transformer_encoder_layer_microbatch(bs, hidden_size, num_heads, ffn_hi
def generate_input():
encoder_input = paddle.uniform(shape=(bs, q_seqlen, hidden_size), dtype=math_dtype)
q_actual_seqlen = paddle.ones(shape=(bs,), dtype="int32") * q_seqlen
kv_actual_seqlen = q_actual_seqlen
attn_mask = paddle.ones(shape=(bs, 1, q_seqlen, kv_seqlen), dtype="bool")
grad_out = paddle.normal(mean=0.0, std=0.02, shape=(bs, q_seqlen, hidden_size)).astype(
"float32"
)
for i in range(0, bs):
grad_out[i, q_actual_seqlen[i] :, :] = 0
grad_out = grad_out.astype(math_dtype)
for i in range(0, bs):
attn_mask[i, 0, 0 : q_actual_seqlen[i], 0 : kv_actual_seqlen[i]] = False
return encoder_input, attn_mask, grad_out
......@@ -1477,7 +1641,9 @@ def test_transformer_encoder_layer_microbatch(bs, hidden_size, num_heads, ffn_hi
out_ref.backward(grad_out)
assert_allclose(out, out_ref, rtol=rtol, atol=atol)
assert_allclose(
layer_cached.self_attention.layernorm_qkv.weight.grad,
layer_normal.self_attention.layernorm_qkv.weight.grad,
rtol=rtol,
atol=atol,
)
......@@ -16,7 +16,7 @@ is_fp8_supported, reason = is_fp8_available()
def create_optimizer(model, use_pure_bf16, use_main_grad):
"""Create optimizer"""
if use_main_grad:
assert use_pure_bf16
model = mix_precision_utils.MixPrecisionLayer(model, dtype="bfloat16")
......@@ -32,7 +32,7 @@ def create_optimizer(model, use_pure_bf16, use_main_grad):
class Net(paddle.nn.Layer):
"""Network used for main_grad testing"""
def __init__(self, fuse_wgrad_accumulation):
super().__init__()
......@@ -40,7 +40,7 @@ class Net(paddle.nn.Layer):
4096,
16384,
32,
layer_type="encoder",
fuse_wgrad_accumulation=fuse_wgrad_accumulation,
)
......@@ -50,7 +50,7 @@ class Net(paddle.nn.Layer):
def train(enable_master_grad, fuse_wgrad_accumulation=False):
"""Train function"""
paddle.seed(10)
accumulate_steps = 4
......@@ -64,7 +64,7 @@ def train(enable_master_grad, fuse_wgrad_accumulation=False):
loss_list = []
for step_id in range(16):
inp = paddle.uniform([2, 1024, 4096], dtype="float32")
inp.stop_gradient = False
with te.fp8_autocast(enabled=True):
out = model(inp)
......@@ -82,8 +82,8 @@ def train(enable_master_grad, fuse_wgrad_accumulation=False):
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
def test_master_grad():
"""Test main_grad"""
paddle.set_default_dtype("float32")
loss1 = train(enable_master_grad=False)
loss2 = train(enable_master_grad=True)
loss3 = train(enable_master_grad=True, fuse_wgrad_accumulation=True)
......
......@@ -56,8 +56,13 @@ from transformer_engine.paddle.fp8 import is_fp8_available
from transformer_engine.paddle.constants import FP8FwdTensors
from transformer_engine.common.recipe import DelayedScaling
GEMM_CASES = [
(256, 256, 512),
(32, 32, 32),
(16384, 1024, 2816),
(16384, 2816, 1024),
(16384, 1024, 1024),
]
is_fp8_supported, reason = is_fp8_available()
SELF_ATTN_CASES = [(2, 512, 12, 64)]
......@@ -74,13 +79,13 @@ def setup():
yield
@pytest.mark.parametrize("fp8_dtype", [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2])
@pytest.mark.parametrize("inplace", [True, False])
def test_quantize_dequantize(fp8_dtype, inplace):
"""
Test cast_to_fp8 and cast_from_fp8
"""
a = paddle.rand(shape=(32, 32), dtype="float32")
# Init fp8_meta
fp8_meta = create_fp8_meta()
a_fp8 = paddle.zeros(shape=a.shape, dtype=paddle.uint8) if inplace else None
......@@ -99,7 +104,7 @@ def copy_bits_from_float_to_uint16(f):
"""
Copy bits
"""
return struct.unpack("<I", struct.pack("<f", f))[0] >> 16
def convert_float_to_uint16(float_list):
......@@ -124,95 +129,106 @@ class TestTranspose:
"""
Test BF16 transpose
"""
a = paddle.rand(shape=(16, 32), dtype="bfloat16")
a_transposed = transpose(a, otype=tex.DType.kBFloat16)
assert_allclose(a_transposed, a.T)
@staticmethod
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest.mark.parametrize("fp8_dtype", [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2])
def test_transpose_fp8(fp8_dtype):
"""
Test FP8 transpose
"""
min_val = -8
max_val = 8
a = paddle.cast(paddle.randint(min_val, max_val, shape=(16, 32)), "float32")
fp8_meta = create_fp8_meta()
a_fp8 = cast_to_fp8(a, fp8_meta, FP8FwdTensors.GEMM1_INPUT, otype=fp8_dtype)
a_fp8_transposed = transpose(a_fp8, otype=fp8_dtype)
a_transposed = cast_from_fp8(
a_fp8_transposed,
fp8_meta,
FP8FwdTensors.GEMM1_INPUT,
itype=fp8_dtype,
otype=tex.DType.kFloat32)
otype=tex.DType.kFloat32,
)
assert_allclose(a_transposed, a.T)
@staticmethod
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest.mark.parametrize("fp8_dtype", [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2])
@pytest.mark.parametrize("inplace", [True, False])
def test_cast_transpose(fp8_dtype, inplace):
"""
Test cast_transpose
"""
min_val = -8
max_val = 8
a = paddle.cast(paddle.randint(min_val, max_val, shape=(16, 32)), "float32")
fp8_meta = create_fp8_meta()
a_fp8_casted, a_fp8_transposed = None, None
if inplace:
a_fp8_casted = paddle.zeros(shape=a.shape, dtype=paddle.uint8)
a_fp8_transposed = paddle.zeros(shape=a.T.shape, dtype=paddle.uint8)
a_fp8_casted, a_fp8_transposed = cast_transpose(
a,
fp8_meta,
FP8FwdTensors.GEMM1_INPUT,
otype=fp8_dtype,
cast_out=a_fp8_casted,
transpose_out=a_fp8_transposed)
transpose_out=a_fp8_transposed,
)
a_transposed = cast_from_fp8(
a_fp8_transposed,
fp8_meta,
FP8FwdTensors.GEMM1_INPUT,
itype=fp8_dtype,
otype=tex.DType.kFloat32)
otype=tex.DType.kFloat32,
)
a_casted = cast_from_fp8(
a_fp8_casted,
fp8_meta,
FP8FwdTensors.GEMM1_INPUT,
itype=fp8_dtype,
otype=tex.DType.kFloat32)
otype=tex.DType.kFloat32,
)
assert_allclose(a_casted, a)
assert_allclose(a_transposed, a.T)
@staticmethod
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest.mark.parametrize("fp8_dtype", [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2])
def test_cast_transpose_bgrad(fp8_dtype):
"""
Test cast_transpose_bgrad
"""
min_val = -8
max_val = 8
a = paddle.cast(paddle.randint(min_val, max_val, shape=(16, 32)), "float32")
fp8_meta = create_fp8_meta()
bgrad, a_fp8_casted, a_fp8_transposed = cast_transpose_bgrad(
a, fp8_meta, FP8FwdTensors.GEMM1_INPUT, otype=fp8_dtype
)
a_transposed = cast_from_fp8(
a_fp8_transposed,
fp8_meta,
FP8FwdTensors.GEMM1_INPUT,
itype=fp8_dtype,
otype=tex.DType.kFloat32)
otype=tex.DType.kFloat32,
)
a_casted = cast_from_fp8(
a_fp8_casted,
fp8_meta,
FP8FwdTensors.GEMM1_INPUT,
itype=fp8_dtype,
otype=tex.DType.kFloat32)
otype=tex.DType.kFloat32,
)
assert_allclose(a_casted, a)
assert_allclose(a_transposed, a.T)
......@@ -229,7 +245,7 @@ class TestActivation:
"""
Test BF16 GELU Forward
"""
a = paddle.rand(shape=(16, 32), dtype="bfloat16") * 2 - 1
gelu_out = te_gelu(a, otype=tex.DType.kBFloat16)
gelu_ref = paddle.nn.GELU()(a)
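# Reference note: GELU(x) = x * Phi(x), with Phi the standard normal CDF;
# paddle.nn.GELU defaults to the exact erf-based form (approximate=False),
# which is what the fused te_gelu kernel is compared against here.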
......@@ -237,21 +253,23 @@ class TestActivation:
@staticmethod
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest.mark.parametrize("fp8_dtype", [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2])
def test_gelu_fp8(fp8_dtype):
"""
Test FP8 GELU Forward
"""
a = paddle.rand(shape=(16, 32), dtype="float32") * 2 - 1
fp8_meta = create_fp8_meta()
gelu_out_fp8 = gelu_fp8(a, fp8_meta, FP8FwdTensors.GEMM1_INPUT, otype=fp8_dtype)
gelu_out = cast_from_fp8(
gelu_out_fp8,
fp8_meta,
FP8FwdTensors.GEMM1_INPUT,
itype=fp8_dtype,
otype=tex.DType.kFloat32)
otype=tex.DType.kFloat32,
)
gelu_ref = paddle.nn.GELU()(a)
......@@ -259,36 +277,38 @@ class TestActivation:
@staticmethod
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest.mark.parametrize("fp8_dtype", [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2])
def test_gelu_bwd_fp8(fp8_dtype):
"""
Test FP8 GELU Backward
"""
# y = GELU(x), calculate ref
x = paddle.rand(shape=(16, 32), dtype="float32") * 2 - 1
x.stop_gradient = False
y = paddle.nn.GELU()(x)
y_grad = paddle.rand(shape=(16, 32), dtype="float32") * 2 - 1
paddle.autograd.backward([y], [y_grad], True)
# calculate fp8
fp8_meta = create_fp8_meta()
x_grad_fp8, x_grad_t_fp8, dbias = dgelu_cast_transpose_bgrad_fp8(
y_grad, x, fp8_meta, FP8FwdTensors.GEMM1_INPUT, otype=fp8_dtype
)
x_grad = cast_from_fp8(
x_grad_fp8,
fp8_meta,
FP8FwdTensors.GEMM1_INPUT,
itype=fp8_dtype,
otype=tex.DType.kFloat32)
otype=tex.DType.kFloat32,
)
x_grad_t = cast_from_fp8(
x_grad_t_fp8,
fp8_meta,
FP8FwdTensors.GEMM1_INPUT,
itype=fp8_dtype,
otype=tex.DType.kFloat32)
otype=tex.DType.kFloat32,
)
assert_allclose(x_grad, x.grad, rtol=0.1, atol=0.01)
assert_allclose(x_grad_t, x.grad.T, rtol=0.1, atol=0.01)
......@@ -299,7 +319,7 @@ class TestActivation:
"""
Test BF16 SwiGLU Forward
"""
a = paddle.rand(shape=(16, 32), dtype="bfloat16") * 2 - 1
swiglu_out = swiglu(a, otype=tex.DType.kBFloat16)
swiglu_ref = swiglu_pd(a)
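# swiglu_pd reference (assumption, consistent with the shapes used in the
# backward test below): the last dimension is split in half and the output
# is silu(x1) * x2, so a (16, 32) input produces a (16, 16) output.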
......@@ -307,21 +327,23 @@ class TestActivation:
@staticmethod
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest.mark.parametrize("fp8_dtype", [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2])
def test_swiglu_fp8(fp8_dtype):
"""
Test FP8 SwiGLU Forward
"""
a = paddle.rand(shape=(16, 32), dtype="float32") * 2 - 1
fp8_meta = create_fp8_meta()
swiglu_out_fp8 = swiglu_fp8(a, fp8_meta, FP8FwdTensors.GEMM1_INPUT, otype=fp8_dtype)
swiglu_out = cast_from_fp8(
swiglu_out_fp8,
fp8_meta,
FP8FwdTensors.GEMM1_INPUT,
itype=fp8_dtype,
otype=tex.DType.kFloat32)
otype=tex.DType.kFloat32,
)
swiglu_ref = swiglu_pd(a)
......@@ -333,10 +355,10 @@ class TestActivation:
Test SwiGLU Backward
"""
# y = SwiGLU(x), calculate ref
x = paddle.rand(shape=(16, 32), dtype="bfloat16") * 2 - 1
x.stop_gradient = False
y = swiglu_pd(x)
y_grad = paddle.rand(shape=(16, 16), dtype="bfloat16") * 2 - 1
paddle.autograd.backward([y], [y_grad], True)
# calculate with the fused dswiglu kernel (BF16)
x_grad = dswiglu(y_grad, x, otype=tex.DType.kBFloat16)
......@@ -350,17 +372,18 @@ class TestGemm:
"""
@staticmethod
@pytest.mark.skipif(
paddle.device.cuda.get_device_capability() < (8, 0), reason="BF16 GEMM requires Ampere+ GPU"
)
@pytest.mark.parametrize("m,n,k", GEMM_CASES)
def test_bf16(m, n, k):
"""
Test "TN" BF16 GEMM
"""
a = paddle.rand(shape=(m, k), dtype="bfloat16")
b = paddle.rand(shape=(n, k), dtype="bfloat16")
workspace = paddle.zeros(shape=[33_554_432], dtype="uint8")
ref_out = paddle.matmul(a, b.T)
# CublasLt inside tex.te_gemm assumes inputs are column major.
......@@ -368,37 +391,51 @@ class TestGemm:
# transpose of X.
# Here we perform "TN" GEMM in column major, i.e., b@a^T = C^T,
# which is equivalent to a@b^T = C in row major.
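# The reinterpretation is free: a row-major [r, c] buffer read as
# column-major is exactly the [c, r] transpose of the same data, so no
# copies are needed to realize the "TN" trick described above.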
actual_out, _, _ = gemm(b, a, paddle.bfloat16, workspace, False, None, False, False, "TN",
None, None, False)
actual_out, _, _ = gemm(
b, a, paddle.bfloat16, workspace, False, None, False, False, "TN", None, None, False
)
assert_allclose(actual_out, ref_out, rtol=1.6e-2, atol=1e-5)
@staticmethod
@pytest.mark.skipif(
paddle.device.cuda.get_device_capability() < (8, 0), reason="BF16 GEMM requires Ampere+ GPU"
)
@pytest.mark.parametrize("m,n,k", GEMM_CASES)
def test_bf16_inplace(m, n, k):
"""
Test "TN" BF16 GEMM, with accumulate=True
"""
min_val = -16
max_val = 16
a = paddle.rand(shape=(m, k), dtype="bfloat16")
b = paddle.rand(shape=(n, k), dtype="bfloat16")
c = paddle.cast(paddle.randint(min_val, max_val, shape=(m, n)), "bfloat16")
workspace = paddle.zeros(shape=[33_554_432], dtype="uint8")
ref_out = c + paddle.matmul(a, b.T)
actual_out = paddle.clone(c)
_, _, _ = gemm(b, a, paddle.bfloat16, workspace, False, None, False, True, "TN", actual_out,
None, False)
_, _, _ = gemm(
b,
a,
paddle.bfloat16,
workspace,
False,
None,
False,
True,
"TN",
actual_out,
None,
False,
)
assert_allclose(actual_out, ref_out, rtol=5e-2, atol=5e-2)
@staticmethod
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest.mark.parametrize("m,n,k", GEMM_CASES)
def test_fp8_randint(m, n, k):
"""
Test "TN" FP8 GEMM
......@@ -409,17 +446,26 @@ class TestGemm:
out_dtype = paddle.float32
fp8_meta = create_fp8_meta(num_gemms=1)
a = paddle.cast(paddle.randint(min_val, max_val, shape=(m, k)), "float32")
a_casted = cast_to_fp8(a, fp8_meta, FP8FwdTensors.GEMM1_INPUT, otype=fp8_dtype)
b = paddle.cast(paddle.randint(min_val, max_val, shape=(n, k)), "float32")
b_casted = cast_to_fp8(b, fp8_meta, FP8FwdTensors.GEMM1_WEIGHT, otype=fp8_dtype)
workspace = paddle.zeros(shape=[33_554_432], dtype="uint8")
ref_out = paddle.matmul(a, b.T)
actual_out, _ = fp8_gemm(
b_casted,
fp8_meta.scale_inv,
FP8FwdTensors.GEMM1_WEIGHT,
fp8_dtype,
a_casted,
fp8_meta.scale_inv,
FP8FwdTensors.GEMM1_INPUT,
fp8_dtype,
out_dtype,
workspace,
)
assert_allclose(actual_out, ref_out)
......@@ -434,14 +480,12 @@ class TestLayerNorm:
"""
Calculate reference using paddle layer_norm op
"""
y = paddle.nn.functional.layer_norm(
x=x, normalized_shape=x.shape[1:], weight=gamma, bias=beta, epsilon=eps
)
mean = paddle.mean(x, axis=-1)
var = paddle.var(x, axis=-1)
inv_var = paddle.sqrt(1.0 / var)
return y, mean, inv_var
@staticmethod
......@@ -453,11 +497,9 @@ class TestLayerNorm:
gamma.stop_gradient = False
beta.stop_gradient = False
y = paddle.nn.functional.layer_norm(
x=x, normalized_shape=x.shape[1:], weight=gamma, bias=beta, epsilon=eps
)
paddle.autograd.backward([y], [dy], True)
......@@ -469,9 +511,9 @@ class TestLayerNorm:
"""
N, H = (16, 32)
eps = 1e-3
x = paddle.uniform(shape=(N, H), dtype="bfloat16")
gamma = paddle.uniform(shape=(H,), dtype="bfloat16")
beta = paddle.uniform(shape=(H,), dtype="bfloat16")
y, mu, rsigma = layernorm_fwd(x, gamma, beta, eps, tex.DType.kBFloat16)
......@@ -490,9 +532,9 @@ class TestLayerNorm:
N, H = (16, 32)
eps = 1e-3
x = paddle.uniform(shape=(N, H), dtype="float32")
gamma = paddle.uniform(shape=(H,), dtype="float32")
beta = paddle.uniform(shape=(H,), dtype="float32")
fp8_tensor = FP8FwdTensors.GEMM1_INPUT
fp8_meta = create_fp8_meta()
......@@ -513,10 +555,10 @@ class TestLayerNorm:
"""
N, H = (16, 32)
eps = 1e-3
x = paddle.uniform(shape=(N, H), dtype="bfloat16")
dy = paddle.uniform(shape=(N, H), dtype="bfloat16")
gamma = paddle.uniform(shape=(H,), dtype="bfloat16")
beta = paddle.uniform(shape=(H,), dtype="bfloat16")
dx_ref, dgamma_ref, dbeta_ref = self.calc_bwd_ref(x, eps, gamma, beta, dy)
......@@ -563,8 +605,8 @@ class TestRMSNorm:
"""
N, H = (16, 32)
eps = 1e-3
x = paddle.uniform(shape=(N, H), dtype="bfloat16")
gamma = paddle.uniform(shape=(H,), dtype="bfloat16")
y, _ = rmsnorm_fwd(x, gamma, eps, tex.DType.kBFloat16)
......@@ -581,8 +623,8 @@ class TestRMSNorm:
N, H = (16, 32)
eps = 1e-3
x = paddle.uniform(shape=(N, H), dtype="float32")
gamma = paddle.uniform(shape=(H,), dtype="float32")
fp8_tensor = FP8FwdTensors.GEMM1_INPUT
fp8_meta = create_fp8_meta()
......@@ -602,9 +644,9 @@ class TestRMSNorm:
"""
N, H = (16, 32)
eps = 1e-3
x = paddle.uniform(shape=(N, H), dtype="bfloat16")
dy = paddle.uniform(shape=(N, H), dtype="bfloat16")
gamma = paddle.uniform(shape=(H,), dtype="bfloat16")
dx_ref, dgamma_ref = self.calc_bwd_ref(x, eps, gamma, dy)
......@@ -620,7 +662,7 @@ class TestFusedAttn:
Test fused attention operators
"""
def set_input(self, b, s_q, s_kv, h, d, dtype, attn_mode="self_attn", is_causal_masking=False):
"""
set test input
"""
......@@ -682,10 +724,10 @@ class TestFusedAttn:
assert attn_mode == "self_attn", "only support causal masking for self attention"
for i in range(0, self.batch_size):
for j in range(self.q_actual_seqlen[i]):
self.attn_mask[i, :, j, : j + 1] = 0
else:
for i in range(0, self.batch_size):
self.attn_mask[i, :, : self.q_actual_seqlen[i], : self.kv_actual_seqlen[i]] = 0
dout = _random((self.batch_size, self.q_seqlen, self.num_heads, self.head_size))
self.dout = paddle.to_tensor(dout, dtype=self.dtype)
......@@ -707,10 +749,10 @@ class TestFusedAttn:
transpose_y=True,
)
attn_mask = paddle.to_tensor(self.attn_mask, stop_gradient=True).cast("bool")
attn_mask_vals = paddle.full(qk_out.shape, -1e4, qk_out.dtype)
attn_mask_out = paddle.where(attn_mask, attn_mask_vals, qk_out)
attn_mask_out = paddle.cast(attn_mask_out, "float32")
softmax_out = F.softmax(attn_mask_out)
softmax_out = paddle.cast(softmax_out, self.dtype)
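# The reference softmax is evaluated in float32 and cast back to the test
# dtype, keeping half-precision rounding error out of the comparison
# baseline.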
......@@ -748,7 +790,7 @@ class TestFusedAttn:
q_cu_seqlen_tensor = paddle.to_tensor(self.q_cu_seqlen, dtype="int32", stop_gradient=True)
kv_cu_seqlen_tensor = paddle.to_tensor(self.kv_cu_seqlen, dtype="int32", stop_gradient=True)
qkv_layout = ("bs3hd" if self.attn_mode == "self_attn" else "bshd_bs2hd")
qkv_layout = "bs3hd" if self.attn_mode == "self_attn" else "bshd_bs2hd"
fused_attention_backend = get_fused_attention_backend(
num_heads=self.num_heads,
num_gqa_groups=self.num_heads,
......@@ -764,7 +806,7 @@ class TestFusedAttn:
qkv_dtype = tex.DType.kBFloat16 if self.dtype == "bfloat16" else tex.DType.kFloat16
out, softmax_aux_tensor, q_grad, k_grad, v_grad = None, None, None, None, None
if self.attn_mode == "self_attn":
out, softmax_aux_tensor, rng_state = fused_attn_fwd_qkvpacked(
qkv_tensor,
q_cu_seqlen_tensor,
......@@ -776,7 +818,8 @@ class TestFusedAttn:
attn_scale=self.scaling_factor,
dropout=self.dropout_prob,
set_zero=False,
attn_mask_type="causal" if self.is_causal_masking else "padding")
attn_mask_type="causal" if self.is_causal_masking else "padding",
)
dqkv, _ = fused_attn_bwd_qkvpacked(
qkv_tensor,
q_cu_seqlen_tensor,
......@@ -790,7 +833,8 @@ class TestFusedAttn:
attn_scale=self.scaling_factor,
dropout=self.dropout_prob,
set_zero=False,
attn_mask_type="causal" if self.is_causal_masking else "padding")
attn_mask_type="causal" if self.is_causal_masking else "padding",
)
q_grad = dqkv[:, :, 0, :, :]
k_grad = dqkv[:, :, 1, :, :]
v_grad = dqkv[:, :, 2, :, :]
......@@ -808,8 +852,10 @@ class TestFusedAttn:
Bias=None,
attn_scale=self.scaling_factor,
dropout=self.dropout_prob,
set_zero=False,
)
dq, dkv, _ = fused_attn_bwd_kvpacked(
q_tensor,
kv_tensor,
q_cu_seqlen_tensor,
kv_cu_seqlen_tensor,
......@@ -823,7 +869,8 @@ class TestFusedAttn:
qkv_dtype=qkv_dtype,
attn_scale=self.scaling_factor,
dropout=self.dropout_prob,
set_zero=False,
)
q_grad = dq
k_grad = dkv[:, :, 0, :, :]
v_grad = dkv[:, :, 1, :, :]
......@@ -871,7 +918,8 @@ class TestFusedAttn:
dropout=self.dropout_prob,
set_zero=False,
qkv_layout=qkv_layout,
attn_mask_type="causal" if self.is_causal_masking else "padding")
attn_mask_type="causal" if self.is_causal_masking else "padding",
)
dq, dk, dv, _ = fused_attn_bwd(
q_tensor,
k_tensor,
......@@ -890,13 +938,14 @@ class TestFusedAttn:
dropout=self.dropout_prob,
set_zero=False,
qkv_layout=qkv_layout,
attn_mask_type="causal" if self.is_causal_masking else "padding")
attn_mask_type="causal" if self.is_causal_masking else "padding",
)
return out, dq, dk, dv
@pytest.mark.parametrize('b, s, h, d', SELF_ATTN_CASES)
@pytest.mark.parametrize('dtype', ['float16', 'bfloat16'])
@pytest.mark.parametrize('is_causal_masking', [True, False])
@pytest.mark.parametrize("b, s, h, d", SELF_ATTN_CASES)
@pytest.mark.parametrize("dtype", ["float16", "bfloat16"])
@pytest.mark.parametrize("is_causal_masking", [True, False])
def test_self_attn_forward_backward(self, b, s, h, d, dtype, is_causal_masking):
"""
test self attention forward + backward
......@@ -922,8 +971,8 @@ class TestFusedAttn:
assert_allclose(k_grad_ref, k_grad, rtol=1e-3, atol=1e-2)
assert_allclose(v_grad_ref, v_grad, rtol=1e-3, atol=1e-2)
@pytest.mark.parametrize("b, s_q, s_kv, h, d", CROSS_ATTN_CASES)
@pytest.mark.parametrize("dtype", ["float16", "bfloat16"])
def test_cross_attn_forward_backward(self, b, s_q, s_kv, h, d, dtype):
"""
test cross attention forward + backward
......@@ -949,9 +998,9 @@ class TestFusedAttn:
assert_allclose(k_grad_ref, k_grad, rtol=1e-3, atol=1e-2)
assert_allclose(v_grad_ref, v_grad, rtol=1e-3, atol=1e-2)
@pytest.mark.parametrize("b, s, h, d", FLASH_ATTN_CASES)
@pytest.mark.parametrize("dtype", ["float16", "bfloat16"])
@pytest.mark.parametrize("is_causal_masking", [True])
def test_flash_attn_forward_backward(self, b, s, h, d, dtype, is_causal_masking):
"""
test flash attention forward + backward
......@@ -977,11 +1026,12 @@ class TestFusedAttn:
assert_allclose(k_grad_ref, k_grad, rtol=1e-3, atol=1e-2)
assert_allclose(v_grad_ref, v_grad, rtol=1e-3, atol=1e-2)
@pytest.mark.parametrize("b, s, h, d", FLASH_ATTN_CASES)
@pytest.mark.parametrize("dtype", ["float16", "bfloat16"])
@pytest.mark.parametrize("is_causal_masking", [False, True])
def test_fused_attn_with_separate_qkv_forward_backward(
self, b, s, h, d, dtype, is_causal_masking
):
"""
test flash attention forward + backward with separate qkv inputs
"""
......@@ -1013,7 +1063,7 @@ class TestSoftmax:
"""
@staticmethod
@pytest.mark.parametrize('dtype', ['float16', 'bfloat16'])
@pytest.mark.parametrize("dtype", ["float16", "bfloat16"])
def test_scaled_softmax_fwd_bwd(dtype):
"""test scaled softmax"""
B, H, S = (16, 4, 32)
......@@ -1034,7 +1084,7 @@ class TestSoftmax:
assert_allclose(dx_ref, dx, rtol=1e-4, atol=1e-3)
@staticmethod
@pytest.mark.parametrize('dtype', ['float16', 'bfloat16'])
@pytest.mark.parametrize("dtype", ["float16", "bfloat16"])
def test_scaled_masked_softmax_fwd_bwd(dtype):
"""test scaled masked softmax"""
B, H, S = (16, 4, 32)
......@@ -1058,7 +1108,7 @@ class TestSoftmax:
assert_allclose(dx_ref, dx, rtol=1e-4, atol=1e-3)
@staticmethod
@pytest.mark.parametrize('dtype', ['float16', 'bfloat16'])
@pytest.mark.parametrize("dtype", ["float16", "bfloat16"])
def test_scaled_upper_triang_masked_softmax_fwd_bwd(dtype):
"""test scaled upper triang masked softmax"""
B, S = (16, 32)
......@@ -1068,7 +1118,7 @@ class TestSoftmax:
x.stop_gradient = False
dy = paddle.uniform(shape=(B, S, S), dtype=dtype)
mask = paddle.ones((S, S), dtype='int32')
mask = paddle.ones((S, S), dtype="int32")
col_beg, col_end = 1, S
for row in range(0, S):
mask[row, col_beg:col_end] = 0
......@@ -1087,7 +1137,7 @@ class TestSoftmax:
assert_allclose(dx_ref, dx, rtol=1e-4, atol=5e-3)
@pytest.mark.parametrize('update_weight_scale_inv', [True, False])
@pytest.mark.parametrize("update_weight_scale_inv", [True, False])
def test_amax_and_scale_update(update_weight_scale_inv):
"""Test update_scale"""
num_gemm = 6
......@@ -1097,11 +1147,11 @@ def test_amax_and_scale_update(update_weight_scale_inv):
fp8_max = recipe.fp8_format.value.max_fwd
non_weight_mask = paddle.to_tensor([True, False] * (num_gemm // 2))
amax_history_tensor = paddle.rand(shape=[history_len, num_gemm], dtype='float32')
amax_history_tensor = paddle.rand(shape=[history_len, num_gemm], dtype="float32")
rolled_history_ref = paddle.roll(amax_history_tensor, -1, axis=0)
rolled_history_ref[0] = 0.0
amax_tensor = paddle.max(amax_history_tensor, axis=0)
scale_tensor = paddle.ones(shape=[num_gemm], dtype='float32')
scale_tensor = paddle.ones(shape=[num_gemm], dtype="float32")
def calc_ref(amax, scale, fp8_max, margin=0):
"""Calculate reference scale"""
......@@ -1110,12 +1160,12 @@ def test_amax_and_scale_update(update_weight_scale_inv):
sf = paddle.where(paddle.isfinite(amax), sf, scale)
return sf
scale_ref = calc_ref(amax_tensor, scale_tensor, fp8_max, 0.)
scale_ref = calc_ref(amax_tensor, scale_tensor, fp8_max, 0.0)
if update_weight_scale_inv:
scale_inv_ref = 1. / scale_ref
scale_inv_ref = 1.0 / scale_ref
else:
scale_inv_ref = paddle.zeros_like(scale_tensor)
scale_inv_ref = paddle.where(non_weight_mask, 1. / scale_ref, scale_inv_ref)
scale_inv_ref = paddle.where(non_weight_mask, 1.0 / scale_ref, scale_inv_ref)
# Placeholder
scale_actual = paddle.zeros_like(scale_tensor)
......@@ -1123,13 +1173,15 @@ def test_amax_and_scale_update(update_weight_scale_inv):
if update_weight_scale_inv:
non_weight_mask = paddle.empty([0])
tex.amax_and_scale_update_inplace(_amax_history=amax_history_tensor,
tex.amax_and_scale_update_inplace(
_amax_history=amax_history_tensor,
_scale=scale_actual,
_scale_inv=scale_inv_actual,
non_weight_mask=non_weight_mask,
fp8_dtype=int(fp8_dtype),
margin=0.,
amax_compute="max")
margin=0.0,
amax_compute="max",
)
assert_allclose(scale_actual, scale_ref, rtol=1e-7, atol=1e-7)
assert_allclose(scale_inv_actual, scale_inv_ref, rtol=1e-7, atol=1e-7)
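For context, a minimal standalone sketch of the delayed-scaling update that this test verifies. It assumes the conventional sf = (fp8_max / amax) / 2**margin rule; the exact body of calc_ref is collapsed in this hunk, so treat this as an illustration, not the library implementation:

# Illustrative sketch (not the library code) of the delayed-scaling update
# checked by test_amax_and_scale_update.
import paddle

def delayed_scaling_update(amax_history, scale, fp8_max, margin=0):
    """Reduce the amax history and recompute scale / scale_inv."""
    amax = paddle.max(amax_history, axis=0)              # reduce over the history window
    sf = (fp8_max / amax) / (2.0 ** margin)              # assumed standard scaling rule
    sf = paddle.where(paddle.isfinite(amax), sf, scale)  # keep the old scale on inf/nan
    rolled = paddle.roll(amax_history, -1, axis=0)       # shift the history by one step
    rolled[0] = 0.0                                      # clear the slot for the next amax
    return rolled, sf, 1.0 / sf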
......@@ -1141,8 +1193,8 @@ def test_update_latest_history():
num_gemm = 6
history_len = 1024
amax_history_tensor = paddle.rand(shape=[history_len, num_gemm], dtype='float32')
amax = paddle.rand(shape=[num_gemm], dtype='float32')
amax_history_tensor = paddle.rand(shape=[history_len, num_gemm], dtype="float32")
amax = paddle.rand(shape=[num_gemm], dtype="float32")
tex.update_latest_amax_history_inplace(_history=amax_history_tensor, amax=amax)
......
......@@ -22,7 +22,7 @@ class TestParallelLinear(TestDistributed):
@unittest.skipIf(not gpu_has_fp8, reason)
def test_linear_tp(self):
"""Tests linear with tensor parallel in BF16"""
self.run_2gpu(str(test_root / 'parallel_tests' / 'linear_tp.py'))
self.run_2gpu(str(test_root / "parallel_tests" / "linear_tp.py"))
class TestParallelLayerNormLinear(TestDistributed):
......@@ -32,7 +32,7 @@ class TestParallelLayerNormLinear(TestDistributed):
@unittest.skipIf(not gpu_has_fp8, reason)
def test_layernorm_linear_tp(self):
"""Tests layernorm_linear with tensor parallel in BF16"""
self.run_2gpu(str(test_root / 'parallel_tests' / 'layernorm_linear_tp.py'))
self.run_2gpu(str(test_root / "parallel_tests" / "layernorm_linear_tp.py"))
class TestParallelLayerNormMLP(TestDistributed):
......@@ -42,7 +42,7 @@ class TestParallelLayerNormMLP(TestDistributed):
@unittest.skipIf(not gpu_has_fp8, reason)
def test_layernorm_mlp_tp(self):
"""Tests layernorm_mlp with tensor parallel in BF16"""
self.run_2gpu(str(test_root / 'parallel_tests' / 'layernorm_mlp_tp.py'))
self.run_2gpu(str(test_root / "parallel_tests" / "layernorm_mlp_tp.py"))
class TestAmaxReduction(TestDistributed):
......@@ -52,7 +52,7 @@ class TestAmaxReduction(TestDistributed):
@unittest.skipIf(not gpu_has_fp8, reason)
def test_amax_reduction(self):
"""Tests amax reduction"""
self.run_2gpu(str(test_root / 'parallel_tests' / 'amax_reduction.py'))
self.run_2gpu(str(test_root / "parallel_tests" / "amax_reduction.py"))
class TestPipelineParallel(TestDistributed):
......@@ -62,7 +62,7 @@ class TestPipelineParallel(TestDistributed):
@unittest.skipIf(not gpu_has_fp8, reason)
def test_pipeline_parallel(self):
"""Tests pipeline parallel"""
self.run_2gpu(str(test_root / 'parallel_tests' / 'linear_pp.py'))
self.run_2gpu(str(test_root / "parallel_tests" / "linear_pp.py"))
class TestGroupSharding(TestDistributed):
......@@ -72,7 +72,7 @@ class TestGroupSharding(TestDistributed):
@unittest.skipIf(not gpu_has_fp8, reason)
def test_group_sharding(self):
"""Tests group sharding"""
self.run_2gpu(str(test_root / 'parallel_tests' / 'group_sharding.py'))
self.run_2gpu(str(test_root / "parallel_tests" / "group_sharding.py"))
class TestParallelAttention(TestDistributed):
......@@ -82,7 +82,7 @@ class TestParallelAttention(TestDistributed):
@unittest.skipIf(not gpu_has_fp8, reason)
def test_attention_tp(self):
"""Tests TransMultiHeadAttentionformer Layer with tensor parallel in BF16"""
self.run_2gpu(str(test_root / 'parallel_tests' / 'attention_tp.py'))
self.run_2gpu(str(test_root / "parallel_tests" / "attention_tp.py"))
class TestParallelTransformerLayer(TestDistributed):
......@@ -92,8 +92,8 @@ class TestParallelTransformerLayer(TestDistributed):
@unittest.skipIf(not gpu_has_fp8, reason)
def test_transformer_tp(self):
"""Tests Transformer Layer with tensor parallel in BF16"""
self.run_2gpu(str(test_root / 'parallel_tests' / 'transformer_tp.py'))
self.run_2gpu(str(test_root / "parallel_tests" / "transformer_tp.py"))
if __name__ == '__main__':
if __name__ == "__main__":
unittest.main()
......@@ -17,7 +17,7 @@ is_fp8_supported, reason = is_fp8_available()
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest.mark.parametrize('use_reentrant', [False, True])
@pytest.mark.parametrize("use_reentrant", [False, True])
def test_transformer_encoder_recompute(use_reentrant):
"""
Test TransformerLayer encoder recompute
......@@ -29,17 +29,17 @@ def test_transformer_encoder_recompute(use_reentrant):
"""Launch training in subprocess and check output"""
try:
cmd = [
'python',
str(test_root / 'recompute_tests' / 'recompute_transformer_encoder.py'),
"python",
str(test_root / "recompute_tests" / "recompute_transformer_encoder.py"),
str(int(enable_recompute)),
str(int(use_reentrant))
str(int(use_reentrant)),
]
result = subprocess.check_output(cmd, stderr=subprocess.STDOUT, universal_newlines=True)
print(result)
loss_match = re.search(r'Loss:\s+(-?\d+\.\d+)', result)
memory_match = re.search(r'Peak memory:\s+(\d+)', result)
loss_match = re.search(r"Loss:\s+(-?\d+\.\d+)", result)
memory_match = re.search(r"Peak memory:\s+(\d+)", result)
loss_value = float(loss_match.group(1))
memory_value = int(memory_match.group(1))
......
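For reference, the regexes in the recompute launcher above expect output lines shaped like the sample below; the exact wording printed by recompute_transformer_encoder.py is not shown in this hunk, so the format is inferred from the patterns themselves:

# Illustrative only: sample output matching the regexes used above.
import re

sample = "Loss: 10.1234\nPeak memory: 123456789"
loss_value = float(re.search(r"Loss:\s+(-?\d+\.\d+)", sample).group(1))   # 10.1234
memory_value = int(re.search(r"Peak memory:\s+(\d+)", sample).group(1))   # 123456789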
......@@ -3,4 +3,5 @@
# See LICENSE for license information.
import transformer_engine.paddle
print("OK")
......@@ -19,7 +19,9 @@ from transformer_engine.paddle.constants import (
FusedAttnBackend,
)
from transformer_engine.paddle.fp8 import FP8TensorMeta
from transformer_engine import transformer_engine_paddle as tex # pylint: disable=wrong-import-order
from transformer_engine import (
transformer_engine_paddle as tex,
) # pylint: disable=wrong-import-order
def create_fp8_meta(num_gemms=1, amax_history_len=10):
......@@ -31,18 +33,14 @@ def create_fp8_meta(num_gemms=1, amax_history_len=10):
return fp8_meta
def assert_allclose(actual,
desired,
rtol=1e-05,
atol=1e-08,
equal_nan=True,
err_msg='',
verbose=True):
def assert_allclose(
actual, desired, rtol=1e-05, atol=1e-08, equal_nan=True, err_msg="", verbose=True
):
"""Compare two input paddle tensors"""
if isinstance(actual, paddle.Tensor):
actual = paddle.cast(actual, 'float32')
actual = paddle.cast(actual, "float32")
if isinstance(desired, paddle.Tensor):
desired = paddle.cast(desired, 'float32')
desired = paddle.cast(desired, "float32")
if len(actual.shape) == 0:
actual = actual.item()
desired = desired.item()
......@@ -54,8 +52,9 @@ def assert_allclose(actual,
def assert_shape(inp, expected_shape):
"""Assert the shape of input tensor equals to expected shape"""
assert inp.shape == expected_shape, f"Expected tensor shape: {expected_shape} != " \
f"actual tensor shape: {inp.shape}"
assert (
inp.shape == expected_shape
), f"Expected tensor shape: {expected_shape} != actual tensor shape: {inp.shape}"
def is_devices_enough(required):
......@@ -91,12 +90,21 @@ def set_random_seed(seed):
np.random.seed(seed + 100 * pp_rank)
seed_offset = seed + 1024 + paddle.distributed.get_world_size()
global_seed = (seed_offset + pp_rank * (mp_size) + dp_rank * (mp_size * pp_size) +
sharding_rank * (mp_size * pp_size * dp_size))
global_seed = (
seed_offset
+ pp_rank * (mp_size)
+ dp_rank * (mp_size * pp_size)
+ sharding_rank * (mp_size * pp_size * dp_size)
)
seed_offset += paddle.distributed.get_world_size()
local_seed = (seed_offset + mp_rank + pp_rank * (mp_size) + dp_rank * (mp_size * pp_size) +
sharding_rank * (mp_size * pp_size * dp_size))
local_seed = (
seed_offset
+ mp_rank
+ pp_rank * (mp_size)
+ dp_rank * (mp_size * pp_size)
+ sharding_rank * (mp_size * pp_size * dp_size)
)
tracker = get_rng_state_tracker()
# tracker.reset()
......
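As a side note, here is a worked instance of the seed arithmetic in set_random_seed above, using purely illustrative rank coordinates (not taken from the tests):

# Illustrative only: reproduces the global/local seed arithmetic shown above
# for a hypothetical 2 x 2 x 2 (mp x pp x dp) layout.
seed = 1234
mp_size, pp_size, dp_size = 2, 2, 2
mp_rank, pp_rank, dp_rank, sharding_rank = 1, 1, 0, 0
world_size = mp_size * pp_size * dp_size  # 8

seed_offset = seed + 1024 + world_size  # 2266
global_seed = (
    seed_offset
    + pp_rank * mp_size
    + dp_rank * (mp_size * pp_size)
    + sharding_rank * (mp_size * pp_size * dp_size)
)  # 2266 + 2 = 2268

seed_offset += world_size  # 2274
local_seed = (
    seed_offset
    + mp_rank
    + pp_rank * mp_size
    + dp_rank * (mp_size * pp_size)
    + sharding_rank * (mp_size * pp_size * dp_size)
)  # 2274 + 1 + 2 = 2277

Note that global_seed has no mp_rank term, so tensor-parallel ranks in the same data/pipeline group share it, while local_seed differs per mp_rank (e.g. for dropout).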
......@@ -112,13 +112,17 @@ def perf_and_loss_plots():
lm_loss_data.append(lm_data["loss"])
lm_perf_data.append(lm_data["perf"])
save_plot(
model_config + " loss", legend,
lm_loss_data, model_config + "_loss.png",
model_config + " loss",
legend,
lm_loss_data,
model_config + "_loss.png",
"LM-Loss",
)
save_plot(
model_config + " perf",
legend, lm_perf_data, model_config + "_perf.png",
legend,
lm_perf_data,
model_config + "_perf.png",
"Time per step (ms)",
)
......
......@@ -68,7 +68,9 @@ def get_filename(
config = f"gpt3_{model}_dp{dp}_tp{tp}_pp{pp}_sp{sp}"
config_dir = os.path.join(mlm_log_dir, config)
os.makedirs(config_dir, exist_ok=True)
fname = f"{'te' if use_te else 'megatron'}" + (f"_fp8_{fp8_recipe}" if fp8_recipe else "") + ".txt"
fname = (
f"{'te' if use_te else 'megatron'}" + (f"_fp8_{fp8_recipe}" if fp8_recipe else "") + ".txt"
)
return os.path.join(config_dir, fname)
......@@ -106,4 +108,5 @@ def test_distributed(dtype, fp8_recipe, dp, tp, pp, sp, use_te, model):
TRANSFORMER_IMPL="transformer_engine" if use_te else "local",
**asdict(model_configs[model]),
),
check=True)
check=True,
)