Unverified Commit b8ba734e authored by Tian Zheng, committed by GitHub

[Paddle] Add parallel support (#357)



* [Paddle] Add TP, DP, PP, FSDP
Signed-off-by: Tian Zheng (Engrg-Hardware 1) <tizheng@nvidia.com>

* Minor fix
Signed-off-by: Tian Zheng (Engrg-Hardware 1) <tizheng@nvidia.com>

* Fix CI failure
Signed-off-by: Tian Zheng (Engrg-Hardware 1) <tizheng@nvidia.com>

* Remove set_nccl_overlap_warning_if_tp
Signed-off-by: Tian Zheng (Engrg-Hardware 1) <tizheng@nvidia.com>

* Improve variable naming
Signed-off-by: Tian Zheng (Engrg-Hardware 1) <tizheng@nvidia.com>

* Refactor FP8 Buffer
Signed-off-by: Tian Zheng (Engrg-Hardware 1) <tizheng@nvidia.com>

* Stylistic changes
Signed-off-by: Tian Zheng (Engrg-Hardware 1) <tizheng@nvidia.com>

* Fix FP32 parallel training
Signed-off-by: Tian Zheng (Engrg-Hardware 1) <tizheng@nvidia.com>

* Fix numel performance issue
Signed-off-by: Tian Zheng (Engrg-Hardware 1) <tizheng@nvidia.com>

* Squashed commit of the following:

commit 79e2e5fd774e67dcdda9aae01a9f31a6479c5d70
Author: Tian Zheng (Engrg-Hardware 1) <tizheng@nvidia.com>
Date:   Sun Aug 20 14:39:16 2023 +0000

    Add TP test
Signed-off-by: Tian Zheng (Engrg-Hardware 1) <tizheng@nvidia.com>

commit 1d40ad60540490f97ed82ba877cc6eda8902cbf6
Author: Tian Zheng (Engrg-Hardware 1) <tizheng@nvidia.com>
Date:   Sun Aug 20 14:22:25 2023 +0000

    Fix tp_size when disabled
Signed-off-by: Tian Zheng (Engrg-Hardware 1) <tizheng@nvidia.com>

commit 6632f735a0c8251862355fc74622af59fae3a509
Author: Tian Zheng (Engrg-Hardware 1) <tizheng@nvidia.com>
Date:   Sun Aug 20 05:52:18 2023 +0000

    Add TP for attention and transformer layer
Signed-off-by: Tian Zheng (Engrg-Hardware 1) <tizheng@nvidia.com>

* Add shape check
Signed-off-by: Tian Zheng (Engrg-Hardware 1) <tizheng@nvidia.com>

* Add FSDP check for stage 1,2,3
Signed-off-by: Tian Zheng (Engrg-Hardware 1) <tizheng@nvidia.com>

* Review changes
Signed-off-by: Tian Zheng (Engrg-Hardware 1) <tizheng@nvidia.com>

* Fix group_sharding test
Signed-off-by: Tian Zheng (Engrg-Hardware 1) <tizheng@nvidia.com>

* Support NVTE_FUSE_ATTN
Signed-off-by: Tian Zheng (Engrg-Hardware 1) <tizheng@nvidia.com>

* Fix CI errors
Signed-off-by: Tian Zheng (Engrg-Hardware 1) <tizheng@nvidia.com>

---------
Signed-off-by: Tian Zheng (Engrg-Hardware 1) <tizheng@nvidia.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 6aa1fcc8
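A minimal sketch of how the tensor-parallel support added in this PR is used, mirroring the linear_tp.py test below; the 2-GPU hybrid configuration, layer sizes, and learning rate are illustrative assumptions, and the script is assumed to be launched with two ranks (e.g. via the dist_launcher.py helper in this PR):

import paddle
from paddle.distributed import fleet
import transformer_engine.paddle as te

# Assumed 2-GPU setup: pure tensor (model) parallelism, no DP/PP.
strategy = fleet.DistributedStrategy()
strategy.hybrid_configs = {"dp_degree": 1, "mp_degree": 2, "pp_degree": 1}
fleet.init(is_collective=True, strategy=strategy)
paddle.set_default_dtype('bfloat16')

# Column-parallel linear: the weight is sharded along the output dimension.
layer = te.Linear(32, 64, parallel_mode='column')
optimizer = paddle.optimizer.SGD(learning_rate=0.001, parameters=layer.parameters())
layer = fleet.distributed_model(layer)
optimizer = fleet.distributed_optimizer(optimizer)

inp = paddle.uniform([16, 32], 'bfloat16')
with te.fp8_autocast(enabled=True):    # FP8 with cross-rank amax reduction
    out = layer(inp)
out.mean().backward()
optimizer.step()
optimizer.clear_grad()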
@@ -57,11 +57,13 @@ class Net(nn.Layer):
def train(args, model, train_loader, optimizer, epoch, use_fp8):
"""Training function."""
model.train()
losses = []
for batch_id, (data, labels) in enumerate(train_loader):
with paddle.amp.auto_cast(dtype='bfloat16', level='O2'): # pylint: disable=not-context-manager
with te.fp8_autocast(enabled=use_fp8):
outputs = model(data)
loss = F.cross_entropy(outputs, labels)
losses.append(loss.item())
loss.backward()
optimizer.step()
@@ -74,7 +76,9 @@ def train(args, model, train_loader, optimizer, epoch, use_fp8):
f"Loss: {loss.item():.6f}")
if args.dry_run:
return loss.item()
return loss.item()
avg_loss = sum(losses) / len(losses)
print(f"Train Epoch: {epoch}, Average Loss: {avg_loss}")
return avg_loss
def evaluate(model, test_loader, epoch, use_fp8):
@@ -226,7 +230,7 @@ class TestMNIST(unittest.TestCase):
@staticmethod
def verify(actual):
"""Check If loss and accuracy match target"""
desired_traing_loss = 0.5
desired_traing_loss = 0.1
desired_test_accuracy = 0.98
assert actual[0] < desired_traing_loss
assert actual[1] > desired_test_accuracy
......
# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Helper functions to launch distributed tests"""
import os
from pathlib import Path
import subprocess
import time
import unittest
from paddle import fluid
from paddle.distributed.utils.launch_utils import (
TrainerProc,
find_free_ports,
get_cluster,
watch_local_trainers,
)
__all__ = ['TestDistributed']
def get_cluster_from_args(selected_gpus):
"""Get node information from selected GPUs"""
cluster_node_ips = '127.0.0.1'
node_ip = '127.0.0.1'
node_ips = [x.strip() for x in cluster_node_ips.split(',')]
# Ensure node_ip is one of the cluster IPs (raises ValueError otherwise)
node_ips.index(node_ip)
free_ports = find_free_ports(len(selected_gpus))
if free_ports is not None:
free_ports = list(free_ports)
trainer_endpoints = []
for ip in node_ips:
trainer_endpoints.append([f"{ip}:{port}" for port in free_ports])
return get_cluster(node_ips, node_ip, trainer_endpoints, selected_gpus)
def get_gpus(selected_gpus):
"""Get selected GPU string"""
selected_gpus = [x.strip() for x in selected_gpus.split(',')]
return selected_gpus
def start_local_trainers(
cluster,
pod,
training_script,
training_script_args,
allocator_strategy="auto_growth",
):
"""Launch trainers"""
current_env = os.environ.copy()
# Paddle broadcasts ncclUniqueId over sockets, and proxies may make trainers
# unreachable, so remove the proxy variables. Setting them to "" would make
# grpc log a "bad uri" error, so they are deleted instead.
current_env.pop("http_proxy", None)
current_env.pop("https_proxy", None)
procs = []
for t in pod.trainers:
proc_env = {
"FLAGS_selected_gpus": ",".join([str(g) for g in t.gpus]),
"PADDLE_TRAINER_ID": f"{t.rank}",
"PADDLE_CURRENT_ENDPOINT": f"{t.endpoint}",
"PADDLE_TRAINERS_NUM": f"{cluster.trainers_nranks()}",
"PADDLE_TRAINER_ENDPOINTS": ",".join(cluster.trainers_endpoints()),
"PYTHONPATH": str(Path(__file__).resolve().parent),
}
proc_env["FLAGS_allocator_strategy"] = allocator_strategy
if allocator_strategy == "auto_growth":
proc_env["FLAGS_fraction_of_gpu_memory_to_use"] = "0.1"
current_env.update(proc_env)
print(f"trainer proc env:{current_env}")
if os.getenv('WITH_COVERAGE', 'OFF') == 'ON':
cmd = "python -m coverage run --branch -p " + training_script
else:
cmd = "python -u " + training_script
print(f"start trainer proc:{cmd} env:{proc_env}")
fn = None
proc = subprocess.Popen(cmd.split(" ") + training_script_args, env=current_env) # pylint: disable=consider-using-with
tp = TrainerProc()
tp.proc = proc
tp.rank = t.rank
tp.log_fn = fn
tp.cmd = cmd
procs.append(tp)
return procs
class TestDistributed(unittest.TestCase):
"""Base class for distributed test"""
@staticmethod
def run_2gpu(
target_file_name,
allocator_strategy="auto_growth",
):
"""Run target file in subprocesses"""
if (not fluid.core.is_compiled_with_cuda() or fluid.core.get_cuda_device_count() == 0):
return
selected_gpus = get_gpus('0,1')
cluster, pod = get_cluster_from_args(selected_gpus)
procs = start_local_trainers(
cluster,
pod,
allocator_strategy=allocator_strategy,
training_script=target_file_name,
training_script_args=[],
)
while True:
alive = watch_local_trainers(procs, cluster.trainers_endpoints())
if not alive:
print(f"Local procs complete, POD info:{pod}")
break
time.sleep(3)
# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Unittest for Linear layer in tensor parallel"""
import unittest
import paddle
from paddle.distributed import fleet
from utils import assert_allclose, set_random_seed
import transformer_engine.paddle as te
def assert_allclose_across_ranks(tensor, group=None):
"""Assert tensor is identical in all ranks"""
gathered_list = []
paddle.distributed.all_gather(gathered_list, tensor, group=group)
assert len(gathered_list) > 1
for gathered_tensor in gathered_list:
assert_allclose(tensor, gathered_tensor)
class TestAmaxReduction(unittest.TestCase):
"""Tests Amax reduction"""
def setUp(self):
self.data_parallel_size = 2
self.init_dist_env()
self.global_dtype = 'bfloat16'
paddle.set_default_dtype(self.global_dtype)
def init_dist_env(self):
"""Init Paddle Fleet environment"""
strategy = fleet.DistributedStrategy()
strategy.hybrid_configs = {
"dp_degree": self.data_parallel_size,
"mp_degree": 1,
"pp_degree": 1,
}
fleet.init(is_collective=True, strategy=strategy)
def test_amax_reduction(self):
"""Tests column parallel linear"""
set_random_seed(1024)
layer1 = te.Linear(16, 16)
layer2 = te.Linear(16, 16)
model = paddle.nn.Sequential(layer1, layer2)
model = fleet.distributed_model(model)
rank_id = paddle.distributed.get_rank()
set_random_seed(rank_id)
optimizer = paddle.optimizer.SGD(learning_rate=10.0, parameters=model.parameters())
optimizer = fleet.distributed_optimizer(optimizer)
def train_one_step(layer, inp, optimizer):
inp = paddle.to_tensor(inp)
inp.stop_gradient = False
out = layer(inp)
loss = out.mean()
loss.backward()
optimizer.step()
optimizer.clear_grad()
return loss
for _ in range(5):
inp = paddle.uniform([16, 16], self.global_dtype)
with te.fp8_autocast(enabled=True):
train_one_step(model, inp, optimizer)
assert_allclose_across_ranks(layer1.fp8_meta["scaling_fwd"].amax_history[-1])
assert_allclose_across_ranks(layer1.fp8_meta["scaling_fwd"].scale)
assert_allclose_across_ranks(layer1.fp8_meta["scaling_fwd"].scale_inv)
assert_allclose_across_ranks(layer2.fp8_meta["scaling_fwd"].amax_history[-1])
assert_allclose_across_ranks(layer2.fp8_meta["scaling_fwd"].scale)
assert_allclose_across_ranks(layer2.fp8_meta["scaling_fwd"].scale_inv)
assert_allclose_across_ranks(layer1.fp8_meta["scaling_bwd"].amax_history[-1])
assert_allclose_across_ranks(layer1.fp8_meta["scaling_bwd"].scale)
assert_allclose_across_ranks(layer1.fp8_meta["scaling_bwd"].scale_inv)
assert_allclose_across_ranks(layer2.fp8_meta["scaling_bwd"].amax_history[-1])
assert_allclose_across_ranks(layer2.fp8_meta["scaling_bwd"].scale)
assert_allclose_across_ranks(layer2.fp8_meta["scaling_bwd"].scale_inv)
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Unittest for group sharding"""
import unittest
import paddle
from paddle.distributed import fleet
from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer import (
DygraphShardingOptimizer,)
from utils import assert_allclose, set_random_seed
import transformer_engine.paddle as te
class TestGroupSharding(unittest.TestCase):
"""Tests group sharding"""
def setUp(self):
self.set_attr()
self.init_dist_env()
paddle.set_default_dtype(self.global_dtype)
def set_attr(self):
"""Set test configs"""
self.sharding_degree = 2
self.global_dtype = 'float32'
self.rtol = 1e-5
self.atol = 1e-5
self.batch_size = 16
self.in_channels = 16
self.out_channels = 32
self.fp8 = False
def init_dist_env(self):
"""Init Paddle Fleet environment"""
strategy = fleet.DistributedStrategy()
strategy.hybrid_configs = {
"dp_degree": 1,
"mp_degree": 1,
"pp_degree": 1,
"sharding_degree": self.sharding_degree,
}
self.strategy = strategy
fleet.init(is_collective=True, strategy=strategy)
def _get_model_and_optimizer(self, model, stage):
if stage == 1:
optimizer = DygraphShardingOptimizer(
hcg=fleet.get_hybrid_communicate_group(),
user_defined_strategy=self.strategy,
params=model.parameters(),
inner_optimizer_class=paddle.optimizer.AdamW,
learning_rate=0.01,
)
model = fleet.distributed_model(model)
optimizer = fleet.distributed_optimizer(optimizer)
elif stage in [2, 3]:
optimizer = paddle.optimizer.AdamW(learning_rate=0.01, parameters=model.parameters())
group = fleet.get_hybrid_communicate_group().get_sharding_parallel_group()
class ShardingLevel: # pylint: disable=too-few-public-methods,
"""Paddle sharding options"""
kStage1 = 'os'
kStage2 = 'os_g'
kStage3 = 'p_g_os'
level = ShardingLevel.kStage3 if stage == 3 else ShardingLevel.kStage2
model, optimizer, _ = paddle.distributed.sharding.group_sharded_parallel(
model=model,
optimizer=optimizer,
level=level,
group=group,
segment_size=256,
)
else:
raise ValueError(f"Stage {stage} not supported")
return model, optimizer
def test_group_sharding_stage1(self):
"""Tests group sharding training"""
set_random_seed(1024)
model_te = te.Linear(self.in_channels, self.out_channels)
model_pd = paddle.nn.Linear(self.in_channels, self.out_channels)
model_pd.weight.copy_(model_te.weight.T, True)
model_pd.bias.copy_(model_te.bias, True)
model_te, optimizer_te = self._get_model_and_optimizer(model_te, stage=1)
model_pd, optimizer_pd = self._get_model_and_optimizer(model_pd, stage=1)
rank_id = paddle.distributed.get_rank()
paddle.seed(rank_id)
def train_one_step(model, inp, optimizer):
out = model(inp)
loss = out.mean()
loss.backward()
optimizer.step()
optimizer.clear_grad()
return loss
for _ in range(5):
inp = paddle.uniform([self.batch_size, self.in_channels], self.global_dtype)
with te.fp8_autocast(enabled=False):
loss_te = train_one_step(model_te, inp, optimizer_te)
loss_pd = train_one_step(model_pd, inp, optimizer_pd)
assert_allclose(loss_te, loss_pd, rtol=self.rtol, atol=self.atol)
assert len(optimizer_te.state_dict()) == 4, \
"Expect each rank to hold 4 optimizer state entries."
def test_group_sharding_stage2(self):
"""Tests group sharding training"""
set_random_seed(1024)
model_te = te.Linear(self.in_channels, self.out_channels)
model_pd = paddle.nn.Linear(self.in_channels, self.out_channels)
model_pd.weight.copy_(model_te.weight.T, True)
model_pd.bias.copy_(model_te.bias, True)
model_te, optimizer_te = self._get_model_and_optimizer(model_te, stage=2)
model_pd, optimizer_pd = self._get_model_and_optimizer(model_pd, stage=2)
rank_id = paddle.distributed.get_rank()
paddle.seed(rank_id)
def train_one_step(model, inp, optimizer):
out = model(inp)
loss = out.mean()
loss.backward()
# Check that gradients are split across trainers
if rank_id == 0:
assert model.bias.grad is None and model.weight.grad is not None
elif rank_id == 1:
assert model.weight.grad is None and model.bias.grad is not None
optimizer.step()
optimizer.clear_grad()
return loss
for _ in range(5):
inp = paddle.uniform([self.batch_size, self.in_channels], self.global_dtype)
with te.fp8_autocast(enabled=False):
loss_te = train_one_step(model_te, inp, optimizer_te)
loss_pd = train_one_step(model_pd, inp, optimizer_pd)
assert_allclose(loss_te, loss_pd, rtol=self.rtol, atol=self.atol)
assert len(optimizer_te.state_dict()) == 4, \
"Expect each rank to hold 4 optimizer state entries."
def test_group_sharding_stage3(self):
"""Tests group sharding training"""
set_random_seed(1024)
model_te = te.Linear(self.in_channels, self.out_channels)
model_pd = paddle.nn.Linear(self.in_channels, self.out_channels)
model_pd.weight.copy_(model_te.weight.T, True)
model_pd.bias.copy_(model_te.bias, True)
model_te, optimizer_te = self._get_model_and_optimizer(model_te, stage=3)
model_pd, optimizer_pd = self._get_model_and_optimizer(model_pd, stage=3)
rank_id = paddle.distributed.get_rank()
paddle.seed(rank_id)
def train_one_step(model, inp, optimizer):
out = model(inp)
loss = out.mean()
loss.backward()
optimizer.step()
optimizer.clear_grad()
return loss
for _ in range(5):
inp = paddle.uniform([self.batch_size, self.in_channels], self.global_dtype)
with te.fp8_autocast(enabled=False):
loss_te = train_one_step(model_te, inp, optimizer_te)
loss_pd = train_one_step(model_pd, inp, optimizer_pd)
assert_allclose(loss_te, loss_pd, rtol=self.rtol, atol=self.atol)
for name, value in optimizer_te.state_dict().items():
if name.endswith('w_0_moment1_0'):
assert value.numel() == \
self.in_channels * self.out_channels // self.sharding_degree, \
"Expect optimizer state to be sharded across trainers."
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Unittest for LayerNormLinear layer in tensor parallel"""
import unittest
import paddle
from paddle.distributed import fleet
from paddle.distributed.fleet.layers.mpu import mp_ops
from utils import assert_allclose, assert_shape, set_random_seed
import transformer_engine.paddle as te
class TestLayerNormLinearTp(unittest.TestCase):
"""Tests LayerNormLinear layer with column/row parallelism in BF16"""
def setUp(self):
self.set_attr()
self.init_dist_env()
paddle.set_default_dtype(self.global_dtype)
def init_dist_env(self):
"""Init Paddle Fleet environment"""
strategy = fleet.DistributedStrategy()
self.model_parallel_size = 2
strategy.hybrid_configs = {
"dp_degree": 1,
"mp_degree": self.model_parallel_size,
"pp_degree": 1,
}
fleet.init(is_collective=True, strategy=strategy)
self.hcg = fleet.get_hybrid_communicate_group()
self.tp_group = self.hcg.get_model_parallel_group()
def set_attr(self):
"""Set test configs"""
self.batch_size = 16
self.in_features = 32
self.out_features = 64
self.global_dtype = 'bfloat16'
self.rtol = 1e-3
self.atol = 1e-3
self.eps = 1e-3
self.fp8 = False
def test_column_parallel_layer(self):
"""Tests column parallel LayerNormLinear"""
set_random_seed(1024)
layer_te = te.LayerNormLinear(
self.in_features,
self.out_features,
eps=self.eps,
parallel_mode='column',
)
layer_pd = te.LayerNormLinear(
self.in_features,
self.out_features,
eps=self.eps,
backend='paddle',
)
# Get total weight
total_weight = []
partial_weight = layer_te.weight.clone().detach()
paddle.distributed.all_gather(total_weight, partial_weight, group=self.tp_group)
total_weight = paddle.concat(total_weight, axis=0)
layer_pd.weight.copy_(total_weight.T, True)
assert_shape(layer_te.weight,
[self.out_features // self.model_parallel_size, self.in_features])
assert_shape(layer_te.bias, [self.out_features // self.model_parallel_size])
optimizer_te = paddle.optimizer.SGD(learning_rate=0.001, parameters=layer_te.parameters())
optimizer_pd = paddle.optimizer.SGD(learning_rate=0.001, parameters=layer_pd.parameters())
layer_te = fleet.distributed_model(layer_te)
optimizer_te = fleet.distributed_optimizer(optimizer_te)
def train_one_step(layer, inp, optimizer, gather=False):
inp = paddle.to_tensor(inp)
inp.stop_gradient = False
out = layer(inp)
if gather:
total_out = mp_ops._c_concat(out, group=self.tp_group)
else:
total_out = out
loss = total_out.mean()
loss.backward()
optimizer.step()
optimizer.clear_grad()
return loss, inp.grad
for _ in range(5):
inp = paddle.uniform([self.batch_size, self.in_features], self.global_dtype)
with te.fp8_autocast(enabled=self.fp8):
loss_tp, grad_input = train_one_step(layer_te, inp, optimizer_te, gather=True)
loss_ref, grad_input_ref = train_one_step(layer_pd, inp, optimizer_pd)
assert_allclose(loss_tp, loss_ref, rtol=self.rtol, atol=self.atol)
assert_allclose(grad_input, grad_input_ref, rtol=self.rtol, atol=self.atol)
class TestLayerNormLinearTpFp8(TestLayerNormLinearTp):
"""Tests LayernormLinear layer with column/row parallelism in FP8"""
def set_attr(self):
"""Set test configs"""
self.batch_size = 16
self.in_features = 32
self.out_features = 64
self.global_dtype = 'bfloat16'
self.rtol = 1e-2
self.atol = 1e-2
self.eps = 1e-3
self.fp8 = True
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Unittest for LayerNormMLP layer in tensor parallel"""
import unittest
import paddle
from paddle.distributed import fleet
from utils import assert_allclose, assert_shape, set_random_seed
import transformer_engine.paddle as te
class TestLayerNormMLPTp(unittest.TestCase):
"""Tests LayerNormMLP layer with model parallel in BF16"""
def setUp(self):
self.set_attr()
self.init_dist_env()
paddle.set_default_dtype(self.global_dtype)
def init_dist_env(self):
"""Init Paddle Fleet environment"""
strategy = fleet.DistributedStrategy()
self.model_parallel_size = 2
strategy.hybrid_configs = {
"dp_degree": 1,
"mp_degree": self.model_parallel_size,
"pp_degree": 1,
}
fleet.init(is_collective=True, strategy=strategy)
self.hcg = fleet.get_hybrid_communicate_group()
self.tp_group = self.hcg.get_model_parallel_group()
def set_attr(self):
"""Set test configs"""
self.batch_size = 16
self.hidden_size = 32
self.ffn_hidden_size = 64
self.global_dtype = 'bfloat16'
self.rtol = 1e-3
self.atol = 1e-3
self.eps = 1e-3
self.fp8 = False
def test_parallel_layer(self):
"""Tests parallel LayerNormMLP"""
set_random_seed(1024)
layer_te = te.LayerNormMLP(
hidden_size=self.hidden_size,
ffn_hidden_size=self.ffn_hidden_size,
eps=self.eps,
set_parallel_mode=True,
)
layer_pd = te.LayerNormMLP(
hidden_size=self.hidden_size,
ffn_hidden_size=self.ffn_hidden_size,
eps=self.eps,
set_parallel_mode=False,
backend='paddle',
)
def _get_total_weight(local_weight, tp_group, axis):
total_weight = []
partial_weight = local_weight.clone().detach()
paddle.distributed.all_gather(total_weight, partial_weight, group=tp_group)
total_weight = paddle.concat(total_weight, axis=axis)
return total_weight
# Get total weight
total_fc1_weight = _get_total_weight(layer_te.fc1_weight, tp_group=self.tp_group, axis=0)
total_fc2_weight = _get_total_weight(layer_te.fc2_weight, tp_group=self.tp_group, axis=1)
layer_pd.fc1_weight.copy_(total_fc1_weight.T, True)
layer_pd.fc2_weight.copy_(total_fc2_weight.T, True)
assert_shape(layer_te.fc1_weight,
[self.ffn_hidden_size // self.model_parallel_size, self.hidden_size])
assert_shape(layer_te.fc1_bias, [self.ffn_hidden_size // self.model_parallel_size])
assert_shape(layer_te.fc2_weight,
[self.hidden_size, self.ffn_hidden_size // self.model_parallel_size])
assert_shape(layer_te.fc2_bias, [self.hidden_size])
optimizer_te = paddle.optimizer.SGD(learning_rate=0.001, parameters=layer_te.parameters())
optimizer_pd = paddle.optimizer.SGD(learning_rate=0.001, parameters=layer_pd.parameters())
layer_te = fleet.distributed_model(layer_te)
optimizer_te = fleet.distributed_optimizer(optimizer_te)
def train_one_step(layer, inp, optimizer):
inp = paddle.to_tensor(inp)
inp.stop_gradient = False
out = layer(inp)
loss = out.mean()
loss.backward()
optimizer.step()
optimizer.clear_grad()
return loss, inp.grad
for _ in range(5):
inp = paddle.uniform([self.batch_size, self.hidden_size], self.global_dtype)
with te.fp8_autocast(enabled=self.fp8):
loss_tp, grad_input = train_one_step(layer_te, inp, optimizer_te)
loss_ref, grad_input_ref = train_one_step(layer_pd, inp, optimizer_pd)
assert_allclose(loss_tp, loss_ref, rtol=self.rtol, atol=self.atol)
assert_allclose(grad_input, grad_input_ref, rtol=self.rtol, atol=self.atol)
class TestLayerNormMLPTpFp8(TestLayerNormMLPTp):
"""Tests LayerNormMLP layer with tensor parallelism in FP8"""
def set_attr(self):
"""Set test configs"""
self.batch_size = 16
self.hidden_size = 32
self.ffn_hidden_size = 64
self.global_dtype = 'bfloat16'
self.rtol = 1e-2
self.atol = 1e-2
self.eps = 1e-3
self.fp8 = True
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Unittest for Linear layer in pipeline parallel"""
import unittest
import numpy as np
import paddle
from paddle.distributed import fleet
from paddle.distributed.fleet.meta_parallel import (
LayerDesc,
PipelineLayer,
)
from utils import assert_allclose, set_random_seed
import transformer_engine.paddle as te
class TEPipelineModel(PipelineLayer):
"""Model for pipeline parallel test"""
def __init__(self,
in_features,
hidden_features,
weight_attrs,
use_te=True,
use_fp8=False,
**kwargs):
self.in_features = in_features
self.hidden_features = hidden_features
self.fp8 = use_fp8
hcg = fleet.get_hybrid_communicate_group()
self.dp_group = hcg.get_data_parallel_group()
Linear = te.Linear if use_te else paddle.nn.Linear
model_desc = [
LayerDesc(Linear, self.in_features, self.hidden_features, weight_attr=weight_attrs[0]),
LayerDesc(Linear, self.hidden_features, self.in_features, weight_attr=weight_attrs[1]),
]
super().__init__(layers=model_desc, loss_fn=paddle.nn.CrossEntropyLoss(), **kwargs)
def forward(self, *args, **kwargs):
with te.fp8_autocast(enabled=self.fp8, fp8_group=self.dp_group):
return super().forward(*args, **kwargs)
class StandaloneModel(paddle.nn.Layer):
"""Model for pipeline parallel test"""
def __init__(self, in_features, hidden_features, weight_attrs):
super().__init__()
self.in_features = in_features
self.hidden_features = hidden_features
Linear = paddle.nn.Linear
self.layer = paddle.nn.Sequential(
Linear(self.in_features, self.hidden_features, weight_attr=weight_attrs[0]),
Linear(self.hidden_features, self.in_features, weight_attr=weight_attrs[1]),
)
self.loss = paddle.nn.CrossEntropyLoss()
def forward(self, inp):
out = self.layer(inp[0])
loss = self.loss(out, inp[1])
return loss
class TestLinearPipelineParallel(unittest.TestCase):
"""Tests Linear layer with pipeline parallel"""
def setUp(self):
self.set_attr()
self.init_dist_env()
paddle.set_default_dtype(self.global_dtype)
def init_dist_env(self):
"""Init Paddle Fleet environment"""
strategy = fleet.DistributedStrategy()
self.pipeline_parallel_size = 2
strategy.hybrid_configs = {
"dp_degree": 1,
"mp_degree": 1,
"pp_degree": self.pipeline_parallel_size,
}
strategy.pipeline_configs = {
"accumulate_steps": self.batch_size // self.micro_batch_size,
"micro_batch_size": self.micro_batch_size,
}
fleet.init(is_collective=True, strategy=strategy)
self.rank = fleet.worker_index()
self.hcg = fleet.get_hybrid_communicate_group()
def set_attr(self):
"""Set test configs"""
self.batch_size = 32
self.micro_batch_size = 16
self.in_features = 32
self.hidden_features = 64
self.global_dtype = 'float32'
self.rtol = 1e-5
self.atol = 1e-5
self.iter = 10
self.fp8 = False
def test_pipeline_train(self):
"""Test pipeline parallel training"""
set_random_seed(1024)
weight1_np = np.random.normal(size=[self.in_features, self.hidden_features])
weight2_np = np.random.normal(size=[self.hidden_features, self.in_features])
weight_attrs = [
paddle.ParamAttr(initializer=paddle.nn.initializer.Assign(weight1_np)),
paddle.ParamAttr(initializer=paddle.nn.initializer.Assign(weight2_np)),
]
weight_attrs_transposed = [
paddle.ParamAttr(initializer=paddle.nn.initializer.Assign(weight1_np.T)),
paddle.ParamAttr(initializer=paddle.nn.initializer.Assign(weight2_np.T)),
]
pipe_model = TEPipelineModel(
self.in_features,
self.hidden_features,
weight_attrs_transposed,
use_te=True,
use_fp8=self.fp8,
seg_method="layer:Linear",
num_stages=self.pipeline_parallel_size,
)
# Check if model is split across ranks as expected
for name, sublayer in pipe_model.named_sublayers():
if name in ('_loss_fn', 'shared_layers'):
continue
if self.rank == 0:
assert tuple(sublayer.weight.shape) == weight1_np.T.shape, \
f"Shape does not match, expect: {weight1_np.T.shape}, " \
f"actual: {tuple(sublayer.weight.shape)}"
elif self.rank == 1:
assert tuple(sublayer.weight.shape) == weight2_np.T.shape, \
f"Shape does not match, expect: {weight2_np.T.shape}, " \
f"actual: {tuple(sublayer.weight.shape)}"
standalone_model = StandaloneModel(
self.in_features,
self.hidden_features,
weight_attrs,
)
optimizer_te = paddle.optimizer.SGD(learning_rate=0.1, parameters=pipe_model.parameters())
optimizer_pd = paddle.optimizer.SGD(learning_rate=0.1,
parameters=standalone_model.parameters())
pipe_model = fleet.distributed_model(pipe_model)
optimizer_te = fleet.distributed_optimizer(optimizer_te)
def train_one_step(layer, inp, optimizer):
loss = layer(inp)
loss.backward()
optimizer.step()
optimizer.clear_grad()
return loss
for i in range(self.iter):
inp = paddle.to_tensor(np.random.normal(size=[self.batch_size, self.in_features]),
dtype=self.global_dtype)
label = paddle.to_tensor(np.random.randint(self.in_features, size=[self.batch_size, 1]))
loss_te = pipe_model.train_batch([inp, label], optimizer_te)
loss_pd = train_one_step(standalone_model, [inp, label], optimizer_pd)
print(f"Iter: {i}, loss_te: {loss_te.item()}, loss_pd: {loss_pd.item()}")
assert_allclose(loss_te, loss_pd, rtol=self.rtol, atol=self.atol)
class TestLinearPipelineParallelFP8(TestLinearPipelineParallel):
"""Tests Linear layer with column/row parallelism in FP8"""
def set_attr(self):
"""Set test configs"""
self.batch_size = 32
self.micro_batch_size = 16
self.in_features = 32
self.hidden_features = 64
self.global_dtype = 'float32'
self.rtol = 5e-2
self.atol = 5e-2
self.iter = 10
self.fp8 = True
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Unittest for Linear layer in tensor parallel"""
import unittest
import paddle
from paddle.distributed import fleet
from paddle.distributed.fleet.layers.mpu import mp_ops
from utils import assert_allclose, assert_shape, set_random_seed
import transformer_engine.paddle as te
class TestLinearTp(unittest.TestCase):
"""Tests Linear layer with column/row parallelism in BF16"""
def setUp(self):
self.set_attr()
self.init_dist_env()
paddle.set_default_dtype(self.global_dtype)
def init_dist_env(self):
"""Init Paddle Fleet environment"""
strategy = fleet.DistributedStrategy()
self.model_parallel_size = 2
strategy.hybrid_configs = {
"dp_degree": 1,
"mp_degree": self.model_parallel_size,
"pp_degree": 1,
}
fleet.init(is_collective=True, strategy=strategy)
self.rank = fleet.worker_index()
self.hcg = fleet.get_hybrid_communicate_group()
self.tp_group = self.hcg.get_model_parallel_group()
self.world_size = self.hcg.get_model_parallel_world_size()
def set_attr(self):
"""Set test configs"""
self.batch_size = 16
self.in_features = 32
self.out_features = 64
self.global_dtype = 'bfloat16'
self.rtol = 1e-3
self.atol = 1e-3
self.fp8 = False
def test_column_parallel_layer(self):
"""Tests column parallel linear"""
set_random_seed(1024)
layer_te = te.Linear(
self.in_features,
self.out_features,
parallel_mode='column',
)
layer_pd = te.Linear(
self.in_features,
self.out_features,
backend='paddle',
)
# Get total weight
total_weight = []
partial_weight = layer_te.weight.clone().detach()
paddle.distributed.all_gather(total_weight, partial_weight, group=self.tp_group)
total_weight = paddle.concat(total_weight, axis=0)
layer_pd.weight.copy_(total_weight.T, True)
assert_shape(layer_te.weight,
[self.out_features // self.model_parallel_size, self.in_features])
assert_shape(layer_te.bias, [self.out_features // self.model_parallel_size])
optimizer_te = paddle.optimizer.SGD(learning_rate=0.001, parameters=layer_te.parameters())
optimizer_pd = paddle.optimizer.SGD(learning_rate=0.001, parameters=layer_pd.parameters())
layer_te = fleet.distributed_model(layer_te)
optimizer_te = fleet.distributed_optimizer(optimizer_te)
def train_one_step(layer, inp, optimizer, gather=False):
inp = paddle.to_tensor(inp)
inp.stop_gradient = False
out = layer(inp)
if gather:
total_out = mp_ops._c_concat(out, group=self.tp_group)
else:
total_out = out
loss = total_out.mean()
loss.backward()
optimizer.step()
optimizer.clear_grad()
return loss, inp.grad
for _ in range(5):
inp = paddle.uniform([self.batch_size, self.in_features], self.global_dtype)
with te.fp8_autocast(enabled=self.fp8):
loss_tp, grad_input = train_one_step(layer_te, inp, optimizer_te, gather=True)
loss_ref, grad_input_ref = train_one_step(layer_pd, inp, optimizer_pd)
assert_allclose(loss_tp, loss_ref, rtol=self.rtol, atol=self.atol)
assert_allclose(grad_input, grad_input_ref, rtol=self.rtol, atol=self.atol)
def test_row_parallel_layer(self):
"""Tests row parallel linear"""
set_random_seed(1024)
layer_te = te.Linear(
self.in_features,
self.out_features,
parallel_mode='row',
)
layer_pd = te.Linear(
self.in_features,
self.out_features,
backend='paddle',
)
# Get total weight
total_weight = []
partial_weight = layer_te.weight.clone().detach()
paddle.distributed.all_gather(total_weight, partial_weight, group=self.tp_group)
total_weight = paddle.concat(total_weight, axis=1)
layer_pd.weight.copy_(total_weight.T, True)
assert_shape(layer_te.weight,
[self.out_features, self.in_features // self.model_parallel_size])
assert_shape(layer_te.bias, [self.out_features])
optimizer_te = paddle.optimizer.SGD(learning_rate=0.001, parameters=layer_te.parameters())
optimizer_pd = paddle.optimizer.SGD(learning_rate=0.001, parameters=layer_pd.parameters())
# Note(tizheng): For this test, we cannot wrap the model with fleet.distributed_model,
# because it would broadcast inputs across the MP group. However, the row-parallel
# Linear expects pre-split inputs, which differ on each rank.
def train_one_step(layer, inp, optimizer, split=False):
inp = paddle.to_tensor(inp, stop_gradient=True)
if split:
# TODO(tizheng): mp_ops._c_split does not work here; see
# https://github.com/PaddlePaddle/Paddle/issues/55565
# input_parallel = mp_ops._c_split(inp, group=layer.tp_group)
split_size = inp.shape[1] // self.world_size
input_parallel = inp[:, split_size * self.rank:split_size * (self.rank + 1)]
else:
input_parallel = inp
input_parallel.stop_gradient = False
out = layer(input_parallel)
loss = out.mean()
loss.backward()
optimizer.step()
optimizer.clear_grad()
if split:
grad_input = []
paddle.distributed.all_gather(grad_input, input_parallel.grad, group=self.tp_group)
grad_input = paddle.concat(grad_input, axis=1)
else:
grad_input = input_parallel.grad
return loss, grad_input
for _ in range(5):
inp = paddle.uniform([self.batch_size, self.in_features], self.global_dtype)
with te.fp8_autocast(enabled=self.fp8):
loss_tp, grad_input = train_one_step(layer_te, inp, optimizer_te, split=True)
loss_ref, grad_input_ref = train_one_step(layer_pd, inp, optimizer_pd)
assert_allclose(loss_tp, loss_ref, rtol=self.rtol, atol=self.atol)
assert_allclose(grad_input, grad_input_ref, rtol=self.rtol, atol=self.atol)
class TestLinearTpFP8(TestLinearTp):
"""Tests Linear layer with column/row parallelism in FP8"""
def set_attr(self):
"""Set test configs"""
self.batch_size = 16
self.in_features = 32
self.out_features = 64
self.global_dtype = 'bfloat16'
self.rtol = 1e-2
self.atol = 1e-2
self.fp8 = True
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Unittest for Transformer layer in tensor parallel"""
import unittest
import paddle
from paddle.distributed import fleet
from utils import assert_allclose, set_random_seed
import transformer_engine.paddle as te
class TestTransformerTp(unittest.TestCase):
"""Tests Transformer layer with model parallel in BF16"""
def setUp(self):
self.set_attr()
self.init_dist_env()
paddle.set_default_dtype(self.global_dtype)
def init_dist_env(self):
"""Init Paddle Fleet environment"""
strategy = fleet.DistributedStrategy()
self.model_parallel_size = 2
strategy.hybrid_configs = {
"dp_degree": 1,
"mp_degree": self.model_parallel_size,
"pp_degree": 1,
}
fleet.init(is_collective=True, strategy=strategy)
self.hcg = fleet.get_hybrid_communicate_group()
self.tp_group = self.hcg.get_model_parallel_group()
def set_attr(self):
"""Set test configs"""
self.batch_size = 16
self.hidden_size = 1024
self.num_heads = 16
self.ffn_hidden_size = 4096
self.q_seqlen = 128
self.kv_seqlen = 128
self.mask_type = 'padding'
self.layer_type = 'encoder'
self.global_dtype = 'bfloat16'
self.rtol = 5e-2
self.atol = 5e-2
self.eps = 1e-3
self.fp8 = False
def test_parallel_layer(self):
"""Tests parallel Transformer"""
set_random_seed(1024)
common_args = [
self.hidden_size,
self.ffn_hidden_size,
self.num_heads,
]
common_kwargs = {
'layernorm_epsilon': self.eps,
'hidden_dropout': 0.0,
'attention_dropout': 0.0,
'self_attn_mask_type': self.mask_type,
'layer_type': self.layer_type,
}
layer_tp = te.TransformerLayer(*common_args, **common_kwargs, set_parallel_mode=True)
layer_single = te.TransformerLayer(*common_args, **common_kwargs, set_parallel_mode=False)
def _get_total_weight(local_weight, tp_group, axis):
total_weight = []
partial_weight = local_weight.clone().detach()
paddle.distributed.all_gather(total_weight, partial_weight, group=tp_group)
total_weight = paddle.concat(total_weight, axis=axis)
return total_weight
def _get_weight(obj, weight_names):
for name in weight_names:
obj = getattr(obj, name)
return obj
def copy_weight(layer_src, layer_dst, partition_mode, weight_names):
weight_src = _get_weight(layer_src, weight_names)
weight_dst = _get_weight(layer_dst, weight_names)
if partition_mode is None:
total_weight = weight_src
elif partition_mode == 'column':
total_weight = _get_total_weight(weight_src, tp_group=self.tp_group, axis=0)
elif partition_mode == 'row':
total_weight = _get_total_weight(weight_src, tp_group=self.tp_group, axis=1)
else:
raise ValueError(f"Partition Mode {partition_mode} is not supported.")
assert weight_dst.shape == total_weight.shape, \
f"Shapes of src:{total_weight.shape} and dst:{weight_dst.shape} do not match."
weight_dst.copy_(total_weight, True)
copy_weight(layer_tp, layer_single, None, ['self_attention', 'layernorm_qkv', 'ln_weight'])
copy_weight(layer_tp, layer_single, 'column', ['self_attention', 'layernorm_qkv', 'weight'])
copy_weight(layer_tp, layer_single, 'row', ['self_attention', 'proj', 'weight'])
copy_weight(layer_tp, layer_single, None, ['layernorm_mlp', 'ln_weight'])
copy_weight(layer_tp, layer_single, 'column', ['layernorm_mlp', 'fc1_weight'])
copy_weight(layer_tp, layer_single, 'row', ['layernorm_mlp', 'fc2_weight'])
optimizer_tp = paddle.optimizer.SGD(learning_rate=0.1, parameters=layer_tp.parameters())
optimizer_single = paddle.optimizer.SGD(learning_rate=0.1,
parameters=layer_single.parameters())
layer_tp = fleet.distributed_model(layer_tp)
optimizer_tp = fleet.distributed_optimizer(optimizer_tp)
def train_one_step(layer, inp_list, optimizer, fp8_enabled):
with te.fp8_autocast(enabled=fp8_enabled):
out = layer(*inp_list)
loss = out.mean()
loss.backward()
optimizer.step()
optimizer.clear_grad()
return loss
for _ in range(5):
inp = paddle.uniform([self.batch_size, self.q_seqlen, self.hidden_size],
self.global_dtype)
mask = paddle.zeros(shape=(self.batch_size, 1, self.q_seqlen, self.kv_seqlen),
dtype='bool')
loss_tp = train_one_step(layer_tp, [inp, mask], optimizer_tp, self.fp8)
loss_single = train_one_step(layer_single, [inp, mask], optimizer_single, self.fp8)
assert_allclose(loss_tp, loss_single, rtol=self.rtol, atol=self.atol)
class TestTransformerTpFp8(TestTransformerTp):
"""Tests Transformer layer with tensor parallelism in FP8"""
def set_attr(self):
"""Set test configs"""
self.batch_size = 16
self.hidden_size = 1024
self.num_heads = 16
self.ffn_hidden_size = 4096
self.q_seqlen = 128
self.kv_seqlen = 128
self.mask_type = 'padding'
self.layer_type = 'encoder'
self.global_dtype = 'bfloat16'
self.rtol = 5e-2
self.atol = 5e-2
self.eps = 1e-3
self.fp8 = True
if __name__ == '__main__':
unittest.main()
@@ -98,8 +98,8 @@ class TestLinear:
"""
Test BF16 Linear
"""
rtol = 1e-2
atol = 1e-2
rtol = 5e-2
atol = 5e-2
input_tensor = paddle.uniform(shape=(bs, in_features), dtype=activation_dtype)
input_tensor.stop_gradient = no_dgrad
@@ -258,8 +258,8 @@ class TestLayerNormLinear:
Test BF16 LayerNormLinear Layer
"""
paddle.set_default_dtype(activation_dtype)
rtol = 1e-2
atol = 1e-2
rtol = 5e-2
atol = 5e-2
input_tensor = paddle.uniform(shape=(bs, in_features), dtype=activation_dtype)
input_tensor.stop_gradient = no_dgrad
@@ -905,7 +905,7 @@ def test_transformer_decoder_layer(bs, hidden_size, num_heads, ffn_hidden_size,
"""
paddle.set_default_dtype(math_dtype)
rtol = 5e-2
atol = 5e-2
atol = 6e-2
eps = 1e-3
encoder_input = paddle.uniform(shape=(bs, q_seqlen, hidden_size), dtype=math_dtype)
......
@@ -728,8 +728,8 @@ class TestFusedAttn:
return out, q_grad, k_grad, v_grad
@pytest.mark.skipif(paddle.device.cuda.get_device_capability() < (8, 0),
reason="cuDNN fMHA requires Ampere+ GPU")
@pytest.mark.skipif(paddle.device.cuda.get_device_capability() not in ((8, 0), (9, 0)),
reason="cuDNN fMHA requires Ampere and Hopper GPU")
@pytest.mark.parametrize('b, s, h, d', SELF_ATTN_CASES)
@pytest.mark.parametrize('dtype', ['float16', 'bfloat16'])
@pytest.mark.parametrize('is_causal_masking', [True, False])
@@ -745,8 +745,8 @@ class TestFusedAttn:
assert_allclose(k_grad_ref, k_grad, rtol=1e-3, atol=1e-2)
assert_allclose(v_grad_ref, v_grad, rtol=1e-3, atol=1e-2)
@pytest.mark.skipif(paddle.device.cuda.get_device_capability() < (8, 0),
reason="cuDNN fMHA requires Ampere+ GPU")
@pytest.mark.skipif(paddle.device.cuda.get_device_capability() not in ((8, 0), (9, 0)),
reason="cuDNN fMHA requires Ampere and Hopper GPU")
@pytest.mark.parametrize('b, s_q, s_kv, h, d', CROSS_ATTN_CASES)
@pytest.mark.parametrize('dtype', ['float16', 'bfloat16'])
def test_cross_attn_forward_backward(self, b, s_q, s_kv, h, d, dtype):
......
# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Test TE Paddle Parallel"""
from pathlib import Path
import unittest
from dist_launcher import TestDistributed
from utils import is_devices_enough
from transformer_engine.paddle.fp8 import is_fp8_available
test_root = Path(__file__).resolve().parent
gpu_has_fp8, reason = is_fp8_available()
class TestParallelLinear(TestDistributed):
"""Test Linear in Parallel mode"""
@unittest.skipIf(not is_devices_enough(2), "TestParallelLinear needs 2 GPUs")
@unittest.skipIf(not gpu_has_fp8, reason)
def test_linear_tp(self):
"""Tests linear with tensor parallel in BF16"""
self.run_2gpu(str(test_root / 'parallel_tests' / 'linear_tp.py'))
class TestParallelLayerNormLinear(TestDistributed):
"""Test LayerNormLinear in Parallel mode"""
@unittest.skipIf(not is_devices_enough(2), "TestParallelLayerNormLinear needs 2 GPUs")
@unittest.skipIf(not gpu_has_fp8, reason)
def test_layernorm_linear_tp(self):
"""Tests layernorm_linear with tensor parallel in BF16"""
self.run_2gpu(str(test_root / 'parallel_tests' / 'layernorm_linear_tp.py'))
class TestParallelLayerNormMLP(TestDistributed):
"""Test LayerNormMLP in Parallel mode"""
@unittest.skipIf(not is_devices_enough(2), "TestParallelLayerNormMLP needs 2 GPUs")
@unittest.skipIf(not gpu_has_fp8, reason)
def test_layernorm_mlp_tp(self):
"""Tests layernorm_mlp with tensor parallel in BF16"""
self.run_2gpu(str(test_root / 'parallel_tests' / 'layernorm_mlp_tp.py'))
class TestAmaxReduction(TestDistributed):
"""Test amax reduction in dp mode"""
@unittest.skipIf(not is_devices_enough(2), "TestAmaxReduction needs 2 GPUs")
@unittest.skipIf(not gpu_has_fp8, reason)
def test_amax_reduction(self):
"""Tests amax reduction"""
self.run_2gpu(str(test_root / 'parallel_tests' / 'amax_reduction.py'))
class TestPipelineParallel(TestDistributed):
"""Test pipeline parallel"""
@unittest.skipIf(not is_devices_enough(2), "TestPipelineParallel needs 2 GPUs")
@unittest.skipIf(not gpu_has_fp8, reason)
def test_pipeline_parallel(self):
"""Tests pipeline parallel"""
self.run_2gpu(str(test_root / 'parallel_tests' / 'linear_pp.py'))
class TestGroupSharding(TestDistributed):
"""Test group sharding"""
@unittest.skipIf(not is_devices_enough(2), "TestGroupSharding needs 2 GPUs")
@unittest.skipIf(not gpu_has_fp8, reason)
def test_group_sharding(self):
"""Tests group sharding"""
self.run_2gpu(str(test_root / 'parallel_tests' / 'group_sharding.py'))
class TestParallelTransformerLayer(TestDistributed):
"""Test Transformer Layer in Parallel mode"""
@unittest.skipIf(not is_devices_enough(2), "TestParallelTransformerLayer needs 2 GPUs")
@unittest.skipIf(not gpu_has_fp8, reason)
def test_transformer_tp(self):
"""Tests Transformer Layer with tensor parallel in BF16"""
self.run_2gpu(str(test_root / 'parallel_tests' / 'transformer_tp.py'))
if __name__ == '__main__':
unittest.main()
@@ -34,3 +34,21 @@ def assert_allclose(actual,
if isinstance(desired, paddle.Tensor):
desired = paddle.cast(desired, 'float32').numpy()
np.testing.assert_allclose(actual, desired, rtol, atol, equal_nan, err_msg, verbose)
def assert_shape(inp, expected_shape):
"""Assert the shape of input tensor equals to expected shape"""
assert inp.shape == expected_shape, f"Expected tensor shape: {expected_shape} != " \
f"actual tensor shape: {inp.shape}"
def is_devices_enough(required):
"""If the number of device is enough"""
return paddle.device.cuda.device_count() >= required
def set_random_seed(seed):
"""Set random seed for reproducability."""
np.random.seed(seed)
paddle.seed(seed)
paddle.distributed.fleet.meta_parallel.model_parallel_random_seed(seed)
@@ -46,3 +46,7 @@ AttnMaskTypes = ("causal", "padding", "no_mask")
AttnTypes = ("self", "cross")
LayerTypes = ("encoder", "decoder")
GemmParallelModes = ("row", "column", None)
dist_group_type = paddle.distributed.collective.Group
# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Methods needed for distributed training."""
from contextlib import contextmanager
from typing import Optional, Union, Tuple
import paddle
import paddle.distributed.fleet.base.topology as tp
from paddle.distributed.fleet.meta_parallel import get_rng_state_tracker
from paddle.distributed.fleet.layers.mpu import mp_ops
from .constants import dist_group_type
_weight_split_axis = {
'transformer_engine': {
'row': 1,
'column': 0
},
'paddle': {
'row': 0,
'column': 1
}
}
def get_tp_group_and_world_size(tp_group: Union[dist_group_type, None],
enable_tp: bool = True) -> Tuple[Union[dist_group_type, None], int]:
"""Get TP group and world size using Fleet API"""
if not (paddle.distributed.is_initialized() and enable_tp):
return None, 1
model_parallel_group = (tp._HYBRID_PARALLEL_GROUP.get_model_parallel_group()
if tp_group is None else tp_group)
world_size = (tp._HYBRID_PARALLEL_GROUP.get_model_parallel_world_size()
if tp_group is None else tp_group.nranks)
return model_parallel_group, world_size
@contextmanager
def track_rng_state(enable: bool) -> None:
"""
Applies get_rng_state_tracker().rng_state() to the context.
If not enabled, it does nothing.
"""
if enable:
with get_rng_state_tracker().rng_state():
yield
else:
yield
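# Illustrative usage (an assumption, not part of this diff): wrap weight
# initialization so that, when tensor parallelism is enabled, parameters are
# drawn from Fleet's model-parallel RNG tracker instead of the default
# generator:
#
#     with track_rng_state(enable=tensor_parallel):
#         weight = paddle.uniform([out_features, in_features])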
def set_tensor_dist_attr(tensor: paddle.Tensor, is_parallel: bool, axis: int) -> None:
"""Set distributed attributes for the input tensor"""
tensor.is_distributed = is_parallel
if is_parallel:
tensor.split_axis = axis
def set_weight_tensor_dist_attr(tensor: paddle.Tensor, is_parallel: bool,
parallel_mode: Optional[str], backend: str) -> None:
"""Set distributed attributes for the weight tensor"""
if not is_parallel or parallel_mode is None:
return
set_tensor_dist_attr(tensor, is_parallel, axis=_weight_split_axis[backend][parallel_mode])
def allreduce(
input_: paddle.Tensor,
tp_group: Optional[dist_group_type] = None,
) -> paddle.Tensor:
"""All-reduce the input tensor across model parallel group."""
# Bypass the function if we are using only 1 GPU.
if tp_group is None or tp_group.nranks == 1:
return input_
# All-reduce.
output = mp_ops._mp_allreduce(
input_,
group=tp_group,
use_calc_stream=True,
use_model_parallel=True,
)
return output
def identity(
input_: paddle.Tensor,
tp_group: Optional[dist_group_type] = None,
) -> paddle.Tensor:
"""
Identity in the forward pass.
All-reduce across the model-parallel group in the backward pass.
"""
output = mp_ops._c_identity(input_, group=tp_group)
return output
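# Illustrative sketch (an assumption, not part of this diff) of how the
# conjugate `identity`/`allreduce` ops compose in tensor-parallel linears:
# a column-parallel layer applies `identity` to its input (no-op in forward,
# all-reduce of the input gradient in backward), while a row-parallel layer
# applies `allreduce` to its partial output in forward.
#
#     def column_parallel_forward(inp, local_weight, tp_group):
#         inp = identity(inp, tp_group)    # backward: all-reduce dgrad
#         return paddle.matmul(inp, local_weight, transpose_y=True)
#
#     def row_parallel_forward(inp_parallel, local_weight, tp_group):
#         partial = paddle.matmul(inp_parallel, local_weight, transpose_y=True)
#         return allreduce(partial, tp_group)    # sum the partial outputs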
@@ -3,9 +3,8 @@
# See LICENSE for license information.
"""FP8 utilities for TransformerEngine"""
import copy
from contextlib import contextmanager
from typing import Tuple, Optional, Dict, Any
from typing import Tuple, Optional, Dict, Any, Union
import numpy as np
@@ -13,6 +12,9 @@ import paddle
import transformer_engine_paddle as tex
from transformer_engine.common.recipe import DelayedScaling, Format
from .constants import dist_group_type
from .fp8_buffer import FP8MetaFwdBuffer, FP8MetaBwdBuffer
# FP8 support
_is_fp8_available = None
_reason_for_no_fp8 = ""
@@ -50,21 +52,27 @@ class FP8State:
"""Stores FP8 state"""
def __init__(self):
self.fp8_enabled = False
self.fp8_calibration = False
self.fp8_recipe = None
self._fp8_enabled = False
self._fp8_calibration = False
self._fp8_recipe = None
self._fp8_distributed_group = None
self._is_first_fp8_module = False
self._fp8_autocast_counter = 0
self._fp8_autocast_depth = 0
self._fp8_fwd_buffer = FP8MetaFwdBuffer()
self._fp8_bwd_buffer = FP8MetaBwdBuffer()
def is_fp8_enabled(self) -> bool:
"""Is FP8 enabled"""
return self.fp8_enabled
return self._fp8_enabled
def is_fp8_calibration(self) -> bool:
"""Is FP8 calibration"""
return self.fp8_calibration
return self._fp8_calibration
def get_fp8_recipe(self) -> DelayedScaling:
"""Return the fp8 recipe"""
return self.fp8_recipe
return self._fp8_recipe
@staticmethod
def get_default_fp8_recipe() -> DelayedScaling:
@@ -73,6 +81,63 @@ class FP8State:
"""
return DelayedScaling()
def get_autocast_id(self) -> int:
"""Returns the number of times of entering the `fp8_autocast` context.
as a unique ID for different training steps."""
return self._fp8_autocast_counter
def is_first_fp8_module(self):
"""Returns `True` only the first time when called multiple
times from within the same `fp8_autocast` context.
"""
tmp = self._is_first_fp8_module
self._is_first_fp8_module = False
return tmp
def get_fp8_group(self) -> Union[dist_group_type, None]:
"""Return the fp8 group for scale/amax comm"""
return self._fp8_distributed_group
def get_fp8_fwd_buffer(self) -> FP8MetaFwdBuffer:
"""Returns global fp8 forward buffer."""
return self._fp8_fwd_buffer
def get_fp8_bwd_buffer(self) -> FP8MetaBwdBuffer:
"""Returns global fp8 backward buffer."""
return self._fp8_bwd_buffer
def enter(
self,
enabled: bool,
calibrating: bool,
fp8_recipe: Optional[DelayedScaling],
fp8_group: Optional[dist_group_type],
) -> None:
"""Called when entering 'fp8_autocast'"""
self.saved_states = (self._fp8_enabled, self._fp8_calibration, self._fp8_recipe,
self._fp8_distributed_group, self._is_first_fp8_module)
self._fp8_enabled = enabled
self._fp8_calibration = calibrating
self._fp8_recipe = self.get_default_fp8_recipe() if fp8_recipe is None else fp8_recipe
self._fp8_distributed_group = fp8_group
if self._fp8_autocast_depth == 0:
self._is_first_fp8_module = True
self._fp8_autocast_counter += 1
self._fp8_autocast_depth += 1
def exit(self):
"""Called when exiting 'fp8_autocast'"""
# Restore saved states
(self._fp8_enabled, self._fp8_calibration, self._fp8_recipe, self._fp8_distributed_group,
self._is_first_fp8_module) = self.saved_states
self._fp8_autocast_depth -= 1
if self._fp8_autocast_depth == 0:
self._fp8_fwd_buffer.finalize()
_global_fp8_state = FP8State()
@@ -87,25 +152,20 @@ def fp8_autocast(
enabled: bool = False,
calibrating: bool = False,
fp8_recipe: Optional[DelayedScaling] = None,
fp8_group: Optional[dist_group_type] = None,
) -> None:
"""
Context manager for FP8 usage.
"""
global _global_fp8_state
saved_fp8_state = copy.deepcopy(_global_fp8_state)
try:
_global_fp8_state.fp8_enabled = enabled
_global_fp8_state.fp8_calibration = calibrating
_global_fp8_state.fp8_recipe = FP8State.get_default_fp8_recipe(
) if fp8_recipe is None else fp8_recipe
_global_fp8_state.enter(enabled, calibrating, fp8_recipe, fp8_group)
if enabled:
fp8_available, reason_for_no_fp8 = is_fp8_available()
assert fp8_available, reason_for_no_fp8
yield
finally:
_global_fp8_state = saved_fp8_state
_global_fp8_state.exit()
def get_fp8_te_dtype(fp8_recipe: DelayedScaling, fprop_tensor: bool = True) -> tex.DType:
......
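A short sketch of the save/restore semantics that FP8State.enter/exit give nested autocasts; the recipes and history lengths here are illustrative assumptions:

import transformer_engine.paddle as te
from transformer_engine.common.recipe import DelayedScaling

outer = DelayedScaling(amax_history_len=16)
inner = DelayedScaling(amax_history_len=1)

with te.fp8_autocast(enabled=True, fp8_recipe=outer):
    # `outer` is active; the first TE module called here sees
    # is_first_fp8_module() == True.
    with te.fp8_autocast(enabled=True, fp8_recipe=inner):
        pass    # enter() saved the outer state; `inner` is active here.
    # exit() restored `outer`; the forward buffer is finalized (amax
    # reduction launched) only when the outermost autocast exits.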
# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""FP8 meta buffer for FP8 amax reduction"""
from abc import ABC, abstractmethod
from functools import partial
import os
from typing import Dict, Any, List, Union
import numpy as np
import paddle
from .constants import dist_group_type
class FP8MetaBufferBase(ABC):
"""
A global buffer that holds FP8 meta for reduction across trainers.
"""
def __init__(self):
self._data = {}
self._buffer_delete_key = None
self._amax_reduce_handle = None
self._amax_reduce_wait_func = None
self._dp_amax_reduce_interval = None
self._dp_amax_reduce_idx = 0
@staticmethod
@abstractmethod
def _get_meta_tensor_key():
"""Returns scaling key in `fp8_meta`."""
@staticmethod
@abstractmethod
def _get_buffer_position_key():
"""Returns module position key in `fp8_meta`."""
@staticmethod
@abstractmethod
def _get_autocast_key():
"""Returns autocast id key in `fp8_meta`."""
def _get_amax_buffer_key(self, fp8_meta: Dict[str, Any]) -> str:
"""Return a key in `_data` for the AMAX storage."""
return f"AMAX_{fp8_meta[self._get_autocast_key()]}"
def _execute_deletion(self) -> None:
"""Delete the key from global amax buffer."""
if (self._buffer_delete_key is not None and self._buffer_delete_key in self._data):
del self._data[self._buffer_delete_key]
def _wait_handle_and_split(
self,
contiguous_amax: paddle.Tensor,
chunk_sizes: List[int],
amax_buffer_key: str,
wait_handle: Union[bool, None],
) -> None:
"""Wait for amax reduction to finish and then copy reduced amax to buffer"""
if wait_handle is not None:
wait_handle.wait()
self._data[amax_buffer_key] = list(contiguous_amax.split(chunk_sizes))
def _global_amax_reduction(
self,
fp8_meta: Dict[str, Any],
tp_group: dist_group_type,
tp_size: int,
) -> Union[partial, None]:
"""Concatenate, reduce, and split amaxes in the global buffer."""
def _reduce_tensor_across_group_op_max(tensor, group, sync_op):
if paddle.distributed.is_initialized():
wait_handle = paddle.distributed.all_reduce(
tensor,
op=paddle.distributed.ReduceOp.MAX,
group=group,
sync_op=sync_op,
)
return wait_handle
return None
amax_buffer_key = self._get_amax_buffer_key(fp8_meta)
# Key already deleted.
if amax_buffer_key not in self._data:
return None
# Reduce AMAX in DP-domain at an interval.
if self._dp_amax_reduce_interval is None:
self._dp_amax_reduce_interval = int(os.getenv("NVTE_DP_AMAX_REDUCE_INTERVAL", "1"))
tp_amax_reduce = False
if self._dp_amax_reduce_idx == 0:
reduce_group = fp8_meta["fp8_group"]
else:
tp_amax_reduce = True
self._dp_amax_reduce_idx = (self._dp_amax_reduce_idx + 1) % self._dp_amax_reduce_interval
if tp_amax_reduce:
if tp_size > 1:
reduce_group = tp_group
else:
return None
chunk_sizes = [x.shape[0] for x in self._data[amax_buffer_key]]
contiguous_amax = paddle.concat(self._data[amax_buffer_key])
wait_handle = _reduce_tensor_across_group_op_max(
contiguous_amax,
reduce_group,
not fp8_meta["async_amax_reduction"],
)
return partial(
self._wait_handle_and_split,
contiguous_amax,
chunk_sizes,
amax_buffer_key,
wait_handle,
)
def add_amax(self, fp8_meta: Dict[str, Any]) -> None:
"""Append `amax_history` to global buffer."""
buffer_key = self._get_amax_buffer_key(fp8_meta)
fp8_meta_tensor_key = self._get_meta_tensor_key()
buffer_position_key = self._get_buffer_position_key()
if buffer_key not in self._data:
self._data[buffer_key] = [fp8_meta[fp8_meta_tensor_key].amax_history[0]]
else:
self._data[buffer_key].append(fp8_meta[fp8_meta_tensor_key].amax_history[0])
if buffer_position_key not in fp8_meta:
fp8_meta[buffer_position_key] = len(self._data[buffer_key]) - 1
# Catch incorrect fp8_autocast usage.
assert fp8_meta[buffer_position_key] == len(self._data[buffer_key]) - 1, \
"Same module is being invoked more than once inside an `fp8_autocast` " \
"region when using FP8 with amax reduction. This behavior is currently " \
"unsupported. For more details and correct usage, please see " \
"https://github.com/NVIDIA/TransformerEngine/pull/93."
def copy_amax_from_buffer(self, fp8_meta: Dict[str, Any]) -> None:
"""Populate current amax with the correct location from buffer."""
fp8_meta_tensor_key = self._get_meta_tensor_key()
buffer_position_key = self._get_buffer_position_key()
if buffer_position_key not in fp8_meta:
return
amax_buffer_key = self._get_amax_buffer_key(fp8_meta)
assert amax_buffer_key in self._data, "TE internal error."
fp8_meta[fp8_meta_tensor_key].amax_history[0] = self._data[amax_buffer_key][
fp8_meta[buffer_position_key]]
def set_for_deletion(self, fp8_meta: Dict[str, Any]) -> None:
"""Delete this amax key from global buffer during autocast end."""
if self._get_autocast_key() not in fp8_meta:
return
self._buffer_delete_key = self._get_amax_buffer_key(fp8_meta)
def get_amax_reduce_handle(self) -> Union[bool, None]:
"""Return AMAX reduction wait handle."""
return self._amax_reduce_handle
def wait(self) -> None:
"""Wait for reduced amax to be available in buffer."""
if self._amax_reduce_wait_func is not None:
self._amax_reduce_wait_func() # pylint: disable=not-callable
self._amax_reduce_wait_func = None
def to_numpy(self) -> Dict[str, List[np.ndarray]]:
"""Convert to numpy arrays"""
out = {}
for k, v in self._data.items():
out[k] = [tensor.numpy() for tensor in v]
return out
def from_numpy(self, buffer: Dict[str, List[np.ndarray]]) -> None:
"""Set buffer values from numpy arrays"""
for k, v in buffer.items():
self._data[k] = [paddle.to_tensor(arr) for arr in v]
class FP8MetaFwdBuffer(FP8MetaBufferBase):
"""FP8Meta Buffer for forward"""
@staticmethod
def _get_meta_tensor_key() -> str:
"""Returns scaling key in `fp8_meta`."""
return "scaling_fwd"
@staticmethod
def _get_buffer_position_key() -> str:
"""Returns module position key in `fp8_meta`."""
return "global_fp8_buffer_pos_fwd"
@staticmethod
def _get_autocast_key() -> str:
"""Returns module position key in `fp8_meta`."""
return "autocast_id_fwd"
def set_for_amax_reduction(
self,
fp8_meta: Dict[str, Any],
tp_group: dist_group_type,
tp_size: int,
) -> None:
"""Sets up the function to call during autocast exit."""
self._amax_global_reduce_func = partial(
self._global_amax_reduction,
fp8_meta,
tp_group,
tp_size,
)
def finalize(self) -> None:
"""
Called at FP8 autocast end.
Performs AMAX reduction and deletes unused buffer entries.
"""
if hasattr(self, '_amax_global_reduce_func') and callable(self._amax_global_reduce_func):
self._amax_reduce_wait_func = self._amax_global_reduce_func()
self._execute_deletion()
class FP8MetaBwdBuffer(FP8MetaBufferBase):
"""FP8Meta Buffer for backward"""
@staticmethod
def _get_meta_tensor_key() -> str:
"""Returns scaling key in `fp8_meta`."""
return "scaling_bwd"
@staticmethod
def _get_buffer_position_key() -> str:
"""Returns module position key in `fp8_meta`."""
return "global_fp8_buffer_pos_bwd"
@staticmethod
def _get_autocast_key() -> str:
"""Returns module position key in `fp8_meta`."""
return "autocast_id_bwd"
def finalize(
self,
fp8_meta: Dict[str, Any],
tp_group: dist_group_type,
tp_size: int,
) -> None:
"""
Called at FP8 autocast end in backward.
Performs AMAX reduction and deletes unused buffer entries.
"""
self._amax_reduce_wait_func = self._global_amax_reduction(fp8_meta, tp_group, tp_size)
self._execute_deletion()
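
The forward and backward buffers mirror each other: the forward buffer defers its reduction through a callable installed by `set_for_amax_reduction` and invoked in `finalize()` at autocast exit, while the backward buffer reduces eagerly when the first FP8 module finalizes its backward. A toy sketch of the deferred-callable pattern (illustrative names, not the TE API):

from functools import partial

class TinyFwdBuffer:
    def set_for_amax_reduction(self, meta, tp_group, tp_size):
        # Capture arguments now; run the reduction later.
        self._reduce_func = partial(self._reduce, meta, tp_group, tp_size)

    def _reduce(self, meta, tp_group, tp_size):
        print(f"reducing amaxes for {meta} over tp_size={tp_size}")
        return lambda: None      # stands in for the wait function

    def finalize(self):
        if hasattr(self, "_reduce_func") and callable(self._reduce_func):
            self._wait_func = self._reduce_func()

buf = TinyFwdBuffer()
buf.set_for_amax_reduction({"layer": 0}, tp_group=None, tp_size=2)
buf.finalize()                   # the reduction actually runs here
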
......@@ -4,27 +4,25 @@
"""Attntion API"""
import math
import os
import warnings
from typing import Optional, Tuple, Union
import paddle
import paddle.nn.functional as F
from transformer_engine.paddle.constants import (
AttnTypes,
TE_DType,
)
from transformer_engine.paddle.cpp_extensions import (
from .layernorm_linear import LayerNormLinear
from .linear import Linear
from .softmax import FusedScaleMaskSoftmax
from ..constants import AttnTypes, TE_DType, dist_group_type
from ..cpp_extensions import (
fused_attn_fwd_qkvpacked,
fused_attn_bwd_qkvpacked,
fused_attn_fwd_kvpacked,
fused_attn_bwd_kvpacked,
)
from transformer_engine.paddle.utils import (attention_mask_func, mask_to_cu_seqlens)
from .base import TransformerEngineBaseLayer
from .layernorm_linear import LayerNormLinear
from .linear import Linear
from .softmax import FusedScaleMaskSoftmax
from ..distributed import get_tp_group_and_world_size, track_rng_state
from ..utils import attention_mask_func, divide, mask_to_cu_seqlens
class FusedAttnFuncPackedQKV(paddle.autograd.PyLayer):
......@@ -161,9 +159,20 @@ class DotProductAttention(paddle.nn.Layer):
self.attn_mask_type = attn_mask_type
self.attention_dropout = attention_dropout
self.attention_type = attention_type
self.backend = backend
self.rng_state = paddle.zeros((2,), dtype='int64')
self.rng_state.persistable = True
self.backend = backend
arch = paddle.device.cuda.get_device_capability()
self.is_fused_attn_supported = arch in ((8, 0), (9, 0))
self.enable_fused_attn = (int(os.getenv("NVTE_FUSED_ATTN", "0"))
                          and self.is_fused_attn_supported)
if not self.enable_fused_attn and backend == 'transformer_engine':
# FMHA is not enabled, falling back to Paddle backend
self.backend = 'paddle'
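
Fused attention is therefore double-gated: it needs an explicit `NVTE_FUSED_ATTN=1` opt-in and an sm80/sm90 GPU, otherwise the layer quietly falls back to the Paddle path. A standalone sketch of the same check (illustrative only):

import os
import paddle

os.environ["NVTE_FUSED_ATTN"] = "1"                   # explicit opt-in
arch = paddle.device.cuda.get_device_capability()     # e.g. (8, 0) on A100
use_fused = (bool(int(os.getenv("NVTE_FUSED_ATTN", "0")))
             and arch in ((8, 0), (9, 0)))
backend = "transformer_engine" if use_fused else "paddle"
print(backend)
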
if self.backend != 'transformer_engine':
self.scale_mask_softmax = FusedScaleMaskSoftmax(attn_mask_type,
attention_mask_func,
......@@ -343,7 +352,7 @@ class DotProductAttention(paddle.nn.Layer):
return out
class MultiHeadAttention(TransformerEngineBaseLayer):
class MultiHeadAttention(paddle.nn.Layer):
"""Attention w/ QKV and Proj Gemms
Parameters
......@@ -390,6 +399,8 @@ class MultiHeadAttention(TransformerEngineBaseLayer):
input_layernorm: bool = False,
attention_type: str = "self",
zero_centered_gamma: bool = False,
set_parallel_mode: bool = False,
tp_group: Optional[dist_group_type] = None,
backend: str = 'transformer_engine',
) -> None:
super().__init__()
......@@ -403,11 +414,19 @@ class MultiHeadAttention(TransformerEngineBaseLayer):
assert attention_type in AttnTypes, f"attention_type {attention_type} not supported"
self.tp_group, self.tp_size = get_tp_group_and_world_size(tp_group,
enable_tp=set_parallel_mode)
self.tensor_parallel = self.tp_size > 1
self.hidden_size_per_attention_head = hidden_size // num_attention_heads
self.num_attention_heads = num_attention_heads
norm_factor = math.sqrt(self.hidden_size_per_attention_head)
self.set_parallel_mode = set_parallel_mode
self.backend = backend
self.num_attention_heads_per_partition = divide(self.num_attention_heads, self.tp_size)
qkv_parallel_mode = "column" if set_parallel_mode else None
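
`divide` here is an exact-division helper: the attention heads are split evenly across tensor-parallel ranks, so e.g. 16 heads at tp_size=4 gives 4 heads per partition. A hedged sketch of the assumed semantics:

# Assumed behavior of `divide` from ..utils (shown standalone for illustration).
def divide(numerator: int, denominator: int) -> int:
    assert numerator % denominator == 0, \
        f"{numerator} is not divisible by {denominator}"
    return numerator // denominator

num_attention_heads, tp_size = 16, 4
heads_per_partition = divide(num_attention_heads, tp_size)   # -> 4
# Each rank's QKV projection then yields [b, s, 3, 4, head_size]
# instead of [b, s, 3, 16, head_size].
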
if self.attention_type == "self":
if self.input_layernorm:
self.layernorm_qkv = LayerNormLinear(
......@@ -418,6 +437,8 @@ class MultiHeadAttention(TransformerEngineBaseLayer):
bias_attr=self.bias_attr,
return_layernorm_output=return_layernorm_output,
zero_centered_gamma=zero_centered_gamma,
parallel_mode=qkv_parallel_mode,
tp_group=self.tp_group,
backend=self.backend,
)
else:
......@@ -426,6 +447,8 @@ class MultiHeadAttention(TransformerEngineBaseLayer):
3 * hidden_size,
self.weight_attr,
self.bias_attr,
parallel_mode=qkv_parallel_mode,
tp_group=self.tp_group,
backend=self.backend,
)
......@@ -439,6 +462,8 @@ class MultiHeadAttention(TransformerEngineBaseLayer):
bias_attr=self.bias_attr,
return_layernorm_output=return_layernorm_output,
zero_centered_gamma=zero_centered_gamma,
parallel_mode=qkv_parallel_mode,
tp_group=self.tp_group,
backend=self.backend,
)
else:
......@@ -447,6 +472,8 @@ class MultiHeadAttention(TransformerEngineBaseLayer):
hidden_size,
self.weight_attr,
self.bias_attr,
parallel_mode=qkv_parallel_mode,
tp_group=self.tp_group,
backend=self.backend,
)
self.key_value = Linear(
......@@ -454,6 +481,8 @@ class MultiHeadAttention(TransformerEngineBaseLayer):
2 * hidden_size,
self.weight_attr,
self.bias_attr,
parallel_mode=qkv_parallel_mode,
tp_group=self.tp_group,
backend=self.backend,
)
......@@ -472,6 +501,8 @@ class MultiHeadAttention(TransformerEngineBaseLayer):
hidden_size,
self.weight_attr,
self.bias_attr,
parallel_mode="row" if set_parallel_mode else None,
tp_group=self.tp_group,
backend=self.backend,
)
......@@ -520,23 +551,26 @@ class MultiHeadAttention(TransformerEngineBaseLayer):
mixed_qkv_layer = self.qkv(hidden_states)
# [b, s_q, 3 * hidden_size] --> [b, s_q, 3, num_heads, head_size]
mixed_qkv_layer = mixed_qkv_layer.reshape(
shape=[0, 0, 3, self.num_attention_heads, self.hidden_size_per_attention_head])
context_layer = self.core_attention(
query_layer=mixed_qkv_layer,
key_value_layer=None,
attention_mask=attention_mask,
core_attention_bias_type=core_attention_bias_type,
core_attention_bias=core_attention_bias,
set_zero=set_zero,
)
mixed_qkv_layer = mixed_qkv_layer.reshape(shape=[
0, 0, 3, self.num_attention_heads_per_partition, self.hidden_size_per_attention_head
])
with track_rng_state(enable=self.tensor_parallel):
context_layer = self.core_attention(
query_layer=mixed_qkv_layer,
key_value_layer=None,
attention_mask=attention_mask,
core_attention_bias_type=core_attention_bias_type,
core_attention_bias=core_attention_bias,
set_zero=set_zero,
)
else: # cross attention
mixed_kv_layer = self.key_value(encoder_output)
# [b, s_kv, 2 * hidden_size] --> [b, s_kv, 2, num_heads, head_size]
mixed_kv_layer = mixed_kv_layer.reshape(
shape=[0, 0, 2, self.num_attention_heads, self.hidden_size_per_attention_head])
mixed_kv_layer = mixed_kv_layer.reshape(shape=[
0, 0, 2, self.num_attention_heads_per_partition, self.hidden_size_per_attention_head
])
if self.input_layernorm:
layernorm_query_outputs = self.layernorm_query(hidden_states)
......@@ -547,16 +581,18 @@ class MultiHeadAttention(TransformerEngineBaseLayer):
else:
query_layer = self.query_layer(hidden_states)
query_layer = query_layer.reshape(
shape=[0, 0, self.num_attention_heads, self.hidden_size_per_attention_head])
context_layer = self.core_attention(
query_layer=query_layer,
key_value_layer=mixed_kv_layer,
attention_mask=attention_mask,
core_attention_bias_type=core_attention_bias_type,
core_attention_bias=core_attention_bias,
set_zero=set_zero,
)
query_layer = query_layer.reshape(shape=[
0, 0, self.num_attention_heads_per_partition, self.hidden_size_per_attention_head
])
with track_rng_state(enable=self.tensor_parallel):
context_layer = self.core_attention(
query_layer=query_layer,
key_value_layer=mixed_kv_layer,
attention_mask=attention_mask,
core_attention_bias_type=core_attention_bias_type,
core_attention_bias=core_attention_bias,
set_zero=set_zero,
)
context_layer = paddle.reshape(context_layer,
[0, 0, context_layer.shape[2] * context_layer.shape[3]])
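
Both attention branches now wrap `core_attention` in `track_rng_state(enable=self.tensor_parallel)`, so dropout inside attention draws from an RNG stream that stays aligned across tensor-parallel ranks while leaving each rank's private stream untouched. A generic stand-in for the pattern (not Paddle's actual RNG API):

from contextlib import contextmanager
import random

@contextmanager
def track_rng_state(enable: bool, shared_seed: int = 1234):
    if enable:
        saved = random.getstate()      # save this rank's private stream
        random.seed(shared_seed)       # all TP ranks agree inside the region
        try:
            yield
        finally:
            random.setstate(saved)     # restore the private stream
    else:
        yield

with track_rng_state(enable=True):
    print(random.random())             # identical on every "rank"
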
......
......@@ -5,6 +5,7 @@
from abc import ABC, abstractmethod
from contextlib import contextmanager
import os
import pickle
from typing import Generator, Dict, Tuple, Union, Any
......@@ -14,7 +15,7 @@ import paddle
from paddle.fluid import core
from paddle.fluid.framework import _dygraph_tracer
from ..constants import FP8BwdTensors
from ..constants import FP8BwdTensors, dist_group_type
from ..cpp_extensions import cast_transpose, cast_transpose_bgrad, cast_to_fp8
from ..fp8 import (
FP8State,
......@@ -24,7 +25,6 @@ from ..fp8 import (
get_fp8_te_dtype,
)
from ..profile import nvtx_range
from ..utils import get_bias_dtype, cast_if_needed
_2X_ACC_FPROP = False
_2X_ACC_DGRAD = True
......@@ -61,9 +61,15 @@ class TransformerEngineBaseLayer(paddle.nn.Layer, ABC):
self.fp8_calibration = False
self.fp8_meta = {}
self.fp8_meta["fp8_checkpoint"] = False
self.fp8_meta["fp8_group"] = None
self.fp8_meta["recipe"] = FP8State.get_default_fp8_recipe()
self.fp8_meta["scaling_fwd"] = FP8TensorMeta(is_forward=True)
self.fp8_meta["scaling_bwd"] = FP8TensorMeta(is_forward=False)
self.tp_group = None
self.tp_size = 1
self.fp8_meta["autocast_id_fwd_stack"] = []
self.fp8_meta["async_amax_reduction"] = bool(
int(os.getenv("NVTE_ASYNC_AMAX_REDUCTION", "0")))
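
`NVTE_ASYNC_AMAX_REDUCTION` chooses whether the amax all-reduce blocks immediately or is only waited on later via the buffer's `wait()`. Presumably set before the first forward pass; a one-line sketch:

import os

# "1" launches the amax all-reduce asynchronously and defers the wait to the
# next iteration's prepare step. The default "0" reduces synchronously.
os.environ["NVTE_ASYNC_AMAX_REDUCTION"] = "1"
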
def set_activation_dtype(self, inp: paddle.Tensor) -> None:
"""Get activation data type for AMP."""
......@@ -102,18 +108,20 @@ class TransformerEngineBaseLayer(paddle.nn.Layer, ABC):
# assume FP8 execution.
def fp8_init(self, num_gemms: int = 1) -> None:
"""Initialize fp8 related metadata and tensors during fprop."""
state = get_global_fp8_state()
self.fp8_enabled = state.is_fp8_enabled()
self.fp8_calibration = state.is_fp8_calibration()
global_fp8_state = get_global_fp8_state()
self.fp8_enabled = global_fp8_state.is_fp8_enabled()
self.fp8_calibration = global_fp8_state.is_fp8_calibration()
self.fp8_meta["fp8_checkpoint"] = self.fp8_enabled or self.fp8_calibration
if self.fp8_enabled or self.fp8_calibration:
# FP8 init has already been run and recipe is the same, don't do anything.
if self.fp8_initialized and state.get_fp8_recipe() == self.fp8_meta["recipe"]:
if (self.fp8_initialized
        and global_fp8_state.get_fp8_recipe() == self.fp8_meta["recipe"]):
return
# Set FP8, recipe, and other FP8 metadata
self.fp8_meta["recipe"] = state.get_fp8_recipe()
self.fp8_meta["recipe"] = global_fp8_state.get_fp8_recipe()
self.fp8_meta["fp8_group"] = global_fp8_state.get_fp8_group()
# Set FP8_MAX per tensor according to recipe
self.fp8_meta["fp8_max_fwd"] = self.fp8_meta["recipe"].fp8_format.value.max_fwd
......@@ -136,6 +144,8 @@ class TransformerEngineBaseLayer(paddle.nn.Layer, ABC):
state = {}
state["scaling_fwd"] = self.fp8_meta["scaling_fwd"].to_numpy()
state["scaling_bwd"] = self.fp8_meta["scaling_bwd"].to_numpy()
state["global_fp8_fwd_buffer"] = get_global_fp8_state().get_fp8_fwd_buffer().to_numpy()
state["global_fp8_bwd_buffer"] = get_global_fp8_state().get_fp8_bwd_buffer().to_numpy()
# Store other picklable values.
extra = {}
for k, v in self.fp8_meta.items():
......@@ -179,6 +189,12 @@ class TransformerEngineBaseLayer(paddle.nn.Layer, ABC):
self.fp8_meta["scaling_fwd"].from_numpy(state["scaling_fwd"])
self.fp8_meta["scaling_bwd"].from_numpy(state["scaling_bwd"])
# Restore global FP8 buffer states.
global_fp8_fwd_buffer = get_global_fp8_state().get_fp8_fwd_buffer()
global_fp8_bwd_buffer = get_global_fp8_state().get_fp8_bwd_buffer()
global_fp8_fwd_buffer.from_numpy(state["global_fp8_fwd_buffer"])
global_fp8_bwd_buffer.from_numpy(state["global_fp8_bwd_buffer"])
# Load extra items.
self.fp8_meta.update(state["extra_fp8_variables"])
self.fp8_meta["recipe"].amax_history_len = self.fp8_meta["scaling_fwd"].amax_history.shape[
......@@ -210,9 +226,22 @@ class TransformerEngineBaseLayer(paddle.nn.Layer, ABC):
# Previous iteration was grad_enabled
if self.fp8_meta.get("update_amax_and_scale_fwd", False):
amax_and_scale_update(self.fp8_meta, True)
global_fp8_fwd_buffer = get_global_fp8_state().get_fp8_fwd_buffer()
global_fp8_fwd_buffer.wait()
if self.fp8_meta["recipe"].reduce_amax:
global_fp8_fwd_buffer.copy_amax_from_buffer(self.fp8_meta)
amax_and_scale_update(self.fp8_meta, True)
global_fp8_fwd_buffer.set_for_deletion(self.fp8_meta)
else:
amax_and_scale_update(self.fp8_meta, True)
if self.fp8_enabled and self.training:
# Setup for amax reduction
if self.fp8_meta["recipe"].reduce_amax:
global_fp8_state = get_global_fp8_state()
self.fp8_meta["first_module"] = global_fp8_state.is_first_fp8_module()
self.fp8_meta["autocast_id_fwd"] = global_fp8_state.get_autocast_id()
self.fp8_meta["autocast_id_fwd_stack"].append(self.fp8_meta["autocast_id_fwd"])
self.fp8_meta["update_amax_and_scale_fwd"] = True
else:
self.fp8_meta["update_amax_and_scale_fwd"] = False
......@@ -220,18 +249,47 @@ class TransformerEngineBaseLayer(paddle.nn.Layer, ABC):
with nvtx_range(self.__class__.__name__ + " forward"):
yield inp
if self.fp8_enabled and self.training and self.fp8_meta["recipe"].reduce_amax:
global_fp8_state = get_global_fp8_state()
global_fp8_fwd_buffer = global_fp8_state.get_fp8_fwd_buffer()
global_fp8_fwd_buffer.add_amax(self.fp8_meta)
global_fp8_fwd_buffer.set_for_amax_reduction(
self.fp8_meta,
self.tp_group,
self.tp_size,
)
@staticmethod
@contextmanager
def prepare_backward(fp8_enabled: bool,
fp8_meta: Dict[str, Any],
tp_group: dist_group_type,
tp_size: int,
name: str = "") -> Generator[None, None, None]:
"""Checks and prep for BWD."""
if fp8_enabled:
amax_and_scale_update(fp8_meta, False)
global_fp8_state = get_global_fp8_state()
global_fp8_bwd_buffer = global_fp8_state.get_fp8_bwd_buffer()
global_fp8_bwd_buffer.wait()
if fp8_meta["recipe"].reduce_amax:
global_fp8_bwd_buffer.copy_amax_from_buffer(fp8_meta)
amax_and_scale_update(fp8_meta, False)
global_fp8_bwd_buffer.set_for_deletion(fp8_meta)
# Get new backward key.
fp8_meta["autocast_id_bwd"] = fp8_meta["autocast_id_fwd_stack"].pop(0)
else:
amax_and_scale_update(fp8_meta, False)
with nvtx_range(name + " backward"):
yield
if fp8_enabled and fp8_meta["recipe"].reduce_amax:
global_fp8_bwd_buffer.add_amax(fp8_meta)
if fp8_meta["first_module"]:
global_fp8_bwd_buffer.finalize(fp8_meta, tp_group, tp_size)
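
`prepare_backward` is designed to wrap the body of each layer's `PyLayer.backward`; a hypothetical caller might look like the following (class name and stored `ctx` fields are assumptions for illustration):

import paddle

class _ExampleFunc(paddle.autograd.PyLayer):
    @staticmethod
    def backward(ctx, grad_output):
        with TransformerEngineBaseLayer.prepare_backward(ctx.fp8_enabled,
                                                         ctx.fp8_meta,
                                                         ctx.tp_group,
                                                         ctx.tp_size,
                                                         name="_ExampleFunc"):
            # dgrad/wgrad computation would go here; on exit the bwd amax is
            # appended and, for the first FP8 module, reduction is finalized.
            dgrad = grad_output
        return dgrad
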
@staticmethod
def grad_output_preprocess(
ctx, grad_output: paddle.Tensor) -> Tuple[Union[paddle.Tensor, None], ...]:
......@@ -258,8 +316,6 @@ class TransformerEngineBaseLayer(paddle.nn.Layer, ABC):
FP8BwdTensors.GRAD_OUTPUT1,
fp8_dtype_backward,
)
bias_dtype = get_bias_dtype(ctx.activation_dtype)
bgrad = cast_if_needed(bgrad, bias_dtype)
else:
if not ctx.fp8_meta["recipe"].override_linear_precision.wgrad:
grad_output_c, grad_output_t = cast_transpose(
......
......@@ -31,7 +31,7 @@ class _LayerNorm(paddle.autograd.PyLayer):
zero_centered_gamma: bool,
) -> paddle.Tensor:
# Make sure input dimensions are compatible
in_features = ln_weight.numel()
in_features = ln_weight.shape[0]
assert inp.shape[-1] == in_features, "LayerNorm not possible"
inputmat = inp.reshape((-1, in_features))
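
Replacing `ln_weight.numel()` with `ln_weight.shape[0]` avoids dispatching an operator on every forward in dygraph; `shape` is host-side metadata and essentially free. A rough micro-benchmark sketch (illustrative; absolute numbers vary by build):

import time
import paddle

w = paddle.ones([4096])

t0 = time.perf_counter()
for _ in range(1000):
    n = w.numel()              # goes through an op call each time in dygraph
t1 = time.perf_counter()
for _ in range(1000):
    n = w.shape[0]             # plain Python metadata lookup
t2 = time.perf_counter()
print(f"numel: {t1 - t0:.4f}s  shape[0]: {t2 - t1:.4f}s")
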
......