Unverified Commit 48065c06 authored by Reza Yazdani, committed by GitHub

Fixing the module-inject Api (#786)

parent e60e92eb
from .replace_module import replace_transformer_layer
@@ -7,10 +7,10 @@ def replace_transformer_layer(orig_layer_impl,
                                model,
                                micro_batch_size,
                                bert_config,
-                               seed,
-                               max_seq_length,
-                               preln=False,
+                               seed=-1,
+                               preln=True,
                                fp16=True,
+                               training=True,
                                huggingface=False,
                                local_rank=-1):
     """ Replace bert-style transformer layers with DeepSpeed's transformer layer
@@ -21,9 +21,9 @@ def replace_transformer_layer(orig_layer_impl,
         micro_batch_size (int): micro batch size per gpu used during training/eval
         bert_config (dict): model config containing hidden size, attention heads, etc.
         seed (int): random seed value
-        max_seq_length (int): max sequence length for training
         preln (bool): does the original layer implementation do pre or post layer norm?
         fp16 (bool): fp16 or fp32
+        training (bool): select between training (True) or inference (False) mode
         huggingface (bool): huggingface implementation is unique (supports both encoder/decoder modes)
     Returns:
@@ -32,7 +32,6 @@ def replace_transformer_layer(orig_layer_impl,
     def replace_fn(child):
         transformer_config = deepspeed.DeepSpeedTransformerConfig(
             batch_size=micro_batch_size,
-            max_seq_length=max_seq_length,
             hidden_size=bert_config.hidden_size,
             heads=bert_config.num_attention_heads,
             attn_dropout_ratio=bert_config.attention_probs_dropout_prob,
@@ -43,7 +42,8 @@ def replace_transformer_layer(orig_layer_impl,
             fp16=fp16,
             pre_layer_norm=preln,
             huggingface=huggingface,
-            local_rank=local_rank)
+            local_rank=local_rank,
+            training=training)
         new_module = deepspeed.DeepSpeedTransformerLayer(transformer_config)
         # copy relevant state from child -> new module
...
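
For reference, a minimal usage sketch of the updated signature. The model, checkpoint, micro batch size, and the BertLayer import path below are illustrative assumptions (the import path in particular varies across transformers versions) and are not part of this commit; only the replace_transformer_layer keyword arguments reflect the diff above.

# Hypothetical usage sketch; model choice and hyperparameter values are assumptions.
import deepspeed
from deepspeed.module_inject import replace_transformer_layer
from transformers import BertForSequenceClassification
from transformers.models.bert.modeling_bert import BertLayer  # import path varies by version

model = BertForSequenceClassification.from_pretrained("bert-base-uncased")

# Swap each HuggingFace BertLayer for a DeepSpeedTransformerLayer using the
# new signature: max_seq_length is gone, seed/preln now have defaults, and
# the training flag selects between training and inference mode.
model = replace_transformer_layer(orig_layer_impl=BertLayer,
                                  model=model,
                                  micro_batch_size=8,        # illustrative value
                                  bert_config=model.config,
                                  preln=False,               # HF BERT is post-layer-norm
                                  fp16=True,
                                  training=False,            # inference mode
                                  huggingface=True)

Note that the preln default flips from False to True in this change, so callers wrapping post-layer-norm models (e.g. HuggingFace BERT) that relied on the old default must now pass preln=False explicitly.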