Unverified Commit f6a08908 authored by Yifan Xiong's avatar Yifan Xiong Committed by GitHub
Browse files

Fix bug in TE BERT model (#461)

Fix bug in TE BERT model, create different instances for each layer
during init.
parent ca0dccac
...@@ -61,15 +61,18 @@ def __init__(self, config, num_classes): ...@@ -61,15 +61,18 @@ def __init__(self, config, num_classes):
self._embedding = torch.nn.Embedding(config.vocab_size, config.hidden_size) self._embedding = torch.nn.Embedding(config.vocab_size, config.hidden_size)
# Build BERT using nn.TransformerEncoderLayer or te.TransformerLayer # Build BERT using nn.TransformerEncoderLayer or te.TransformerLayer
# input shape: (seq_len, batch_size, hidden_size) # input shape: (seq_len, batch_size, hidden_size)
encoder_layer = te.TransformerLayer( self._encoder_layers = torch.nn.ModuleList(
[
te.TransformerLayer(
config.hidden_size, config.hidden_size,
config.intermediate_size, config.intermediate_size,
config.num_attention_heads, config.num_attention_heads,
apply_residual_connection_post_layernorm=True, apply_residual_connection_post_layernorm=True,
output_layernorm=True, output_layernorm=True,
layer_type='encoder', layer_type='encoder',
) for _ in range(config.num_hidden_layers)
]
) )
self._encoder_layers = torch.nn.ModuleList([encoder_layer for _ in range(config.num_hidden_layers)])
# BertPooler used in huggingface transformers # BertPooler used in huggingface transformers
# https://github.com/huggingface/transformers/blob/accad48e/src/transformers/models/bert/modeling_bert.py#L893 # https://github.com/huggingface/transformers/blob/accad48e/src/transformers/models/bert/modeling_bert.py#L893
self._pooler = torch.nn.Sequential( self._pooler = torch.nn.Sequential(
...@@ -113,7 +116,6 @@ def __init__(self, name, parameters=''): ...@@ -113,7 +116,6 @@ def __init__(self, name, parameters=''):
Precision.FLOAT16, Precision.FLOAT16,
Precision.FP8_HYBRID, Precision.FP8_HYBRID,
Precision.FP8_E4M3, Precision.FP8_E4M3,
Precision.FP8_E5M2,
] ]
self._optimizer_type = Optimizer.ADAMW self._optimizer_type = Optimizer.ADAMW
self._loss_fn = torch.nn.CrossEntropyLoss() self._loss_fn = torch.nn.CrossEntropyLoss()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment