Unverified Commit 805b9872 authored by Tian Zheng, committed by GitHub

[Paddle] Add control of RNG state (#410)



* Add control of attention dropout and hidden dropout RNG state
Signed-off-by: Tian Zheng (Engrg-Hardware 1) <tizheng@nvidia.com>

* Fix CI error
Signed-off-by: Tian Zheng (Engrg-Hardware 1) <tizheng@nvidia.com>

---------
Signed-off-by: Tian Zheng (Engrg-Hardware 1) <tizheng@nvidia.com>
parent 3a63b13d
......@@ -107,6 +107,7 @@ class TestLinearPipelineParallel(unittest.TestCase):
def test_pipeline_train(self):
"""Test pipeline parallel training"""
set_random_seed(1024)
np.random.seed(1024)
weight1_np = np.random.normal(size=[self.in_features, self.hidden_features])
weight2_np = np.random.normal(size=[self.hidden_features, self.in_features])
......
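For context, a minimal sketch of why the test now seeds NumPy explicitly before building the weights. It is not part of this commit and the feature sizes are placeholders; the point is that with the same seed every rank draws bit-identical initial weights before they are distributed across pipeline stages.

```python
# Sketch only: identical NumPy seeds give identical initial weights on every rank.
import numpy as np

def make_weights(in_features=32, hidden_features=64, seed=1024):
    np.random.seed(seed)
    w1 = np.random.normal(size=[in_features, hidden_features])
    w2 = np.random.normal(size=[hidden_features, in_features])
    return w1, w2

# Two independent "ranks" reproduce the same arrays bit for bit.
w_a = make_weights()
w_b = make_weights()
assert all(np.array_equal(x, y) for x, y in zip(w_a, w_b))
```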
......@@ -1085,7 +1085,7 @@ def test_transformer_decoder_layer(bs, hidden_size, num_heads, ffn_hidden_size,
assert_allclose(layer_te.self_attention.qkv.bias.grad,
layer_pd.self_attention.qkv.bias.grad,
rtol=0.01,
atol=0.5)
atol=0.6)
else:
assert_allclose(layer_te.self_attention.layernorm_qkv.bias.grad,
layer_pd.self_attention.layernorm_qkv.bias.grad,
......
......@@ -3,9 +3,12 @@
# See LICENSE for license information.
"""Utils for testing"""
import random
import numpy as np
import paddle
from paddle.distributed import fleet
from paddle.distributed.fleet.meta_parallel import get_rng_state_tracker
import transformer_engine # pylint: disable=unused-import
......@@ -49,6 +52,43 @@ def is_devices_enough(required):
def set_random_seed(seed):
"""Set random seed for reproducability."""
np.random.seed(seed)
paddle.seed(seed)
paddle.distributed.fleet.meta_parallel.model_parallel_random_seed(seed)
hcg = fleet.get_hybrid_communicate_group()
if paddle.distributed.get_world_size() > 1:
# obtain the ranks of the hybrid parallel groups
mp_rank = hcg.get_model_parallel_rank()
mp_size = hcg.get_model_parallel_world_size()
pp_rank = hcg.get_stage_id()
pp_size = hcg.get_pipe_parallel_world_size()
dp_rank = hcg.get_data_parallel_rank()
dp_size = hcg.get_data_parallel_world_size()
sharding_rank = hcg.get_sharding_parallel_rank()
else:
mp_rank, mp_size = 0, 1
pp_rank, pp_size = 0, 1
dp_rank, dp_size = 0, 1
sharding_rank, _ = 0, 1
random.seed(seed + 100 * pp_rank)
np.random.seed(seed + 100 * pp_rank)
seed_offset = seed + 1024 + paddle.distributed.get_world_size()
global_seed = (seed_offset + pp_rank * (mp_size) + dp_rank * (mp_size * pp_size) +
sharding_rank * (mp_size * pp_size * dp_size))
seed_offset += paddle.distributed.get_world_size()
local_seed = (seed_offset + mp_rank + pp_rank * (mp_size) + dp_rank * (mp_size * pp_size) +
sharding_rank * (mp_size * pp_size * dp_size))
tracker = get_rng_state_tracker()
# tracker.reset()
if "global_seed" not in tracker.states_:
tracker.add("global_seed", global_seed)
if "local_seed" not in tracker.states_:
tracker.add("local_seed", local_seed)
paddle.seed(global_seed)
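For reference, a standalone sketch of what the helper above computes and registers. It assumes a single-process run (world size 1, all parallel ranks 0) and a GPU build of Paddle, since the tracker snapshots CUDA RNG states; the real function reads the ranks from fleet's hybrid communicate group.

```python
# Sketch, not part of the commit: derive the two seeds and register them
# with Paddle's RNG state tracker, mirroring set_random_seed above.
import paddle
from paddle.distributed.fleet.meta_parallel import get_rng_state_tracker

seed, world_size = 1024, 1
mp_rank, mp_size = 0, 1
pp_rank, pp_size = 0, 1
dp_rank, dp_size = 0, 1
sharding_rank = 0

seed_offset = seed + 1024 + world_size
global_seed = (seed_offset + pp_rank * mp_size + dp_rank * (mp_size * pp_size)
               + sharding_rank * (mp_size * pp_size * dp_size))
seed_offset += world_size
local_seed = (seed_offset + mp_rank + pp_rank * mp_size + dp_rank * (mp_size * pp_size)
              + sharding_rank * (mp_size * pp_size * dp_size))

tracker = get_rng_state_tracker()
if "global_seed" not in tracker.states_:
    tracker.add("global_seed", global_seed)   # identical across tensor-parallel ranks
if "local_seed" not in tracker.states_:
    tracker.add("local_seed", local_seed)     # differs per tensor-parallel rank

# Dropout drawn under "local_seed" uses (and advances) only that tracked state.
with tracker.rng_state("local_seed"):
    mask_sample = paddle.nn.functional.dropout(paddle.ones([4, 4]), p=0.5, training=True)
```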
......@@ -39,13 +39,13 @@ def get_tp_group_and_world_size(tp_group: Union[dist_group_type, None],
@contextmanager
def track_rng_state(enable: bool) -> None:
def track_rng_state(enable: bool, **kwargs) -> None:
"""
Applies get_rng_state_tracker().rng_state() to the context.
If not enabled, it does nothing.
"""
if enable:
with get_rng_state_tracker().rng_state():
with get_rng_state_tracker().rng_state(**kwargs):
yield
else:
yield
......
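A usage sketch of the extended context manager: the keyword arguments are forwarded to the tracker's rng_state(), so callers can pick which registered state drives the RNG inside the block. It assumes "local_seed" and "global_seed" were already added to the tracker (e.g. by set_random_seed above), and the import path follows the ..distributed module shown in the diff.

```python
import paddle
from transformer_engine.paddle.distributed import track_rng_state

x = paddle.ones([2, 8])

# Attention dropout: per-rank state, so tensor-parallel ranks get different masks.
with track_rng_state(enable=True, name="local_seed"):
    y_attn = paddle.nn.functional.dropout(x, p=0.1, training=True)

# Hidden dropout: replicated state, so tensor-parallel ranks agree on the mask.
with track_rng_state(enable=True, name="global_seed"):
    y_hidden = paddle.nn.functional.dropout(x, p=0.1, training=True)

# With enable=False the tracker is bypassed and the default generator is used.
with track_rng_state(enable=False):
    y_plain = paddle.nn.functional.dropout(x, p=0.1, training=True)
```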
......@@ -401,6 +401,7 @@ class MultiHeadAttention(paddle.nn.Layer):
zero_centered_gamma: bool = False,
set_parallel_mode: bool = False,
tp_group: Optional[dist_group_type] = None,
rng_state_name: str = 'local_seed',
backend: str = 'transformer_engine',
) -> None:
super().__init__()
......@@ -422,6 +423,7 @@ class MultiHeadAttention(paddle.nn.Layer):
self.num_attention_heads = num_attention_heads
norm_factor = math.sqrt(self.hidden_size_per_attention_head)
self.set_parallel_mode = set_parallel_mode
self.rng_state_name = rng_state_name
self.backend = backend
self.num_attention_heads_per_partition = divide(self.num_attention_heads, self.tp_size)
......@@ -555,7 +557,7 @@ class MultiHeadAttention(paddle.nn.Layer):
0, 0, 3, self.num_attention_heads_per_partition, self.hidden_size_per_attention_head
])
with track_rng_state(enable=self.tensor_parallel):
with track_rng_state(enable=self.tensor_parallel, name=self.rng_state_name):
context_layer = self.core_attention(
query_layer=mixed_qkv_layer,
key_value_layer=None,
......@@ -584,7 +586,7 @@ class MultiHeadAttention(paddle.nn.Layer):
query_layer = query_layer.reshape(shape=[
0, 0, self.num_attention_heads_per_partition, self.hidden_size_per_attention_head
])
with track_rng_state(enable=self.tensor_parallel):
with track_rng_state(enable=self.tensor_parallel, name=self.rng_state_name):
context_layer = self.core_attention(
query_layer=query_layer,
key_value_layer=mixed_kv_layer,
......
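An illustrative constructor call for the new MultiHeadAttention knob. Every argument except rng_state_name is an assumption based on the surrounding diff context, and the package-level export is assumed from the layer module's own imports.

```python
import transformer_engine.paddle as te

attn = te.MultiHeadAttention(
    hidden_size=1024,          # assumed size
    num_attention_heads=16,    # assumed size
    attention_dropout=0.1,     # assumed argument name
    # New in this commit: which tracked RNG state seeds the attention-dropout
    # mask when tensor parallelism is enabled. 'local_seed' (the default)
    # gives each tensor-parallel rank its own mask.
    rng_state_name='local_seed',
)
```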
......@@ -9,6 +9,7 @@ import paddle
from . import LayerNormMLP, LayerNorm, MultiHeadAttention
from ..constants import AttnMaskTypes, LayerTypes, dist_group_type
from ..distributed import get_tp_group_and_world_size, track_rng_state
class TransformerLayer(paddle.nn.Layer):
......@@ -90,6 +91,8 @@ class TransformerLayer(paddle.nn.Layer):
activation: str = 'gelu',
set_parallel_mode: bool = False,
tp_group: Optional[dist_group_type] = None,
attention_dropout_rng_state_name: str = 'local_seed',
hidden_dropout_rng_state_name: str = 'global_seed',
backend: str = 'transformer_engine') -> None:
super().__init__()
......@@ -99,7 +102,10 @@ class TransformerLayer(paddle.nn.Layer):
self.apply_residual_connection_post_layernorm = apply_residual_connection_post_layernorm
self.self_attn_mask_type = self_attn_mask_type
self.set_parallel_mode = set_parallel_mode
self.tp_group = tp_group
self.tp_group, self.tp_size = get_tp_group_and_world_size(tp_group,
enable_tp=set_parallel_mode)
self.tensor_parallel = self.tp_size > 1
self.hidden_dropout_rng_state_name = hidden_dropout_rng_state_name
assert (self_attn_mask_type
in AttnMaskTypes), f"self_attn_mask_type {self_attn_mask_type} not supported"
......@@ -119,6 +125,7 @@ class TransformerLayer(paddle.nn.Layer):
"zero_centered_gamma": zero_centered_gamma,
"set_parallel_mode": set_parallel_mode,
"tp_group": tp_group,
"rng_state_name": attention_dropout_rng_state_name,
"backend": backend,
}
......@@ -224,11 +231,12 @@ class TransformerLayer(paddle.nn.Layer):
residual = hidden_states
# dropout add.
out = paddle.nn.functional.dropout(
attention_output,
p=self.hidden_dropout,
training=True,
)
with track_rng_state(enable=self.tensor_parallel, name=self.hidden_dropout_rng_state_name):
out = paddle.nn.functional.dropout(
attention_output,
p=self.hidden_dropout,
training=True,
)
bda_output = residual + out
# Cross attention.
......@@ -247,11 +255,13 @@ class TransformerLayer(paddle.nn.Layer):
attention_output = inter_attention_outputs
residual = bda_output
out = paddle.nn.functional.dropout(
attention_output,
p=self.hidden_dropout,
training=True,
)
with track_rng_state(enable=self.tensor_parallel,
name=self.hidden_dropout_rng_state_name):
out = paddle.nn.functional.dropout(
attention_output,
p=self.hidden_dropout,
training=True,
)
bda_output = residual + out
# MLP.
......@@ -263,7 +273,8 @@ class TransformerLayer(paddle.nn.Layer):
residual = bda_output
# dropout add.
out = paddle.nn.functional.dropout(mlp_output, p=self.hidden_dropout, training=True)
with track_rng_state(enable=self.tensor_parallel, name=self.hidden_dropout_rng_state_name):
out = paddle.nn.functional.dropout(mlp_output, p=self.hidden_dropout, training=True)
output = residual + out
# For BERT like architectures.
......
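Putting the two new TransformerLayer arguments together, a sketch with assumed sizes and the defaults shown in the diff: attention dropout draws from the per-rank 'local_seed' state while hidden dropout draws from the replicated 'global_seed' state, so tensor-parallel ranks agree on the hidden-dropout masks but not on the attention-dropout masks.

```python
import transformer_engine.paddle as te

layer = te.TransformerLayer(
    hidden_size=1024,          # assumed size
    ffn_hidden_size=4096,      # assumed size
    num_attention_heads=16,    # assumed size
    hidden_dropout=0.1,
    attention_dropout=0.1,
    set_parallel_mode=True,    # requires an initialized fleet hybrid group
    attention_dropout_rng_state_name='local_seed',   # new in this commit
    hidden_dropout_rng_state_name='global_seed',     # new in this commit
)
```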