Unverified Commit e547f8e2 authored by Tian Zheng, committed by GitHub

[Paddle] Add sequence parallel (#561)



* Add SP for linear
Signed-off-by: Tian Zheng (Engrg-Hardware 1) <tizheng@nvidia.com>

* Add SP for LayerNormLinear
Signed-off-by: Tian Zheng (Engrg-Hardware 1) <tizheng@nvidia.com>

* Add SP for LayerNormMLP
Signed-off-by: Tian Zheng (Engrg-Hardware 1) <tizheng@nvidia.com>

* Add SP API for transformer layer
Signed-off-by: Tian Zheng (Engrg-Hardware 1) <tizheng@nvidia.com>

* Add sequence_parallel attr
Signed-off-by: Tian Zheng (Engrg-Hardware 1) <tizheng@nvidia.com>

* Add SP unittests for Transformer and Attention
Signed-off-by: Tian Zheng (Engrg-Hardware 1) <tizheng@nvidia.com>

* Fix compatibility with PaddleNLP
Signed-off-by: Tian Zheng (Engrg-Hardware 1) <tizheng@nvidia.com>

* Copyright
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: Tian Zheng (Engrg-Hardware 1) <tizheng@nvidia.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 7a3ed9e2
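Taken together, the changes below expose sequence parallelism through a `sequence_parallel` flag on the parallel layers. A minimal usage sketch (not part of the diff; it assumes a fleet hybrid-parallel environment with mp_degree > 1 has already been initialized, mirroring the tests below):

    import paddle
    import transformer_engine.paddle as te
    from utils import register_sequence_parallel_allreduce_hooks  # test helper added in this PR

    # Column-parallel linear whose input/output activations are sharded along
    # the sequence (first) dimension across the model-parallel group.
    layer = te.Linear(32, 64, parallel_mode='column', sequence_parallel=True)

    # Replicated parameters marked as sequence-parallel (e.g. LayerNorm weights)
    # need their gradients all-reduced across the MP group after backward.
    register_sequence_parallel_allreduce_hooks(layer, accumulation_steps=1)

    rank, world_size = 0, 2        # illustrative; normally from fleet.worker_index()
    full_inp = paddle.uniform([16, 32], 'bfloat16')
    shard = full_inp[rank * 16 // world_size:(rank + 1) * 16 // world_size, :]
    with te.fp8_autocast(enabled=False):
        out = layer(shard)         # each rank computes on its own sequence shard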
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Unittest for MultiHeadAttention layer in tensor/sequence parallel"""

import unittest

import paddle
from paddle.distributed import fleet
from paddle.distributed.fleet.layers.mpu import mp_ops

from utils import assert_allclose, set_random_seed, register_sequence_parallel_allreduce_hooks
import transformer_engine.paddle as te


class TestAttentionTp(unittest.TestCase):
    """Tests MultiHeadAttention layer with model parallel in BF16"""

    def setUp(self):
        self.set_attr()
        self.init_dist_env()
        paddle.set_default_dtype(self.global_dtype)

    def init_dist_env(self):
        """Init Paddle Fleet environment"""
        strategy = fleet.DistributedStrategy()
        self.model_parallel_size = 2
        strategy.hybrid_configs = {
            "dp_degree": 1,
            "mp_degree": self.model_parallel_size,
            "pp_degree": 1,
        }
        strategy.hybrid_configs["mp_configs"].need_broadcast_data = False
        fleet.init(is_collective=True, strategy=strategy)
        self.rank = fleet.worker_index()
        self.hcg = fleet.get_hybrid_communicate_group()
        self.tp_group = self.hcg.get_model_parallel_group()
        self.world_size = self.hcg.get_model_parallel_world_size()

    def set_attr(self):
        """Set test configs"""
        self.batch_size = 16
        self.hidden_size = 1024
        self.num_heads = 16
        self.q_seqlen = 128
        self.kv_seqlen = 128
        self.mask_type = 'padding'
        self.global_dtype = 'bfloat16'
        self.rtol = 5e-3
        self.atol = 5e-3
        self.eps = 1e-3
        self.fp8 = False
        self.sequence_parallel = False

    def _train_one_step(self, layer, inp_list, optimizer, fp8_enabled, sequence_parallel=False):
        inp, mask = inp_list
        if sequence_parallel:
            split_size = inp.shape[0] // self.world_size
            input_parallel = inp[split_size * self.rank:split_size * (self.rank + 1), :]
        else:
            input_parallel = inp
        with te.fp8_autocast(enabled=fp8_enabled):
            out = layer(input_parallel, mask)
        if sequence_parallel:
            total_out = mp_ops._c_concat(out, group=self.tp_group)
            total_out = paddle.concat(paddle.split(total_out, self.world_size, axis=-1), axis=0)
        else:
            total_out = out
        loss = total_out.mean()
        loss.backward()
        optimizer.step()
        optimizer.clear_grad()
        return loss, total_out

    def test_parallel_layer(self):
        """Tests parallel MultiHeadAttention"""
        set_random_seed(1024)
        common_args = (
            self.hidden_size,
            self.num_heads,
        )
        common_kwargs = {
            'layernorm_epsilon': self.eps,
            'attention_dropout': 0.0,
            'attn_mask_type': self.mask_type,
            'attention_type': 'self',
            "tp_group": self.tp_group,
            "input_layernorm": True,
        }
        layer_tp = te.MultiHeadAttention(*common_args,
                                         **common_kwargs,
                                         set_parallel_mode=True,
                                         sequence_parallel=self.sequence_parallel)
        layer_single = te.MultiHeadAttention(*common_args, **common_kwargs, set_parallel_mode=False)

        def _get_total_weight(local_weight, tp_group, axis, interleave=False):
            total_weight = []
            partial_weight = local_weight.clone().detach()
            paddle.distributed.all_gather(total_weight, partial_weight, group=tp_group)
            if interleave:
                # Due to the interleaved qkv layout, need to concat on the num_head
                # dimension for column parallel linear in the MultiHeadAttention layer
                assert axis == 0
                assert [3 * self.hidden_size // self.world_size,
                        self.hidden_size] == partial_weight.shape
                local_num_head = self.num_heads // self.world_size
                for idx, _ in enumerate(total_weight):
                    total_weight[idx] = total_weight[idx].reshape(
                        [3, local_num_head, -1, self.hidden_size])
                total_weight = paddle.concat(total_weight, axis=1).reshape([-1, self.hidden_size])
            else:
                total_weight = paddle.concat(total_weight, axis=axis)
            return total_weight

        def _get_weight(obj, weight_names):
            for name in weight_names:
                obj = getattr(obj, name)
            return obj

        def copy_weight(layer_src, layer_dst, partition_mode, weight_names, interleave=False):
            weight_src = _get_weight(layer_src, weight_names)
            weight_dst = _get_weight(layer_dst, weight_names)
            if partition_mode is None:
                total_weight = weight_src
            elif partition_mode == 'column':
                total_weight = _get_total_weight(weight_src,
                                                 tp_group=self.tp_group,
                                                 axis=0,
                                                 interleave=interleave)
            elif partition_mode == 'row':
                total_weight = _get_total_weight(weight_src, tp_group=self.tp_group, axis=1)
            else:
                raise ValueError(f"Partition Mode {partition_mode} is not supported.")
            assert weight_dst.shape == total_weight.shape, \
                f"Shapes of src:{total_weight.shape} and dst:{weight_dst.shape} do not match."
            weight_dst.copy_(total_weight, True)

        copy_weight(layer_tp, layer_single, None, ['layernorm_qkv', 'ln_weight'])
        copy_weight(layer_tp, layer_single, 'column', ['layernorm_qkv', 'weight'], interleave=True)
        copy_weight(layer_tp, layer_single, 'row', ['proj', 'weight'])

        if self.sequence_parallel:
            register_sequence_parallel_allreduce_hooks(layer_tp, accumulation_steps=1)

        optimizer_tp = paddle.optimizer.SGD(learning_rate=0.01, parameters=layer_tp.parameters())
        optimizer_single = paddle.optimizer.SGD(learning_rate=0.01,
                                                parameters=layer_single.parameters())

        layer_tp = fleet.distributed_model(layer_tp)
        optimizer_tp = fleet.distributed_optimizer(optimizer_tp)

        for _ in range(5):
            inp = paddle.uniform([self.batch_size, self.q_seqlen, self.hidden_size],
                                 self.global_dtype)
            mask = paddle.zeros(shape=(self.batch_size, 1, self.q_seqlen, self.kv_seqlen),
                                dtype='bool')
            loss_tp, out_tp = self._train_one_step(layer_tp, [inp, mask], optimizer_tp, self.fp8,
                                                   self.sequence_parallel)
            loss_single, out_single = self._train_one_step(layer_single, [inp, mask],
                                                           optimizer_single, self.fp8)
            assert_allclose(out_tp, out_single, rtol=self.rtol, atol=self.atol)
            assert_allclose(loss_tp, loss_single, rtol=self.rtol, atol=self.atol)


class TestAttentionTpFp8(TestAttentionTp):
    """Tests MultiHeadAttention layer with model parallel in FP8"""

    def set_attr(self):
        """Set test configs"""
        self.batch_size = 16
        self.hidden_size = 1024
        self.num_heads = 16
        self.q_seqlen = 128
        self.kv_seqlen = 128
        self.mask_type = 'padding'
        self.global_dtype = 'bfloat16'
        self.rtol = 5e-2
        self.atol = 5e-2
        self.eps = 1e-3
        self.fp8 = True
        self.sequence_parallel = False


class TestAttentionSp(TestAttentionTp):
    """Tests MultiHeadAttention layer with sequence parallel in BF16"""

    def set_attr(self):
        """Set test configs"""
        self.batch_size = 16
        self.hidden_size = 1024
        self.num_heads = 16
        self.q_seqlen = 128
        self.kv_seqlen = 128
        self.mask_type = 'padding'
        self.global_dtype = 'bfloat16'
        self.rtol = 5e-3
        self.atol = 5e-3
        self.eps = 1e-3
        self.fp8 = False
        self.sequence_parallel = True


class TestAttentionSpFp8(TestAttentionTp):
    """Tests MultiHeadAttention layer with sequence parallel in FP8"""

    def set_attr(self):
        """Set test configs"""
        self.batch_size = 16
        self.hidden_size = 1024
        self.num_heads = 16
        self.q_seqlen = 128
        self.kv_seqlen = 128
        self.mask_type = 'padding'
        self.global_dtype = 'bfloat16'
        self.rtol = 5e-2
        self.atol = 1e-1
        self.eps = 1e-3
        self.fp8 = True
        self.sequence_parallel = True


if __name__ == '__main__':
    unittest.main()
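A note on the output gather in `_train_one_step` above: `mp_ops._c_concat` concatenates the per-rank outputs along the last (hidden) dimension, so the test splits that result back apart and re-concatenates along dimension 0 to restore the original layout. A single-process sketch of that reshuffle, simulating world_size = 2 with plain slicing (a 2-D analogue of the 3-D tensors in the test):

    import paddle

    world_size = 2
    full = paddle.arange(8 * 6, dtype='float32').reshape([8, 6])   # [seq, hidden]
    shards = paddle.split(full, world_size, axis=0)                # per-rank sequence shards
    gathered = paddle.concat(shards, axis=-1)                      # what _c_concat produces: [4, 12]
    restored = paddle.concat(paddle.split(gathered, world_size, axis=-1), axis=0)
    assert bool((restored == full).all())                          # round trip recovers [seq, hidden]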
@@ -30,9 +30,12 @@ class TestLayerNormLinearTp(unittest.TestCase):
             "mp_degree": self.model_parallel_size,
             "pp_degree": 1,
         }
+        strategy.hybrid_configs["mp_configs"].need_broadcast_data = False
         fleet.init(is_collective=True, strategy=strategy)
+        self.rank = fleet.worker_index()
         self.hcg = fleet.get_hybrid_communicate_group()
         self.tp_group = self.hcg.get_model_parallel_group()
+        self.world_size = self.hcg.get_model_parallel_world_size()

     def set_attr(self):
         """Set test configs"""
@@ -44,6 +47,39 @@ class TestLayerNormLinearTp(unittest.TestCase):
         self.atol = 1e-3
         self.eps = 1e-3
         self.fp8 = False
+        self.sequence_parallel = False
+
+    def _train_one_step(self, layer, inp, optimizer, split_input='none', gather_output=False):
+        inp = paddle.to_tensor(inp, stop_gradient=True)
+        assert split_input in ['none', 'column', 'row']
+        if split_input == 'column':
+            split_size = inp.shape[1] // self.world_size
+            input_parallel = inp[:, split_size * self.rank:split_size * (self.rank + 1)]
+        elif split_input == 'row':
+            split_size = inp.shape[0] // self.world_size
+            input_parallel = inp[split_size * self.rank:split_size * (self.rank + 1), :]
+        else:
+            input_parallel = inp
+        input_parallel.stop_gradient = False
+        out = layer(input_parallel)
+        if gather_output:
+            total_out = mp_ops._c_concat(out, group=self.tp_group)
+        else:
+            total_out = out
+        loss = total_out.mean()
+        loss.backward()
+        optimizer.step()
+        optimizer.clear_grad()
+        if split_input != 'none':
+            grad_input = []
+            paddle.distributed.all_gather(grad_input, input_parallel.grad, group=self.tp_group)
+            if split_input == 'column':
+                grad_input = paddle.concat(grad_input, axis=1)
+            elif split_input == 'row':
+                grad_input = paddle.concat(grad_input, axis=0)
+        else:
+            grad_input = input_parallel.grad
+        return loss, grad_input

     def test_column_parallel_layer(self):
         """Tests column parallel LayerNormLinear"""
@@ -53,6 +89,7 @@ class TestLayerNormLinearTp(unittest.TestCase):
             self.out_features,
             eps=self.eps,
             parallel_mode='column',
+            sequence_parallel=self.sequence_parallel,
         )
         layer_pd = te.LayerNormLinear(
             self.in_features,
@@ -77,25 +114,16 @@ class TestLayerNormLinearTp(unittest.TestCase):
         layer_te = fleet.distributed_model(layer_te)
         optimizer_te = fleet.distributed_optimizer(optimizer_te)

-        def train_one_step(layer, inp, optimizer, gather=False):
-            inp = paddle.to_tensor(inp)
-            inp.stop_gradient = False
-            out = layer(inp)
-            if gather:
-                total_out = mp_ops._c_concat(out, group=self.tp_group)
-            else:
-                total_out = out
-            loss = total_out.mean()
-            loss.backward()
-            optimizer.step()
-            optimizer.clear_grad()
-            return loss, inp.grad
-
         for _ in range(5):
             inp = paddle.uniform([self.batch_size, self.in_features], self.global_dtype)
             with te.fp8_autocast(enabled=self.fp8):
-                loss_tp, grad_input = train_one_step(layer_te, inp, optimizer_te, gather=True)
-            loss_ref, grad_input_ref = train_one_step(layer_pd, inp, optimizer_pd)
+                loss_tp, grad_input = self._train_one_step(
+                    layer_te,
+                    inp,
+                    optimizer_te,
+                    split_input='row' if self.sequence_parallel else 'none',
+                    gather_output=True)
+            loss_ref, grad_input_ref = self._train_one_step(layer_pd, inp, optimizer_pd)
             assert_allclose(loss_tp, loss_ref, rtol=self.rtol, atol=self.atol)
             assert_allclose(grad_input, grad_input_ref, rtol=self.rtol, atol=self.atol)
@@ -113,6 +141,39 @@ class TestLayerNormLinearTpFp8(TestLayerNormLinearTp):
         self.atol = 1e-2
         self.eps = 1e-3
         self.fp8 = True
+        self.sequence_parallel = False
+
+
+class TestLayerNormLinearSp(TestLayerNormLinearTp):
+    """Tests LayerNormLinear layer with sequence parallelism"""
+
+    def set_attr(self):
+        """Set test configs"""
+        self.batch_size = 16
+        self.in_features = 32
+        self.out_features = 64
+        self.global_dtype = 'bfloat16'
+        self.rtol = 1e-3
+        self.atol = 1e-3
+        self.eps = 1e-3
+        self.fp8 = False
+        self.sequence_parallel = True
+
+
+class TestLayerNormLinearSpFp8(TestLayerNormLinearTp):
+    """Tests LayerNormLinear layer with sequence parallelism in FP8"""
+
+    def set_attr(self):
+        """Set test configs"""
+        self.batch_size = 16
+        self.in_features = 32
+        self.out_features = 64
+        self.global_dtype = 'bfloat16'
+        self.rtol = 1e-2
+        self.atol = 1e-2
+        self.eps = 1e-3
+        self.fp8 = True
+        self.sequence_parallel = True

 if __name__ == '__main__':
...
@@ -7,6 +7,7 @@ import unittest
 import paddle
 from paddle.distributed import fleet
+from paddle.distributed.fleet.layers.mpu import mp_ops

 from utils import assert_allclose, assert_shape, set_random_seed
 import transformer_engine.paddle as te
@@ -29,9 +30,12 @@ class TestLayerNormMLPTp(unittest.TestCase):
             "mp_degree": self.model_parallel_size,
             "pp_degree": 1,
         }
+        strategy.hybrid_configs["mp_configs"].need_broadcast_data = False
         fleet.init(is_collective=True, strategy=strategy)
+        self.rank = fleet.worker_index()
         self.hcg = fleet.get_hybrid_communicate_group()
         self.tp_group = self.hcg.get_model_parallel_group()
+        self.world_size = self.hcg.get_model_parallel_world_size()

     def set_attr(self):
         """Set test configs"""
@@ -43,6 +47,40 @@ class TestLayerNormMLPTp(unittest.TestCase):
         self.atol = 1e-3
         self.eps = 1e-3
         self.fp8 = False
+        self.sequence_parallel = False
+
+    def _train_one_step(self, layer, inp, optimizer, split_input='none', gather_output=False):
+        inp = paddle.to_tensor(inp, stop_gradient=True)
+        assert split_input in ['none', 'column', 'row']
+        if split_input == 'column':
+            split_size = inp.shape[1] // self.world_size
+            input_parallel = inp[:, split_size * self.rank:split_size * (self.rank + 1)]
+        elif split_input == 'row':
+            split_size = inp.shape[0] // self.world_size
+            input_parallel = inp[split_size * self.rank:split_size * (self.rank + 1), :]
+        else:
+            input_parallel = inp
+        input_parallel.stop_gradient = False
+        out = layer(input_parallel)
+        if gather_output:
+            # Need to concat on the first dim, while _c_concat concats on the last dim
+            total_out = mp_ops._c_concat(out.T, group=self.tp_group).T
+        else:
+            total_out = out
+        loss = total_out.mean()
+        loss.backward()
+        optimizer.step()
+        optimizer.clear_grad()
+        if split_input != 'none':
+            grad_input = []
+            paddle.distributed.all_gather(grad_input, input_parallel.grad, group=self.tp_group)
+            if split_input == 'column':
+                grad_input = paddle.concat(grad_input, axis=1)
+            elif split_input == 'row':
+                grad_input = paddle.concat(grad_input, axis=0)
+        else:
+            grad_input = input_parallel.grad
+        return loss, grad_input

     def test_parallel_layer(self):
         """Tests parallel LayerNormMLP"""
@@ -52,6 +90,7 @@ class TestLayerNormMLPTp(unittest.TestCase):
             ffn_hidden_size=self.ffn_hidden_size,
             eps=self.eps,
             set_parallel_mode=True,
+            sequence_parallel=self.sequence_parallel,
         )
         layer_pd = te.LayerNormMLP(
             hidden_size=self.hidden_size,
@@ -87,21 +126,16 @@ class TestLayerNormMLPTp(unittest.TestCase):
         layer_te = fleet.distributed_model(layer_te)
         optimizer_te = fleet.distributed_optimizer(optimizer_te)

-        def train_one_step(layer, inp, optimizer):
-            inp = paddle.to_tensor(inp)
-            inp.stop_gradient = False
-            out = layer(inp)
-            loss = out.mean()
-            loss.backward()
-            optimizer.step()
-            optimizer.clear_grad()
-            return loss, inp.grad
-
         for _ in range(5):
             inp = paddle.uniform([self.batch_size, self.hidden_size], self.global_dtype)
             with te.fp8_autocast(enabled=self.fp8):
-                loss_tp, grad_input = train_one_step(layer_te, inp, optimizer_te)
-            loss_ref, grad_input_ref = train_one_step(layer_pd, inp, optimizer_pd)
+                loss_tp, grad_input = self._train_one_step(
+                    layer_te,
+                    inp,
+                    optimizer_te,
+                    split_input='row' if self.sequence_parallel else 'none',
+                    gather_output=self.sequence_parallel)
+            loss_ref, grad_input_ref = self._train_one_step(layer_pd, inp, optimizer_pd)
             assert_allclose(loss_tp, loss_ref, rtol=self.rtol, atol=self.atol)
             assert_allclose(grad_input, grad_input_ref, rtol=self.rtol, atol=self.atol)
@@ -119,6 +153,39 @@ class TestLayerNormMLPTpFp8(TestLayerNormMLPTp):
         self.atol = 1e-2
         self.eps = 1e-3
         self.fp8 = True
+        self.sequence_parallel = False
+
+
+class TestLayerNormMLPSp(TestLayerNormMLPTp):
+    """Tests LayerNormMLP layer with sequence parallel in BF16"""
+
+    def set_attr(self):
+        """Set test configs"""
+        self.batch_size = 16
+        self.hidden_size = 32
+        self.ffn_hidden_size = 64
+        self.global_dtype = 'bfloat16'
+        self.rtol = 1e-3
+        self.atol = 1e-3
+        self.eps = 1e-3
+        self.fp8 = False
+        self.sequence_parallel = True
+
+
+class TestLayerNormMLPSpFp8(TestLayerNormMLPTp):
+    """Tests LayerNormMLP layer with sequence parallelism in FP8"""
+
+    def set_attr(self):
+        """Set test configs"""
+        self.batch_size = 16
+        self.hidden_size = 32
+        self.ffn_hidden_size = 64
+        self.global_dtype = 'bfloat16'
+        self.rtol = 1e-2
+        self.atol = 1e-2
+        self.eps = 1e-3
+        self.fp8 = True
+        self.sequence_parallel = True

 if __name__ == '__main__':
...
@@ -30,6 +30,7 @@ class TestLinearTp(unittest.TestCase):
             "mp_degree": self.model_parallel_size,
             "pp_degree": 1,
         }
+        strategy.hybrid_configs["mp_configs"].need_broadcast_data = False
         fleet.init(is_collective=True, strategy=strategy)
         self.rank = fleet.worker_index()
         self.hcg = fleet.get_hybrid_communicate_group()
@@ -45,6 +46,39 @@ class TestLinearTp(unittest.TestCase):
         self.rtol = 1e-3
         self.atol = 1e-3
         self.fp8 = False
+        self.sequence_parallel = False
+
+    def _train_one_step(self, layer, inp, optimizer, split_input='none', gather_output=False):
+        inp = paddle.to_tensor(inp, stop_gradient=True)
+        assert split_input in ['none', 'column', 'row']
+        if split_input == 'column':
+            split_size = inp.shape[1] // self.world_size
+            input_parallel = inp[:, split_size * self.rank:split_size * (self.rank + 1)]
+        elif split_input == 'row':
+            split_size = inp.shape[0] // self.world_size
+            input_parallel = inp[split_size * self.rank:split_size * (self.rank + 1), :]
+        else:
+            input_parallel = inp
+        input_parallel.stop_gradient = False
+        out = layer(input_parallel)
+        if gather_output:
+            total_out = mp_ops._c_concat(out, group=self.tp_group)
+        else:
+            total_out = out
+        loss = total_out.mean()
+        loss.backward()
+        optimizer.step()
+        optimizer.clear_grad()
+        if split_input != 'none':
+            grad_input = []
+            paddle.distributed.all_gather(grad_input, input_parallel.grad, group=self.tp_group)
+            if split_input == 'column':
+                grad_input = paddle.concat(grad_input, axis=1)
+            elif split_input == 'row':
+                grad_input = paddle.concat(grad_input, axis=0)
+        else:
+            grad_input = input_parallel.grad
+        return loss, grad_input

     def test_column_parallel_layer(self):
         """Tests column parallel linear"""
@@ -53,6 +87,7 @@ class TestLinearTp(unittest.TestCase):
             self.in_features,
             self.out_features,
             parallel_mode='column',
+            sequence_parallel=self.sequence_parallel,
         )
         layer_pd = te.Linear(
             self.in_features,
@@ -76,25 +111,16 @@ class TestLinearTp(unittest.TestCase):
         layer_te = fleet.distributed_model(layer_te)
         optimizer_te = fleet.distributed_optimizer(optimizer_te)

-        def train_one_step(layer, inp, optimizer, gather=False):
-            inp = paddle.to_tensor(inp)
-            inp.stop_gradient = False
-            out = layer(inp)
-            if gather:
-                total_out = mp_ops._c_concat(out, group=self.tp_group)
-            else:
-                total_out = out
-            loss = total_out.mean()
-            loss.backward()
-            optimizer.step()
-            optimizer.clear_grad()
-            return loss, inp.grad
-
         for _ in range(5):
             inp = paddle.uniform([self.batch_size, self.in_features], self.global_dtype)
             with te.fp8_autocast(enabled=self.fp8):
-                loss_tp, grad_input = train_one_step(layer_te, inp, optimizer_te, gather=True)
-            loss_ref, grad_input_ref = train_one_step(layer_pd, inp, optimizer_pd)
+                loss_tp, grad_input = self._train_one_step(
+                    layer_te,
+                    inp,
+                    optimizer_te,
+                    split_input='row' if self.sequence_parallel else 'none',
+                    gather_output=True)
+            loss_ref, grad_input_ref = self._train_one_step(layer_pd, inp, optimizer_pd)
             assert_allclose(loss_tp, loss_ref, rtol=self.rtol, atol=self.atol)
             assert_allclose(grad_input, grad_input_ref, rtol=self.rtol, atol=self.atol)
@@ -105,6 +131,7 @@ class TestLinearTp(unittest.TestCase):
             self.in_features,
             self.out_features,
             parallel_mode='row',
+            sequence_parallel=self.sequence_parallel,
         )
         layer_pd = te.Linear(
             self.in_features,
@@ -125,39 +152,18 @@ class TestLinearTp(unittest.TestCase):
         optimizer_te = paddle.optimizer.SGD(learning_rate=0.001, parameters=layer_te.parameters())
         optimizer_pd = paddle.optimizer.SGD(learning_rate=0.001, parameters=layer_pd.parameters())

-        # Note(tizheng): For this test, we cannot wrap model with fleet.distributed_model,
-        # because it will broadcast inputs across mp group. However, RPL expects splitted
-        # inputs, which is different on each rank.
-        def train_one_step(layer, inp, optimizer, split=False):
-            inp = paddle.to_tensor(inp, stop_gradient=True)
-            if split:
-                # TODO(tizheng): Why not working?
-                # issue: https://github.com/PaddlePaddle/Paddle/issues/55565
-                # input_parallel = mp_ops._c_split(inp, group=layer.tp_group)
-                split_size = inp.shape[1] // self.world_size
-                input_parallel = inp[:, split_size * self.rank:split_size * (self.rank + 1)]
-            else:
-                input_parallel = inp
-            input_parallel.stop_gradient = False
-            out = layer(input_parallel)
-            loss = out.mean()
-            loss.backward()
-            optimizer.step()
-            optimizer.clear_grad()
-            if split:
-                grad_input = []
-                paddle.distributed.all_gather(grad_input, input_parallel.grad, group=self.tp_group)
-                grad_input = paddle.concat(grad_input, axis=1)
-            else:
-                grad_input = input_parallel.grad
-            return loss, grad_input
+        layer_te = fleet.distributed_model(layer_te)
+        optimizer_te = fleet.distributed_optimizer(optimizer_te)

         for _ in range(5):
             inp = paddle.uniform([self.batch_size, self.in_features], self.global_dtype)
             with te.fp8_autocast(enabled=self.fp8):
-                loss_tp, grad_input = train_one_step(layer_te, inp, optimizer_te, split=True)
-            loss_ref, grad_input_ref = train_one_step(layer_pd, inp, optimizer_pd)
+                loss_tp, grad_input = self._train_one_step(layer_te,
+                                                           inp,
+                                                           optimizer_te,
+                                                           split_input='column',
+                                                           gather_output=self.sequence_parallel)
+            loss_ref, grad_input_ref = self._train_one_step(layer_pd, inp, optimizer_pd)
             assert_allclose(loss_tp, loss_ref, rtol=self.rtol, atol=self.atol)
             assert_allclose(grad_input, grad_input_ref, rtol=self.rtol, atol=self.atol)
@@ -174,6 +180,37 @@ class TestLinearTpFP8(TestLinearTp):
         self.rtol = 1e-2
         self.atol = 1e-2
         self.fp8 = True
+        self.sequence_parallel = False
+
+
+class TestLinearSp(TestLinearTp):
+    """Tests Linear layer with sequence parallelism"""
+
+    def set_attr(self):
+        """Set test configs"""
+        self.batch_size = 16
+        self.in_features = 32
+        self.out_features = 64
+        self.global_dtype = 'bfloat16'
+        self.rtol = 1e-3
+        self.atol = 1e-3
+        self.fp8 = False
+        self.sequence_parallel = True
+
+
+class TestLinearSpFP8(TestLinearTp):
+    """Tests Linear layer with sequence parallelism in FP8"""
+
+    def set_attr(self):
+        """Set test configs"""
+        self.batch_size = 16
+        self.in_features = 32
+        self.out_features = 64
+        self.global_dtype = 'bfloat16'
+        self.rtol = 1e-2
+        self.atol = 1e-2
+        self.fp8 = True
+        self.sequence_parallel = True

 if __name__ == '__main__':
...
@@ -7,8 +7,9 @@ import unittest
 import paddle
 from paddle.distributed import fleet
+from paddle.distributed.fleet.layers.mpu import mp_ops

-from utils import assert_allclose, set_random_seed
+from utils import assert_allclose, set_random_seed, register_sequence_parallel_allreduce_hooks
 import transformer_engine.paddle as te
@@ -29,9 +30,12 @@ class TestTransformerTp(unittest.TestCase):
             "mp_degree": self.model_parallel_size,
             "pp_degree": 1,
         }
+        strategy.hybrid_configs["mp_configs"].need_broadcast_data = False
         fleet.init(is_collective=True, strategy=strategy)
+        self.rank = fleet.worker_index()
         self.hcg = fleet.get_hybrid_communicate_group()
         self.tp_group = self.hcg.get_model_parallel_group()
+        self.world_size = self.hcg.get_model_parallel_world_size()

     def set_attr(self):
         """Set test configs"""
@@ -48,6 +52,27 @@ class TestTransformerTp(unittest.TestCase):
         self.atol = 5e-2
         self.eps = 1e-3
         self.fp8 = False
+        self.sequence_parallel = False
+
+    def _train_one_step(self, layer, inp_list, optimizer, fp8_enabled, sequence_parallel=False):
+        inp, mask = inp_list
+        if sequence_parallel:
+            split_size = inp.shape[0] // self.world_size
+            input_parallel = inp[split_size * self.rank:split_size * (self.rank + 1), :]
+        else:
+            input_parallel = inp
+        with te.fp8_autocast(enabled=fp8_enabled):
+            out = layer(input_parallel, mask)
+        if sequence_parallel:
+            total_out = mp_ops._c_concat(out, group=self.tp_group)
+            total_out = paddle.concat(paddle.split(total_out, self.world_size, axis=-1), axis=0)
+        else:
+            total_out = out
+        loss = total_out.mean()
+        loss.backward()
+        optimizer.step()
+        optimizer.clear_grad()
+        return loss, total_out

     def test_parallel_layer(self):
         """Tests parallel Transformer"""
@@ -64,13 +89,28 @@ class TestTransformerTp(unittest.TestCase):
             'self_attn_mask_type': self.mask_type,
             'layer_type': self.layer_type,
         }
-        layer_tp = te.TransformerLayer(*common_args, **common_kwargs, set_parallel_mode=True)
+        layer_tp = te.TransformerLayer(*common_args,
+                                       **common_kwargs,
+                                       set_parallel_mode=True,
+                                       sequence_parallel=self.sequence_parallel)
         layer_single = te.TransformerLayer(*common_args, **common_kwargs, set_parallel_mode=False)

-        def _get_total_weight(local_weight, tp_group, axis):
+        def _get_total_weight(local_weight, tp_group, axis, interleave=False):
             total_weight = []
             partial_weight = local_weight.clone().detach()
             paddle.distributed.all_gather(total_weight, partial_weight, group=tp_group)
-            total_weight = paddle.concat(total_weight, axis=axis)
+            if interleave:
+                # Due to the interleaved qkv layout, need to concat on the num_head
+                # dimension for column parallel linear in the MultiHeadAttention layer
+                assert axis == 0
+                assert [3 * self.hidden_size // self.world_size,
+                        self.hidden_size] == partial_weight.shape
+                local_num_head = self.num_heads // self.world_size
+                for idx, _ in enumerate(total_weight):
+                    total_weight[idx] = total_weight[idx].reshape(
+                        [3, local_num_head, -1, self.hidden_size])
+                total_weight = paddle.concat(total_weight, axis=1).reshape([-1, self.hidden_size])
+            else:
+                total_weight = paddle.concat(total_weight, axis=axis)
             return total_weight
@@ -79,13 +119,16 @@ class TestTransformerTp(unittest.TestCase):
                 obj = getattr(obj, name)
             return obj

-        def copy_weight(layer_src, layer_dst, partition_mode, weight_names):
+        def copy_weight(layer_src, layer_dst, partition_mode, weight_names, interleave=False):
             weight_src = _get_weight(layer_src, weight_names)
             weight_dst = _get_weight(layer_dst, weight_names)
             if partition_mode is None:
                 total_weight = weight_src
             elif partition_mode == 'column':
-                total_weight = _get_total_weight(weight_src, tp_group=self.tp_group, axis=0)
+                total_weight = _get_total_weight(weight_src,
+                                                 tp_group=self.tp_group,
+                                                 axis=0,
+                                                 interleave=interleave)
             elif partition_mode == 'row':
                 total_weight = _get_total_weight(weight_src, tp_group=self.tp_group, axis=1)
             else:
@@ -95,41 +138,62 @@ class TestTransformerTp(unittest.TestCase):
             weight_dst.copy_(total_weight, True)

         copy_weight(layer_tp, layer_single, None, ['self_attention', 'layernorm_qkv', 'ln_weight'])
-        copy_weight(layer_tp, layer_single, 'column', ['self_attention', 'layernorm_qkv', 'weight'])
+        copy_weight(layer_tp,
+                    layer_single,
+                    'column', ['self_attention', 'layernorm_qkv', 'weight'],
+                    interleave=True)
         copy_weight(layer_tp, layer_single, 'row', ['self_attention', 'proj', 'weight'])
         copy_weight(layer_tp, layer_single, None, ['layernorm_mlp', 'ln_weight'])
         copy_weight(layer_tp, layer_single, 'column', ['layernorm_mlp', 'fc1_weight'])
         copy_weight(layer_tp, layer_single, 'row', ['layernorm_mlp', 'fc2_weight'])

-        optimizer_tp = paddle.optimizer.SGD(learning_rate=0.1, parameters=layer_tp.parameters())
-        optimizer_single = paddle.optimizer.SGD(learning_rate=0.1,
+        if self.sequence_parallel:
+            register_sequence_parallel_allreduce_hooks(layer_tp, accumulation_steps=1)
+
+        optimizer_tp = paddle.optimizer.SGD(learning_rate=0.01, parameters=layer_tp.parameters())
+        optimizer_single = paddle.optimizer.SGD(learning_rate=0.01,
                                                 parameters=layer_single.parameters())

         layer_tp = fleet.distributed_model(layer_tp)
         optimizer_tp = fleet.distributed_optimizer(optimizer_tp)

-        def train_one_step(layer, inp_list, optimizer, fp8_enabled):
-            with te.fp8_autocast(enabled=fp8_enabled):
-                out = layer(*inp_list)
-            loss = out.mean()
-            loss.backward()
-            optimizer.step()
-            optimizer.clear_grad()
-            return loss
-
         for _ in range(5):
             inp = paddle.uniform([self.batch_size, self.q_seqlen, self.hidden_size],
                                  self.global_dtype)
             mask = paddle.zeros(shape=(self.batch_size, 1, self.q_seqlen, self.kv_seqlen),
                                 dtype='bool')
-            loss_tp = train_one_step(layer_tp, [inp, mask], optimizer_tp, self.fp8)
-            loss_single = train_one_step(layer_single, [inp, mask], optimizer_single, self.fp8)
+            loss_tp, out_tp = self._train_one_step(layer_tp, [inp, mask], optimizer_tp, self.fp8,
+                                                   self.sequence_parallel)
+            loss_single, out_single = self._train_one_step(layer_single, [inp, mask],
+                                                           optimizer_single, self.fp8)
+            assert_allclose(out_tp, out_single, rtol=self.rtol, atol=self.atol)
             assert_allclose(loss_tp, loss_single, rtol=self.rtol, atol=self.atol)


 class TestTransformerTpFp8(TestTransformerTp):
     """Tests Transformer layer with tensor parallelism in FP8"""

+    def set_attr(self):
+        """Set test configs"""
+        self.batch_size = 16
+        self.hidden_size = 1024
+        self.num_heads = 16
+        self.ffn_hidden_size = 4096
+        self.q_seqlen = 128
+        self.kv_seqlen = 128
+        self.mask_type = 'padding'
+        self.layer_type = 'encoder'
+        self.global_dtype = 'bfloat16'
+        self.rtol = 5e-2
+        self.atol = 0.5
+        self.eps = 1e-3
+        self.fp8 = True
+        self.sequence_parallel = False
+
+
+class TestTransformerSp(TestTransformerTp):
+    """Tests Transformer layer with sequence parallel in BF16"""
+
     def set_attr(self):
         """Set test configs"""
         self.batch_size = 16
@@ -144,7 +208,29 @@ class TestTransformerTpFp8(TestTransformerTp):
         self.rtol = 5e-2
         self.atol = 5e-2
         self.eps = 1e-3
-        self.fp8 = True
+        self.fp8 = False
+        self.sequence_parallel = True
+
+
+class TestTransformerSpFp8(TestTransformerSp):
+    """Tests Transformer layer with sequence parallelism in FP8"""
+
+    def set_attr(self):
+        """Set test configs"""
+        self.batch_size = 16
+        self.hidden_size = 1024
+        self.num_heads = 16
+        self.ffn_hidden_size = 4096
+        self.q_seqlen = 128
+        self.kv_seqlen = 128
+        self.mask_type = 'padding'
+        self.layer_type = 'encoder'
+        self.global_dtype = 'bfloat16'
+        self.rtol = 5e-2
+        self.atol = 0.5
+        self.eps = 1e-3
+        self.fp8 = True
+        self.sequence_parallel = True

 if __name__ == '__main__':
...
@@ -75,6 +75,16 @@ class TestGroupSharding(TestDistributed):
         self.run_2gpu(str(test_root / 'parallel_tests' / 'group_sharding.py'))


+class TestParallelAttention(TestDistributed):
+    """Test MultiHeadAttention layer in parallel mode"""
+
+    @unittest.skipIf(not is_devices_enough(2), "TestParallelAttention needs 2 GPUs")
+    @unittest.skipIf(not gpu_has_fp8, reason)
+    def test_attention_tp(self):
+        """Tests MultiHeadAttention layer with tensor parallel in BF16"""
+        self.run_2gpu(str(test_root / 'parallel_tests' / 'attention_tp.py'))
+
+
 class TestParallelTransformerLayer(TestDistributed):
     """Test Transformer Layer in Parallel mode"""
...
@@ -40,9 +40,15 @@ def assert_allclose(actual,
                     verbose=True):
     """Compare two input paddle tensors"""
     if isinstance(actual, paddle.Tensor):
-        actual = paddle.cast(actual, 'float32').numpy()
+        actual = paddle.cast(actual, 'float32')
     if isinstance(desired, paddle.Tensor):
-        desired = paddle.cast(desired, 'float32').numpy()
+        desired = paddle.cast(desired, 'float32')
+    if len(actual.shape) == 0:
+        actual = actual.item()
+        desired = desired.item()
+    else:
+        actual = actual.numpy()
+        desired = desired.numpy()
     np.testing.assert_allclose(actual, desired, rtol, atol, equal_nan, err_msg, verbose)
@@ -59,6 +65,7 @@ def is_devices_enough(required):
 def set_random_seed(seed):
     """Set random seed for reproducability."""
+    fleet.meta_parallel.model_parallel_random_seed(seed)
     hcg = fleet.get_hybrid_communicate_group()

     if paddle.distributed.get_world_size() > 1:
@@ -161,3 +168,46 @@ def is_fused_attention_supported(
         mask_type=mask_type,
     )
     return backend != FusedAttnBackend["No_Backend"]
+
+
+def register_sequence_parallel_allreduce_hooks(model, accumulation_steps) -> None:
+    """Register allreduce hooks for sequence parallel tensors"""
+
+    def is_sequence_parallel_parameter(parameter):
+        """If input tensor is marked as sequence parallel tensor"""
+        out = getattr(parameter, "sequence_parallel", False)
+        return out
+
+    def create_allreduce_gradient_hook(param, accumulation_steps):
+        """Create allreduce gradient hook"""
+        hcg = fleet.get_hybrid_communicate_group()
+        pg = hcg.get_model_parallel_group().process_group
+
+        step = [0]
+
+        @paddle.autograd.no_grad()
+        def __impl__():
+            step[0] += 1
+            if (step[0] % accumulation_steps) == 0:
+                if hasattr(param, "main_grad"):
+                    pg.allreduce(param.main_grad).wait()
+                else:
+                    pg.allreduce(param.grad).wait()
+
+        return __impl__
+
+    if accumulation_steps <= 0 or not paddle.distributed.is_initialized():
+        return
+
+    hcg = fleet.get_hybrid_communicate_group()
+    mp_group = hcg.get_model_parallel_group()
+    if mp_group.nranks <= 1:
+        return
+
+    params = []
+    for p in model.parameters():
+        if is_sequence_parallel_parameter(p):
+            params.append(p)
+
+    for p in params:
+        hook = create_allreduce_gradient_hook(p, accumulation_steps)
+        p._register_backward_hook(hook)
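Why these hooks are needed: with sequence parallelism each rank runs LayerNorm over a different sequence shard, so the gradients of the replicated LayerNorm parameters differ per rank and must be summed over the model-parallel group. A hypothetical wiring sketch (argument order for te.TransformerLayer assumed from the tests above; requires an initialized fleet MP group):

    model = te.TransformerLayer(1024, 4096, 16,
                                set_parallel_mode=True,
                                sequence_parallel=True)
    # Marked parameters (see mark_as_sequence_parallel_parameter below) get a
    # backward hook that all-reduces .grad (or .main_grad) over the MP group
    # every `accumulation_steps`-th backward pass.
    register_sequence_parallel_allreduce_hooks(model, accumulation_steps=1)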
@@ -291,7 +291,7 @@ std::vector<paddle::Tensor> te_layernorm_fwd_fp8(const paddle::Tensor &input,
     size_t N = shape[0];
     size_t H = shape[1];

-    auto ln_out = paddle::empty_like(input, input.dtype(), input.place());
+    auto ln_out = paddle::empty_like(input, Nvte2PaddleDType(Int2NvteDType(otype)), input.place());
     auto mu = paddle::empty({static_cast<int64_t>(N)}, paddle::DataType::FLOAT32, input.place());
     auto rsigma =
         paddle::empty({static_cast<int64_t>(N)}, paddle::DataType::FLOAT32, input.place());
...
@@ -114,6 +114,57 @@ def allreduce(
     return output, wait_handle


+def allgather(
+    input_: paddle.Tensor,
+    tp_group: Optional[dist_group_type] = None,
+    sync_op: bool = True,
+) -> Tuple[paddle.Tensor, Any]:
+    """All-gather the input tensor across model parallel group."""
+    # Bypass the function if we are using only 1 GPU.
+    if tp_group is None or tp_group.nranks == 1:
+        return input_, None
+
+    parallelism = tp_group.nranks
+    output_shape = input_.shape
+    output_shape[0] = output_shape[0] * parallelism
+    output = paddle.empty(shape=output_shape, dtype=input_.dtype)
+    wait_handle = tp_group.process_group.all_gather_into_tensor(output, input_, sync_op)
+    if sync_op:
+        wait_handle.wait()
+        return output, None
+    return output, wait_handle
+
+
+def reduce_scatter(
+    input_: paddle.Tensor,
+    tp_group: Optional[dist_group_type] = None,
+    sync_op: bool = True,
+) -> Tuple[paddle.Tensor, Any]:
+    """Reduce-scatter the input tensor across model parallel group."""
+    # Bypass the function if we are using only 1 GPU.
+    if tp_group is None or tp_group.nranks == 1:
+        return input_, None
+
+    parallelism = tp_group.nranks
+    output_shape = input_.shape
+    assert (
+        input_.shape[0] % parallelism == 0
+    ), f"Input sequence length {input_.shape[0]} can't be divided " \
+       f"exactly by sequence parallelism {parallelism}"
+    output_shape[0] = output_shape[0] // parallelism
+    output = paddle.empty(shape=output_shape, dtype=input_.dtype)
+    wait_handle = paddle.distributed.stream.reduce_scatter(output,
+                                                           input_,
+                                                           op=paddle.distributed.ReduceOp.SUM,
+                                                           group=tp_group,
+                                                           sync_op=sync_op)
+    if sync_op:
+        return output, None
+    return output, wait_handle
+
+
 def identity(
     input_: paddle.Tensor,
     tp_group: Optional[dist_group_type] = None,
@@ -125,3 +176,11 @@ def identity(
     output = mp_ops._c_identity(input_, group=tp_group)

     return output
+
+
+def mark_as_sequence_parallel_parameter(parameter: paddle.Tensor):
+    """
+    Set sequence_parallel attribute to input tensor. It is used for registering allreduce
+    hooks in PaddleNLP sequence parallel training.
+    """
+    setattr(parameter, "sequence_parallel", True)
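The two new collectives are duals along dimension 0: allgather re-materializes the full sequence before a tensor-parallel region, and reduce_scatter sums partial row-parallel outputs while re-sharding the sequence. A hypothetical call pattern (tp_group of size P, local shard x of shape [s/P, h]; some_row_parallel_gemm stands in for the surrounding layer logic):

    full_x, _ = allgather(x, tp_group)                  # [s/P, h] -> [s, h]
    partial_y = some_row_parallel_gemm(full_x)          # each rank holds a partial sum of y
    y_local, _ = reduce_scatter(partial_y, tp_group)    # summed and re-sharded: [s, h] -> [s/P, h]
    # With sync_op=False both return a wait handle instead, so the collective
    # can overlap with compute; call wait_handle.wait() before using the output.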
@@ -237,8 +237,7 @@ class DotProductAttention(paddle.nn.Layer):
             self.fused_attention_backend = tex.get_fused_attn_backend(
                 TE_DType[query_layer.dtype], TE_DType[query_layer.dtype],
                 tex.get_nvte_qkv_layout(self.qkv_layout), AttnBiasType[core_attention_bias_type],
-                AttnMaskType[self.attn_mask_type], self.attention_dropout,
-                query_layer.shape[-2],
+                AttnMaskType[self.attn_mask_type], self.attention_dropout, query_layer.shape[-2],
                 key_value_layer.shape[-2] if key_value_layer is not None else query_layer.shape[-2],
                 max_s_q, max_s_kv, query_layer.shape[-1])
@@ -401,6 +400,8 @@ class MultiHeadAttention(paddle.nn.Layer):
         if set to `True`, QKV and FC1 layers are used as Column Parallel
         whereas PROJ and FC2 is used as Row Parallel as described
         `here <https://arxiv.org/pdf/1909.08053.pdf>`_.
+    sequence_parallel : bool, default = `False`
+        if set to `True`, uses sequence parallelism.
     tp_group : ProcessGroup, default = `None`
         tensor parallel process group.
     rng_state_name : str, default = `local_seed`
@@ -427,6 +428,7 @@ class MultiHeadAttention(paddle.nn.Layer):
                  attention_type: str = "self",
                  zero_centered_gamma: bool = False,
                  set_parallel_mode: bool = False,
+                 sequence_parallel: bool = False,
                  tp_group: Optional[dist_group_type] = None,
                  rng_state_name: str = 'local_seed',
                  backend: str = 'transformer_engine',
@@ -445,7 +447,7 @@ class MultiHeadAttention(paddle.nn.Layer):
         self.tp_group, self.tp_size = get_tp_group_and_world_size(tp_group,
                                                                   enable_tp=set_parallel_mode)
         self.tensor_parallel = self.tp_size > 1
+        self.sequence_parallel = self.tensor_parallel and sequence_parallel
         self.hidden_size_per_attention_head = hidden_size // num_attention_heads
         self.num_attention_heads = num_attention_heads
         norm_factor = math.sqrt(self.hidden_size_per_attention_head)
@@ -467,6 +469,7 @@ class MultiHeadAttention(paddle.nn.Layer):
                 return_layernorm_output=return_layernorm_output,
                 zero_centered_gamma=zero_centered_gamma,
                 parallel_mode=qkv_parallel_mode,
+                sequence_parallel=self.sequence_parallel,
                 tp_group=self.tp_group,
                 backend=self.backend,
             )
@@ -477,6 +480,7 @@ class MultiHeadAttention(paddle.nn.Layer):
                 self.weight_attr,
                 self.bias_attr,
                 parallel_mode=qkv_parallel_mode,
+                sequence_parallel=self.sequence_parallel,
                 tp_group=self.tp_group,
                 backend=self.backend,
             )
@@ -492,6 +496,7 @@ class MultiHeadAttention(paddle.nn.Layer):
                 return_layernorm_output=return_layernorm_output,
                 zero_centered_gamma=zero_centered_gamma,
                 parallel_mode=qkv_parallel_mode,
+                sequence_parallel=self.sequence_parallel,
                 tp_group=self.tp_group,
                 backend=self.backend,
             )
@@ -502,6 +507,7 @@ class MultiHeadAttention(paddle.nn.Layer):
                 self.weight_attr,
                 self.bias_attr,
                 parallel_mode=qkv_parallel_mode,
+                sequence_parallel=self.sequence_parallel,
                 tp_group=self.tp_group,
                 backend=self.backend,
             )
@@ -511,6 +517,7 @@ class MultiHeadAttention(paddle.nn.Layer):
                 self.weight_attr,
                 self.bias_attr,
                 parallel_mode=qkv_parallel_mode,
+                sequence_parallel=self.sequence_parallel,
                 tp_group=self.tp_group,
                 backend=self.backend,
             )
@@ -531,6 +538,7 @@ class MultiHeadAttention(paddle.nn.Layer):
             self.weight_attr,
             self.bias_attr,
             parallel_mode="row" if set_parallel_mode else None,
+            sequence_parallel=self.sequence_parallel,
             tp_group=self.tp_group,
             backend=self.backend,
         )
@@ -569,10 +577,21 @@ class MultiHeadAttention(paddle.nn.Layer):
             backprop.
         """

-        # hidden_states: [b, s_q, hidden_size]
         if self.attn_mask_type != "causal" and attention_mask is not None:
             assert (attention_mask.dtype == paddle.bool), "Attention mask must be a boolean tensor"

+        input_dim = len(hidden_states.shape)
+        if input_dim == 2:
+            # hidden_states: [b * s_q, hidden_size]
+            # need to get max_seq_len from attention_mask
+            assert attention_mask is not None
+            max_seq_len = attention_mask.shape[-1]
+        elif input_dim == 3:
+            # hidden_states: [b, s_q, hidden_size]
+            max_seq_len = hidden_states.shape[1]
+        else:
+            raise ValueError(f"hidden_states should have 2 or 3 dimensions, got {input_dim}.")
+
         if self.attention_type == "self":
             if self.input_layernorm:
                 layernorm_qkv_outputs = self.layernorm_qkv(hidden_states)
@@ -585,7 +604,8 @@ class MultiHeadAttention(paddle.nn.Layer):
             # [b, s_q, 3 * hidden_size] --> [b, s_q, 3, num_heads, head_size]
             mixed_qkv_layer = mixed_qkv_layer.reshape(shape=[
-                0, 0, 3, self.num_attention_heads_per_partition, self.hidden_size_per_attention_head
+                -1, max_seq_len, 3, self.num_attention_heads_per_partition,
+                self.hidden_size_per_attention_head
             ])

             with track_rng_state(enable=self.tensor_parallel, name=self.rng_state_name):
@@ -614,7 +634,8 @@ class MultiHeadAttention(paddle.nn.Layer):
                 mixed_kv_layer = self.key_value(encoder_output)
             # [b, s_kv, 2 * hidden_size] --> [b, s_kv, 2, num_heads, head_size]
             mixed_kv_layer = mixed_kv_layer.reshape(shape=[
-                0, 0, 2, self.num_attention_heads_per_partition, self.hidden_size_per_attention_head
+                -1, max_seq_len, 2, self.num_attention_heads_per_partition,
+                self.hidden_size_per_attention_head
             ])

             if self.input_layernorm:
@@ -627,7 +648,8 @@ class MultiHeadAttention(paddle.nn.Layer):
                 query_layer = self.query_layer(hidden_states)
             query_layer = query_layer.reshape(shape=[
-                0, 0, self.num_attention_heads_per_partition, self.hidden_size_per_attention_head
+                -1, max_seq_len, self.num_attention_heads_per_partition,
+                self.hidden_size_per_attention_head
             ])
             with track_rng_state(enable=self.tensor_parallel, name=self.rng_state_name):
                 if recompute_core_attention:
@@ -651,8 +673,13 @@ class MultiHeadAttention(paddle.nn.Layer):
                     set_zero=set_zero,
                 )

-        context_layer = paddle.reshape(context_layer,
-                                       [0, 0, context_layer.shape[2] * context_layer.shape[3]])
+        if input_dim == 3:
+            context_layer = paddle.reshape(
+                context_layer, [-1, max_seq_len, context_layer.shape[2] * context_layer.shape[3]])
+        else:  # input_dim == 2
+            context_layer = paddle.reshape(context_layer,
+                                           [-1, context_layer.shape[2] * context_layer.shape[3]])

         # Output. [b, s, hidden]
         attention_output = self.proj(context_layer)
...
...@@ -16,7 +16,7 @@ from paddle.fluid import core ...@@ -16,7 +16,7 @@ from paddle.fluid import core
from paddle.fluid.framework import _dygraph_tracer from paddle.fluid.framework import _dygraph_tracer
from ..constants import FP8BwdTensors, dist_group_type from ..constants import FP8BwdTensors, dist_group_type
from ..cpp_extensions import cast_transpose, cast_transpose_bgrad, cast_to_fp8 from ..cpp_extensions import cast_transpose, cast_transpose_bgrad, cast_to_fp8, transpose
from ..fp8 import ( from ..fp8 import (
FP8State, FP8State,
FP8TensorMeta, FP8TensorMeta,
...@@ -24,6 +24,7 @@ from ..fp8 import ( ...@@ -24,6 +24,7 @@ from ..fp8 import (
get_global_fp8_state, get_global_fp8_state,
get_fp8_te_dtype, get_fp8_te_dtype,
) )
from ..distributed import allgather
from ..profile import nvtx_range from ..profile import nvtx_range
from ..recompute import is_in_recompute_phase from ..recompute import is_in_recompute_phase
from ..fp8_buffer import FP8RecomputeBuffer from ..fp8_buffer import FP8RecomputeBuffer
...@@ -310,8 +311,8 @@ class TransformerEngineBaseLayer(paddle.nn.Layer, ABC): ...@@ -310,8 +311,8 @@ class TransformerEngineBaseLayer(paddle.nn.Layer, ABC):
global_fp8_bwd_buffer.finalize(fp8_meta, tp_group, tp_size) global_fp8_bwd_buffer.finalize(fp8_meta, tp_group, tp_size)
@staticmethod @staticmethod
def grad_output_preprocess( def grad_output_preprocess(ctx, grad_output: paddle.Tensor,
ctx, grad_output: paddle.Tensor) -> Tuple[Union[paddle.Tensor, None], ...]: row_parallel_mode: bool) -> Tuple[Union[paddle.Tensor, None], ...]:
"""Utility function for backward. """Utility function for backward.
Returns tuple in order (all optional/None based on training precion/recipe): Returns tuple in order (all optional/None based on training precion/recipe):
             R1: gathered `grad_output` in higher precision.
@@ -320,13 +321,37 @@ class TransformerEngineBaseLayer(paddle.nn.Layer, ABC):
             R4: bias gradient on R1.
         """
         grad_output_mat = grad_output.reshape((-1, grad_output.shape[-1]))
+        gather_grad_output = row_parallel_mode and ctx.sequence_parallel
         # No-FP8 case: bgrad is fused with wgrad for this case.
         if not ctx.fp8_enabled:
+            if gather_grad_output:
+                grad_output_mat, _ = allgather(grad_output_mat, ctx.tp_group)
             return grad_output_mat, None, None, None
         fp8_dtype_backward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=False)
+        if gather_grad_output:
+            if not ctx.fp8_meta["recipe"].override_linear_precision.wgrad:
+                # FP8 case with gather: unfused bgrad, cast, transpose for efficient gather
+                if ctx.use_bias:
+                    bgrad = grad_output_mat.sum(axis=0)
+                else:
+                    bgrad = None
+                grad_output_c = cast_to_fp8(
+                    grad_output_mat,
+                    ctx.fp8_meta["scaling_bwd"],
+                    FP8BwdTensors.GRAD_OUTPUT1,
+                    fp8_dtype_backward,
+                )
+                grad_output_c, _ = allgather(grad_output_c, ctx.tp_group)
+                grad_output_t = transpose(grad_output_c, fp8_dtype_backward)
+                return grad_output_mat, grad_output_c, grad_output_t, bgrad
+            # FP8 case with gather and non-FP8 wgrad
+            grad_output_mat, _ = allgather(grad_output_mat, ctx.tp_group)
         # FP8 case without gather: cast, transpose, bgrad fused
         if ctx.use_bias:
             bgrad, grad_output_c, grad_output_t = cast_transpose_bgrad(
......
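The new `row_parallel_mode` argument exists because, under sequence parallelism, a row-parallel GEMM ends its forward pass with a reduce-scatter along the sequence dimension; each rank therefore holds only a sequence shard of the output, and backward must all-gather the incoming gradient before the dgrad/wgrad GEMMs run (in the FP8 path above, the gather is done on the already-casted tensor so less data crosses the wire). A minimal sketch of that gather with plain Paddle collectives, illustrative only and not the fused TE path:

import paddle
import paddle.distributed as dist

def gather_grad_output(grad_output_shard, tp_group):
    """All-gather a sequence-sharded gradient along dim 0 (sketch)."""
    parts = []
    dist.all_gather(parts, grad_output_shard, group=tp_group)
    return paddle.concat(parts, axis=0)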
@@ -12,6 +12,7 @@ from paddle.nn.initializer import Constant
 from ..constants import TE_DType
 from ..cpp_extensions import layernorm_fwd, layernorm_bwd
+from ..distributed import mark_as_sequence_parallel_parameter
 __all__ = ["LayerNorm"]
@@ -90,6 +91,11 @@ class LayerNorm(paddle.nn.Layer):
                           (1 + \gamma) + \beta
     backend: {'transformer_engine', 'paddle'}, default = `transformer_engine`
              backend to use for the layer normalization operation.
+    Parallelism parameters
+    ----------------------
+    sequence_parallel : bool, default = `False`
+                       if set to `True`, uses sequence parallelism.
     """
     def __init__(
@@ -99,11 +105,13 @@ class LayerNorm(paddle.nn.Layer):
         weight_attr: Union[paddle.ParamAttr, None] = None,
         bias_attr: Union[paddle.ParamAttr, None, bool] = None,
         zero_centered_gamma: bool = False,
+        sequence_parallel: bool = False,
         backend: str = 'transformer_engine',
     ) -> None:
         super().__init__()
         self.eps = eps
         self.zero_centered_gamma = zero_centered_gamma
+        self.sequence_parallel = sequence_parallel
         self.backend = backend
         self._dtype = self._helper.get_default_dtype()
@@ -130,6 +138,10 @@ class LayerNorm(paddle.nn.Layer):
             is_bias=True,
         )
+        if self.sequence_parallel:
+            mark_as_sequence_parallel_parameter(self.weight)
+            mark_as_sequence_parallel_parameter(self.bias)
         # These many SMs are subtracted from the total SM count when calling forward
         # and backward LayerNorm C APIs. These envvars can be used to prevent the LN
         # kernels from using all SMs in the device. This is useful for cases such as
......
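Marking the LayerNorm weights matters at optimizer time: gamma and beta are replicated across tensor-parallel ranks, yet under sequence parallelism each rank computes their gradients from a different sequence shard, so those gradients must be all-reduced over the TP group before the update (the unit tests do this via `register_sequence_parallel_allreduce_hooks`). A hedged sketch of the idea, assuming the marker simply sets a `sequence_parallel` attribute on the parameter:

import paddle.distributed as dist

def allreduce_sequence_parallel_grads(model, tp_group):
    # Sum grads of replicated, sequence-parallel-marked params across TP ranks.
    for param in model.parameters():
        if getattr(param, 'sequence_parallel', False) and param.grad is not None:
            dist.all_reduce(param.grad, group=tp_group)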
@@ -16,7 +16,6 @@ from ..cpp_extensions import (
     layernorm_fwd,
     layernorm_fwd_fp8,
     layernorm_bwd,
-    transpose,
 )
 from .base import TransformerEngineBaseLayer
@@ -29,6 +28,7 @@ from ..distributed import (
     track_rng_state,
     set_tensor_dist_attr,
     set_weight_tensor_dist_attr,
+    mark_as_sequence_parallel_parameter,
 )
 from ..fp8 import get_fp8_te_dtype
 from ..utils import (
@@ -145,6 +145,7 @@ class _LayerNormLinear(paddle.autograd.PyLayer):
         zero_centered_gamma: bool,
         parallel_mode: Union[str, None],
         tensor_parallel: bool,
+        sequence_parallel: bool,
         tp_group: Union[dist_group_type, None],
         tp_size: int,
     ) -> Union[Tuple[paddle.Tensor, ...], paddle.Tensor]:
@@ -190,6 +191,7 @@ class _LayerNormLinear(paddle.autograd.PyLayer):
             activation_dtype,
             parallel_mode,
             tensor_parallel,
+            sequence_parallel,
             tp_group,
             is_grad_enabled,
         )
@@ -217,6 +219,7 @@ class _LayerNormLinear(paddle.autograd.PyLayer):
             ctx.zero_centered_gamma = zero_centered_gamma
             ctx.parallel_mode = parallel_mode
             ctx.tensor_parallel = tensor_parallel
+            ctx.sequence_parallel = sequence_parallel
             ctx.tp_group = tp_group
             ctx.tp_size = tp_size
             ctx.requires_dgrad = not inp.stop_gradient
@@ -256,18 +259,15 @@ class _LayerNormLinear(paddle.autograd.PyLayer):
                 grad_output_c,
                 grad_output_t,
                 bgrad,
-            ) = TransformerEngineBaseLayer.grad_output_preprocess(ctx, grad_outputs[0])
+            ) = TransformerEngineBaseLayer.grad_output_preprocess(ctx, grad_outputs[0],
+                                                                  ctx.parallel_mode == "row")
             # Prepare ln_out for Linear bwd
-            ln_out_no_fp8, ln_out_t = None, None
+            linear_inputmat = ln_out
             if ctx.fp8_enabled:
                 fp8_dtype_forward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=True)
-                fp8_wgrad = not ctx.fp8_meta["recipe"].override_linear_precision.wgrad
-                if ctx.requires_wgrad:
-                    if fp8_wgrad:
-                        ln_out_t = transpose(ln_out, fp8_dtype_forward)
-                    else:
-                        ln_out_no_fp8 = cast_from_fp8(
+                if ctx.requires_wgrad and ctx.fp8_meta["recipe"].override_linear_precision.wgrad:
+                    linear_inputmat = cast_from_fp8(
                         ln_out,
                         ctx.fp8_meta["scaling_fwd"],
                         FP8FwdTensors.GEMM1_INPUT,
@@ -277,8 +277,8 @@ class _LayerNormLinear(paddle.autograd.PyLayer):
             # Linear Bwd
             dgrad, wgrad, bgrad_ = _linear_bwd(
-                ln_out_no_fp8 if ctx.fp8_enabled else ln_out,
-                ln_out_t,
+                linear_inputmat,
+                None,  # inputmat_t will be automatically computed if not provided
                 FP8FwdTensors.GEMM1_INPUT,
                 weight,
                 weight_t_fp8,
@@ -296,6 +296,7 @@ class _LayerNormLinear(paddle.autograd.PyLayer):
                 ctx.activation_dtype,
                 ctx.parallel_mode,
                 ctx.tensor_parallel,
+                ctx.sequence_parallel,
                 ctx.tp_group,
             )
@@ -367,6 +368,8 @@ class LayerNormLinear(TransformerEngineBaseLayer):
                     used to decide whether this Linear layer is Column Parallel Linear or Row
                     Parallel Linear as described `here <https://arxiv.org/pdf/1909.08053.pdf>`_.
                     When set to `None`, no communication is performed.
+    sequence_parallel : bool, default = `False`
+                       if set to `True`, uses sequence parallelism.
     """
     def __init__(
@@ -379,6 +382,7 @@ class LayerNormLinear(TransformerEngineBaseLayer):
         return_layernorm_output: bool = False,
         zero_centered_gamma: bool = False,
         parallel_mode: Optional[str] = None,
+        sequence_parallel: bool = False,
         tp_group: Union[dist_group_type, None] = None,
         backend: str = 'transformer_engine',
     ) -> None:
@@ -409,6 +413,8 @@ class LayerNormLinear(TransformerEngineBaseLayer):
         elif self.parallel_mode == "row":
             self.in_features = divide(self.in_features, self.tp_size)
+        self.sequence_parallel = self.tensor_parallel and sequence_parallel
         # LayerNorm weights
         self.ln_weight = self.create_parameter(
             shape=[self.in_features],
@@ -425,6 +431,10 @@ class LayerNormLinear(TransformerEngineBaseLayer):
             is_bias=True,
         )
+        if self.sequence_parallel:
+            mark_as_sequence_parallel_parameter(self.ln_weight)
+            mark_as_sequence_parallel_parameter(self.ln_bias)
         # Initialize Linear weight parameter
         with track_rng_state(enable=self.tensor_parallel):
             # TE linear weight is in column major
@@ -451,6 +461,8 @@ class LayerNormLinear(TransformerEngineBaseLayer):
             )
             if parallel_mode == "column":
                 set_tensor_dist_attr(self.bias, self.tensor_parallel, axis=0)
+            if parallel_mode == "row" and self.sequence_parallel:
+                mark_as_sequence_parallel_parameter(self.bias)
         else:
             self.bias = None
@@ -500,6 +512,7 @@ class LayerNormLinear(TransformerEngineBaseLayer):
             self.zero_centered_gamma,
             self.parallel_mode,
             self.tensor_parallel,
+            self.sequence_parallel,
             self.tp_group,
             self.tp_size,
         )
......
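Usage sketch for the new flag (sizes are illustrative; `sequence_parallel` only takes effect when the layer actually runs tensor-parallel, i.e. tp_size > 1, with the process group set up through fleet as in the unit tests):

import transformer_engine.paddle as te

qkv_proj = te.LayerNormLinear(
    in_features=1024,
    out_features=3 * 1024,
    parallel_mode='column',   # column-parallel: input is all-gathered along the sequence
    sequence_parallel=True,
)
# Each TP rank feeds its own sequence shard, e.g. [seq_len // tp_size, batch, hidden].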
@@ -18,7 +18,6 @@ from ..cpp_extensions import (
     cast_from_fp8,
     dgelu_cast_transpose_bgrad_fp8,
     gelu_fp8,
-    transpose,
 )
 from ..distributed import (
     allreduce,
@@ -27,6 +26,7 @@ from ..distributed import (
     track_rng_state,
     set_tensor_dist_attr,
     set_weight_tensor_dist_attr,
+    mark_as_sequence_parallel_parameter,
 )
 from ..fp8 import get_fp8_te_dtype
 from ..utils import (
@@ -39,7 +39,6 @@ from ..utils import (
     saved_tensor_allow_none,
 )
 __all__ = ["LayerNormMLP"]
@@ -63,6 +62,7 @@ def _mlp_forward(
     is_grad_enabled: bool,
     set_parallel_mode: bool,
     tensor_parallel: bool,
+    sequence_parallel: bool,
     tp_group: Union[dist_group_type, None],
 ):
     if fp8_enabled:
@@ -78,6 +78,7 @@
             activation_dtype,
             'column' if set_parallel_mode else None,
             tensor_parallel,
+            sequence_parallel,
             tp_group,
             is_grad_enabled,
         )
@@ -100,6 +101,7 @@
             activation_dtype,
             'row' if set_parallel_mode else None,
             tensor_parallel,
+            sequence_parallel,
             tp_group,
             is_grad_enabled,
         )
@@ -116,6 +118,7 @@
             activation_dtype,
             'column' if set_parallel_mode else None,
             tensor_parallel,
+            sequence_parallel,
             tp_group,
             activation=activation,
         )
@@ -132,6 +135,7 @@
             activation_dtype,
             'row' if set_parallel_mode else None,
             tensor_parallel,
+            sequence_parallel,
             tp_group,
         )
     return (
@@ -172,6 +176,7 @@ def _mlp_backward(
     activation: str,
     set_parallel_mode: bool,
     tensor_parallel: bool,
+    sequence_parallel: bool,
     tp_group: Union[dist_group_type, None],
 ):
     (
@@ -186,13 +191,9 @@
         fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)
         fp8_dtype_backward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=False)
         # FC2 Bwd
-        fc2_input_no_fp8, fc2_input_t = None, None
         fp8_wgrad = not fp8_meta["recipe"].override_linear_precision.wgrad
-        if requires_fc2_wgrad:
-            if fp8_wgrad:
-                fc2_input_t = transpose(fc2_input, fp8_dtype_forward)
-            else:
-                fc2_input_no_fp8 = cast_from_fp8(
+        if requires_fc2_wgrad and not fp8_wgrad:
+            fc2_input = cast_from_fp8(
                 fc2_input,
                 fp8_meta["scaling_fwd"],
                 fc2_input_fp8_index,
@@ -201,8 +202,8 @@
             )
         fc2_dgrad, fc2_wgrad = _linear_bwd_fp8(
-            fc2_input_no_fp8,
-            fc2_input_t,
+            fc2_input,
+            None,
             fc2_input_fp8_index,
             fc2_weight_t_fp8,
             fc2_weight_fp8_index,
@@ -217,6 +218,7 @@
             activation_dtype,
             'row' if set_parallel_mode else None,
             tensor_parallel,
+            sequence_parallel,
             tp_group,
         )
@@ -233,11 +235,8 @@
             fc1_bgrad = fc1_bgrad_
         # FC1 Bwd
-        dgelu_no_fp8, fc1_input_no_fp8, fc1_input_t = None, None, None
-        if requires_fc1_wgrad:
-            if fp8_wgrad:
-                fc1_input_t = transpose(fc1_input, fp8_dtype_forward)
-            else:
+        dgelu_no_fp8 = None
+        if requires_fc1_wgrad and not fp8_wgrad:
             # TODO(tizheng) Paddle lacks fused dgelu_bgrad OP. Cast from dgrad(fp8) instead.
             dgelu_no_fp8 = cast_from_fp8(
                 dgelu,
@@ -246,7 +245,7 @@
                 fp8_dtype_backward,
                 TE_DType[activation_dtype],
             )
-            fc1_input_no_fp8 = cast_from_fp8(
+            fc1_input = cast_from_fp8(
                 fc1_input,
                 fp8_meta["scaling_fwd"],
                 fc1_input_fp8_index,
@@ -255,8 +254,8 @@
             )
         fc1_dgrad, fc1_wgrad = _linear_bwd_fp8(
-            fc1_input_no_fp8,
-            fc1_input_t,
+            fc1_input,
+            None,
             fc1_input_fp8_index,
             fc1_weight_t_fp8,
             fc1_weight_fp8_index,
@@ -271,6 +270,7 @@
             activation_dtype,
             'column' if set_parallel_mode else None,
             tensor_parallel,
+            sequence_parallel,
             tp_group,
         )
     else:
@@ -284,6 +284,7 @@
             activation_dtype,
             'row' if set_parallel_mode else None,
             tensor_parallel,
+            sequence_parallel,
             tp_group,
             gelu_input=fc1_out,
             activation=activation,
@@ -298,6 +299,7 @@
             activation_dtype,
             'column' if set_parallel_mode else None,
             tensor_parallel,
+            sequence_parallel,
             tp_group,
         )
     return (
@@ -337,6 +339,7 @@ class _LayerNormMLP(paddle.autograd.PyLayer):
         activation: str,
         set_parallel_mode: bool,
         tensor_parallel: bool,
+        sequence_parallel: bool,
         tp_group: Union[dist_group_type, None],
         tp_size: int,
     ) -> Union[Tuple[paddle.Tensor, ...], paddle.Tensor]:
@@ -398,6 +401,7 @@ class _LayerNormMLP(paddle.autograd.PyLayer):
             is_grad_enabled,
             set_parallel_mode,
             tensor_parallel,
+            sequence_parallel,
             tp_group,
         )
@@ -429,6 +433,7 @@ class _LayerNormMLP(paddle.autograd.PyLayer):
             ctx.zero_centered_gamma = zero_centered_gamma
             ctx.set_parallel_mode = set_parallel_mode
             ctx.tensor_parallel = tensor_parallel
+            ctx.sequence_parallel = sequence_parallel
             ctx.tp_group = tp_group
             ctx.tp_size = tp_size
             ctx.requires_dgrad = not inp.stop_gradient
@@ -476,7 +481,7 @@ class _LayerNormMLP(paddle.autograd.PyLayer):
                 grad_output_c,
                 grad_output_t,
                 fc2_bgrad,
-            ) = TransformerEngineBaseLayer.grad_output_preprocess(ctx, grad_outputs[0])
+            ) = TransformerEngineBaseLayer.grad_output_preprocess(ctx, grad_outputs[0], True)
             (
                 fc1_dgrad,
@@ -513,6 +518,7 @@ class _LayerNormMLP(paddle.autograd.PyLayer):
                 ctx.activation,
                 ctx.set_parallel_mode,
                 ctx.tensor_parallel,
+                ctx.sequence_parallel,
                 ctx.tp_group,
             )
             if not ctx.fp8_enabled:
@@ -588,6 +594,8 @@ class LayerNormMLP(TransformerEngineBaseLayer):
     set_parallel_mode : bool, default = `False`
                       if set to `True`, FC1 is used as Column Parallel and FC2 is used as Row
                       Parallel as described `here <https://arxiv.org/pdf/1909.08053.pdf>`_.
+    sequence_parallel : bool, default = `False`
+                       if set to `True`, uses sequence parallelism.
     tp_group : paddle.distributed.collective.Group, default = `None`
                tensor parallel process group.
@@ -604,6 +612,7 @@ class LayerNormMLP(TransformerEngineBaseLayer):
         return_layernorm_output: bool = False,
         zero_centered_gamma: bool = False,
         set_parallel_mode: bool = False,
+        sequence_parallel: bool = False,
         tp_group: Optional[dist_group_type] = None,
         backend: str = 'transformer_engine',
     ) -> None:
@@ -626,6 +635,7 @@ class LayerNormMLP(TransformerEngineBaseLayer):
                                                                 enable_tp=set_parallel_mode)
         self.tensor_parallel = self.tp_size > 1
         self.set_parallel_mode = set_parallel_mode
+        self.sequence_parallel = self.tensor_parallel and sequence_parallel
         if self.set_parallel_mode:
             self.size_per_partition = divide(self.ffn_hidden_size, self.tp_size)
@@ -648,6 +658,10 @@ class LayerNormMLP(TransformerEngineBaseLayer):
             is_bias=True,
         )
+        if self.sequence_parallel:
+            mark_as_sequence_parallel_parameter(self.ln_weight)
+            mark_as_sequence_parallel_parameter(self.ln_bias)
         # FC1 weights
         with track_rng_state(enable=self.tensor_parallel):
             self.fc1_weight = self.create_parameter(
@@ -698,6 +712,8 @@ class LayerNormMLP(TransformerEngineBaseLayer):
                 dtype=self._dtype,
                 is_bias=True,
             )
+            if self.set_parallel_mode and self.sequence_parallel:
+                mark_as_sequence_parallel_parameter(self.fc2_bias)
         else:
             self.fc2_bias = None
@@ -751,6 +767,7 @@ class LayerNormMLP(TransformerEngineBaseLayer):
             self.activation,
             self.set_parallel_mode,
             self.tensor_parallel,
+            self.sequence_parallel,
             self.tp_group,
             self.tp_size,
         )
......
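Threading `sequence_parallel` through `_mlp_forward`/`_mlp_backward` gives the MLP the standard SP communication pattern: the column-parallel FC1 all-gathers the sequence-sharded input, and the row-parallel FC2 reduce-scatters its partial outputs back into sequence shards. A minimal sketch of that data flow with plain Paddle collectives (no FP8, no fused kernels; the shapes in the docstring are assumptions):

import paddle
import paddle.distributed as dist

def sp_mlp_forward(x_shard, fc1_weight, fc2_weight, tp_group):
    """x_shard: [s/tp, b, h]; fc1_weight: [h, 4h/tp]; fc2_weight: [4h/tp, h]."""
    nranks = tp_group.nranks
    parts = []
    dist.all_gather(parts, x_shard, group=tp_group)        # gather sequence shards
    x_total = paddle.concat(parts, axis=0)                 # [s, b, h]
    fc1_out = paddle.nn.functional.gelu(paddle.matmul(x_total, fc1_weight))
    partial = paddle.matmul(fc1_out, fc2_weight)           # partial sums over ffn shards
    out_shard = paddle.empty_like(x_shard)
    dist.reduce_scatter(out_shard, paddle.split(partial, nranks, axis=0),
                        group=tp_group)                    # sum partials + re-shard sequence
    return out_shard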
@@ -18,14 +18,17 @@ from .base import (
 )
 from ..constants import FP8FwdTensors, FP8BwdTensors, GemmParallelModes, dist_group_type
-from ..cpp_extensions import gemm, fp8_gemm, cast_to_fp8, cast_transpose
+from ..cpp_extensions import gemm, fp8_gemm, cast_to_fp8, cast_transpose, transpose
 from ..distributed import (
+    allgather,
     allreduce,
     get_tp_group_and_world_size,
     identity,
+    reduce_scatter,
     track_rng_state,
     set_tensor_dist_attr,
     set_weight_tensor_dist_attr,
+    mark_as_sequence_parallel_parameter,
 )
 from ..fp8 import get_fp8_te_dtype
 from ..utils import (
@@ -52,6 +55,7 @@ def _linear_fwd_fp8(
     activation_dtype: paddle.dtype,
     parallel_mode: Union[str, None],
     tensor_parallel: bool,
+    sequence_parallel: bool,
     tp_group: Union[dist_group_type, None],
     is_grad_enabled: bool,
 ):
@@ -60,6 +64,11 @@
         bias_dtype = get_bias_dtype(activation_dtype)
         bias = cast_if_needed(bias, bias_dtype)
+    if parallel_mode == "column" and sequence_parallel:
+        inputmat_total, _ = allgather(inputmat, tp_group)
+    else:
+        inputmat_total = inputmat
     if is_grad_enabled:
         weight_fp8, weight_t_fp8 = cast_transpose(
             weight,
@@ -81,7 +90,7 @@
             fp8_meta["scaling_fwd"].scale_inv,
             weight_fp8_index,
             fp8_dtype_forward,
-            inputmat,
+            inputmat_total,
             fp8_meta["scaling_fwd"].scale_inv,
             inputmat_fp8_index,
             fp8_dtype_forward,
@@ -92,8 +101,9 @@
             use_split_accumulator=_2X_ACC_FPROP,
         )
-    # Row Parallel Linear
-    if parallel_mode == "row" and tensor_parallel:
+    if parallel_mode == "row" and sequence_parallel:
+        out, _ = reduce_scatter(out, tp_group)
+    elif parallel_mode == "row" and tensor_parallel:
         out, _ = allreduce(out, tp_group)
     return out, weight_t_fp8
@@ -111,11 +121,17 @@ def _linear_fwd_non_fp8(
     activation_dtype: paddle.dtype,
     parallel_mode: Union[str, None],
     tensor_parallel: bool,
+    sequence_parallel: bool,
     tp_group: Union[dist_group_type, None],
     activation: str = "",
 ):
     """Non-FP8 path of Linear Fwd"""
+    if parallel_mode == "column" and sequence_parallel:
+        inputmat_total, _ = allgather(inputmat, tp_group)
+    else:
+        inputmat_total = inputmat
     # Layer parameters are initialized as float32 dtype by default.
     # Cast the parameters to activation_dtype if the current dtype
     # does not match activation_dtype. The casting is inplace, so it
@@ -126,13 +142,13 @@
     if fp8_calibration:
         # amax of input
         fp8_meta["scaling_fwd"].amax_history[0, inputmat_fp8_index.value] = \
-            paddle.max(paddle.abs(inputmat)).item()
+            paddle.max(paddle.abs(inputmat_total)).item()
         # amax of weight
         fp8_meta["scaling_fwd"].amax_history[0, weight_fp8_index.value] = \
             paddle.max(paddle.abs(weight)).item()
     outputs = gemm(weight,
-                   inputmat,
+                   inputmat_total,
                    activation_dtype,
                    get_workspace(),
                    bias=bias,
@@ -144,8 +160,10 @@
         return out, gelu_out
     out, _, _ = outputs
-    # Row Parallel Linear
-    if parallel_mode == "row" and tensor_parallel:
+    if parallel_mode == "row" and sequence_parallel:
+        out, _ = reduce_scatter(out, tp_group)
+    elif parallel_mode == "row" and tensor_parallel:
         out, _ = allreduce(out, tp_group)
     return out
@@ -163,6 +181,7 @@ def _linear_fwd(
     activation_dtype: paddle.dtype,
     parallel_mode: Union[str, None],
     tensor_parallel: bool,
+    sequence_parallel: bool,
     tp_group: Union[dist_group_type, None],
     is_grad_enabled: bool,
 ):
@@ -178,6 +197,7 @@
             activation_dtype,
             parallel_mode,
             tensor_parallel,
+            sequence_parallel,
             tp_group,
             is_grad_enabled,
         )
@@ -194,6 +214,7 @@
             activation_dtype,
             parallel_mode,
             tensor_parallel,
+            sequence_parallel,
             tp_group,
         )
     return (
@@ -219,9 +240,20 @@ def _linear_bwd_fp8(
     activation_dtype: paddle.dtype,
     parallel_mode: Union[str, None],
     tensor_parallel: bool,
+    sequence_parallel: bool,
     tp_group: Union[dist_group_type, None],
 ):
     dgrad, wgrad, handle = None, None, None
+    # Overlap input AG with dgrad
+    inputmat_total = None
+    inputmat_t_total = None
+    if requires_wgrad and parallel_mode == "column" and sequence_parallel:
+        inputmat_total, handle = allgather(inputmat, tp_group, sync_op=not requires_dgrad)
+    else:
+        inputmat_total = inputmat
+        inputmat_t_total = inputmat_t
     fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)
     fp8_dtype_backward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=False)
     if requires_dgrad:
@@ -238,13 +270,21 @@
             get_workspace(),
             use_split_accumulator=_2X_ACC_DGRAD,
         )
-        if parallel_mode == "column" and tensor_parallel:
+        # Overlap dgrad-RS/AR with wgrad
+        if parallel_mode == "column" and sequence_parallel:
+            if handle is not None:
+                handle.wait()
+            dgrad, handle = reduce_scatter(dgrad, tp_group, sync_op=False)
+        elif parallel_mode == "column" and tensor_parallel:
             dgrad, handle = allreduce(dgrad, tp_group, sync_op=False)
     if requires_wgrad:
         if not fp8_meta["recipe"].override_linear_precision.wgrad:
+            if inputmat_t_total is None:
+                inputmat_t_total = transpose(inputmat_total, fp8_dtype_backward)
             wgrad = fp8_gemm(
-                inputmat_t,
+                inputmat_t_total,
                 fwd_scale_inverses,
                 inputmat_fp8_index,
                 fp8_dtype_forward,
@@ -258,7 +298,7 @@
             )
         else:
             wgrad, _, _ = gemm(
-                inputmat,
+                inputmat_total,
                 grad_output,
                 activation_dtype,
                 get_workspace(),
@@ -282,6 +322,7 @@ def _linear_bwd_non_fp8(
     activation_dtype: paddle.dtype,
     parallel_mode: Union[str, None],
     tensor_parallel: bool,
+    sequence_parallel: bool,
     tp_group: Union[dist_group_type, None],
     gelu_input: Union[paddle.Tensor, None] = None,
     activation: str = "",
@@ -290,6 +331,14 @@
     """
     Performs Linear Backward. Optionally, fuses GELU backward and dbias.
     """
     dgrad, wgrad, bgrad, handle = None, None, None, None
+    # Overlap input AG with dgrad
+    inputmat_total = None
+    if requires_wgrad and parallel_mode == "column" and sequence_parallel:
+        inputmat_total, handle = allgather(inputmat, tp_group, sync_op=not requires_dgrad)
+    else:
+        inputmat_total = inputmat
     if requires_dgrad:
         dgrad, _, _ = gemm(
             weight,
@@ -301,12 +350,17 @@
             gelu_input=gelu_input,
             grad=True,
         )
-        if parallel_mode == "column" and tensor_parallel:
+        # Overlap dgrad-RS/AR with wgrad
+        if parallel_mode == "column" and sequence_parallel:
+            if handle is not None:
+                handle.wait()
+            dgrad, handle = reduce_scatter(dgrad, tp_group, sync_op=False)
+        elif parallel_mode == "column" and tensor_parallel:
             dgrad, handle = allreduce(dgrad, tp_group, sync_op=False)
     if requires_wgrad:
         wgrad, bgrad, _ = gemm(
-            inputmat,
+            inputmat_total,
             grad_output,
             activation_dtype,
             get_workspace(),
@@ -343,6 +397,7 @@ def _linear_bwd(
     activation_dtype: paddle.dtype,
     parallel_mode: Union[str, None],
     tensor_parallel: bool,
+    sequence_parallel: bool,
     tp_group: Union[dist_group_type, None],
 ):
     dgrad, wgrad, bgrad = None, None, None
@@ -364,6 +419,7 @@
             activation_dtype,
             parallel_mode,
             tensor_parallel,
+            sequence_parallel,
             tp_group,
         )
     else:
@@ -377,6 +433,7 @@
             activation_dtype,
             parallel_mode,
             tensor_parallel,
+            sequence_parallel,
             tp_group,
         )
     return dgrad, wgrad, bgrad
@@ -399,6 +456,7 @@ class _Linear(paddle.autograd.PyLayer):
         is_grad_enabled: bool,
         parallel_mode: Union[str, None],
         tensor_parallel: bool,
+        sequence_parallel: bool,
         tp_group: Union[dist_group_type, None],
         tp_size: int,
     ) -> paddle.Tensor:
@@ -413,11 +471,11 @@ class _Linear(paddle.autograd.PyLayer):
         inputmat_no_fp8 = inputmat
         # FP8 casting
+        inputmat_t = None
         if fp8_enabled:
             fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)
-            if not fp8_meta["recipe"].override_linear_precision.wgrad:
-                if is_grad_enabled:
+            if (not fp8_meta["recipe"].override_linear_precision.wgrad and is_grad_enabled
+                    and not sequence_parallel):
                 inputmat, inputmat_t = cast_transpose(
                     inputmat,
                     fp8_meta["scaling_fwd"],
@@ -431,13 +489,6 @@ class _Linear(paddle.autograd.PyLayer):
                     FP8FwdTensors.GEMM1_INPUT,
                     fp8_dtype_forward,
                 )
-            else:
-                inputmat, inputmat_t = cast_to_fp8(
-                    inputmat,
-                    fp8_meta["scaling_fwd"],
-                    FP8FwdTensors.GEMM1_INPUT,
-                    fp8_dtype_forward,
-                ), None
         # GEMM Fwd
         out, weight_t_fp8 = _linear_fwd(
@@ -453,16 +504,21 @@
             activation_dtype,
             parallel_mode,
             tensor_parallel,
+            sequence_parallel,
             tp_group,
             is_grad_enabled,
         )
         if is_grad_enabled:
-            fp8_wgrad = fp8_enabled and not fp8_meta["recipe"].override_linear_precision.wgrad
+            saved_inputmat = None
+            if fp8_enabled and sequence_parallel:
+                saved_inputmat = inputmat
+            else:
+                saved_inputmat = inputmat_no_fp8
             save_for_backward_allow_none(
                 ctx,
-                inputmat_no_fp8 if not weight.stop_gradient and not fp8_wgrad else None,
-                inputmat_t if not weight.stop_gradient and fp8_wgrad else None,
+                saved_inputmat,
+                inputmat_t,
                 weight,
                 weight_t_fp8 if fp8_enabled else None,
                 fp8_meta["scaling_fwd"].scale_inv.clone() if fp8_enabled else None,
@@ -474,6 +530,7 @@
             ctx.inp_shape = inp.shape
             ctx.parallel_mode = parallel_mode
             ctx.tensor_parallel = tensor_parallel
+            ctx.sequence_parallel = sequence_parallel
             ctx.tp_group = tp_group
             ctx.tp_size = tp_size
             ctx.requires_dgrad = not inp.stop_gradient
@@ -503,7 +560,8 @@
                 grad_output_c,
                 grad_output_t,
                 bgrad,
-            ) = TransformerEngineBaseLayer.grad_output_preprocess(ctx, grad_output)
+            ) = TransformerEngineBaseLayer.grad_output_preprocess(ctx, grad_output,
+                                                                  ctx.parallel_mode == "row")
             dgrad, wgrad, bgrad_ = _linear_bwd(
                 inputmat,
@@ -525,6 +583,7 @@
                 ctx.activation_dtype,
                 ctx.parallel_mode,
                 ctx.tensor_parallel,
+                ctx.sequence_parallel,
                 ctx.tp_group,
             )
@@ -570,7 +629,8 @@ class Linear(TransformerEngineBaseLayer):
                     used to decide whether this Linear layer is Column Parallel Linear or Row
                     Parallel Linear as described `here <https://arxiv.org/pdf/1909.08053.pdf>`_.
                     When set to `None`, no communication is performed.
+    sequence_parallel : bool, default = `False`
+                       if set to `True`, uses sequence parallelism.
     """
     def __init__(
@@ -580,6 +640,7 @@ class Linear(TransformerEngineBaseLayer):
         weight_attr: Union[paddle.ParamAttr, None] = None,
         bias_attr: Union[paddle.ParamAttr, None, bool] = None,
         parallel_mode: Optional[str] = None,
+        sequence_parallel: bool = False,
        tp_group: Union[dist_group_type, None] = None,
        backend: str = 'transformer_engine',
     ) -> None:
@@ -605,6 +666,8 @@ class Linear(TransformerEngineBaseLayer):
         elif self.parallel_mode == "row":
             self.in_features = divide(self.in_features, self.tp_size)
+        self.sequence_parallel = self.tensor_parallel and sequence_parallel
         # Initialize weight parameter
         with track_rng_state(enable=self.tensor_parallel):
             # TE linear weight is in column major
@@ -631,6 +694,8 @@ class Linear(TransformerEngineBaseLayer):
             )
             if parallel_mode == "column":
                 set_tensor_dist_attr(self.bias, self.tensor_parallel, axis=0)
+            if parallel_mode == "row" and self.sequence_parallel:
+                mark_as_sequence_parallel_parameter(self.bias)
         else:
             self.bias = None
@@ -665,6 +730,7 @@ class Linear(TransformerEngineBaseLayer):
             paddle.is_grad_enabled(),
             self.parallel_mode,
             self.tensor_parallel,
+            self.sequence_parallel,
             self.tp_group,
             self.tp_size,
         )
......
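The `sync_op=not requires_dgrad` / `handle.wait()` pattern above is about overlap: the all-gather of the saved input runs concurrently with the dgrad GEMM, and the reduce-scatter of dgrad runs concurrently with the wgrad GEMM. A hedged sketch of the same schedule for a column-parallel linear, assuming Paddle's async collectives (`sync_op=False`) return a waitable task, and a row-major weight layout for readability (TE itself stores the weight column-major and uses fused/FP8 GEMMs):

import paddle
import paddle.distributed as dist

def column_parallel_bwd_sp(grad_output, x_shard, weight, tp_group):
    """grad_output: [s, n/tp]; x_shard: [s/tp, h]; weight: [h, n/tp]."""
    # 1) Async all-gather of the saved input shard, overlapped with the dgrad GEMM.
    parts = []
    gather_task = dist.all_gather(parts, x_shard, group=tp_group, sync_op=False)
    dgrad_full = paddle.matmul(grad_output, weight, transpose_y=True)   # [s, h], partial sums
    gather_task.wait()
    x_total = paddle.concat(parts, axis=0)                              # [s, h]
    # 2) Async reduce-scatter of dgrad (sums partials, re-shards the sequence),
    #    overlapped with the wgrad GEMM.
    dgrad_shard = paddle.empty_like(x_shard)
    rs_task = dist.reduce_scatter(dgrad_shard,
                                  paddle.split(dgrad_full, tp_group.nranks, axis=0),
                                  group=tp_group, sync_op=False)
    wgrad = paddle.matmul(x_total, grad_output, transpose_x=True)       # [h, n/tp]
    rs_task.wait()
    return dgrad_shard, wgrad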
@@ -4,6 +4,7 @@
 """Transformer"""
 from typing import Optional, Union
+import warnings
 import paddle
 from paddle.incubate.nn.layer.fused_dropout_add import FusedDropoutAdd
@@ -75,6 +76,8 @@ class TransformerLayer(paddle.nn.Layer):
                        if set to `True`, QKV and FC1 layers are used as Column Parallel
                        whereas PROJ and FC2 are used as Row Parallel as described
                        `here <https://arxiv.org/pdf/1909.08053.pdf>`_.
+    sequence_parallel : bool, default = `False`
+                       if set to `True`, uses sequence parallelism.
     tp_group : ProcessGroup, default = `None`
               tensor parallel process group.
     attention_dropout_rng_state_name : str, default = `local_seed`
@@ -107,6 +110,7 @@ class TransformerLayer(paddle.nn.Layer):
         zero_centered_gamma: bool = False,
         activation: str = 'gelu',
         set_parallel_mode: bool = False,
+        sequence_parallel: bool = False,
         tp_group: Optional[dist_group_type] = None,
         attention_dropout_rng_state_name: str = 'local_seed',
         hidden_dropout_rng_state_name: str = 'global_seed',
@@ -122,7 +126,13 @@
         self.tp_group, self.tp_size = get_tp_group_and_world_size(tp_group,
                                                                   enable_tp=set_parallel_mode)
         self.tensor_parallel = self.tp_size > 1
+        self.sequence_parallel = self.tensor_parallel and sequence_parallel
         self.hidden_dropout_rng_state_name = hidden_dropout_rng_state_name
+        # SP needs local seed for hidden dropout
+        if self.sequence_parallel and self.hidden_dropout_rng_state_name == 'global_seed':
+            warnings.warn("RNG state for hidden dropout needs to be different across TP ranks. "
+                          "Forcing hidden_dropout_rng_state_name to 'local_seed'")
+            self.hidden_dropout_rng_state_name = 'local_seed'
         assert (self_attn_mask_type
                 in AttnMaskTypes), f"self_attn_mask_type {self_attn_mask_type} not supported"
@@ -141,6 +151,7 @@ class TransformerLayer(paddle.nn.Layer):
             "return_layernorm_output": apply_residual_connection_post_layernorm,
             "zero_centered_gamma": zero_centered_gamma,
             "set_parallel_mode": set_parallel_mode,
+            "sequence_parallel": self.sequence_parallel,
             "tp_group": tp_group,
             "rng_state_name": attention_dropout_rng_state_name,
             "backend": backend,
@@ -173,6 +184,7 @@ class TransformerLayer(paddle.nn.Layer):
             return_layernorm_output=apply_residual_connection_post_layernorm,
             zero_centered_gamma=zero_centered_gamma,
             set_parallel_mode=set_parallel_mode,
+            sequence_parallel=self.sequence_parallel,
             tp_group=tp_group,
             backend=backend,
         )
@@ -186,6 +198,7 @@ class TransformerLayer(paddle.nn.Layer):
             weight_attr,
             bias_attr,
             zero_centered_gamma=zero_centered_gamma,
+            sequence_parallel=self.sequence_parallel,
             backend=backend,
         )
......
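Putting it together at the layer level (sizes illustrative; the fleet hybrid-parallel setup and per-rank sequence splitting follow the unit tests at the top of this commit): sequence parallelism is a single flag, and the layer forces a per-rank RNG stream for hidden dropout because each rank now drops activations on a different sequence shard. Gradients of the replicated, sequence-parallel-marked parameters still need the extra TP all-reduce hook.

import transformer_engine.paddle as te

layer = te.TransformerLayer(
    hidden_size=1024,
    ffn_hidden_size=4096,
    num_attention_heads=16,
    set_parallel_mode=True,
    sequence_parallel=True,
)
# Each TP rank feeds [seq_len // tp_size, batch, hidden] and gets back a
# sequence shard of the same shape.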