Unverified Commit f6a08908 authored by Yifan Xiong, committed by GitHub

Fix bug in TE BERT model (#461)

Fix a bug in the TE BERT model: create a separate te.TransformerLayer
instance for each encoder layer during init, instead of reusing one
shared instance across the whole stack.
parent ca0dccac
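
The old code constructed a single te.TransformerLayer and repeated that one object in a torch.nn.ModuleList, so every "layer" of the stack pointed at the same module and the same weights. A minimal sketch of the pitfall, using torch.nn.Linear as a stand-in for te.TransformerLayer so it runs without Transformer Engine:

import torch

shared = torch.nn.Linear(8, 8)
# Buggy pattern: the comprehension repeats one pre-built module.
buggy = torch.nn.ModuleList([shared for _ in range(4)])
# Fixed pattern: construct a new module on every iteration.
fixed = torch.nn.ModuleList([torch.nn.Linear(8, 8) for _ in range(4)])

assert all(layer is buggy[0] for layer in buggy)   # one object in four slots
print(len(list(buggy.parameters())))               # 2: one weight, one bias in total
print(len(list(fixed.parameters())))               # 8: four independent layers
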
@@ -61,15 +61,18 @@ def __init__(self, config, num_classes):
         self._embedding = torch.nn.Embedding(config.vocab_size, config.hidden_size)
         # Build BERT using nn.TransformerEncoderLayer or te.TransformerLayer
         # input shape: (seq_len, batch_size, hidden_size)
-        encoder_layer = te.TransformerLayer(
-            config.hidden_size,
-            config.intermediate_size,
-            config.num_attention_heads,
-            apply_residual_connection_post_layernorm=True,
-            output_layernorm=True,
-            layer_type='encoder',
+        self._encoder_layers = torch.nn.ModuleList(
+            [
+                te.TransformerLayer(
+                    config.hidden_size,
+                    config.intermediate_size,
+                    config.num_attention_heads,
+                    apply_residual_connection_post_layernorm=True,
+                    output_layernorm=True,
+                    layer_type='encoder',
+                ) for _ in range(config.num_hidden_layers)
+            ]
         )
-        self._encoder_layers = torch.nn.ModuleList([encoder_layer for _ in range(config.num_hidden_layers)])
         # BertPooler used in huggingface transformers
         # https://github.com/huggingface/transformers/blob/accad48e/src/transformers/models/bert/modeling_bert.py#L893
         self._pooler = torch.nn.Sequential(
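
An alternative fix, sketched here only as a hypothetical, is to deep-copy a template layer, which is how torch.nn.TransformerEncoder clones its layers internally via copy.deepcopy. Each copy gets its own parameters, though all copies then start from identical initial weights, whereas constructing each te.TransformerLayer inside the comprehension, as this commit does, gives every layer an independent random initialization:

import copy
import torch

template = torch.nn.Linear(8, 8)   # stand-in for a template layer
layers = torch.nn.ModuleList([copy.deepcopy(template) for _ in range(4)])

assert layers[0] is not layers[1]                  # distinct module objects
assert layers[0].weight is not layers[1].weight    # distinct parameters
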
@@ -113,7 +116,6 @@ def __init__(self, name, parameters=''):
             Precision.FLOAT16,
             Precision.FP8_HYBRID,
             Precision.FP8_E4M3,
-            Precision.FP8_E5M2,
         ]
         self._optimizer_type = Optimizer.ADAMW
         self._loss_fn = torch.nn.CrossEntropyLoss()