"vscode:/vscode.git/clone" did not exist on "45d94235a56f30305408c32b48a03711c3bb6824"
Unverified Commit e547f8e2 authored by Tian Zheng, committed by GitHub

[Paddle] Add sequence parallel (#561)



* Add SP for linear
Signed-off-by: Tian Zheng (Engrg-Hardware 1) <tizheng@nvidia.com>

* Add SP for LayerNormLinear
Signed-off-by: Tian Zheng (Engrg-Hardware 1) <tizheng@nvidia.com>

* Add SP for LayerNormMLP
Signed-off-by: Tian Zheng (Engrg-Hardware 1) <tizheng@nvidia.com>

* Add SP API for transformer layer
Signed-off-by: Tian Zheng (Engrg-Hardware 1) <tizheng@nvidia.com>

* Add sequence_parallel attr
Signed-off-by: Tian Zheng (Engrg-Hardware 1) <tizheng@nvidia.com>

* Add SP unittests for Transformer and Attention
Signed-off-by: Tian Zheng (Engrg-Hardware 1) <tizheng@nvidia.com>

* Fix compatibility with PaddleNLP
Signed-off-by: Tian Zheng (Engrg-Hardware 1) <tizheng@nvidia.com>

* Copyright
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: Tian Zheng (Engrg-Hardware 1) <tizheng@nvidia.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 7a3ed9e2
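Before the diff itself, a minimal usage sketch of the sequence-parallel API this commit adds, pieced together from the tests below. The 2-GPU launch command, layer sizes, and the tiny training step are illustrative assumptions, not part of the change:

# Run with e.g.: python -m paddle.distributed.launch --gpus "0,1" sp_example.py
import paddle
from paddle.distributed import fleet
import transformer_engine.paddle as te

# Hybrid-parallel setup with tensor-parallel (mp) degree 2. Input broadcast is
# disabled because each rank feeds its own shard of the sequence.
strategy = fleet.DistributedStrategy()
strategy.hybrid_configs = {"dp_degree": 1, "mp_degree": 2, "pp_degree": 1}
strategy.hybrid_configs["mp_configs"].need_broadcast_data = False
fleet.init(is_collective=True, strategy=strategy)

# Column-parallel Linear with the new sequence_parallel flag: the layer
# all-gathers the sequence-sharded input along dim 0 before the GEMM.
paddle.set_default_dtype("bfloat16")
layer = te.Linear(1024, 4096, parallel_mode="column", sequence_parallel=True)

# This rank's [local_tokens, hidden] slice of the flattened input.
inp_shard = paddle.uniform([64, 1024], "bfloat16")
with te.fp8_autocast(enabled=False):
    out = layer(inp_shard)    # out: [total_tokens, 4096 // tp_size]
out.mean().backward()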
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Unittest for Transformer layer in tensor parallel"""
import unittest
import paddle
from paddle.distributed import fleet
from paddle.distributed.fleet.layers.mpu import mp_ops
from utils import assert_allclose, set_random_seed, register_sequence_parallel_allreduce_hooks
import transformer_engine.paddle as te
class TestAttentionTp(unittest.TestCase):
"""Tests MultiHeadAttention layer with model parallel in BF16"""
def setUp(self):
self.set_attr()
self.init_dist_env()
paddle.set_default_dtype(self.global_dtype)
def init_dist_env(self):
"""Init Paddle Fleet environment"""
strategy = fleet.DistributedStrategy()
self.model_parallel_size = 2
strategy.hybrid_configs = {
"dp_degree": 1,
"mp_degree": self.model_parallel_size,
"pp_degree": 1,
}
strategy.hybrid_configs["mp_configs"].need_broadcast_data = False
fleet.init(is_collective=True, strategy=strategy)
self.rank = fleet.worker_index()
self.hcg = fleet.get_hybrid_communicate_group()
self.tp_group = self.hcg.get_model_parallel_group()
self.world_size = self.hcg.get_model_parallel_world_size()
def set_attr(self):
"""Set test configs"""
self.batch_size = 16
self.hidden_size = 1024
self.num_heads = 16
self.q_seqlen = 128
self.kv_seqlen = 128
self.mask_type = 'padding'
self.global_dtype = 'bfloat16'
self.rtol = 5e-3
self.atol = 5e-3
self.eps = 1e-3
self.fp8 = False
self.sequence_parallel = False
def _train_one_step(self, layer, inp_list, optimizer, fp8_enabled, sequence_parallel=False):
inp, mask = inp_list
if sequence_parallel:
split_size = inp.shape[0] // self.world_size
input_parallel = inp[split_size * self.rank:split_size * (self.rank + 1), :]
else:
input_parallel = inp
with te.fp8_autocast(enabled=fp8_enabled):
out = layer(input_parallel, mask)
if sequence_parallel:
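# _c_concat gathers along the last dim, so split the gathered tensor along the
# last dim and re-concat the chunks along dim 0 to rebuild the full output for
# comparison against the single-GPU reference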
total_out = mp_ops._c_concat(out, group=self.tp_group)
total_out = paddle.concat(paddle.split(total_out, self.world_size, axis=-1), axis=0)
else:
total_out = out
loss = total_out.mean()
loss.backward()
optimizer.step()
optimizer.clear_grad()
return loss, total_out
def test_parallel_layer(self):
"""Tests parallel Transformer"""
set_random_seed(1024)
common_args = (
self.hidden_size,
self.num_heads,
)
common_kwargs = {
'layernorm_epsilon': self.eps,
'attention_dropout': 0.0,
'attn_mask_type': self.mask_type,
'attention_type': 'self',
"tp_group": self.tp_group,
"input_layernorm": True,
}
layer_tp = te.MultiHeadAttention(*common_args,
**common_kwargs,
set_parallel_mode=True,
sequence_parallel=self.sequence_parallel)
layer_single = te.MultiHeadAttention(*common_args, **common_kwargs, set_parallel_mode=False)
def _get_total_weight(local_weight, tp_group, axis, interleave=False):
total_weight = []
partial_weight = local_weight.clone().detach()
paddle.distributed.all_gather(total_weight, partial_weight, group=tp_group)
if interleave:
# Due to the interleaved qkv layout, need to concat on the num_head
# dimension for the column-parallel linear in the MultiHeadAttention layer
assert axis == 0
assert [3 * self.hidden_size // self.world_size,
self.hidden_size] == partial_weight.shape
local_num_head = self.num_heads // self.world_size
for idx, _ in enumerate(total_weight):
total_weight[idx] = total_weight[idx].reshape(
[3, local_num_head, -1, self.hidden_size])
total_weight = paddle.concat(total_weight, axis=1).reshape([-1, self.hidden_size])
else:
total_weight = paddle.concat(total_weight, axis=axis)
return total_weight
def _get_weight(obj, weight_names):
for name in weight_names:
obj = getattr(obj, name)
return obj
def copy_weight(layer_src, layer_dst, partition_mode, weight_names, interleave=False):
weight_src = _get_weight(layer_src, weight_names)
weight_dst = _get_weight(layer_dst, weight_names)
if partition_mode is None:
total_weight = weight_src
elif partition_mode == 'column':
total_weight = _get_total_weight(weight_src,
tp_group=self.tp_group,
axis=0,
interleave=interleave)
elif partition_mode == 'row':
total_weight = _get_total_weight(weight_src, tp_group=self.tp_group, axis=1)
else:
raise ValueError(f"Partition Mode {partition_mode} is not supported.")
assert weight_dst.shape == total_weight.shape, \
f"Shapes of src:{total_weight.shape} and dst:{weight_dst.shape} do not match."
weight_dst.copy_(total_weight, True)
copy_weight(layer_tp, layer_single, None, ['layernorm_qkv', 'ln_weight'])
copy_weight(layer_tp, layer_single, 'column', ['layernorm_qkv', 'weight'], interleave=True)
copy_weight(layer_tp, layer_single, 'row', ['proj', 'weight'])
if self.sequence_parallel:
register_sequence_parallel_allreduce_hooks(layer_tp, accumulation_steps=1)
optimizer_tp = paddle.optimizer.SGD(learning_rate=0.01, parameters=layer_tp.parameters())
optimizer_single = paddle.optimizer.SGD(learning_rate=0.01,
parameters=layer_single.parameters())
layer_tp = fleet.distributed_model(layer_tp)
optimizer_tp = fleet.distributed_optimizer(optimizer_tp)
for _ in range(5):
inp = paddle.uniform([self.batch_size, self.q_seqlen, self.hidden_size],
self.global_dtype)
mask = paddle.zeros(shape=(self.batch_size, 1, self.q_seqlen, self.kv_seqlen),
dtype='bool')
loss_tp, out_tp = self._train_one_step(layer_tp, [inp, mask], optimizer_tp, self.fp8,
self.sequence_parallel)
loss_single, out_single = self._train_one_step(layer_single, [inp, mask],
optimizer_single, self.fp8)
assert_allclose(out_tp, out_single, rtol=self.rtol, atol=self.atol)
assert_allclose(loss_tp, loss_single, rtol=self.rtol, atol=self.atol)
class TestAttentionTpFp8(TestAttentionTp):
"""Tests MultiHeadAttention layer with model parallel in FP8"""
def set_attr(self):
"""Set test configs"""
self.batch_size = 16
self.hidden_size = 1024
self.num_heads = 16
self.q_seqlen = 128
self.kv_seqlen = 128
self.mask_type = 'padding'
self.global_dtype = 'bfloat16'
self.rtol = 5e-2
self.atol = 5e-2
self.eps = 1e-3
self.fp8 = True
self.sequence_parallel = False
class TestAttentionSp(TestAttentionTp):
"""Tests MultiHeadAttention layer with sequence parallel in BF16"""
def set_attr(self):
"""Set test configs"""
self.batch_size = 16
self.hidden_size = 1024
self.num_heads = 16
self.q_seqlen = 128
self.kv_seqlen = 128
self.mask_type = 'padding'
self.global_dtype = 'bfloat16'
self.rtol = 5e-3
self.atol = 5e-3
self.eps = 1e-3
self.fp8 = False
self.sequence_parallel = True
class TestAttentionSpFp8(TestAttentionTp):
"""Tests MultiHeadAttention layer with sequence parallel in FP8"""
def set_attr(self):
"""Set test configs"""
self.batch_size = 16
self.hidden_size = 1024
self.num_heads = 16
self.q_seqlen = 128
self.kv_seqlen = 128
self.mask_type = 'padding'
self.global_dtype = 'bfloat16'
self.rtol = 5e-2
self.atol = 1e-1
self.eps = 1e-3
self.fp8 = True
self.sequence_parallel = True
if __name__ == '__main__':
unittest.main()
......@@ -30,9 +30,12 @@ class TestLayerNormLinearTp(unittest.TestCase):
"mp_degree": self.model_parallel_size,
"pp_degree": 1,
}
strategy.hybrid_configs["mp_configs"].need_broadcast_data = False
fleet.init(is_collective=True, strategy=strategy)
self.rank = fleet.worker_index()
self.hcg = fleet.get_hybrid_communicate_group()
self.tp_group = self.hcg.get_model_parallel_group()
self.world_size = self.hcg.get_model_parallel_world_size()
def set_attr(self):
"""Set test configs"""
......@@ -44,6 +47,39 @@ class TestLayerNormLinearTp(unittest.TestCase):
self.atol = 1e-3
self.eps = 1e-3
self.fp8 = False
self.sequence_parallel = False
def _train_one_step(self, layer, inp, optimizer, split_input='none', gather_output=False):
inp = paddle.to_tensor(inp, stop_gradient=True)
assert split_input in ['none', 'column', 'row']
if split_input == 'column':
split_size = inp.shape[1] // self.world_size
input_parallel = inp[:, split_size * self.rank:split_size * (self.rank + 1)]
elif split_input == 'row':
split_size = inp.shape[0] // self.world_size
input_parallel = inp[split_size * self.rank:split_size * (self.rank + 1), :]
else:
input_parallel = inp
input_parallel.stop_gradient = False
out = layer(input_parallel)
if gather_output:
total_out = mp_ops._c_concat(out, group=self.tp_group)
else:
total_out = out
loss = total_out.mean()
loss.backward()
optimizer.step()
optimizer.clear_grad()
if split_input != 'none':
grad_input = []
paddle.distributed.all_gather(grad_input, input_parallel.grad, group=self.tp_group)
if split_input == 'column':
grad_input = paddle.concat(grad_input, axis=1)
elif split_input == 'row':
grad_input = paddle.concat(grad_input, axis=0)
else:
grad_input = input_parallel.grad
return loss, grad_input
def test_column_parallel_layer(self):
"""Tests column parallel LayerNormLinear"""
......@@ -53,6 +89,7 @@ class TestLayerNormLinearTp(unittest.TestCase):
self.out_features,
eps=self.eps,
parallel_mode='column',
sequence_parallel=self.sequence_parallel,
)
layer_pd = te.LayerNormLinear(
self.in_features,
......@@ -77,25 +114,16 @@ class TestLayerNormLinearTp(unittest.TestCase):
layer_te = fleet.distributed_model(layer_te)
optimizer_te = fleet.distributed_optimizer(optimizer_te)
def train_one_step(layer, inp, optimizer, gather=False):
inp = paddle.to_tensor(inp)
inp.stop_gradient = False
out = layer(inp)
if gather:
total_out = mp_ops._c_concat(out, group=self.tp_group)
else:
total_out = out
loss = total_out.mean()
loss.backward()
optimizer.step()
optimizer.clear_grad()
return loss, inp.grad
for _ in range(5):
inp = paddle.uniform([self.batch_size, self.in_features], self.global_dtype)
with te.fp8_autocast(enabled=self.fp8):
loss_tp, grad_input = train_one_step(layer_te, inp, optimizer_te, gather=True)
loss_ref, grad_input_ref = train_one_step(layer_pd, inp, optimizer_pd)
loss_tp, grad_input = self._train_one_step(
layer_te,
inp,
optimizer_te,
split_input='row' if self.sequence_parallel else 'none',
gather_output=True)
loss_ref, grad_input_ref = self._train_one_step(layer_pd, inp, optimizer_pd)
assert_allclose(loss_tp, loss_ref, rtol=self.rtol, atol=self.atol)
assert_allclose(grad_input, grad_input_ref, rtol=self.rtol, atol=self.atol)
......@@ -113,6 +141,39 @@ class TestLayerNormLinearTpFp8(TestLayerNormLinearTp):
self.atol = 1e-2
self.eps = 1e-3
self.fp8 = True
self.sequence_parallel = False
class TestLayerNormLinearSp(TestLayerNormLinearTp):
"""Tests LayernormLinear layer with sequence parallelism"""
def set_attr(self):
"""Set test configs"""
self.batch_size = 16
self.in_features = 32
self.out_features = 64
self.global_dtype = 'bfloat16'
self.rtol = 1e-3
self.atol = 1e-3
self.eps = 1e-3
self.fp8 = False
self.sequence_parallel = True
class TestLayerNormLinearSpFp8(TestLayerNormLinearTp):
"""Tests LayernormLinear layer with sequence parallelism in FP8"""
def set_attr(self):
"""Set test configs"""
self.batch_size = 16
self.in_features = 32
self.out_features = 64
self.global_dtype = 'bfloat16'
self.rtol = 1e-2
self.atol = 1e-2
self.eps = 1e-3
self.fp8 = True
self.sequence_parallel = True
if __name__ == '__main__':
......
......@@ -7,6 +7,7 @@ import unittest
import paddle
from paddle.distributed import fleet
from paddle.distributed.fleet.layers.mpu import mp_ops
from utils import assert_allclose, assert_shape, set_random_seed
import transformer_engine.paddle as te
......@@ -29,9 +30,12 @@ class TestLayerNormMLPTp(unittest.TestCase):
"mp_degree": self.model_parallel_size,
"pp_degree": 1,
}
strategy.hybrid_configs["mp_configs"].need_broadcast_data = False
fleet.init(is_collective=True, strategy=strategy)
self.rank = fleet.worker_index()
self.hcg = fleet.get_hybrid_communicate_group()
self.tp_group = self.hcg.get_model_parallel_group()
self.world_size = self.hcg.get_model_parallel_world_size()
def set_attr(self):
"""Set test configs"""
......@@ -43,6 +47,40 @@ class TestLayerNormMLPTp(unittest.TestCase):
self.atol = 1e-3
self.eps = 1e-3
self.fp8 = False
self.sequence_parallel = False
def _train_one_step(self, layer, inp, optimizer, split_input='none', gather_output=False):
inp = paddle.to_tensor(inp, stop_gradient=True)
assert split_input in ['none', 'column', 'row']
if split_input == 'column':
split_size = inp.shape[1] // self.world_size
input_parallel = inp[:, split_size * self.rank:split_size * (self.rank + 1)]
elif split_input == 'row':
split_size = inp.shape[0] // self.world_size
input_parallel = inp[split_size * self.rank:split_size * (self.rank + 1), :]
else:
input_parallel = inp
input_parallel.stop_gradient = False
out = layer(input_parallel)
if gather_output:
# Need to concat on the first dim, while _c_concat concats on the last dim
total_out = mp_ops._c_concat(out.T, group=self.tp_group).T
else:
total_out = out
loss = total_out.mean()
loss.backward()
optimizer.step()
optimizer.clear_grad()
if split_input != 'none':
grad_input = []
paddle.distributed.all_gather(grad_input, input_parallel.grad, group=self.tp_group)
if split_input == 'column':
grad_input = paddle.concat(grad_input, axis=1)
elif split_input == 'row':
grad_input = paddle.concat(grad_input, axis=0)
else:
grad_input = input_parallel.grad
return loss, grad_input
def test_parallel_layer(self):
"""Tests parallel LayerNormMLP"""
......@@ -52,6 +90,7 @@ class TestLayerNormMLPTp(unittest.TestCase):
ffn_hidden_size=self.ffn_hidden_size,
eps=self.eps,
set_parallel_mode=True,
sequence_parallel=self.sequence_parallel,
)
layer_pd = te.LayerNormMLP(
hidden_size=self.hidden_size,
......@@ -87,21 +126,16 @@ class TestLayerNormMLPTp(unittest.TestCase):
layer_te = fleet.distributed_model(layer_te)
optimizer_te = fleet.distributed_optimizer(optimizer_te)
def train_one_step(layer, inp, optimizer):
inp = paddle.to_tensor(inp)
inp.stop_gradient = False
out = layer(inp)
loss = out.mean()
loss.backward()
optimizer.step()
optimizer.clear_grad()
return loss, inp.grad
for _ in range(5):
inp = paddle.uniform([self.batch_size, self.hidden_size], self.global_dtype)
with te.fp8_autocast(enabled=self.fp8):
loss_tp, grad_input = train_one_step(layer_te, inp, optimizer_te)
loss_ref, grad_input_ref = train_one_step(layer_pd, inp, optimizer_pd)
loss_tp, grad_input = self._train_one_step(
layer_te,
inp,
optimizer_te,
split_input='row' if self.sequence_parallel else 'none',
gather_output=self.sequence_parallel)
loss_ref, grad_input_ref = self._train_one_step(layer_pd, inp, optimizer_pd)
assert_allclose(loss_tp, loss_ref, rtol=self.rtol, atol=self.atol)
assert_allclose(grad_input, grad_input_ref, rtol=self.rtol, atol=self.atol)
......@@ -119,6 +153,39 @@ class TestLayerNormMLPTpFp8(TestLayerNormMLPTp):
self.atol = 1e-2
self.eps = 1e-3
self.fp8 = True
self.sequence_parallel = False
class TestLayerNormMLPSp(TestLayerNormMLPTp):
"""Tests LayerNormMLP layer with sequence parallel in BF16"""
def set_attr(self):
"""Set test configs"""
self.batch_size = 16
self.hidden_size = 32
self.ffn_hidden_size = 64
self.global_dtype = 'bfloat16'
self.rtol = 1e-3
self.atol = 1e-3
self.eps = 1e-3
self.fp8 = False
self.sequence_parallel = True
class TestLayerNormMLPSpFp8(TestLayerNormMLPTp):
"""Tests LayerNormMLP layer with sequence parallelism in FP8"""
def set_attr(self):
"""Set test configs"""
self.batch_size = 16
self.hidden_size = 32
self.ffn_hidden_size = 64
self.global_dtype = 'bfloat16'
self.rtol = 1e-2
self.atol = 1e-2
self.eps = 1e-3
self.fp8 = True
self.sequence_parallel = True
if __name__ == '__main__':
......
......@@ -30,6 +30,7 @@ class TestLinearTp(unittest.TestCase):
"mp_degree": self.model_parallel_size,
"pp_degree": 1,
}
strategy.hybrid_configs["mp_configs"].need_broadcast_data = False
fleet.init(is_collective=True, strategy=strategy)
self.rank = fleet.worker_index()
self.hcg = fleet.get_hybrid_communicate_group()
......@@ -45,6 +46,39 @@ class TestLinearTp(unittest.TestCase):
self.rtol = 1e-3
self.atol = 1e-3
self.fp8 = False
self.sequence_parallel = False
def _train_one_step(self, layer, inp, optimizer, split_input='none', gather_output=False):
inp = paddle.to_tensor(inp, stop_gradient=True)
assert split_input in ['none', 'column', 'row']
if split_input == 'column':
split_size = inp.shape[1] // self.world_size
input_parallel = inp[:, split_size * self.rank:split_size * (self.rank + 1)]
elif split_input == 'row':
split_size = inp.shape[0] // self.world_size
input_parallel = inp[split_size * self.rank:split_size * (self.rank + 1), :]
else:
input_parallel = inp
input_parallel.stop_gradient = False
out = layer(input_parallel)
if gather_output:
total_out = mp_ops._c_concat(out, group=self.tp_group)
else:
total_out = out
loss = total_out.mean()
loss.backward()
optimizer.step()
optimizer.clear_grad()
if split_input != 'none':
grad_input = []
paddle.distributed.all_gather(grad_input, input_parallel.grad, group=self.tp_group)
if split_input == 'column':
grad_input = paddle.concat(grad_input, axis=1)
elif split_input == 'row':
grad_input = paddle.concat(grad_input, axis=0)
else:
grad_input = input_parallel.grad
return loss, grad_input
def test_column_parallel_layer(self):
"""Tests column parallel linear"""
......@@ -53,6 +87,7 @@ class TestLinearTp(unittest.TestCase):
self.in_features,
self.out_features,
parallel_mode='column',
sequence_parallel=self.sequence_parallel,
)
layer_pd = te.Linear(
self.in_features,
......@@ -76,25 +111,16 @@ class TestLinearTp(unittest.TestCase):
layer_te = fleet.distributed_model(layer_te)
optimizer_te = fleet.distributed_optimizer(optimizer_te)
def train_one_step(layer, inp, optimizer, gather=False):
inp = paddle.to_tensor(inp)
inp.stop_gradient = False
out = layer(inp)
if gather:
total_out = mp_ops._c_concat(out, group=self.tp_group)
else:
total_out = out
loss = total_out.mean()
loss.backward()
optimizer.step()
optimizer.clear_grad()
return loss, inp.grad
for _ in range(5):
inp = paddle.uniform([self.batch_size, self.in_features], self.global_dtype)
with te.fp8_autocast(enabled=self.fp8):
loss_tp, grad_input = train_one_step(layer_te, inp, optimizer_te, gather=True)
loss_ref, grad_input_ref = train_one_step(layer_pd, inp, optimizer_pd)
loss_tp, grad_input = self._train_one_step(
layer_te,
inp,
optimizer_te,
split_input='row' if self.sequence_parallel else 'none',
gather_output=True)
loss_ref, grad_input_ref = self._train_one_step(layer_pd, inp, optimizer_pd)
assert_allclose(loss_tp, loss_ref, rtol=self.rtol, atol=self.atol)
assert_allclose(grad_input, grad_input_ref, rtol=self.rtol, atol=self.atol)
......@@ -105,6 +131,7 @@ class TestLinearTp(unittest.TestCase):
self.in_features,
self.out_features,
parallel_mode='row',
sequence_parallel=self.sequence_parallel,
)
layer_pd = te.Linear(
self.in_features,
......@@ -125,39 +152,18 @@ class TestLinearTp(unittest.TestCase):
optimizer_te = paddle.optimizer.SGD(learning_rate=0.001, parameters=layer_te.parameters())
optimizer_pd = paddle.optimizer.SGD(learning_rate=0.001, parameters=layer_pd.parameters())
# Note(tizheng): For this test, we cannot wrap the model with fleet.distributed_model,
# because it will broadcast inputs across the mp group. However, RPL expects split
# inputs, which are different on each rank.
def train_one_step(layer, inp, optimizer, split=False):
inp = paddle.to_tensor(inp, stop_gradient=True)
if split:
# TODO(tizheng): Why not working?
# issue: https://github.com/PaddlePaddle/Paddle/issues/55565
# input_parallel = mp_ops._c_split(inp, group=layer.tp_group)
split_size = inp.shape[1] // self.world_size
input_parallel = inp[:, split_size * self.rank:split_size * (self.rank + 1)]
else:
input_parallel = inp
input_parallel.stop_gradient = False
out = layer(input_parallel)
loss = out.mean()
loss.backward()
optimizer.step()
optimizer.clear_grad()
if split:
grad_input = []
paddle.distributed.all_gather(grad_input, input_parallel.grad, group=self.tp_group)
grad_input = paddle.concat(grad_input, axis=1)
else:
grad_input = input_parallel.grad
return loss, grad_input
layer_te = fleet.distributed_model(layer_te)
optimizer_te = fleet.distributed_optimizer(optimizer_te)
for _ in range(5):
inp = paddle.uniform([self.batch_size, self.in_features], self.global_dtype)
with te.fp8_autocast(enabled=self.fp8):
loss_tp, grad_input = train_one_step(layer_te, inp, optimizer_te, split=True)
loss_ref, grad_input_ref = train_one_step(layer_pd, inp, optimizer_pd)
loss_tp, grad_input = self._train_one_step(layer_te,
inp,
optimizer_te,
split_input='column',
gather_output=self.sequence_parallel)
loss_ref, grad_input_ref = self._train_one_step(layer_pd, inp, optimizer_pd)
assert_allclose(loss_tp, loss_ref, rtol=self.rtol, atol=self.atol)
assert_allclose(grad_input, grad_input_ref, rtol=self.rtol, atol=self.atol)
......@@ -174,6 +180,37 @@ class TestLinearTpFP8(TestLinearTp):
self.rtol = 1e-2
self.atol = 1e-2
self.fp8 = True
self.sequence_parallel = False
class TestLinearSp(TestLinearTp):
"""Tests Linear layer with sequence parallelism"""
def set_attr(self):
"""Set test configs"""
self.batch_size = 16
self.in_features = 32
self.out_features = 64
self.global_dtype = 'bfloat16'
self.rtol = 1e-3
self.atol = 1e-3
self.fp8 = False
self.sequence_parallel = True
class TestLinearSpFP8(TestLinearTp):
"""Tests Linear layer with sequence parallelism in FP8"""
def set_attr(self):
"""Set test configs"""
self.batch_size = 16
self.in_features = 32
self.out_features = 64
self.global_dtype = 'bfloat16'
self.rtol = 1e-2
self.atol = 1e-2
self.fp8 = True
self.sequence_parallel = True
if __name__ == '__main__':
......
......@@ -7,8 +7,9 @@ import unittest
import paddle
from paddle.distributed import fleet
from paddle.distributed.fleet.layers.mpu import mp_ops
from utils import assert_allclose, set_random_seed
from utils import assert_allclose, set_random_seed, register_sequence_parallel_allreduce_hooks
import transformer_engine.paddle as te
......@@ -29,9 +30,12 @@ class TestTransformerTp(unittest.TestCase):
"mp_degree": self.model_parallel_size,
"pp_degree": 1,
}
strategy.hybrid_configs["mp_configs"].need_broadcast_data = False
fleet.init(is_collective=True, strategy=strategy)
self.rank = fleet.worker_index()
self.hcg = fleet.get_hybrid_communicate_group()
self.tp_group = self.hcg.get_model_parallel_group()
self.world_size = self.hcg.get_model_parallel_world_size()
def set_attr(self):
"""Set test configs"""
......@@ -48,6 +52,27 @@ class TestTransformerTp(unittest.TestCase):
self.atol = 5e-2
self.eps = 1e-3
self.fp8 = False
self.sequence_parallel = False
def _train_one_step(self, layer, inp_list, optimizer, fp8_enabled, sequence_parallel=False):
inp, mask = inp_list
if sequence_parallel:
split_size = inp.shape[0] // self.world_size
input_parallel = inp[split_size * self.rank:split_size * (self.rank + 1), :]
else:
input_parallel = inp
with te.fp8_autocast(enabled=fp8_enabled):
out = layer(input_parallel, mask)
if sequence_parallel:
total_out = mp_ops._c_concat(out, group=self.tp_group)
total_out = paddle.concat(paddle.split(total_out, self.world_size, axis=-1), axis=0)
else:
total_out = out
loss = total_out.mean()
loss.backward()
optimizer.step()
optimizer.clear_grad()
return loss, total_out
def test_parallel_layer(self):
"""Tests parallel Transformer"""
......@@ -64,14 +89,29 @@ class TestTransformerTp(unittest.TestCase):
'self_attn_mask_type': self.mask_type,
'layer_type': self.layer_type,
}
layer_tp = te.TransformerLayer(*common_args, **common_kwargs, set_parallel_mode=True)
layer_tp = te.TransformerLayer(*common_args,
**common_kwargs,
set_parallel_mode=True,
sequence_parallel=self.sequence_parallel)
layer_single = te.TransformerLayer(*common_args, **common_kwargs, set_parallel_mode=False)
def _get_total_weight(local_weight, tp_group, axis):
def _get_total_weight(local_weight, tp_group, axis, interleave=False):
total_weight = []
partial_weight = local_weight.clone().detach()
paddle.distributed.all_gather(total_weight, partial_weight, group=tp_group)
total_weight = paddle.concat(total_weight, axis=axis)
if interleave:
# Due to the interleaved qkv layout, need to concat on the num_head
# dimension for the column-parallel linear in the MultiHeadAttention layer
assert axis == 0
assert [3 * self.hidden_size // self.world_size,
self.hidden_size] == partial_weight.shape
local_num_head = self.num_heads // self.world_size
for idx, _ in enumerate(total_weight):
total_weight[idx] = total_weight[idx].reshape(
[3, local_num_head, -1, self.hidden_size])
total_weight = paddle.concat(total_weight, axis=1).reshape([-1, self.hidden_size])
else:
total_weight = paddle.concat(total_weight, axis=axis)
return total_weight
def _get_weight(obj, weight_names):
......@@ -79,13 +119,16 @@ class TestTransformerTp(unittest.TestCase):
obj = getattr(obj, name)
return obj
def copy_weight(layer_src, layer_dst, partition_mode, weight_names):
def copy_weight(layer_src, layer_dst, partition_mode, weight_names, interleave=False):
weight_src = _get_weight(layer_src, weight_names)
weight_dst = _get_weight(layer_dst, weight_names)
if partition_mode is None:
total_weight = weight_src
elif partition_mode == 'column':
total_weight = _get_total_weight(weight_src, tp_group=self.tp_group, axis=0)
total_weight = _get_total_weight(weight_src,
tp_group=self.tp_group,
axis=0,
interleave=interleave)
elif partition_mode == 'row':
total_weight = _get_total_weight(weight_src, tp_group=self.tp_group, axis=1)
else:
......@@ -95,41 +138,62 @@ class TestTransformerTp(unittest.TestCase):
weight_dst.copy_(total_weight, True)
copy_weight(layer_tp, layer_single, None, ['self_attention', 'layernorm_qkv', 'ln_weight'])
copy_weight(layer_tp, layer_single, 'column', ['self_attention', 'layernorm_qkv', 'weight'])
copy_weight(layer_tp,
layer_single,
'column', ['self_attention', 'layernorm_qkv', 'weight'],
interleave=True)
copy_weight(layer_tp, layer_single, 'row', ['self_attention', 'proj', 'weight'])
copy_weight(layer_tp, layer_single, None, ['layernorm_mlp', 'ln_weight'])
copy_weight(layer_tp, layer_single, 'column', ['layernorm_mlp', 'fc1_weight'])
copy_weight(layer_tp, layer_single, 'row', ['layernorm_mlp', 'fc2_weight'])
optimizer_tp = paddle.optimizer.SGD(learning_rate=0.1, parameters=layer_tp.parameters())
optimizer_single = paddle.optimizer.SGD(learning_rate=0.1,
if self.sequence_parallel:
register_sequence_parallel_allreduce_hooks(layer_tp, accumulation_steps=1)
optimizer_tp = paddle.optimizer.SGD(learning_rate=0.01, parameters=layer_tp.parameters())
optimizer_single = paddle.optimizer.SGD(learning_rate=0.01,
parameters=layer_single.parameters())
layer_tp = fleet.distributed_model(layer_tp)
optimizer_tp = fleet.distributed_optimizer(optimizer_tp)
def train_one_step(layer, inp_list, optimizer, fp8_enabled):
with te.fp8_autocast(enabled=fp8_enabled):
out = layer(*inp_list)
loss = out.mean()
loss.backward()
optimizer.step()
optimizer.clear_grad()
return loss
for _ in range(5):
inp = paddle.uniform([self.batch_size, self.q_seqlen, self.hidden_size],
self.global_dtype)
mask = paddle.zeros(shape=(self.batch_size, 1, self.q_seqlen, self.kv_seqlen),
dtype='bool')
loss_tp = train_one_step(layer_tp, [inp, mask], optimizer_tp, self.fp8)
loss_single = train_one_step(layer_single, [inp, mask], optimizer_single, self.fp8)
loss_tp, out_tp = self._train_one_step(layer_tp, [inp, mask], optimizer_tp, self.fp8,
self.sequence_parallel)
loss_single, out_single = self._train_one_step(layer_single, [inp, mask],
optimizer_single, self.fp8)
assert_allclose(out_tp, out_single, rtol=self.rtol, atol=self.atol)
assert_allclose(loss_tp, loss_single, rtol=self.rtol, atol=self.atol)
class TestTransformerTpFp8(TestTransformerTp):
"""Tests Transformer layer with tensor parallelism in FP8"""
def set_attr(self):
"""Set test configs"""
self.batch_size = 16
self.hidden_size = 1024
self.num_heads = 16
self.ffn_hidden_size = 4096
self.q_seqlen = 128
self.kv_seqlen = 128
self.mask_type = 'padding'
self.layer_type = 'encoder'
self.global_dtype = 'bfloat16'
self.rtol = 5e-2
self.atol = 0.5
self.eps = 1e-3
self.fp8 = True
self.sequence_parallel = False
class TestTransformerSp(TestTransformerTp):
"""Tests Transformer layer with sequence parallel in BF16"""
def set_attr(self):
"""Set test configs"""
self.batch_size = 16
......@@ -144,7 +208,29 @@ class TestTransformerTpFp8(TestTransformerTp):
self.rtol = 5e-2
self.atol = 5e-2
self.eps = 1e-3
self.fp8 = False
self.sequence_parallel = True
class TestTransformerSpFp8(TestTransformerSp):
"""Tests Transformer layer with sequence parallelism in FP8"""
def set_attr(self):
"""Set test configs"""
self.batch_size = 16
self.hidden_size = 1024
self.num_heads = 16
self.ffn_hidden_size = 4096
self.q_seqlen = 128
self.kv_seqlen = 128
self.mask_type = 'padding'
self.layer_type = 'encoder'
self.global_dtype = 'bfloat16'
self.rtol = 5e-2
self.atol = 0.5
self.eps = 1e-3
self.fp8 = True
self.sequence_parallel = True
if __name__ == '__main__':
......
......@@ -75,6 +75,16 @@ class TestGroupSharding(TestDistributed):
self.run_2gpu(str(test_root / 'parallel_tests' / 'group_sharding.py'))
class TestParallelAttention(TestDistributed):
"""Test MultiHeadAttention Layer in Parallel mode"""
@unittest.skipIf(not is_devices_enough(2), "TestParallelAttention needs 2 GPUs")
@unittest.skipIf(not gpu_has_fp8, reason)
def test_attention_tp(self):
"""Tests TransMultiHeadAttentionformer Layer with tensor parallel in BF16"""
self.run_2gpu(str(test_root / 'parallel_tests' / 'attention_tp.py'))
class TestParallelTransformerLayer(TestDistributed):
"""Test Transformer Layer in Parallel mode"""
......
......@@ -40,9 +40,15 @@ def assert_allclose(actual,
verbose=True):
"""Compare two input paddle tensors"""
if isinstance(actual, paddle.Tensor):
actual = paddle.cast(actual, 'float32').numpy()
actual = paddle.cast(actual, 'float32')
if isinstance(desired, paddle.Tensor):
desired = paddle.cast(desired, 'float32').numpy()
desired = paddle.cast(desired, 'float32')
if len(actual.shape) == 0:
actual = actual.item()
desired = desired.item()
else:
actual = actual.numpy()
desired = desired.numpy()
np.testing.assert_allclose(actual, desired, rtol, atol, equal_nan, err_msg, verbose)
......@@ -59,6 +65,7 @@ def is_devices_enough(required):
def set_random_seed(seed):
"""Set random seed for reproducability."""
fleet.meta_parallel.model_parallel_random_seed(seed)
hcg = fleet.get_hybrid_communicate_group()
if paddle.distributed.get_world_size() > 1:
......@@ -161,3 +168,46 @@ def is_fused_attention_supported(
mask_type=mask_type,
)
return backend != FusedAttnBackend["No_Backend"]
def register_sequence_parallel_allreduce_hooks(model, accumulation_steps) -> None:
"""Register allreduce hooks for sequence parallel tensors"""
def is_sequence_parallel_parameter(parameter):
"""If input tensor is marked as sequence parallel tensor"""
out = getattr(parameter, "sequence_parallel", False)
return out
def create_allreduce_gradient_hook(param, accumulation_steps):
"""Create allreduce gradient hook"""
hcg = fleet.get_hybrid_communicate_group()
pg = hcg.get_model_parallel_group().process_group
step = [0]
@paddle.autograd.no_grad()
def __impl__():
step[0] += 1
if (step[0] % accumulation_steps) == 0:
if hasattr(param, "main_grad"):
pg.allreduce(param.main_grad).wait()
else:
pg.allreduce(param.grad).wait()
return __impl__
if accumulation_steps <= 0 or not paddle.distributed.is_initialized():
return
hcg = fleet.get_hybrid_communicate_group()
mp_group = hcg.get_model_parallel_group()
if mp_group.nranks <= 1:
return
params = []
for p in model.parameters():
if is_sequence_parallel_parameter(p):
params.append(p)
for p in params:
hook = create_allreduce_gradient_hook(p, accumulation_steps)
p._register_backward_hook(hook)
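# Why the helper above is needed (a note, not part of the diff): with sequence
# parallelism the LayerNorm weights/biases and row-parallel biases stay
# replicated on every tensor-parallel rank, but each rank back-propagates only
# through its own sequence shard, so their gradients are partial sums that must
# be all-reduced over the model-parallel group. A minimal usage sketch matching
# the tests above (layer sizes are placeholders):
layer_tp = te.LayerNormMLP(hidden_size=1024, ffn_hidden_size=4096,
                           set_parallel_mode=True, sequence_parallel=True)
# Registers a backward hook on every parameter carrying the `sequence_parallel`
# attribute; with accumulation_steps=1 the grad is all-reduced after each step.
register_sequence_parallel_allreduce_hooks(layer_tp, accumulation_steps=1)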
......@@ -291,7 +291,7 @@ std::vector<paddle::Tensor> te_layernorm_fwd_fp8(const paddle::Tensor &input,
size_t N = shape[0];
size_t H = shape[1];
auto ln_out = paddle::empty_like(input, input.dtype(), input.place());
auto ln_out = paddle::empty_like(input, Nvte2PaddleDType(Int2NvteDType(otype)), input.place());
auto mu = paddle::empty({static_cast<int64_t>(N)}, paddle::DataType::FLOAT32, input.place());
auto rsigma =
paddle::empty({static_cast<int64_t>(N)}, paddle::DataType::FLOAT32, input.place());
......
......@@ -114,6 +114,57 @@ def allreduce(
return output, wait_handle
def allgather(
input_: paddle.Tensor,
tp_group: Optional[dist_group_type] = None,
sync_op: bool = True,
) -> Tuple[paddle.Tensor, Any]:
"""All-gather the input tensor across model parallel group."""
# Bypass the function if we are using only 1 GPU.
if tp_group is None or tp_group.nranks == 1:
return input_, None
parallelism = tp_group.nranks
output_shape = input_.shape
output_shape[0] = output_shape[0] * parallelism
output = paddle.empty(shape=output_shape, dtype=input_.dtype)
wait_handle = tp_group.process_group.all_gather_into_tensor(output, input_, sync_op)
if sync_op:
wait_handle.wait()
return output, None
return output, wait_handle
def reduce_scatter(
input_: paddle.Tensor,
tp_group: Optional[dist_group_type] = None,
sync_op: bool = True,
) -> Tuple[paddle.Tensor, Any]:
"""Reduce-scatter the input tensor across model parallel group."""
# Bypass the function if we are using only 1 GPU.
if tp_group is None or tp_group.nranks == 1:
return input_, None
parallelism = tp_group.nranks
output_shape = input_.shape
assert (
input_.shape[0] % parallelism == 0
), f"Input sequence length {input_.shape[0]} can't be divided " \
f"exactly by sequence parallelism {parallelism}"
output_shape[0] = output_shape[0] // parallelism
output = paddle.empty(shape=output_shape, dtype=input_.dtype)
wait_handle = paddle.distributed.stream.reduce_scatter(output,
input_,
op=paddle.distributed.ReduceOp.SUM,
group=tp_group,
sync_op=sync_op)
if sync_op:
return output, None
return output, wait_handle
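# Note on the two collectives above (illustrative shapes, tp = 2, tensors laid
# out as [tokens, hidden]):
#   allgather:      [tokens/2, hidden] -> [tokens, hidden]
#                   used before a column-parallel GEMM in sequence parallelism
#   reduce_scatter: [tokens, hidden]   -> [tokens/2, hidden]
#                   sums the partial row-parallel outputs across ranks and
#                   scatters the result back along the sequence dimension
# The two are duals of each other, which keeps the sequence-parallel forward
# and backward communication symmetric.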
def identity(
input_: paddle.Tensor,
tp_group: Optional[dist_group_type] = None,
......@@ -125,3 +176,11 @@ def identity(
output = mp_ops._c_identity(input_, group=tp_group)
return output
def mark_as_sequence_parallel_parameter(parameter: paddle.Tensor):
"""
Set sequence_parallel attribute to input tensor. It is used for registering allreduce
hooks in PaddleNLP sequence parallel training.
"""
setattr(parameter, "sequence_parallel", True)
......@@ -237,8 +237,7 @@ class DotProductAttention(paddle.nn.Layer):
self.fused_attention_backend = tex.get_fused_attn_backend(
TE_DType[query_layer.dtype], TE_DType[query_layer.dtype],
tex.get_nvte_qkv_layout(self.qkv_layout), AttnBiasType[core_attention_bias_type],
AttnMaskType[self.attn_mask_type], self.attention_dropout,
query_layer.shape[-2],
AttnMaskType[self.attn_mask_type], self.attention_dropout, query_layer.shape[-2],
key_value_layer.shape[-2] if key_value_layer is not None else query_layer.shape[-2],
max_s_q, max_s_kv, query_layer.shape[-1])
......@@ -401,6 +400,8 @@ class MultiHeadAttention(paddle.nn.Layer):
if set to `True`, QKV and FC1 layers are used as Column Parallel
whereas PROJ and FC2 is used as Row Parallel as described
`here <https://arxiv.org/pdf/1909.08053.pdf>`_.
sequence_parallel : bool, default = `False`
if set to `True`, uses sequence parallelism.
tp_group : ProcessGroup, default = `None`
tensor parallel process group.
rng_state_name : str, default = `local_seed`
......@@ -427,6 +428,7 @@ class MultiHeadAttention(paddle.nn.Layer):
attention_type: str = "self",
zero_centered_gamma: bool = False,
set_parallel_mode: bool = False,
sequence_parallel: bool = False,
tp_group: Optional[dist_group_type] = None,
rng_state_name: str = 'local_seed',
backend: str = 'transformer_engine',
......@@ -445,7 +447,7 @@ class MultiHeadAttention(paddle.nn.Layer):
self.tp_group, self.tp_size = get_tp_group_and_world_size(tp_group,
enable_tp=set_parallel_mode)
self.tensor_parallel = self.tp_size > 1
self.sequence_parallel = self.tensor_parallel and sequence_parallel
self.hidden_size_per_attention_head = hidden_size // num_attention_heads
self.num_attention_heads = num_attention_heads
norm_factor = math.sqrt(self.hidden_size_per_attention_head)
......@@ -467,6 +469,7 @@ class MultiHeadAttention(paddle.nn.Layer):
return_layernorm_output=return_layernorm_output,
zero_centered_gamma=zero_centered_gamma,
parallel_mode=qkv_parallel_mode,
sequence_parallel=self.sequence_parallel,
tp_group=self.tp_group,
backend=self.backend,
)
......@@ -477,6 +480,7 @@ class MultiHeadAttention(paddle.nn.Layer):
self.weight_attr,
self.bias_attr,
parallel_mode=qkv_parallel_mode,
sequence_parallel=self.sequence_parallel,
tp_group=self.tp_group,
backend=self.backend,
)
......@@ -492,6 +496,7 @@ class MultiHeadAttention(paddle.nn.Layer):
return_layernorm_output=return_layernorm_output,
zero_centered_gamma=zero_centered_gamma,
parallel_mode=qkv_parallel_mode,
sequence_parallel=self.sequence_parallel,
tp_group=self.tp_group,
backend=self.backend,
)
......@@ -502,6 +507,7 @@ class MultiHeadAttention(paddle.nn.Layer):
self.weight_attr,
self.bias_attr,
parallel_mode=qkv_parallel_mode,
sequence_parallel=self.sequence_parallel,
tp_group=self.tp_group,
backend=self.backend,
)
......@@ -511,6 +517,7 @@ class MultiHeadAttention(paddle.nn.Layer):
self.weight_attr,
self.bias_attr,
parallel_mode=qkv_parallel_mode,
sequence_parallel=self.sequence_parallel,
tp_group=self.tp_group,
backend=self.backend,
)
......@@ -531,6 +538,7 @@ class MultiHeadAttention(paddle.nn.Layer):
self.weight_attr,
self.bias_attr,
parallel_mode="row" if set_parallel_mode else None,
sequence_parallel=self.sequence_parallel,
tp_group=self.tp_group,
backend=self.backend,
)
......@@ -569,10 +577,21 @@ class MultiHeadAttention(paddle.nn.Layer):
backprop.
"""
# hidden_states: [b, s_q, hidden_size]
if self.attn_mask_type != "causal" and attention_mask is not None:
assert (attention_mask.dtype == paddle.bool), "Attention mask must be a boolean tensor"
input_dim = len(hidden_states.shape)
if input_dim == 2:
# hidden_states: [b * s_q, hidden_size]
# need to get max_seq_len from attention_mask
assert attention_mask is not None
max_seq_len = attention_mask.shape[-1]
elif input_dim == 3:
# hidden_states: [b, s_q, hidden_size]
max_seq_len = hidden_states.shape[1]
else:
raise ValueError(f"hidden_states should have 2 or 3 dimensions, got {input_dim}.")
if self.attention_type == "self":
if self.input_layernorm:
layernorm_qkv_outputs = self.layernorm_qkv(hidden_states)
......@@ -585,7 +604,8 @@ class MultiHeadAttention(paddle.nn.Layer):
# [b, s_q, 3 * hidden_size] --> [b, s_q, 3, num_heads, head_size]
mixed_qkv_layer = mixed_qkv_layer.reshape(shape=[
0, 0, 3, self.num_attention_heads_per_partition, self.hidden_size_per_attention_head
-1, max_seq_len, 3, self.num_attention_heads_per_partition,
self.hidden_size_per_attention_head
])
with track_rng_state(enable=self.tensor_parallel, name=self.rng_state_name):
......@@ -614,7 +634,8 @@ class MultiHeadAttention(paddle.nn.Layer):
mixed_kv_layer = self.key_value(encoder_output)
# [b, s_kv, 2 * hidden_size] --> [b, s_kv, 2, num_heads, head_size]
mixed_kv_layer = mixed_kv_layer.reshape(shape=[
0, 0, 2, self.num_attention_heads_per_partition, self.hidden_size_per_attention_head
-1, max_seq_len, 2, self.num_attention_heads_per_partition,
self.hidden_size_per_attention_head
])
if self.input_layernorm:
......@@ -627,7 +648,8 @@ class MultiHeadAttention(paddle.nn.Layer):
query_layer = self.query_layer(hidden_states)
query_layer = query_layer.reshape(shape=[
0, 0, self.num_attention_heads_per_partition, self.hidden_size_per_attention_head
-1, max_seq_len, self.num_attention_heads_per_partition,
self.hidden_size_per_attention_head
])
with track_rng_state(enable=self.tensor_parallel, name=self.rng_state_name):
if recompute_core_attention:
......@@ -651,8 +673,13 @@ class MultiHeadAttention(paddle.nn.Layer):
set_zero=set_zero,
)
context_layer = paddle.reshape(context_layer,
[0, 0, context_layer.shape[2] * context_layer.shape[3]])
if input_dim == 3:
context_layer = paddle.reshape(
context_layer, [-1, max_seq_len, context_layer.shape[2] * context_layer.shape[3]])
else: # input_dim == 2
context_layer = paddle.reshape(context_layer,
[-1, context_layer.shape[2] * context_layer.shape[3]])
# Output. [b, s, hidden]
attention_output = self.proj(context_layer)
......
......@@ -16,7 +16,7 @@ from paddle.fluid import core
from paddle.fluid.framework import _dygraph_tracer
from ..constants import FP8BwdTensors, dist_group_type
from ..cpp_extensions import cast_transpose, cast_transpose_bgrad, cast_to_fp8
from ..cpp_extensions import cast_transpose, cast_transpose_bgrad, cast_to_fp8, transpose
from ..fp8 import (
FP8State,
FP8TensorMeta,
......@@ -24,6 +24,7 @@ from ..fp8 import (
get_global_fp8_state,
get_fp8_te_dtype,
)
from ..distributed import allgather
from ..profile import nvtx_range
from ..recompute import is_in_recompute_phase
from ..fp8_buffer import FP8RecomputeBuffer
......@@ -310,8 +311,8 @@ class TransformerEngineBaseLayer(paddle.nn.Layer, ABC):
global_fp8_bwd_buffer.finalize(fp8_meta, tp_group, tp_size)
@staticmethod
def grad_output_preprocess(
ctx, grad_output: paddle.Tensor) -> Tuple[Union[paddle.Tensor, None], ...]:
def grad_output_preprocess(ctx, grad_output: paddle.Tensor,
row_parallel_mode: bool) -> Tuple[Union[paddle.Tensor, None], ...]:
"""Utility function for backward.
Returns tuple in order (all optional/None based on training precision/recipe):
R1: gathered `grad_output` in higher precision.
......@@ -320,13 +321,37 @@ class TransformerEngineBaseLayer(paddle.nn.Layer, ABC):
R4: bias gradient on R1.
"""
grad_output_mat = grad_output.reshape((-1, grad_output.shape[-1]))
gather_grad_output = row_parallel_mode and ctx.sequence_parallel
# No-FP8 case: bgrad is fused with wgrad for this case.
if not ctx.fp8_enabled:
if gather_grad_output:
grad_output_mat, _ = allgather(grad_output_mat, ctx.tp_group)
return grad_output_mat, None, None, None
fp8_dtype_backward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=False)
if gather_grad_output:
if not ctx.fp8_meta["recipe"].override_linear_precision.wgrad:
# FP8 case with gather: unfused bgrad, cast, transpose for efficient gather
if ctx.use_bias:
bgrad = grad_output_mat.sum(axis=0)
else:
bgrad = None
grad_output_c = cast_to_fp8(
grad_output_mat,
ctx.fp8_meta["scaling_bwd"],
FP8BwdTensors.GRAD_OUTPUT1,
fp8_dtype_backward,
)
grad_output_c, _ = allgather(grad_output_c, ctx.tp_group)
grad_output_t = transpose(grad_output_c, fp8_dtype_backward)
return grad_output_mat, grad_output_c, grad_output_t, bgrad
# FP8 case with gather and non-FP8 wgrad
grad_output_mat, _ = allgather(grad_output_mat, ctx.tp_group)
# FP8 case without gather: cast, transpose, bgrad fused
if ctx.use_bias:
bgrad, grad_output_c, grad_output_t = cast_transpose_bgrad(
......
......@@ -12,6 +12,7 @@ from paddle.nn.initializer import Constant
from ..constants import TE_DType
from ..cpp_extensions import layernorm_fwd, layernorm_bwd
from ..distributed import mark_as_sequence_parallel_parameter
__all__ = ["LayerNorm"]
......@@ -90,6 +91,11 @@ class LayerNorm(paddle.nn.Layer):
(1 + \gamma) + \beta
backend: {'transformer_engine', 'paddle'}, default = `transformer_engine`
backend to use for softmax operation.
Parallelism parameters
----------------------
sequence_parallel : bool, default = `False`
if set to `True`, uses sequence parallelism.
"""
def __init__(
......@@ -99,11 +105,13 @@ class LayerNorm(paddle.nn.Layer):
weight_attr: Union[paddle.ParamAttr, None] = None,
bias_attr: Union[paddle.ParamAttr, None, bool] = None,
zero_centered_gamma: bool = False,
sequence_parallel: bool = False,
backend: str = 'transformer_engine',
) -> None:
super().__init__()
self.eps = eps
self.zero_centered_gamma = zero_centered_gamma
self.sequence_parallel = sequence_parallel
self.backend = backend
self._dtype = self._helper.get_default_dtype()
......@@ -130,6 +138,10 @@ class LayerNorm(paddle.nn.Layer):
is_bias=True,
)
if self.sequence_parallel:
mark_as_sequence_parallel_parameter(self.weight)
mark_as_sequence_parallel_parameter(self.bias)
# These many SMs are subtracted from the total SM count when calling forward
# and backward LayerNorm C APIs. These envvars can be used to prevent the LN
# kernels from using all SMs in the device. This is useful for cases such as
......
......@@ -16,7 +16,6 @@ from ..cpp_extensions import (
layernorm_fwd,
layernorm_fwd_fp8,
layernorm_bwd,
transpose,
)
from .base import TransformerEngineBaseLayer
......@@ -29,6 +28,7 @@ from ..distributed import (
track_rng_state,
set_tensor_dist_attr,
set_weight_tensor_dist_attr,
mark_as_sequence_parallel_parameter,
)
from ..fp8 import get_fp8_te_dtype
from ..utils import (
......@@ -145,6 +145,7 @@ class _LayerNormLinear(paddle.autograd.PyLayer):
zero_centered_gamma: bool,
parallel_mode: Union[str, None],
tensor_parallel: bool,
sequence_parallel: bool,
tp_group: Union[dist_group_type, None],
tp_size: int,
) -> Union[Tuple[paddle.Tensor, ...], paddle.Tensor]:
......@@ -190,6 +191,7 @@ class _LayerNormLinear(paddle.autograd.PyLayer):
activation_dtype,
parallel_mode,
tensor_parallel,
sequence_parallel,
tp_group,
is_grad_enabled,
)
......@@ -217,6 +219,7 @@ class _LayerNormLinear(paddle.autograd.PyLayer):
ctx.zero_centered_gamma = zero_centered_gamma
ctx.parallel_mode = parallel_mode
ctx.tensor_parallel = tensor_parallel
ctx.sequence_parallel = sequence_parallel
ctx.tp_group = tp_group
ctx.tp_size = tp_size
ctx.requires_dgrad = not inp.stop_gradient
......@@ -256,29 +259,26 @@ class _LayerNormLinear(paddle.autograd.PyLayer):
grad_output_c,
grad_output_t,
bgrad,
) = TransformerEngineBaseLayer.grad_output_preprocess(ctx, grad_outputs[0])
) = TransformerEngineBaseLayer.grad_output_preprocess(ctx, grad_outputs[0],
ctx.parallel_mode == "row")
# Prepare ln_out for Linear bwd
ln_out_no_fp8, ln_out_t = None, None
linear_inputmat = ln_out
if ctx.fp8_enabled:
fp8_dtype_forward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=True)
fp8_wgrad = not ctx.fp8_meta["recipe"].override_linear_precision.wgrad
if ctx.requires_wgrad:
if fp8_wgrad:
ln_out_t = transpose(ln_out, fp8_dtype_forward)
else:
ln_out_no_fp8 = cast_from_fp8(
ln_out,
ctx.fp8_meta["scaling_fwd"],
FP8FwdTensors.GEMM1_INPUT,
fp8_dtype_forward,
TE_DType[ctx.activation_dtype],
)
if ctx.requires_wgrad and ctx.fp8_meta["recipe"].override_linear_precision.wgrad:
linear_inputmat = cast_from_fp8(
ln_out,
ctx.fp8_meta["scaling_fwd"],
FP8FwdTensors.GEMM1_INPUT,
fp8_dtype_forward,
TE_DType[ctx.activation_dtype],
)
# Linear Bwd
dgrad, wgrad, bgrad_ = _linear_bwd(
ln_out_no_fp8 if ctx.fp8_enabled else ln_out,
ln_out_t,
linear_inputmat,
None, # inputmat_t will be automatically computed if not provided
FP8FwdTensors.GEMM1_INPUT,
weight,
weight_t_fp8,
......@@ -296,6 +296,7 @@ class _LayerNormLinear(paddle.autograd.PyLayer):
ctx.activation_dtype,
ctx.parallel_mode,
ctx.tensor_parallel,
ctx.sequence_parallel,
ctx.tp_group,
)
......@@ -367,6 +368,8 @@ class LayerNormLinear(TransformerEngineBaseLayer):
used to decide whether this Linear layer is Column Parallel Linear or Row
Parallel Linear as described `here <https://arxiv.org/pdf/1909.08053.pdf>`_.
When set to `None`, no communication is performed.
sequence_parallel : bool, default = `False`
if set to `True`, uses sequence parallelism.
"""
def __init__(
......@@ -379,6 +382,7 @@ class LayerNormLinear(TransformerEngineBaseLayer):
return_layernorm_output: bool = False,
zero_centered_gamma: bool = False,
parallel_mode: Optional[str] = None,
sequence_parallel: bool = False,
tp_group: Union[dist_group_type, None] = None,
backend: str = 'transformer_engine',
) -> None:
......@@ -409,6 +413,8 @@ class LayerNormLinear(TransformerEngineBaseLayer):
elif self.parallel_mode == "row":
self.in_features = divide(self.in_features, self.tp_size)
self.sequence_parallel = self.tensor_parallel and sequence_parallel
# LayerNorm weights
self.ln_weight = self.create_parameter(
shape=[self.in_features],
......@@ -425,6 +431,10 @@ class LayerNormLinear(TransformerEngineBaseLayer):
is_bias=True,
)
if self.sequence_parallel:
mark_as_sequence_parallel_parameter(self.ln_weight)
mark_as_sequence_parallel_parameter(self.ln_bias)
# Initialize Linear weight parameter
with track_rng_state(enable=self.tensor_parallel):
# TE linear weight is in column major
......@@ -451,6 +461,8 @@ class LayerNormLinear(TransformerEngineBaseLayer):
)
if parallel_mode == "column":
set_tensor_dist_attr(self.bias, self.tensor_parallel, axis=0)
if parallel_mode == "row" and self.sequence_parallel:
mark_as_sequence_parallel_parameter(self.bias)
else:
self.bias = None
......@@ -500,6 +512,7 @@ class LayerNormLinear(TransformerEngineBaseLayer):
self.zero_centered_gamma,
self.parallel_mode,
self.tensor_parallel,
self.sequence_parallel,
self.tp_group,
self.tp_size,
)
......
......@@ -18,7 +18,6 @@ from ..cpp_extensions import (
cast_from_fp8,
dgelu_cast_transpose_bgrad_fp8,
gelu_fp8,
transpose,
)
from ..distributed import (
allreduce,
......@@ -27,6 +26,7 @@ from ..distributed import (
track_rng_state,
set_tensor_dist_attr,
set_weight_tensor_dist_attr,
mark_as_sequence_parallel_parameter,
)
from ..fp8 import get_fp8_te_dtype
from ..utils import (
......@@ -39,7 +39,6 @@ from ..utils import (
saved_tensor_allow_none,
)
__all__ = ["LayerNormMLP"]
......@@ -63,6 +62,7 @@ def _mlp_forward(
is_grad_enabled: bool,
set_parallel_mode: bool,
tensor_parallel: bool,
sequence_parallel: bool,
tp_group: Union[dist_group_type, None],
):
if fp8_enabled:
......@@ -78,6 +78,7 @@ def _mlp_forward(
activation_dtype,
'column' if set_parallel_mode else None,
tensor_parallel,
sequence_parallel,
tp_group,
is_grad_enabled,
)
......@@ -100,6 +101,7 @@ def _mlp_forward(
activation_dtype,
'row' if set_parallel_mode else None,
tensor_parallel,
sequence_parallel,
tp_group,
is_grad_enabled,
)
......@@ -116,6 +118,7 @@ def _mlp_forward(
activation_dtype,
'column' if set_parallel_mode else None,
tensor_parallel,
sequence_parallel,
tp_group,
activation=activation,
)
......@@ -132,6 +135,7 @@ def _mlp_forward(
activation_dtype,
'row' if set_parallel_mode else None,
tensor_parallel,
sequence_parallel,
tp_group,
)
return (
......@@ -172,6 +176,7 @@ def _mlp_backward(
activation: str,
set_parallel_mode: bool,
tensor_parallel: bool,
sequence_parallel: bool,
tp_group: Union[dist_group_type, None],
):
(
......@@ -186,23 +191,19 @@ def _mlp_backward(
fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)
fp8_dtype_backward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=False)
# FC2 Bwd
fc2_input_no_fp8, fc2_input_t = None, None
fp8_wgrad = not fp8_meta["recipe"].override_linear_precision.wgrad
if requires_fc2_wgrad:
if fp8_wgrad:
fc2_input_t = transpose(fc2_input, fp8_dtype_forward)
else:
fc2_input_no_fp8 = cast_from_fp8(
fc2_input,
fp8_meta["scaling_fwd"],
fc2_input_fp8_index,
fp8_dtype_forward,
TE_DType[activation_dtype],
)
if requires_fc2_wgrad and not fp8_wgrad:
fc2_input = cast_from_fp8(
fc2_input,
fp8_meta["scaling_fwd"],
fc2_input_fp8_index,
fp8_dtype_forward,
TE_DType[activation_dtype],
)
fc2_dgrad, fc2_wgrad = _linear_bwd_fp8(
fc2_input_no_fp8,
fc2_input_t,
fc2_input,
None,
fc2_input_fp8_index,
fc2_weight_t_fp8,
fc2_weight_fp8_index,
......@@ -217,6 +218,7 @@ def _mlp_backward(
activation_dtype,
'row' if set_parallel_mode else None,
tensor_parallel,
sequence_parallel,
tp_group,
)
......@@ -233,30 +235,27 @@ def _mlp_backward(
fc1_bgrad = fc1_bgrad_
# FC1 Bwd
dgelu_no_fp8, fc1_input_no_fp8, fc1_input_t = None, None, None
if requires_fc1_wgrad:
if fp8_wgrad:
fc1_input_t = transpose(fc1_input, fp8_dtype_forward)
else:
# TODO(tizheng) Paddle lacks fused dgelu_bgrad OP. Cast from dgrad(fp8) instead.
dgelu_no_fp8 = cast_from_fp8(
dgelu,
fp8_meta["scaling_bwd"],
fc1_grad_output_fp8_index,
fp8_dtype_backward,
TE_DType[activation_dtype],
)
fc1_input_no_fp8 = cast_from_fp8(
fc1_input,
fp8_meta["scaling_fwd"],
fc1_input_fp8_index,
fp8_dtype_forward,
TE_DType[activation_dtype],
)
dgelu_no_fp8 = None
if requires_fc1_wgrad and not fp8_wgrad:
# TODO(tizheng) Paddle lacks fused dgelu_bgrad OP. Cast from dgrad(fp8) instead.
dgelu_no_fp8 = cast_from_fp8(
dgelu,
fp8_meta["scaling_bwd"],
fc1_grad_output_fp8_index,
fp8_dtype_backward,
TE_DType[activation_dtype],
)
fc1_input = cast_from_fp8(
fc1_input,
fp8_meta["scaling_fwd"],
fc1_input_fp8_index,
fp8_dtype_forward,
TE_DType[activation_dtype],
)
fc1_dgrad, fc1_wgrad = _linear_bwd_fp8(
fc1_input_no_fp8,
fc1_input_t,
fc1_input,
None,
fc1_input_fp8_index,
fc1_weight_t_fp8,
fc1_weight_fp8_index,
......@@ -271,6 +270,7 @@ def _mlp_backward(
activation_dtype,
'column' if set_parallel_mode else None,
tensor_parallel,
sequence_parallel,
tp_group,
)
else:
......@@ -284,6 +284,7 @@ def _mlp_backward(
activation_dtype,
'row' if set_parallel_mode else None,
tensor_parallel,
sequence_parallel,
tp_group,
gelu_input=fc1_out,
activation=activation,
......@@ -298,6 +299,7 @@ def _mlp_backward(
activation_dtype,
'column' if set_parallel_mode else None,
tensor_parallel,
sequence_parallel,
tp_group,
)
return (
......@@ -337,6 +339,7 @@ class _LayerNormMLP(paddle.autograd.PyLayer):
activation: str,
set_parallel_mode: bool,
tensor_parallel: bool,
sequence_parallel: bool,
tp_group: Union[dist_group_type, None],
tp_size: int,
) -> Union[Tuple[paddle.Tensor, ...], paddle.Tensor]:
......@@ -398,6 +401,7 @@ class _LayerNormMLP(paddle.autograd.PyLayer):
is_grad_enabled,
set_parallel_mode,
tensor_parallel,
sequence_parallel,
tp_group,
)
......@@ -429,6 +433,7 @@ class _LayerNormMLP(paddle.autograd.PyLayer):
ctx.zero_centered_gamma = zero_centered_gamma
ctx.set_parallel_mode = set_parallel_mode
ctx.tensor_parallel = tensor_parallel
ctx.sequence_parallel = sequence_parallel
ctx.tp_group = tp_group
ctx.tp_size = tp_size
ctx.requires_dgrad = not inp.stop_gradient
......@@ -476,7 +481,7 @@ class _LayerNormMLP(paddle.autograd.PyLayer):
grad_output_c,
grad_output_t,
fc2_bgrad,
) = TransformerEngineBaseLayer.grad_output_preprocess(ctx, grad_outputs[0])
) = TransformerEngineBaseLayer.grad_output_preprocess(ctx, grad_outputs[0], True)
(
fc1_dgrad,
......@@ -513,6 +518,7 @@ class _LayerNormMLP(paddle.autograd.PyLayer):
ctx.activation,
ctx.set_parallel_mode,
ctx.tensor_parallel,
ctx.sequence_parallel,
ctx.tp_group,
)
if not ctx.fp8_enabled:
......@@ -588,6 +594,8 @@ class LayerNormMLP(TransformerEngineBaseLayer):
set_parallel_mode : bool, default = `False`
if set to `True`, FC1 is used as Column Parallel and FC2 is used as Row
Parallel as described `here <https://arxiv.org/pdf/1909.08053.pdf>`_.
sequence_parallel : bool, default = `False`
if set to `True`, uses sequence parallelism.
tp_group : paddle.distributed.collective.Group, default = `None`
tensor parallel process group.
......@@ -604,6 +612,7 @@ class LayerNormMLP(TransformerEngineBaseLayer):
return_layernorm_output: bool = False,
zero_centered_gamma: bool = False,
set_parallel_mode: bool = False,
sequence_parallel: bool = False,
tp_group: Optional[dist_group_type] = None,
backend: str = 'transformer_engine',
) -> None:
......@@ -626,6 +635,7 @@ class LayerNormMLP(TransformerEngineBaseLayer):
enable_tp=set_parallel_mode)
self.tensor_parallel = self.tp_size > 1
self.set_parallel_mode = set_parallel_mode
self.sequence_parallel = self.tensor_parallel and sequence_parallel
if self.set_parallel_mode:
self.size_per_partition = divide(self.ffn_hidden_size, self.tp_size)
......@@ -648,6 +658,10 @@ class LayerNormMLP(TransformerEngineBaseLayer):
is_bias=True,
)
if self.sequence_parallel:
mark_as_sequence_parallel_parameter(self.ln_weight)
mark_as_sequence_parallel_parameter(self.ln_bias)
# FC1 weights
with track_rng_state(enable=self.tensor_parallel):
self.fc1_weight = self.create_parameter(
......@@ -698,6 +712,8 @@ class LayerNormMLP(TransformerEngineBaseLayer):
dtype=self._dtype,
is_bias=True,
)
if self.set_parallel_mode and self.sequence_parallel:
mark_as_sequence_parallel_parameter(self.fc2_bias)
else:
self.fc2_bias = None
......@@ -751,6 +767,7 @@ class LayerNormMLP(TransformerEngineBaseLayer):
self.activation,
self.set_parallel_mode,
self.tensor_parallel,
self.sequence_parallel,
self.tp_group,
self.tp_size,
)
......
......@@ -18,14 +18,17 @@ from .base import (
)
from ..constants import FP8FwdTensors, FP8BwdTensors, GemmParallelModes, dist_group_type
from ..cpp_extensions import gemm, fp8_gemm, cast_to_fp8, cast_transpose
from ..cpp_extensions import gemm, fp8_gemm, cast_to_fp8, cast_transpose, transpose
from ..distributed import (
allgather,
allreduce,
get_tp_group_and_world_size,
identity,
reduce_scatter,
track_rng_state,
set_tensor_dist_attr,
set_weight_tensor_dist_attr,
mark_as_sequence_parallel_parameter,
)
from ..fp8 import get_fp8_te_dtype
from ..utils import (
......@@ -52,6 +55,7 @@ def _linear_fwd_fp8(
activation_dtype: paddle.dtype,
parallel_mode: Union[str, None],
tensor_parallel: bool,
sequence_parallel: bool,
tp_group: Union[dist_group_type, None],
is_grad_enabled: bool,
):
......@@ -60,6 +64,11 @@ def _linear_fwd_fp8(
bias_dtype = get_bias_dtype(activation_dtype)
bias = cast_if_needed(bias, bias_dtype)
if parallel_mode == "column" and sequence_parallel:
inputmat_total, _ = allgather(inputmat, tp_group)
else:
inputmat_total = inputmat
if is_grad_enabled:
weight_fp8, weight_t_fp8 = cast_transpose(
weight,
......@@ -81,7 +90,7 @@ def _linear_fwd_fp8(
fp8_meta["scaling_fwd"].scale_inv,
weight_fp8_index,
fp8_dtype_forward,
inputmat,
inputmat_total,
fp8_meta["scaling_fwd"].scale_inv,
inputmat_fp8_index,
fp8_dtype_forward,
......@@ -92,8 +101,9 @@ def _linear_fwd_fp8(
use_split_accumulator=_2X_ACC_FPROP,
)
# Row Parallel Linear
if parallel_mode == "row" and tensor_parallel:
if parallel_mode == "row" and sequence_parallel:
out, _ = reduce_scatter(out, tp_group)
elif parallel_mode == "row" and tensor_parallel:
out, _ = allreduce(out, tp_group)
return out, weight_t_fp8
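The two branches above define the sequence-parallel communication pattern: a column-parallel GEMM first all-gathers the sequence-sharded input, while a row-parallel GEMM replaces its all-reduce with a reduce-scatter so the output stays sharded along the sequence. A standalone sketch of the same pattern with plain `paddle.distributed` collectives (a 2-D `[tokens, features]` layout and an explicit `tp_size` are assumptions for illustration; this is not TE's internal implementation):

```python
import paddle
import paddle.distributed as dist

def column_parallel_fwd_sp(x_shard, weight_shard, tp_group):
    """x_shard: [tokens/tp, in_features]; weight_shard: [in_features, out_features/tp]."""
    gathered = []
    dist.all_gather(gathered, x_shard, group=tp_group)   # rebuild the full token dimension
    x_total = paddle.concat(gathered, axis=0)
    return paddle.matmul(x_total, weight_shard)          # output is split along out_features

def row_parallel_fwd_sp(x_local, weight_shard, tp_group, tp_size):
    """Partial results of a row-parallel GEMM are summed and re-sharded in one step."""
    y_partial = paddle.matmul(x_local, weight_shard)     # needs a sum across ranks
    y_shards = list(paddle.split(y_partial, tp_size, axis=0))
    y_out = paddle.empty_like(y_shards[0])
    dist.reduce_scatter(y_out, y_shards, group=tp_group) # sum + keep only the local slice
    return y_out                                         # token dimension stays sharded
```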
......@@ -111,11 +121,17 @@ def _linear_fwd_non_fp8(
activation_dtype: paddle.dtype,
parallel_mode: Union[str, None],
tensor_parallel: bool,
sequence_parallel: bool,
tp_group: Union[dist_group_type, None],
activation: str = "",
):
"""Non-FP8 path of Linear Fwd"""
if parallel_mode == "column" and sequence_parallel:
inputmat_total, _ = allgather(inputmat, tp_group)
else:
inputmat_total = inputmat
# Layer parameters are initialized as float32 dtype by default.
# Cast the parameters to activation_dtype if the current dtype
# does not match activation_dtype. The casting is inplace, so it
......@@ -126,13 +142,13 @@ def _linear_fwd_non_fp8(
if fp8_calibration:
# amax of input
fp8_meta["scaling_fwd"].amax_history[0, inputmat_fp8_index.value] = \
paddle.max(paddle.abs(inputmat)).item()
paddle.max(paddle.abs(inputmat_total)).item()
# amax of weight
fp8_meta["scaling_fwd"].amax_history[0, weight_fp8_index.value] = \
paddle.max(paddle.abs(weight)).item()
outputs = gemm(weight,
inputmat,
inputmat_total,
activation_dtype,
get_workspace(),
bias=bias,
......@@ -144,8 +160,10 @@ def _linear_fwd_non_fp8(
return out, gelu_out
out, _, _ = outputs
# Row Parallel Linear
if parallel_mode == "row" and tensor_parallel:
if parallel_mode == "row" and sequence_parallel:
out, _ = reduce_scatter(out, tp_group)
elif parallel_mode == "row" and tensor_parallel:
out, _ = allreduce(out, tp_group)
return out
......@@ -163,6 +181,7 @@ def _linear_fwd(
activation_dtype: paddle.dtype,
parallel_mode: Union[str, None],
tensor_parallel: bool,
sequence_parallel: bool,
tp_group: Union[dist_group_type, None],
is_grad_enabled: bool,
):
......@@ -178,6 +197,7 @@ def _linear_fwd(
activation_dtype,
parallel_mode,
tensor_parallel,
sequence_parallel,
tp_group,
is_grad_enabled,
)
......@@ -194,6 +214,7 @@ def _linear_fwd(
activation_dtype,
parallel_mode,
tensor_parallel,
sequence_parallel,
tp_group,
)
return (
......@@ -219,9 +240,20 @@ def _linear_bwd_fp8(
activation_dtype: paddle.dtype,
parallel_mode: Union[str, None],
tensor_parallel: bool,
sequence_parallel: bool,
tp_group: Union[dist_group_type, None],
):
dgrad, wgrad, handle = None, None, None
# Overlap input AG with dgrad
inputmat_total = None
inputmat_t_total = None
if requires_wgrad and parallel_mode == "column" and sequence_parallel:
inputmat_total, handle = allgather(inputmat, tp_group, sync_op=not requires_dgrad)
else:
inputmat_total = inputmat
inputmat_t_total = inputmat_t
fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)
fp8_dtype_backward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=False)
if requires_dgrad:
......@@ -238,13 +270,21 @@ def _linear_bwd_fp8(
get_workspace(),
use_split_accumulator=_2X_ACC_DGRAD,
)
if parallel_mode == "column" and tensor_parallel:
# Overlap dgrad-RS/AR with wgrad
if parallel_mode == "column" and sequence_parallel:
if handle is not None:
handle.wait()
dgrad, handle = reduce_scatter(dgrad, tp_group, sync_op=False)
elif parallel_mode == "column" and tensor_parallel:
dgrad, handle = allreduce(dgrad, tp_group, sync_op=False)
if requires_wgrad:
if not fp8_meta["recipe"].override_linear_precision.wgrad:
if inputmat_t_total is None:
inputmat_t_total = transpose(inputmat_total, fp8_dtype_backward)
wgrad = fp8_gemm(
inputmat_t,
inputmat_t_total,
fwd_scale_inverses,
inputmat_fp8_index,
fp8_dtype_forward,
......@@ -258,7 +298,7 @@ def _linear_bwd_fp8(
)
else:
wgrad, _, _ = gemm(
inputmat,
inputmat_total,
grad_output,
activation_dtype,
get_workspace(),
......@@ -282,6 +322,7 @@ def _linear_bwd_non_fp8(
activation_dtype: paddle.dtype,
parallel_mode: Union[str, None],
tensor_parallel: bool,
sequence_parallel: bool,
tp_group: Union[dist_group_type, None],
gelu_input: Union[paddle.Tensor, None] = None,
activation: str = "",
......@@ -290,6 +331,14 @@ def _linear_bwd_non_fp8(
Performs Linear Backward. Optionally, fuses GELU backward and dbias.
"""
dgrad, wgrad, bgrad, handle = None, None, None, None
# Overlap input AG with dgrad
inputmat_total = None
if requires_wgrad and parallel_mode == "column" and sequence_parallel:
inputmat_total, handle = allgather(inputmat, tp_group, sync_op=not requires_dgrad)
else:
inputmat_total = inputmat
if requires_dgrad:
dgrad, _, _ = gemm(
weight,
......@@ -301,12 +350,17 @@ def _linear_bwd_non_fp8(
gelu_input=gelu_input,
grad=True,
)
if parallel_mode == "column" and tensor_parallel:
# Overlap dgrad-RS/AR with wgrad
if parallel_mode == "column" and sequence_parallel:
if handle is not None:
handle.wait()
dgrad, handle = reduce_scatter(dgrad, tp_group, sync_op=False)
elif parallel_mode == "column" and tensor_parallel:
dgrad, handle = allreduce(dgrad, tp_group, sync_op=False)
if requires_wgrad:
wgrad, bgrad, _ = gemm(
inputmat,
inputmat_total,
grad_output,
activation_dtype,
get_workspace(),
......@@ -343,6 +397,7 @@ def _linear_bwd(
activation_dtype: paddle.dtype,
parallel_mode: Union[str, None],
tensor_parallel: bool,
sequence_parallel: bool,
tp_group: Union[dist_group_type, None],
):
dgrad, wgrad, bgrad = None, None, None
......@@ -364,6 +419,7 @@ def _linear_bwd(
activation_dtype,
parallel_mode,
tensor_parallel,
sequence_parallel,
tp_group,
)
else:
......@@ -377,6 +433,7 @@ def _linear_bwd(
activation_dtype,
parallel_mode,
tensor_parallel,
sequence_parallel,
tp_group,
)
return dgrad, wgrad, bgrad
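Both backward paths above hide communication behind computation: the saved input shard is all-gathered asynchronously and only consumed by the wgrad GEMM, and the dgrad reduce-scatter (or all-reduce without SP) is launched asynchronously while wgrad runs. A condensed sketch of that scheduling, assuming a Paddle version whose collectives accept `sync_op=False` and return a waitable task, with a 2-D `[tokens, features]` layout and plain matmuls standing in for TE's fused GEMMs:

```python
import paddle
import paddle.distributed as dist

def column_parallel_bwd_sp(grad_out, x_shard, weight_shard, tp_group, tp_size):
    # 1) Start gathering the saved input shard; only wgrad needs the result.
    gathered = []
    ag_task = dist.all_gather(gathered, x_shard, group=tp_group, sync_op=False)

    # 2) dgrad GEMM runs while the all-gather is in flight.
    dgrad_full = paddle.matmul(grad_out, weight_shard, transpose_y=True)

    # 3) The gather must finish before wgrad and before the next collective is queued.
    ag_task.wait()
    x_total = paddle.concat(gathered, axis=0)

    # 4) Reduce-scatter dgrad asynchronously; wgrad does not depend on its result.
    dgrad_shards = list(paddle.split(dgrad_full, tp_size, axis=0))
    dgrad_out = paddle.empty_like(dgrad_shards[0])
    rs_task = dist.reduce_scatter(dgrad_out, dgrad_shards, group=tp_group, sync_op=False)

    # 5) wgrad GEMM overlaps with the dgrad reduce-scatter.
    wgrad = paddle.matmul(x_total, grad_out, transpose_x=True)

    rs_task.wait()
    return dgrad_out, wgrad
```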
......@@ -399,6 +456,7 @@ class _Linear(paddle.autograd.PyLayer):
is_grad_enabled: bool,
parallel_mode: Union[str, None],
tensor_parallel: bool,
sequence_parallel: bool,
tp_group: Union[dist_group_type, None],
tp_size: int,
) -> paddle.Tensor:
......@@ -413,31 +471,24 @@ class _Linear(paddle.autograd.PyLayer):
inputmat_no_fp8 = inputmat
# FP8 casting
inputmat_t = None
if fp8_enabled:
fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)
if not fp8_meta["recipe"].override_linear_precision.wgrad:
if is_grad_enabled:
inputmat, inputmat_t = cast_transpose(
inputmat,
fp8_meta["scaling_fwd"],
FP8FwdTensors.GEMM1_INPUT,
fp8_dtype_forward,
)
else:
inputmat = cast_to_fp8(
inputmat,
fp8_meta["scaling_fwd"],
FP8FwdTensors.GEMM1_INPUT,
fp8_dtype_forward,
)
if (not fp8_meta["recipe"].override_linear_precision.wgrad and is_grad_enabled
and not sequence_parallel):
inputmat, inputmat_t = cast_transpose(
inputmat,
fp8_meta["scaling_fwd"],
FP8FwdTensors.GEMM1_INPUT,
fp8_dtype_forward,
)
else:
inputmat, inputmat_t = cast_to_fp8(
inputmat = cast_to_fp8(
inputmat,
fp8_meta["scaling_fwd"],
FP8FwdTensors.GEMM1_INPUT,
fp8_dtype_forward,
), None
)
# GEMM Fwd
out, weight_t_fp8 = _linear_fwd(
......@@ -453,16 +504,21 @@ class _Linear(paddle.autograd.PyLayer):
activation_dtype,
parallel_mode,
tensor_parallel,
sequence_parallel,
tp_group,
is_grad_enabled,
)
if is_grad_enabled:
fp8_wgrad = fp8_enabled and not fp8_meta["recipe"].override_linear_precision.wgrad
saved_inputmat = None
if fp8_enabled and sequence_parallel:
saved_inputmat = inputmat
else:
saved_inputmat = inputmat_no_fp8
save_for_backward_allow_none(
ctx,
inputmat_no_fp8 if not weight.stop_gradient and not fp8_wgrad else None,
inputmat_t if not weight.stop_gradient and fp8_wgrad else None,
saved_inputmat,
inputmat_t,
weight,
weight_t_fp8 if fp8_enabled else None,
fp8_meta["scaling_fwd"].scale_inv.clone() if fp8_enabled else None,
......@@ -474,6 +530,7 @@ class _Linear(paddle.autograd.PyLayer):
ctx.inp_shape = inp.shape
ctx.parallel_mode = parallel_mode
ctx.tensor_parallel = tensor_parallel
ctx.sequence_parallel = sequence_parallel
ctx.tp_group = tp_group
ctx.tp_size = tp_size
ctx.requires_dgrad = not inp.stop_gradient
......@@ -503,7 +560,8 @@ class _Linear(paddle.autograd.PyLayer):
grad_output_c,
grad_output_t,
bgrad,
) = TransformerEngineBaseLayer.grad_output_preprocess(ctx, grad_output)
) = TransformerEngineBaseLayer.grad_output_preprocess(ctx, grad_output,
ctx.parallel_mode == "row")
dgrad, wgrad, bgrad_ = _linear_bwd(
inputmat,
......@@ -525,6 +583,7 @@ class _Linear(paddle.autograd.PyLayer):
ctx.activation_dtype,
ctx.parallel_mode,
ctx.tensor_parallel,
ctx.sequence_parallel,
ctx.tp_group,
)
......@@ -570,7 +629,8 @@ class Linear(TransformerEngineBaseLayer):
used to decide whether this Linear layer is Column Parallel Linear or Row
Parallel Linear as described `here <https://arxiv.org/pdf/1909.08053.pdf>`_.
When set to `None`, no communication is performed.
sequence_parallel : bool, default = `False`
if set to `True`, the sequence dimension of the input/output is split across the
tensor parallel group (sequence parallelism): a column parallel layer all-gathers
its input and a row parallel layer reduce-scatters its output.
"""
def __init__(
......@@ -580,6 +640,7 @@ class Linear(TransformerEngineBaseLayer):
weight_attr: Union[paddle.ParamAttr, None] = None,
bias_attr: Union[paddle.ParamAttr, None, bool] = None,
parallel_mode: Optional[str] = None,
sequence_parallel: bool = False,
tp_group: Union[dist_group_type, None] = None,
backend: str = 'transformer_engine',
) -> None:
......@@ -605,6 +666,8 @@ class Linear(TransformerEngineBaseLayer):
elif self.parallel_mode == "row":
self.in_features = divide(self.in_features, self.tp_size)
self.sequence_parallel = self.tensor_parallel and sequence_parallel
# Initialize weight parameter
with track_rng_state(enable=self.tensor_parallel):
# TE linear weight is in column major
......@@ -631,6 +694,8 @@ class Linear(TransformerEngineBaseLayer):
)
if parallel_mode == "column":
set_tensor_dist_attr(self.bias, self.tensor_parallel, axis=0)
if parallel_mode == "row" and self.sequence_parallel:
mark_as_sequence_parallel_parameter(self.bias)
else:
self.bias = None
......@@ -665,6 +730,7 @@ class Linear(TransformerEngineBaseLayer):
paddle.is_grad_enabled(),
self.parallel_mode,
self.tensor_parallel,
self.sequence_parallel,
self.tp_group,
self.tp_size,
)
......
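Taken together, the `Linear` changes let a caller request sequence parallelism per layer. A hedged usage sketch (the sizes and the fleet setup are assumptions; the keyword names follow the constructor shown in this diff):

```python
import paddle
from paddle.distributed import fleet
import transformer_engine.paddle as te

# Assumes fleet.init() has already set up a model-parallel group (mp_degree > 1).
tp_group = fleet.get_hybrid_communicate_group().get_model_parallel_group()

fc = te.Linear(
    1024, 4096,
    parallel_mode='column',    # input is all-gathered along the sequence before the GEMM
    sequence_parallel=True,
    tp_group=tp_group,
)

local_tokens = paddle.randn([256, 1024])   # this rank's slice of the flattened sequence
out = fc(local_tokens)                     # full-sequence output, split along out_features
```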
......@@ -4,6 +4,7 @@
"""Transformer"""
from typing import Optional, Union
import warnings
import paddle
from paddle.incubate.nn.layer.fused_dropout_add import FusedDropoutAdd
......@@ -75,6 +76,8 @@ class TransformerLayer(paddle.nn.Layer):
if set to `True`, QKV and FC1 layers are used as Column Parallel
whereas PROJ and FC2 are used as Row Parallel as described
`here <https://arxiv.org/pdf/1909.08053.pdf>`_.
sequence_parallel : bool, default = `False`
if set to `True`, hidden states are split along the sequence dimension across
the tensor parallel group (sequence parallelism). Only effective when
`set_parallel_mode` is enabled.
tp_group : ProcessGroup, default = `None`
tensor parallel process group.
attention_dropout_rng_state_name : str, default = `local_seed`
......@@ -107,6 +110,7 @@ class TransformerLayer(paddle.nn.Layer):
zero_centered_gamma: bool = False,
activation: str = 'gelu',
set_parallel_mode: bool = False,
sequence_parallel: bool = False,
tp_group: Optional[dist_group_type] = None,
attention_dropout_rng_state_name: str = 'local_seed',
hidden_dropout_rng_state_name: str = 'global_seed',
......@@ -122,7 +126,13 @@ class TransformerLayer(paddle.nn.Layer):
self.tp_group, self.tp_size = get_tp_group_and_world_size(tp_group,
enable_tp=set_parallel_mode)
self.tensor_parallel = self.tp_size > 1
self.sequence_parallel = self.tensor_parallel and sequence_parallel
self.hidden_dropout_rng_state_name = hidden_dropout_rng_state_name
# SP needs local seed for hidden dropout
if self.sequence_parallel and self.hidden_dropout_rng_state_name == 'global_seed':
warnings.warn("RNG state for hidden dropout needs to be different across TP ranks. "
"Forcing hidden_dropout_rng_state_name to 'local_seed'")
self.hidden_dropout_rng_state_name = 'local_seed'
assert (self_attn_mask_type
in AttnMaskTypes), f"self_attn_mask_type {self_attn_mask_type} not supported"
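A one-line illustration of why the warning above forces `local_seed`: with sequence parallelism each TP rank applies hidden dropout to a different sequence shard, so the dropout RNG streams must differ per rank. The offsets below are illustrative only, not the values Paddle's RNG state tracker actually uses:

```python
# Megatron-style seed split (illustrative values only).
base_seed = 1234
mp_rank = 0   # e.g. fleet.get_hybrid_communicate_group().get_model_parallel_rank()
global_seed = base_seed                   # identical on every TP rank: replicated activations
local_seed = base_seed + 2718 + mp_rank   # per-rank: sequence-sharded activations under SP
```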
......@@ -141,6 +151,7 @@ class TransformerLayer(paddle.nn.Layer):
"return_layernorm_output": apply_residual_connection_post_layernorm,
"zero_centered_gamma": zero_centered_gamma,
"set_parallel_mode": set_parallel_mode,
"sequence_parallel": self.sequence_parallel,
"tp_group": tp_group,
"rng_state_name": attention_dropout_rng_state_name,
"backend": backend,
......@@ -173,6 +184,7 @@ class TransformerLayer(paddle.nn.Layer):
return_layernorm_output=apply_residual_connection_post_layernorm,
zero_centered_gamma=zero_centered_gamma,
set_parallel_mode=set_parallel_mode,
sequence_parallel=self.sequence_parallel,
tp_group=tp_group,
backend=backend,
)
......@@ -186,6 +198,7 @@ class TransformerLayer(paddle.nn.Layer):
weight_attr,
bias_attr,
zero_centered_gamma=zero_centered_gamma,
sequence_parallel=self.sequence_parallel,
backend=backend,
)
......
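At the top level, `sequence_parallel` is set once on `TransformerLayer` and propagated to the attention, `LayerNormMLP`, and final `LayerNorm` sublayers shown above. A hedged end-to-end sketch (the sizes and the already-initialized fleet group are assumptions):

```python
import paddle
from paddle.distributed import fleet
import transformer_engine.paddle as te

# Assumes fleet.init() has been called with mp_degree equal to the TP size.
tp_group = fleet.get_hybrid_communicate_group().get_model_parallel_group()

layer = te.TransformerLayer(
    hidden_size=1024,
    ffn_hidden_size=4096,
    num_attention_heads=16,
    set_parallel_mode=True,
    sequence_parallel=True,   # hidden dropout falls back to 'local_seed', per the warning above
    tp_group=tp_group,
)
```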