Unverified Commit 805b9872, authored by Tian Zheng, committed by GitHub

[Paddle] Add control of RNG state (#410)

* Add control of attention dropout and hidden dropout RNG state

Signed-off-by: Tian Zheng (Engrg-Hardware 1) <tizheng@nvidia.com>

* Fix CI error

Signed-off-by: Tian Zheng (Engrg-Hardware 1) <tizheng@nvidia.com>

---------

Signed-off-by: Tian Zheng (Engrg-Hardware 1) <tizheng@nvidia.com>

parent 3a63b13d
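This commit lets callers choose which RNG-tracker state seeds each dropout: attention dropout (applied to tensor-parallel-partitioned activations) defaults to the per-rank 'local_seed' state, while hidden dropout (applied to activations replicated across TP ranks) defaults to the shared 'global_seed' state. A minimal usage sketch, with illustrative sizes, assuming the fleet hybrid-parallel group and both tracker states have already been initialized (e.g. via a helper like the set_random_seed below):

import transformer_engine.paddle as te

layer = te.TransformerLayer(
    hidden_size=1024,                                  # illustrative dimensions
    ffn_hidden_size=4096,
    num_attention_heads=16,
    set_parallel_mode=True,                            # enable tensor parallelism
    attention_dropout_rng_state_name='local_seed',     # per-rank dropout masks
    hidden_dropout_rng_state_name='global_seed',       # masks shared across TP ranks
)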
@@ -107,6 +107,7 @@ class TestLinearPipelineParallel(unittest.TestCase):
     def test_pipeline_train(self):
         """Test pipeline parallel training"""
         set_random_seed(1024)
+        np.random.seed(1024)
         weight1_np = np.random.normal(size=[self.in_features, self.hidden_features])
         weight2_np = np.random.normal(size=[self.hidden_features, self.in_features])
...
@@ -1085,7 +1085,7 @@ def test_transformer_decoder_layer(bs, hidden_size, num_heads, ffn_hidden_size,
         assert_allclose(layer_te.self_attention.qkv.bias.grad,
                         layer_pd.self_attention.qkv.bias.grad,
                         rtol=0.01,
-                        atol=0.5)
+                        atol=0.6)
     else:
         assert_allclose(layer_te.self_attention.layernorm_qkv.bias.grad,
                         layer_pd.self_attention.layernorm_qkv.bias.grad,
...
@@ -3,9 +3,12 @@
 # See LICENSE for license information.
 """Utils for testing"""
+import random
 import numpy as np
 import paddle
+from paddle.distributed import fleet
+from paddle.distributed.fleet.meta_parallel import get_rng_state_tracker
 import transformer_engine    # pylint: disable=unused-import
@@ -49,6 +52,43 @@ def is_devices_enough(required):
 def set_random_seed(seed):
     """Set random seed for reproducibility."""
-    np.random.seed(seed)
-    paddle.seed(seed)
-    paddle.distributed.fleet.meta_parallel.model_parallel_random_seed(seed)
+    hcg = fleet.get_hybrid_communicate_group()
+    if paddle.distributed.get_world_size() > 1:
+        # Obtain the ranks and group sizes of the hybrid-parallel dimensions.
+        mp_rank = hcg.get_model_parallel_rank()
+        mp_size = hcg.get_model_parallel_world_size()
+        pp_rank = hcg.get_stage_id()
+        pp_size = hcg.get_pipe_parallel_world_size()
+        dp_rank = hcg.get_data_parallel_rank()
+        dp_size = hcg.get_data_parallel_world_size()
+        sharding_rank = hcg.get_sharding_parallel_rank()
+    else:
+        mp_rank, mp_size = 0, 1
+        pp_rank, pp_size = 0, 1
+        dp_rank, dp_size = 0, 1
+        sharding_rank = 0
+
+    random.seed(seed + 100 * pp_rank)
+    np.random.seed(seed + 100 * pp_rank)
+
+    seed_offset = seed + 1024 + paddle.distributed.get_world_size()
+    global_seed = (seed_offset + pp_rank * (mp_size) + dp_rank * (mp_size * pp_size) +
+                   sharding_rank * (mp_size * pp_size * dp_size))
+
+    seed_offset += paddle.distributed.get_world_size()
+    local_seed = (seed_offset + mp_rank + pp_rank * (mp_size) + dp_rank * (mp_size * pp_size) +
+                  sharding_rank * (mp_size * pp_size * dp_size))
+
+    tracker = get_rng_state_tracker()
+    if "global_seed" not in tracker.states_:
+        tracker.add("global_seed", global_seed)
+    if "local_seed" not in tracker.states_:
+        tracker.add("local_seed", local_seed)
+
+    paddle.seed(global_seed)
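To make the seed layout concrete, a worked example under assumed sizes (not part of the diff):

# Assumed: seed=1024, world_size=8, mp_size=2, pp_size=2, dp_size=2,
# sharding disabled. Rank under consideration: mp_rank=1, pp_rank=0, dp_rank=1.
seed_offset = 1024 + 1024 + 8                      # 2056
global_seed = 2056 + 0 * 2 + 1 * (2 * 2) + 0       # 2060
seed_offset += 8                                   # 2064
local_seed = 2064 + 1 + 0 * 2 + 1 * (2 * 2) + 0    # 2069
# The peer rank with mp_rank=0 gets global_seed=2060 as well (the formula has
# no mp_rank term) but local_seed=2068: identical masks across the TP group
# for 'global_seed', a distinct stream per TP rank for 'local_seed'.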
@@ -39,13 +39,13 @@ def get_tp_group_and_world_size(tp_group: Union[dist_group_type, None],
 @contextmanager
-def track_rng_state(enable: bool) -> None:
+def track_rng_state(enable: bool, **kwargs) -> None:
     """
     Applies get_rng_state_tracker().rng_state() to the context.
     If not enabled, it does nothing.
     """
     if enable:
-        with get_rng_state_tracker().rng_state():
+        with get_rng_state_tracker().rng_state(**kwargs):
             yield
     else:
         yield
...
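The only change is that track_rng_state now forwards **kwargs to Paddle's RNG tracker, so callers can name which registered state to swap in. A minimal sketch of calling it directly (the module path follows the ..distributed relative import later in this commit; the seed value is arbitrary):

from paddle.distributed.fleet.meta_parallel import get_rng_state_tracker
from transformer_engine.paddle.distributed import track_rng_state

get_rng_state_tracker().add('local_seed', 4096)    # register a named state once

with track_rng_state(enable=True, name='local_seed'):
    # Paddle's generator is temporarily the 'local_seed' state, so any dropout
    # here draws from that stream; the previous state is restored on exit.
    ...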
@@ -401,6 +401,7 @@ class MultiHeadAttention(paddle.nn.Layer):
             zero_centered_gamma: bool = False,
             set_parallel_mode: bool = False,
             tp_group: Optional[dist_group_type] = None,
+            rng_state_name: str = 'local_seed',
             backend: str = 'transformer_engine',
     ) -> None:
         super().__init__()
@@ -422,6 +423,7 @@ class MultiHeadAttention(paddle.nn.Layer):
         self.num_attention_heads = num_attention_heads
         norm_factor = math.sqrt(self.hidden_size_per_attention_head)
         self.set_parallel_mode = set_parallel_mode
+        self.rng_state_name = rng_state_name
         self.backend = backend
         self.num_attention_heads_per_partition = divide(self.num_attention_heads, self.tp_size)
@@ -555,7 +557,7 @@ class MultiHeadAttention(paddle.nn.Layer):
                 0, 0, 3, self.num_attention_heads_per_partition, self.hidden_size_per_attention_head
             ])
-            with track_rng_state(enable=self.tensor_parallel):
+            with track_rng_state(enable=self.tensor_parallel, name=self.rng_state_name):
                 context_layer = self.core_attention(
                     query_layer=mixed_qkv_layer,
                     key_value_layer=None,
@@ -584,7 +586,7 @@ class MultiHeadAttention(paddle.nn.Layer):
             query_layer = query_layer.reshape(shape=[
                 0, 0, self.num_attention_heads_per_partition, self.hidden_size_per_attention_head
             ])
-            with track_rng_state(enable=self.tensor_parallel):
+            with track_rng_state(enable=self.tensor_parallel, name=self.rng_state_name):
                 context_layer = self.core_attention(
                     query_layer=query_layer,
                     key_value_layer=mixed_kv_layer,
...
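Callers of MultiHeadAttention can now choose the tracker state used around core_attention. A hedged sketch (only rng_state_name comes from this diff; the other arguments and their values are illustrative):

import transformer_engine.paddle as te

attn = te.MultiHeadAttention(
    hidden_size=1024,
    num_attention_heads=16,
    attention_dropout=0.1,
    set_parallel_mode=True,
    rng_state_name='local_seed',   # state swapped in around core_attention
)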
@@ -9,6 +9,7 @@ import paddle
 from . import LayerNormMLP, LayerNorm, MultiHeadAttention
 from ..constants import AttnMaskTypes, LayerTypes, dist_group_type
+from ..distributed import get_tp_group_and_world_size, track_rng_state

 class TransformerLayer(paddle.nn.Layer):
@@ -90,6 +91,8 @@ class TransformerLayer(paddle.nn.Layer):
                  activation: str = 'gelu',
                  set_parallel_mode: bool = False,
                  tp_group: Optional[dist_group_type] = None,
+                 attention_dropout_rng_state_name: str = 'local_seed',
+                 hidden_dropout_rng_state_name: str = 'global_seed',
                  backend: str = 'transformer_engine') -> None:
         super().__init__()
@@ -99,7 +102,10 @@ class TransformerLayer(paddle.nn.Layer):
         self.apply_residual_connection_post_layernorm = apply_residual_connection_post_layernorm
         self.self_attn_mask_type = self_attn_mask_type
         self.set_parallel_mode = set_parallel_mode
-        self.tp_group = tp_group
+        self.tp_group, self.tp_size = get_tp_group_and_world_size(tp_group,
+                                                                  enable_tp=set_parallel_mode)
+        self.tensor_parallel = self.tp_size > 1
+        self.hidden_dropout_rng_state_name = hidden_dropout_rng_state_name
         assert (self_attn_mask_type
                 in AttnMaskTypes), f"self_attn_mask_type {self_attn_mask_type} not supported"
@@ -119,6 +125,7 @@ class TransformerLayer(paddle.nn.Layer):
             "zero_centered_gamma": zero_centered_gamma,
             "set_parallel_mode": set_parallel_mode,
             "tp_group": tp_group,
+            "rng_state_name": attention_dropout_rng_state_name,
             "backend": backend,
         }
@@ -224,11 +231,12 @@ class TransformerLayer(paddle.nn.Layer):
         residual = hidden_states

         # dropout add.
-        out = paddle.nn.functional.dropout(
-            attention_output,
-            p=self.hidden_dropout,
-            training=True,
-        )
+        with track_rng_state(enable=self.tensor_parallel, name=self.hidden_dropout_rng_state_name):
+            out = paddle.nn.functional.dropout(
+                attention_output,
+                p=self.hidden_dropout,
+                training=True,
+            )
         bda_output = residual + out

         # Cross attention.
@@ -247,11 +255,13 @@ class TransformerLayer(paddle.nn.Layer):
             attention_output = inter_attention_outputs
             residual = bda_output

-            out = paddle.nn.functional.dropout(
-                attention_output,
-                p=self.hidden_dropout,
-                training=True,
-            )
+            with track_rng_state(enable=self.tensor_parallel,
+                                 name=self.hidden_dropout_rng_state_name):
+                out = paddle.nn.functional.dropout(
+                    attention_output,
+                    p=self.hidden_dropout,
+                    training=True,
+                )
             bda_output = residual + out

         # MLP.
@@ -263,7 +273,8 @@ class TransformerLayer(paddle.nn.Layer):
         residual = bda_output

         # dropout add.
-        out = paddle.nn.functional.dropout(mlp_output, p=self.hidden_dropout, training=True)
+        with track_rng_state(enable=self.tensor_parallel, name=self.hidden_dropout_rng_state_name):
+            out = paddle.nn.functional.dropout(mlp_output, p=self.hidden_dropout, training=True)
         output = residual + out

         # For BERT like architectures.
...
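The intended effect of routing the hidden-dropout sites through 'global_seed' can be checked directly against the tracker: states registered with the same seed on every tensor-parallel rank yield identical masks, while per-rank seeds yield distinct ones. A small sketch (assumes 'global_seed' and 'local_seed' were registered as in set_random_seed above; run on each TP rank):

import paddle
from paddle.distributed.fleet.meta_parallel import get_rng_state_tracker

tracker = get_rng_state_tracker()

with tracker.rng_state('global_seed'):
    # Seeded identically across the TP group, so every rank drops the
    # same positions of the replicated hidden activations.
    hidden_mask = paddle.bernoulli(paddle.full([8], 0.9))

with tracker.rng_state('local_seed'):
    # Seeded per rank, so each TP rank masks its attention partition differently.
    attn_mask = paddle.bernoulli(paddle.full([8], 0.9))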