"docs/vscode:/vscode.git/clone" did not exist on "9d27df8071bb39d117755200ace81a3669b4134c"
Unverified Commit 48065c06 authored by Reza Yazdani's avatar Reza Yazdani Committed by GitHub
Browse files

Fixing the module-inject Api (#786)

parent e60e92eb
from .replace_module import replace_transformer_layer
@@ -7,10 +7,10 @@ def replace_transformer_layer(orig_layer_impl,
                               model,
                               micro_batch_size,
                               bert_config,
-                              seed,
-                              max_seq_length,
-                              preln=False,
+                              seed=-1,
+                              preln=True,
                               fp16=True,
+                              training=True,
                               huggingface=False,
                               local_rank=-1):
     """ Replace bert-style transformer layers with DeepSpeed's transformer layer
@@ -21,9 +21,9 @@ def replace_transformer_layer(orig_layer_impl,
         micro_batch_size (int): micro batch size per gpu used during training/eval
         bert_config (dict): model config containing hidden size, attention heads, etc.
         seed (int): random seed value
-        max_seq_length (int): max sequence length for training
         preln (bool): does the original layer implementation do pre or post layer norm?
         fp16 (bool): fp16 or fp32
+        training (bool): select between training (True) or inference (False) mode
         huggingface (bool): huggingface implementation is unique (supports both encoder/decoder modes)
     Returns:
@@ -32,7 +32,6 @@ def replace_transformer_layer(orig_layer_impl,
     def replace_fn(child):
         transformer_config = deepspeed.DeepSpeedTransformerConfig(
             batch_size=micro_batch_size,
-            max_seq_length=max_seq_length,
             hidden_size=bert_config.hidden_size,
             heads=bert_config.num_attention_heads,
             attn_dropout_ratio=bert_config.attention_probs_dropout_prob,
@@ -43,7 +42,8 @@ def replace_transformer_layer(orig_layer_impl,
             fp16=fp16,
             pre_layer_norm=preln,
             huggingface=huggingface,
-            local_rank=local_rank)
+            local_rank=local_rank,
+            training=training)
         new_module = deepspeed.DeepSpeedTransformerLayer(transformer_config)
         # copy relevant state from child -> new module
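For reference, a minimal sketch of how the updated API might be called after this change. It assumes a HuggingFace BERT model; the class names, import paths, and argument values below are illustrative (they depend on the installed transformers version) and are not part of the commit.

# Hypothetical usage sketch of the updated replace_transformer_layer signature.
# Import path for BertLayer varies across transformers versions; shown here for
# transformers >= 4.x. Only replace_transformer_layer itself comes from the diff above.
from transformers import BertConfig, BertForSequenceClassification
from transformers.models.bert.modeling_bert import BertLayer

from deepspeed.module_inject import replace_transformer_layer

config = BertConfig()                                   # supplies hidden_size, num_attention_heads, dropout, ...
model = BertForSequenceClassification(config)

# After this commit, max_seq_length is gone, seed and preln have defaults,
# and the new training flag is forwarded into DeepSpeedTransformerConfig.
model = replace_transformer_layer(orig_layer_impl=BertLayer,
                                  model=model,
                                  micro_batch_size=8,
                                  bert_config=config,
                                  preln=False,          # HF BERT layers use post-LayerNorm
                                  fp16=True,
                                  training=True,        # parameter added in this commit
                                  huggingface=True,
                                  local_rank=-1)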