Unverified commit 44bd538b authored by Jeff Rasley, committed by GitHub

Module replacement support (#586)


Co-authored-by: Reza Yazdani <reyazda@microsoft.com>
Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
parent 5ab12795
import copy
import torch
from deepspeed.ops.transformer import DeepSpeedTransformerLayer, DeepSpeedTransformerConfig


def module_inject(layer_obj,
                  model,
                  config,
                  micro_batch_size,
                  max_seq_length,
                  seed,
                  preln,
                  fp16=True):
    for name, child in model.named_children():
        if isinstance(child, layer_obj):
            print('REPLACING BertLayer')

            cuda_config = DeepSpeedTransformerConfig(
                batch_size=micro_batch_size,
                max_seq_length=max_seq_length,
                hidden_size=config.hidden_size,
                heads=config.num_attention_heads,
                attn_dropout_ratio=config.attention_probs_dropout_prob,
                hidden_dropout_ratio=config.hidden_dropout_prob,
                num_hidden_layers=config.num_hidden_layers,
                initializer_range=config.initializer_range,
                seed=seed,
                fp16=fp16,
                pre_layer_norm=preln)

            new_module = DeepSpeedTransformerLayer(cuda_config)

            # copy relevant state from child -> new module
            qw = child.attention.self.query.weight
            qb = child.attention.self.query.bias
            kw = child.attention.self.key.weight
            kb = child.attention.self.key.bias
            vw = child.attention.self.value.weight
            vb = child.attention.self.value.bias

            qkvw = torch.cat((qw, kw, vw), 0)
            qkvb = torch.cat((qb, kb, vb), 0)

            new_module.attn_qkvw.data = qkvw
            new_module.attn_qkvb.data = qkvb
            new_module.attn_ow.data = child.attention.output.dense.weight
            new_module.attn_ob.data = child.attention.output.dense.bias
            if preln:
                attention_layernorm = child.PostAttentionLayerNorm
            else:
                attention_layernorm = child.attention.output.LayerNorm
            new_module.attn_nw.data = attention_layernorm.weight
            new_module.attn_nb.data = attention_layernorm.bias
            if preln:
                intermediate_ff = child.intermediate.dense_act
            else:
                intermediate_ff = child.intermediate.dense
            new_module.inter_w.data = intermediate_ff.weight
            new_module.inter_b.data = intermediate_ff.bias
            new_module.output_w.data = child.output.dense.weight
            new_module.output_b.data = child.output.dense.bias
            if preln:
                transformer_layernorm = child.PreAttentionLayerNorm
            else:
                transformer_layernorm = child.output.LayerNorm
            new_module.norm_w.data = transformer_layernorm.weight
            new_module.norm_b.data = transformer_layernorm.bias

            setattr(model, name, copy.deepcopy(new_module))
        else:
            module_inject(layer_obj,
                          child,
                          config,
                          micro_batch_size,
                          max_seq_length,
                          seed,
                          preln,
                          fp16)
    return model
def test_hi():
    from turing.nvidia_modelingpreln import BertConfig as BertConfigPreLN
    from turing.nvidia_modelingpreln import BertForQuestionAnswering as BertForQuestionAnsweringPreLN
    from turing.nvidia_modelingpreln import BertLayer

    bert_model_config = {
        "vocab_size_or_config_json_file": 119547,
        "hidden_size": 1024,
        "num_hidden_layers": 1,
        "num_attention_heads": 16,
        "intermediate_size": 4096,
        "hidden_act": "gelu",
        "hidden_dropout_prob": 0.1,
        "attention_probs_dropout_prob": 0.1,
        "max_position_embeddings": 512,
        "type_vocab_size": 2,
        "initializer_range": 0.02
    }
    bert_config = BertConfigPreLN(**bert_model_config)
    base_model = BertForQuestionAnsweringPreLN(bert_config, args=None)
    #base_model = LinearStack()

    test_model = copy.deepcopy(base_model)
    # The reference model is the pre-layer-norm variant, so pass preln=True.
    test_model = module_inject(BertLayer,
                               test_model,
                               bert_config,
                               micro_batch_size=4,
                               max_seq_length=384,
                               seed=1234,
                               preln=True)

    print('BASE', base_model)
    print('TEST', test_model)

    #base_model.eval()
    #test_model.eval()
    #test_input = torch.rand(1, base_model.input_dim)
    #base_output = base_model(test_input)
    #test_output = test_model(test_input)
    #
    #assert torch.allclose(base_output, test_output, atol=3e-8)
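The commented-out assertions above sketch the intended numerical check. A minimal runnable version of that comparison, assuming both models accept the same keyword inputs (the helper name and dict-based calling convention are illustrative, not part of this patch):

```python
def check_equivalence(base_model, test_model, inputs, atol=1e-2):
    # Compare two models on identical inputs; `inputs` is a dict of forward kwargs.
    base_model.eval()
    test_model.eval()
    with torch.no_grad():
        base_out = base_model(**inputs)
        test_out = test_model(**inputs)
    # Some forwards return tuples; compare the first tensor either way.
    base_t = base_out[0] if isinstance(base_out, (tuple, list)) else base_out
    test_t = test_out[0] if isinstance(test_out, (tuple, list)) else test_out
    # The DeepSpeed kernel uses fused fp16/fp32 ops, so a loose tolerance is appropriate.
    assert torch.allclose(base_t.float(), test_t.float(), atol=atol)
```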
import copy
import torch
import deepspeed


def replace_transformer_layer(orig_layer_impl,
                              model,
                              micro_batch_size,
                              bert_config,
                              seed,
                              max_seq_length,
                              preln=False,
                              fp16=True,
                              huggingface=False,
                              local_rank=-1):
    """ Replace bert-style transformer layers with DeepSpeed's transformer layer

    Arguments:
        orig_layer_impl (torch.nn.Module): the original transformer layer implementation to look for,
            e.g., transformers.modeling_bert.BertLayer.
        model (torch.nn.Module): user's nn.Module representing their model
        micro_batch_size (int): micro batch size per gpu used during training/eval
        bert_config (dict): model config containing hidden size, attention heads, etc.
        seed (int): random seed value
        max_seq_length (int): max sequence length for training
        preln (bool): whether the original layer implementation uses pre- or post-layer-norm
        fp16 (bool): fp16 or fp32
        huggingface (bool): the HuggingFace implementation is unique (supports both encoder/decoder modes)
        local_rank (int): local GPU rank (-1 for the default device)

    Returns:
        Updated nn.Module with replaced transformer layers
    """
    def replace_fn(child):
        transformer_config = deepspeed.DeepSpeedTransformerConfig(
            batch_size=micro_batch_size,
            max_seq_length=max_seq_length,
            hidden_size=bert_config.hidden_size,
            heads=bert_config.num_attention_heads,
            attn_dropout_ratio=bert_config.attention_probs_dropout_prob,
            hidden_dropout_ratio=bert_config.hidden_dropout_prob,
            num_hidden_layers=bert_config.num_hidden_layers,
            initializer_range=bert_config.initializer_range,
            seed=seed,
            fp16=fp16,
            pre_layer_norm=preln,
            huggingface=huggingface,
            local_rank=local_rank)
        new_module = deepspeed.DeepSpeedTransformerLayer(transformer_config)

        # copy relevant state from child -> new module
        qw = child.attention.self.query.weight
        qb = child.attention.self.query.bias
        kw = child.attention.self.key.weight
        kb = child.attention.self.key.bias
        vw = child.attention.self.value.weight
        vb = child.attention.self.value.bias

        qkvw = torch.cat((qw, kw, vw), 0)
        qkvb = torch.cat((qb, kb, vb), 0)
        #qw.data,kw.data,vw.data = torch.chunk(qkvw, 3, dim=0)
        #qb.data,kb.data,vb.data = torch.chunk(qkvb, 3, dim=0)

        new_module.attn_qkvw.data = qkvw
        new_module.attn_qkvb.data = qkvb
        new_module.attn_ow.data = child.attention.output.dense.weight
        new_module.attn_ob.data = child.attention.output.dense.bias
        if preln:
            attention_layernorm = child.PostAttentionLayerNorm
        else:
            attention_layernorm = child.attention.output.LayerNorm
        new_module.attn_nw.data = attention_layernorm.weight
        new_module.attn_nb.data = attention_layernorm.bias
        if preln:
            intermediate_ff = child.intermediate.dense_act
        else:
            intermediate_ff = child.intermediate.dense
        new_module.inter_w.data = intermediate_ff.weight
        new_module.inter_b.data = intermediate_ff.bias
        new_module.output_w.data = child.output.dense.weight
        new_module.output_b.data = child.output.dense.bias
        if preln:
            transformer_layernorm = child.PreAttentionLayerNorm
        else:
            transformer_layernorm = child.output.LayerNorm
        new_module.norm_w.data = transformer_layernorm.weight
        new_module.norm_b.data = transformer_layernorm.bias
        return new_module

    return replace_module(model=model, orig_class=orig_layer_impl, replace_fn=replace_fn)
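As a usage sketch: the model and config construction below is illustrative (not part of this patch) and assumes a pre-4.x HuggingFace transformers release, matching the `transformers.modeling_bert.BertLayer` path cited in the docstring.

```python
from transformers import BertConfig, BertModel
from transformers.modeling_bert import BertLayer

hf_config = BertConfig()
hf_model = BertModel(hf_config)

ds_model = replace_transformer_layer(orig_layer_impl=BertLayer,
                                     model=hf_model,
                                     micro_batch_size=8,
                                     bert_config=hf_config,
                                     seed=1234,
                                     max_seq_length=128,
                                     preln=False,
                                     fp16=True,
                                     huggingface=True)
```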
def revert_transformer_layer(orig_layer_impl, model, bert_config, preln=False):
    """ Revert DeepSpeed's transformer layer back to the original bert-style transformer layer

    Arguments:
        orig_layer_impl (torch.nn.Module): the original transformer layer implementation that was replaced,
            e.g., transformers.modeling_bert.BertLayer.
        model (torch.nn.Module): user's nn.Module representing their model
        bert_config (dict): model config containing hidden size, attention heads, etc.
        preln (bool): whether the original layer implementation uses pre- or post-layer-norm

    Returns:
        Updated nn.Module with original bert-style transformer layers
    """
    def replace_fn(child):
        #from turing.nvidia_modelingpreln import BertLayer
        orig_module = orig_layer_impl(bert_config)

        # copy relevant state from child -> original module
        qkvw = child.attn_qkvw.data
        qkvb = child.attn_qkvb.data
        qw, kw, vw = torch.chunk(qkvw, 3, dim=0)
        qb, kb, vb = torch.chunk(qkvb, 3, dim=0)

        orig_module.attention.self.query.weight.data = qw
        orig_module.attention.self.query.bias.data = qb
        orig_module.attention.self.key.weight.data = kw
        orig_module.attention.self.key.bias.data = kb
        orig_module.attention.self.value.weight.data = vw
        orig_module.attention.self.value.bias.data = vb

        orig_module.attention.output.dense.weight.data = child.attn_ow.data
        orig_module.attention.output.dense.bias.data = child.attn_ob.data

        attn_ln_w = child.attn_nw.data
        attn_ln_b = child.attn_nb.data
        if preln:
            orig_module.PostAttentionLayerNorm.weight.data = attn_ln_w
            orig_module.PostAttentionLayerNorm.bias.data = attn_ln_b
        else:
            orig_module.attention.output.LayerNorm.weight.data = attn_ln_w
            orig_module.attention.output.LayerNorm.bias.data = attn_ln_b

        inter_ff_w = child.inter_w.data
        inter_ff_b = child.inter_b.data
        if preln:
            orig_module.intermediate.dense_act.weight.data = inter_ff_w
            orig_module.intermediate.dense_act.bias.data = inter_ff_b
        else:
            orig_module.intermediate.dense.weight.data = inter_ff_w
            orig_module.intermediate.dense.bias.data = inter_ff_b

        orig_module.output.dense.weight.data = child.output_w.data
        orig_module.output.dense.bias.data = child.output_b.data

        transformer_ln_w = child.norm_w.data
        transformer_ln_b = child.norm_b.data
        if preln:
            orig_module.PreAttentionLayerNorm.weight.data = transformer_ln_w
            orig_module.PreAttentionLayerNorm.bias.data = transformer_ln_b
        else:
            orig_module.output.LayerNorm.weight.data = transformer_ln_w
            orig_module.output.LayerNorm.bias.data = transformer_ln_b
        return orig_module

    return replace_module(model=model,
                          orig_class=deepspeed.DeepSpeedTransformerLayer,
                          replace_fn=replace_fn)
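Continuing the sketch above, a round trip back to the stock implementation might look like:

```python
# Continuing the sketch: restore stock BertLayer modules, copying the fused
# DeepSpeed parameters back into per-projection weights.
reverted = revert_transformer_layer(orig_layer_impl=BertLayer,
                                    model=ds_model,
                                    bert_config=hf_config,
                                    preln=False)
```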
def replace_module(model, orig_class, replace_fn):
    """ Scan the model for instances of ``orig_class`` to replace using ``replace_fn``.

    Arguments:
        model (torch.nn.Module): the model to augment
        orig_class (torch.nn.Module): the module class to search for
        replace_fn (method): a method to convert instances of ``orig_class`` to the
            desired type and return a new instance.

    Returns:
        A modified ``model``.
    """
    policy = {orig_class: replace_fn}
    return _replace_module(model, policy)


def _replace_module(model, policies):
    """ Traverse model's children recursively and apply any transformations in ``policies``.

    Arguments:
        model (torch.nn.Module): model to augment
        policies (dict): Mapping of source class to replacement function.

    Returns:
        Modified ``model``.
    """
    for name, child in model.named_children():
        if child.__class__ in policies:
            setattr(model, name, policies[child.__class__](child))
        else:
            _replace_module(child, policies)
    return model
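Since the traversal is driven by a class-to-function policy dict, the same helper can swap any module type. A minimal sketch with plain `torch.nn` modules:

```python
# Minimal sketch of the policy-driven traversal with plain torch.nn modules:
# replace every nn.ReLU in a model with nn.GELU.
import torch.nn as nn

net = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 8), nn.ReLU())
net = _replace_module(net, {nn.ReLU: lambda child: nn.GELU()})
assert not any(isinstance(m, nn.ReLU) for m in net.modules())
```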
@@ -3,4 +3,7 @@
 from . import lamb
 from . import sparse_attention
 from . import transformer
+from .transformer import DeepSpeedTransformerLayer, DeepSpeedTransformerConfig
+from .module_inject import replace_module
 from ..git_version_info import compatible_ops as __compatible_ops__
import copy
import torch
import deepspeed
from deepspeed.ops import DeepSpeedTransformerConfig
def _copy_child_transformer_state(new_module, orig_child, pre_layer_norm):
    # copy relevant state from original child -> new module
    qw = orig_child.attention.self.query.weight
    qb = orig_child.attention.self.query.bias
    kw = orig_child.attention.self.key.weight
    kb = orig_child.attention.self.key.bias
    vw = orig_child.attention.self.value.weight
    vb = orig_child.attention.self.value.bias

    # fuse the three projections into single QKV parameters
    qkvw = torch.cat((qw, kw, vw), 0)
    qkvb = torch.cat((qb, kb, vb), 0)
    #qw.data,kw.data,vw.data = torch.chunk(qkvw, 3, dim=0)
    #qb.data,kb.data,vb.data = torch.chunk(qkvb, 3, dim=0)

    new_module.attn_qkvw.data = qkvw
    new_module.attn_qkvb.data = qkvb
    new_module.attn_ow.data = orig_child.attention.output.dense.weight
    new_module.attn_ob.data = orig_child.attention.output.dense.bias
    if pre_layer_norm:
        attention_layernorm = orig_child.PostAttentionLayerNorm
    else:
        attention_layernorm = orig_child.attention.output.LayerNorm
    new_module.attn_nw.data = attention_layernorm.weight
    new_module.attn_nb.data = attention_layernorm.bias
    if pre_layer_norm:
        intermediate_ff = orig_child.intermediate.dense_act
    else:
        intermediate_ff = orig_child.intermediate.dense
    new_module.inter_w.data = intermediate_ff.weight
    new_module.inter_b.data = intermediate_ff.bias
    new_module.output_w.data = orig_child.output.dense.weight
    new_module.output_b.data = orig_child.output.dense.bias
    if pre_layer_norm:
        transformer_layernorm = orig_child.PreAttentionLayerNorm
    else:
        transformer_layernorm = orig_child.output.LayerNorm
    new_module.norm_w.data = transformer_layernorm.weight
    new_module.norm_b.data = transformer_layernorm.bias
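The fusion is a plain row-wise concatenation that `torch.chunk` inverts exactly; a quick shape check:

```python
# Quick shape check for the QKV fusion used above: torch.cat along dim 0
# stacks the three [hidden, hidden] projections into one [3*hidden, hidden]
# matrix, and torch.chunk recovers the originals exactly.
import torch

hidden = 1024
qw, kw, vw = (torch.randn(hidden, hidden) for _ in range(3))
qkvw = torch.cat((qw, kw, vw), 0)
assert qkvw.shape == (3 * hidden, hidden)
q2, k2, v2 = torch.chunk(qkvw, 3, dim=0)
assert torch.equal(q2, qw) and torch.equal(k2, kw) and torch.equal(v2, vw)
```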
def _replace_transformer_layer(orig_layer_impl, model, transformer_config):
    """ Replace bert-style transformer layers with DeepSpeed's transformer layer

    Arguments:
        orig_layer_impl (torch.nn.Module): the original transformer layer implementation to look for,
            e.g., transformers.modeling_bert.BertLayer.
        model (torch.nn.Module): user's nn.Module representing their model
        transformer_config (dict): deepspeed transformer layer config containing hidden size, attention heads, etc.

    Returns:
        Updated nn.Module with replaced transformer layers
    """
    def replace_fn(child):
        new_module = deepspeed.DeepSpeedTransformerLayer(transformer_config)
        _copy_child_transformer_state(new_module,
                                      child,
                                      transformer_config.pre_layer_norm)
        return new_module

    return _replace_module(model=model,
                           orig_class=orig_layer_impl,
                           replace_fn=replace_fn)
def replace_module(orig_module_impl, model, replacement_module_config):
    """ Replace a client module

    Arguments:
        orig_module_impl (torch.nn.Module): original module implementation to replace,
            e.g., transformers.modeling_bert.BertLayer.
        model (torch.nn.Module): user's nn.Module representing their model
        replacement_module_config (dict): deepspeed replacement module config (e.g., DeepSpeedTransformerConfig).

    Returns:
        Updated nn.Module with replaced modules
    """
    assert isinstance(replacement_module_config, DeepSpeedTransformerConfig), \
        'Only DeepSpeedTransformerConfig is currently supported as replacement config'
    return _replace_transformer_layer(orig_layer_impl=orig_module_impl,
                                      model=model,
                                      transformer_config=replacement_module_config)
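With this config-driven API the caller builds the `DeepSpeedTransformerConfig` once and hands it in. A sketch reusing `hf_model` and the import path from the earlier example (illustrative values, not part of this patch):

```python
# Sketch of the config-driven API; the kwargs mirror those used by
# replace_transformer_layer earlier in this commit.
from transformers.modeling_bert import BertLayer

transformer_config = DeepSpeedTransformerConfig(batch_size=8,
                                                max_seq_length=128,
                                                hidden_size=768,
                                                heads=12,
                                                attn_dropout_ratio=0.1,
                                                hidden_dropout_ratio=0.1,
                                                num_hidden_layers=12,
                                                initializer_range=0.02,
                                                seed=1234,
                                                fp16=True,
                                                pre_layer_norm=False,
                                                huggingface=True)
ds_model = replace_module(orig_module_impl=BertLayer,
                          model=hf_model,
                          replacement_module_config=transformer_config)
```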
def _revert_transformer_layer(orig_layer_impl, model, bert_config, transformer_config):
    """ Revert DeepSpeed's transformer layer back to the original bert-style transformer layer

    Arguments:
        orig_layer_impl (torch.nn.Module): the original transformer layer implementation that was replaced,
            e.g., transformers.modeling_bert.BertLayer.
        model (torch.nn.Module): user's nn.Module representing their model
        bert_config (dict): model config containing hidden size, attention heads, etc.
        transformer_config (dict): deepspeed transformer config used for the replacement

    Returns:
        Updated nn.Module with original bert-style transformer layers
    """
    def replace_fn(child):
        #from turing.nvidia_modelingpreln import BertLayer
        orig_module = orig_layer_impl(bert_config)

        # copy relevant state from child -> original module
        qkvw = child.attn_qkvw.data
        qkvb = child.attn_qkvb.data
        qw, kw, vw = torch.chunk(qkvw, 3, dim=0)
        qb, kb, vb = torch.chunk(qkvb, 3, dim=0)

        orig_module.attention.self.query.weight.data = qw
        orig_module.attention.self.query.bias.data = qb
        orig_module.attention.self.key.weight.data = kw
        orig_module.attention.self.key.bias.data = kb
        orig_module.attention.self.value.weight.data = vw
        orig_module.attention.self.value.bias.data = vb

        orig_module.attention.output.dense.weight.data = child.attn_ow.data
        orig_module.attention.output.dense.bias.data = child.attn_ob.data

        attn_ln_w = child.attn_nw.data
        attn_ln_b = child.attn_nb.data
        if transformer_config.pre_layer_norm:
            orig_module.PostAttentionLayerNorm.weight.data = attn_ln_w
            orig_module.PostAttentionLayerNorm.bias.data = attn_ln_b
        else:
            orig_module.attention.output.LayerNorm.weight.data = attn_ln_w
            orig_module.attention.output.LayerNorm.bias.data = attn_ln_b

        inter_ff_w = child.inter_w.data
        inter_ff_b = child.inter_b.data
        if transformer_config.pre_layer_norm:
            orig_module.intermediate.dense_act.weight.data = inter_ff_w
            orig_module.intermediate.dense_act.bias.data = inter_ff_b
        else:
            orig_module.intermediate.dense.weight.data = inter_ff_w
            orig_module.intermediate.dense.bias.data = inter_ff_b

        orig_module.output.dense.weight.data = child.output_w.data
        orig_module.output.dense.bias.data = child.output_b.data

        transformer_ln_w = child.norm_w.data
        transformer_ln_b = child.norm_b.data
        if transformer_config.pre_layer_norm:
            orig_module.PreAttentionLayerNorm.weight.data = transformer_ln_w
            orig_module.PreAttentionLayerNorm.bias.data = transformer_ln_b
        else:
            orig_module.output.LayerNorm.weight.data = transformer_ln_w
            orig_module.output.LayerNorm.bias.data = transformer_ln_b
        return orig_module

    return _replace_module(model=model,
                           orig_class=deepspeed.DeepSpeedTransformerLayer,
                           replace_fn=replace_fn)
def revert_module(orig_module_impl,
                  model,
                  orig_module_config,
                  replacement_module_config):
    """ Revert DeepSpeed's module back to the original client module

    Arguments:
        orig_module_impl (torch.nn.Module): the original module that was replaced,
            e.g., transformers.modeling_bert.BertLayer.
        model (torch.nn.Module): user's nn.Module representing their model
        orig_module_config (dict): original module configuration
        replacement_module_config (dict): replacement deepspeed module configuration

    Returns:
        Updated nn.Module with original bert-style transformer layers
    """
    assert isinstance(replacement_module_config, DeepSpeedTransformerConfig), \
        'Only DeepSpeedTransformerConfig is currently supported as replacement config'
    return _revert_transformer_layer(orig_layer_impl=orig_module_impl,
                                     model=model,
                                     bert_config=orig_module_config,
                                     transformer_config=replacement_module_config)
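`revert_module` is the inverse of `replace_module`; continuing the sketch above:

```python
# Continuing the sketch: undo the replacement, restoring BertLayer modules
# with the weights copied back out of the fused DeepSpeed parameters.
orig_model = revert_module(orig_module_impl=BertLayer,
                           model=ds_model,
                           orig_module_config=hf_config,
                           replacement_module_config=transformer_config)
```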
def _replace_module(model, orig_class, replace_fn):
    """ Scan the model for instances of ``orig_class`` to replace using ``replace_fn``.

    Arguments:
        model (torch.nn.Module): the model to augment
        orig_class (torch.nn.Module): the module class to search for
        replace_fn (method): a method to convert instances of ``orig_class`` to the
            desired type and return a new instance.

    Returns:
        A modified ``model``.
    """
    policy = {orig_class: replace_fn}
    return _replace_module_using_policies(model, policy)
def _replace_module_using_policies(model, policies):
    """ Traverse model's children recursively and apply any transformations in ``policies``.

    Arguments:
        model (torch.nn.Module): model to augment
        policies (dict): Mapping of source class to replacement function.

    Returns:
        Modified ``model``.
    """
    for name, child in model.named_children():
        if child.__class__ in policies:
            setattr(model, name, policies[child.__class__](child))
        else:
            _replace_module_using_policies(child, policies)
    return model
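Because `_replace_module_using_policies` takes a dict, several classes can be swapped in a single traversal; a sketch with plain `torch.nn` modules:

```python
# Sketch: one traversal applying two policies at once.
import torch.nn as nn

policies = {
    nn.ReLU: lambda child: nn.GELU(),
    nn.Dropout: lambda child: nn.Identity(),
}
net = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Dropout(0.1))
net = _replace_module_using_policies(net, policies)
```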
@@ -87,6 +87,10 @@ class DeepSpeedTransformerConfig(TransformerConfig):
             that by enabling it, the pretraining tasks such as BERT are not affected and can obtain
             a high accuracy level. On the other hand, for the downstream tasks, such as fine-tuning, we recommend
             to turn it off in order to be able to reproduce the same result through the regular kernel execution.
+        huggingface: Enable if using the HuggingFace interface style for sending out the forward results.
+        training: Enable for training rather than inference.
     """
     def __init__(self,
                  batch_size=-1,
@@ -105,7 +109,9 @@ class DeepSpeedTransformerConfig(TransformerConfig):
                  gelu_checkpoint=False,
                  adjust_init_range=True,
                  attn_dropout_checkpoint=False,
-                 stochastic_mode=False):
+                 stochastic_mode=False,
+                 huggingface=False,
+                 training=True):
         super(DeepSpeedTransformerConfig,
               self).__init__(
                   batch_size,
@@ -124,10 +130,11 @@ class DeepSpeedTransformerConfig(TransformerConfig):
         self.gelu_checkpoint = gelu_checkpoint  # True: if higher batch size is required
         self.adjust_init_range = adjust_init_range
         self.test_gemm = False
-        self.training = True
+        self.training = training
         self.is_grad_enabled = True
         self.attn_dropout_checkpoint = attn_dropout_checkpoint
         self.stochastic_mode = stochastic_mode
+        self.huggingface = huggingface

     @classmethod
     def from_dict(cls, json_object):
@@ -252,7 +259,7 @@ class DeepSpeedTransformerFunction(Function):
             norm_w.register_hook(lambda x, self=self: grads.append([x, "norm_W"]))
             norm_b.register_hook(lambda x, self=self: grads.append([x, "norm_B"]))

-        if config.is_grad_enabled:
+        if config.is_grad_enabled and config.training:
             if (config.pre_layer_norm and config.normalize_invertible):
                 ctx.save_for_backward(input_mask,
                                       attn_qkvw,
@@ -313,7 +320,11 @@ class DeepSpeedTransformerFunction(Function):
         if inp_size[1] % 16 != 0:
             output = torch.narrow(output, 1, 0, inp_size[1])
-        return output
+
+        if config.huggingface:
+            return (output, )  # outputs -> (output) : outputs[0] = output
+        else:
+            return output

     @staticmethod
     def backward(ctx, grad_output):
@@ -412,6 +423,25 @@ class DeepSpeedTransformerFunction(Function):
                 norm_w,
                 norm_b)

+        # This appears to be an effective way to release context memory
+        ctx.qkv_tf = None
+        ctx.soft_inp = None
+        ctx.ctx_bufB = None
+        ctx.gelu_inp = None
+        ctx.ff2_inp = None
+        ctx.attn_o_inp = None
+        ctx.ff1_inp = None
+        ctx.add_res = None
+        ctx.inp_norm = None
+        ctx.config = None
+        ctx.attn_layer_norm_mean = None
+        ctx.layer_norm_mean = None
+        ctx.attn_prob_dropout_mask = None
+        ctx.attn_output_dropout_mask = None
+        ctx.layer_output_dropout_mask = None
+        ctx.attn_layer_norm_var = None
+        ctx.layer_norm_var = None
+
         if grad_output_shape[1] % 16 != 0:
             grad_input = torch.narrow(grad_input, 1, 0, grad_output_shape[1])
@@ -438,21 +468,24 @@ class DeepSpeedTransformerLayer(nn.Module):
 class DeepSpeedTransformerLayer(nn.Module):
     """Initialize the DeepSpeed Transformer Layer.

+    Static variable:
+        layer_id: The layer-index counter starting from 0 and incrementing by 1 every time
+            a layer object is instantiated, e.g. if a model has 24 transformer layers,
+            layer_id goes from 0 to 23.
+
     Arguments:
-        layer_id: The layer index starting from 0, e.g. if model has 24 transformer layers,
-            layer_id will be 0,1,2...23 when each layer object is instantiated
         config: An object of DeepSpeedTransformerConfig
         initial_weights: Optional: Only used for unit test
         initial_biases: Optional: Only used for unit test
     """
-    def __init__(self, layer_id, config, initial_weights=None, initial_biases=None):
+    layer_id = 0
+
+    def __init__(self, config, initial_weights=None, initial_biases=None):
         super(DeepSpeedTransformerLayer, self).__init__()
         self.config = config
-        self.config.layer_id = layer_id
+        self.config.layer_id = DeepSpeedTransformerLayer.layer_id
+        DeepSpeedTransformerLayer.layer_id = DeepSpeedTransformerLayer.layer_id + 1

         print("DeepSpeed Transformer config is ", self.config.__dict__)
@@ -548,11 +581,18 @@ class DeepSpeedTransformerLayer(nn.Module):
         self.norm_w.data.fill_(1.0)
         self.norm_b.data.zero_()

-    def forward(self, input, input_mask, grads=None):
+    def forward(self,
+                hidden_states,
+                attention_mask=None,
+                head_mask=None,
+                encoder_hidden_states=None,
+                encoder_attention_mask=None,
+                output_attentions=False,
+                grads=None):
         self.config.training = self.training
         self.config.is_grad_enabled = torch.is_grad_enabled()
-        return DeepSpeedTransformerFunction.apply(input,
-                                                  input_mask,
+        return DeepSpeedTransformerFunction.apply(hidden_states,
+                                                  attention_mask,
                                                   self,
                                                   grads,
                                                   self.config.layer_id,
...
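The new class-level counter assigns layer ids automatically, so callers no longer pass an index. A quick sketch of the behavior (config construction elided; `cuda_config` is assumed to be a valid `DeepSpeedTransformerConfig`):

```python
# Each instantiation stamps the current class counter into its config and
# then bumps it, so ids run 0..23 for a 24-layer model.
layers = [DeepSpeedTransformerLayer(cuda_config) for _ in range(24)]
assert DeepSpeedTransformerLayer.layer_id == 24
```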
@@ -530,21 +530,10 @@ def see_memory_usage(message):
     # Print message except when distributed but not rank 0
     logger.info(message)
-    logger.info(
-        "Memory Allocated %s GigaBytes ",
-        torch.cuda.memory_allocated() / (1024 * 1024 * 1024),
-    )
-    logger.info(
-        "Max Memory Allocated %s GigaBytes",
-        torch.cuda.max_memory_allocated() / (1024 * 1024 * 1024),
-    )
-    logger.info(
-        "Cache Allocated %s GigaBytes",
-        torch.cuda.memory_cached() / (1024 * 1024 * 1024),
-    )
-    logger.info(
-        "Max cache Allocated %s GigaBytes",
-        torch.cuda.max_memory_cached() / (1024 * 1024 * 1024),
-    )
+    logger.info(
+        f"MA {round(torch.cuda.memory_allocated() / (1024 * 1024 * 1024), 2)} GB \
+        Max_MA {round(torch.cuda.max_memory_allocated() / (1024 * 1024 * 1024), 2)} GB \
+        CA {round(torch.cuda.memory_cached() / (1024 * 1024 * 1024), 2)} GB \
+        Max_CA {round(torch.cuda.max_memory_cached() / (1024 * 1024 * 1024))} GB ")

 def call_to_str(base, *args, **kwargs):
...
@@ -284,10 +284,10 @@ transformer layers using DeepSpeed transformer kernel as below.
                                                  gelu_checkpoint=args.gelu_checkpoint,
                                                  stochastic_mode=True)
-            self.layer = nn.ModuleList([copy.deepcopy(DeepSpeedTransformerLayer(i, cuda_config)) for i in range(config.num_hidden_layers)])
+            layer = DeepSpeedTransformerLayer(cuda_config)
         else:
             layer = BertLayer(config)
         self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(config.num_hidden_layers)])
 ```
 All configuration settings come from the DeepSpeed configuration file and
 command arguments, and thus we must pass the `args` variable here in this model.
...
@@ -43,8 +43,8 @@ config = DeepSpeedTransformerConfig(batch_size = 64,
                                     normalize_invertible=False,
                                     gelu_checkpoint=False)
 self.layer = nn.ModuleList([
-    copy.deepcopy(DeepSpeedTransformerLayer(i, cuda_config))
-    for i in range(config.num_hidden_layers)
+    copy.deepcopy(DeepSpeedTransformerLayer(cuda_config))
+    for _ in range(config.num_hidden_layers)
 ])
 ```
 ### Transformer kernel Parameters
...
@@ -83,11 +83,10 @@ class DSEncoder(nn.Module):
         super(DSEncoder, self).__init__()
         self.FinalLayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12)
         self.layer = nn.ModuleList([
-            copy.deepcopy(DeepSpeedTransformerLayer(i,
-                                                    config,
+            copy.deepcopy(DeepSpeedTransformerLayer(config,
                                                     weights,
                                                     biases))
-            for i in range(config.num_hidden_layers)
+            for _ in range(config.num_hidden_layers)
         ])
         self.grads = []
         self.pre_or_post = config.pre_layer_norm
@@ -122,7 +121,9 @@ class DSEncoder(nn.Module):
             # decoder layers
             else:
                 for i, layer_module in enumerate(self.layer):
-                    hidden_states = layer_module(hidden_states, attention_mask, self.grads)
+                    hidden_states = layer_module(hidden_states,
+                                                 attention_mask,
+                                                 grads=self.grads)
                     hidden_states.register_hook(
                         lambda x,
                         self=self: self.grads.append([x,
...
@@ -48,11 +48,10 @@ class DSEncoder(nn.Module):
         super(DSEncoder, self).__init__()
         self.FinalLayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12)
         self.layer = nn.ModuleList([
-            copy.deepcopy(DeepSpeedTransformerLayer(i,
-                                                    config,
+            copy.deepcopy(DeepSpeedTransformerLayer(config,
                                                     weights,
                                                     biases))
-            for i in range(config.num_hidden_layers)
+            for _ in range(config.num_hidden_layers)
         ])
         self.grads = []
         self.pre_or_post = config.pre_layer_norm
@@ -88,11 +87,6 @@ class DSEncoder(nn.Module):
             else:
                 for i, layer_module in enumerate(self.layer):
                     hidden_states = layer_module(hidden_states, attention_mask)
-                    hidden_states.register_hook(
-                        lambda x,
-                        i=i,
-                        self=self: self.grads.append([x,
-                                                      "hidden_state"]))

             if output_all_encoded_layers:
                 all_encoder_layers.append(hidden_states)
@@ -103,9 +97,6 @@ class DSEncoder(nn.Module):
             all_encoder_layers.append(hidden_states)
         return all_encoder_layers

-    def get_grads(self):
-        return self.grads

 def create_models(ds_config):
     bert_config = BertConfig(vocab_size_or_config_json_file=119547,
@@ -201,7 +192,7 @@ def run_forward(ds_config, seq_len, atol=1e-2, verbose=False, test_bsz=None):
                              output_all_encoded_layers=False,
                              checkpoint_activations=False)

-    # check grads
+    # check forward evaluation
     check_equal(base_results, ds_results, atol=atol, verbose=verbose)
...