Unverified Commit 44bd538b authored by Jeff Rasley, committed by GitHub

Module replacement support (#586)


Co-authored-by: Reza Yazdani <reyazda@microsoft.com>
Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
parent 5ab12795
import copy

import torch

from deepspeed.ops.transformer import DeepSpeedTransformerLayer, DeepSpeedTransformerConfig


def module_inject(layer_obj,
                  model,
                  config,
                  micro_batch_size,
                  max_seq_length,
                  seed,
                  preln,
                  fp16=True):
    for name, child in model.named_children():
        if isinstance(child, layer_obj):
            print('REPLACING BertLayer')

            cuda_config = DeepSpeedTransformerConfig(
                batch_size=micro_batch_size,
                max_seq_length=max_seq_length,
                hidden_size=config.hidden_size,
                heads=config.num_attention_heads,
                attn_dropout_ratio=config.attention_probs_dropout_prob,
                hidden_dropout_ratio=config.hidden_dropout_prob,
                num_hidden_layers=config.num_hidden_layers,
                initializer_range=config.initializer_range,
                seed=seed,
                fp16=fp16,
                pre_layer_norm=preln)

            new_module = DeepSpeedTransformerLayer(cuda_config)

            # copy relevant state from child -> new module
            qw = child.attention.self.query.weight
            qb = child.attention.self.query.bias
            kw = child.attention.self.key.weight
            kb = child.attention.self.key.bias
            vw = child.attention.self.value.weight
            vb = child.attention.self.value.bias

            # the DeepSpeed kernel stores Q, K and V as a single fused
            # parameter, concatenated along the output dimension
            qkvw = torch.cat((qw, kw, vw), 0)
            qkvb = torch.cat((qb, kb, vb), 0)

            new_module.attn_qkvw.data = qkvw
            new_module.attn_qkvb.data = qkvb
            new_module.attn_ow.data = child.attention.output.dense.weight
            new_module.attn_ob.data = child.attention.output.dense.bias

            if preln:
                attention_layernorm = child.PostAttentionLayerNorm
            else:
                attention_layernorm = child.attention.output.LayerNorm
            new_module.attn_nw.data = attention_layernorm.weight
            new_module.attn_nb.data = attention_layernorm.bias

            if preln:
                intermediate_ff = child.intermediate.dense_act
            else:
                intermediate_ff = child.intermediate.dense
            new_module.inter_w.data = intermediate_ff.weight
            new_module.inter_b.data = intermediate_ff.bias

            new_module.output_w.data = child.output.dense.weight
            new_module.output_b.data = child.output.dense.bias

            if preln:
                transformer_layernorm = child.PreAttentionLayerNorm
            else:
                transformer_layernorm = child.output.LayerNorm
            new_module.norm_w.data = transformer_layernorm.weight
            new_module.norm_b.data = transformer_layernorm.bias

            setattr(model, name, copy.deepcopy(new_module))
        else:
            module_inject(layer_obj,
                          child,
                          config,
                          micro_batch_size,
                          max_seq_length,
                          seed,
                          preln,
                          fp16)
    return model
def test_hi():
    from turing.nvidia_modelingpreln import BertConfig as BertConfigPreLN
    from turing.nvidia_modelingpreln import BertForQuestionAnswering as BertForQuestionAnsweringPreLN
    from turing.nvidia_modelingpreln import BertLayer

    bert_model_config = {
        "vocab_size_or_config_json_file": 119547,
        "hidden_size": 1024,
        "num_hidden_layers": 1,
        "num_attention_heads": 16,
        "intermediate_size": 4096,
        "hidden_act": "gelu",
        "hidden_dropout_prob": 0.1,
        "attention_probs_dropout_prob": 0.1,
        "max_position_embeddings": 512,
        "type_vocab_size": 2,
        "initializer_range": 0.02
    }
    bert_config = BertConfigPreLN(**bert_model_config)
    base_model = BertForQuestionAnsweringPreLN(bert_config, args=None)

    #base_model = LinearStack()

    test_model = copy.deepcopy(base_model)
    # the model above uses pre-layer-norm, so pass preln=True
    test_model = module_inject(BertLayer,
                               test_model,
                               bert_config,
                               4,
                               384,
                               1234,
                               preln=True)

    print('BASE', base_model)
    print('TEST', test_model)

    #base_model.eval()
    #test_model.eval()
    #test_input = torch.rand(1, base_model.input_dim)
    #base_output = base_model(test_input)
    #test_output = test_model(test_input)
    #
    #assert torch.allclose(base_output, test_output, atol=3e-8)
import copy
import torch
import deepspeed
def replace_transformer_layer(orig_layer_impl,
                              model,
                              micro_batch_size,
                              bert_config,
                              seed,
                              max_seq_length,
                              preln=False,
                              fp16=True,
                              huggingface=False,
                              local_rank=-1):
    """ Replace bert-style transformer layers with DeepSpeed's transformer layer.

    Arguments:
        orig_layer_impl (torch.nn.Module): the original transformer layer implementation to look for,
            e.g., transformers.modeling_bert.BertLayer.
        model (torch.nn.Module): user's nn.Module representing their model
        micro_batch_size (int): micro batch size per gpu used during training/eval
        bert_config (dict): model config containing hidden size, attention heads, etc.
        seed (int): random seed value
        max_seq_length (int): max sequence length for training
        preln (bool): True if the original layer implementation uses pre-layer-norm,
            False for post-layer-norm
        fp16 (bool): fp16 or fp32
        huggingface (bool): huggingface implementation is unique (supports both encoder/decoder modes)
        local_rank (int): local GPU rank, or -1 if not running distributed

    Returns:
        Updated nn.Module with replaced transformer layers
    """
    def replace_fn(child):
        transformer_config = deepspeed.DeepSpeedTransformerConfig(
            batch_size=micro_batch_size,
            max_seq_length=max_seq_length,
            hidden_size=bert_config.hidden_size,
            heads=bert_config.num_attention_heads,
            attn_dropout_ratio=bert_config.attention_probs_dropout_prob,
            hidden_dropout_ratio=bert_config.hidden_dropout_prob,
            num_hidden_layers=bert_config.num_hidden_layers,
            initializer_range=bert_config.initializer_range,
            seed=seed,
            fp16=fp16,
            pre_layer_norm=preln,
            huggingface=huggingface,
            local_rank=local_rank)
        new_module = deepspeed.DeepSpeedTransformerLayer(transformer_config)

        # copy relevant state from child -> new module
        qw = child.attention.self.query.weight
        qb = child.attention.self.query.bias
        kw = child.attention.self.key.weight
        kb = child.attention.self.key.bias
        vw = child.attention.self.value.weight
        vb = child.attention.self.value.bias

        # fuse Q, K and V into the single QKV parameter the kernel expects
        qkvw = torch.cat((qw, kw, vw), 0)
        qkvb = torch.cat((qb, kb, vb), 0)
        #qw.data,kw.data,vw.data = torch.chunk(qkvw, 3, dim=0)
        #qb.data,kb.data,vb.data = torch.chunk(qkvb, 3, dim=0)
        new_module.attn_qkvw.data = qkvw
        new_module.attn_qkvb.data = qkvb
        new_module.attn_ow.data = child.attention.output.dense.weight
        new_module.attn_ob.data = child.attention.output.dense.bias

        if preln:
            attention_layernorm = child.PostAttentionLayerNorm
        else:
            attention_layernorm = child.attention.output.LayerNorm
        new_module.attn_nw.data = attention_layernorm.weight
        new_module.attn_nb.data = attention_layernorm.bias

        if preln:
            intermediate_ff = child.intermediate.dense_act
        else:
            intermediate_ff = child.intermediate.dense
        new_module.inter_w.data = intermediate_ff.weight
        new_module.inter_b.data = intermediate_ff.bias

        new_module.output_w.data = child.output.dense.weight
        new_module.output_b.data = child.output.dense.bias

        if preln:
            transformer_layernorm = child.PreAttentionLayerNorm
        else:
            transformer_layernorm = child.output.LayerNorm
        new_module.norm_w.data = transformer_layernorm.weight
        new_module.norm_b.data = transformer_layernorm.bias

        return new_module

    return replace_module(model=model, orig_class=orig_layer_impl, replace_fn=replace_fn)
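
# A minimal usage sketch (illustrative only, not part of this module): swap
# HuggingFace-style BertLayer modules for the DeepSpeed kernel. `net` and
# `net_config` are placeholder names for a transformers model and its config.
#
#   from transformers.modeling_bert import BertLayer
#   net = replace_transformer_layer(orig_layer_impl=BertLayer,
#                                   model=net,
#                                   micro_batch_size=8,
#                                   bert_config=net_config,
#                                   seed=42,
#                                   max_seq_length=128,
#                                   preln=False,
#                                   fp16=True,
#                                   huggingface=True)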
def revert_transformer_layer(orig_layer_impl, model, bert_config, preln=False):
    """ Revert DeepSpeed's transformer layer back to the original bert-style transformer layer.

    Arguments:
        orig_layer_impl (torch.nn.Module): the original transformer layer implementation that was replaced,
            e.g., transformers.modeling_bert.BertLayer.
        model (torch.nn.Module): user's nn.Module representing their model
        bert_config (dict): model config containing hidden size, attention heads, etc.
        preln (bool): True if the original layer implementation uses pre-layer-norm

    Returns:
        Updated nn.Module with original bert-style transformer layers
    """
    def replace_fn(child):
        #from turing.nvidia_modelingpreln import BertLayer
        orig_module = orig_layer_impl(bert_config)

        # copy relevant state from child -> original module
        qkvw = child.attn_qkvw.data
        qkvb = child.attn_qkvb.data
        # split the fused QKV parameter back into separate Q, K and V tensors
        qw, kw, vw = torch.chunk(qkvw, 3, dim=0)
        qb, kb, vb = torch.chunk(qkvb, 3, dim=0)

        orig_module.attention.self.query.weight.data = qw
        orig_module.attention.self.query.bias.data = qb
        orig_module.attention.self.key.weight.data = kw
        orig_module.attention.self.key.bias.data = kb
        orig_module.attention.self.value.weight.data = vw
        orig_module.attention.self.value.bias.data = vb

        orig_module.attention.output.dense.weight.data = child.attn_ow.data
        orig_module.attention.output.dense.bias.data = child.attn_ob.data

        attn_ln_w = child.attn_nw.data
        attn_ln_b = child.attn_nb.data
        if preln:
            orig_module.PostAttentionLayerNorm.weight.data = attn_ln_w
            orig_module.PostAttentionLayerNorm.bias.data = attn_ln_b
        else:
            orig_module.attention.output.LayerNorm.weight.data = attn_ln_w
            orig_module.attention.output.LayerNorm.bias.data = attn_ln_b

        inter_ff_w = child.inter_w.data
        inter_ff_b = child.inter_b.data
        if preln:
            orig_module.intermediate.dense_act.weight.data = inter_ff_w
            orig_module.intermediate.dense_act.bias.data = inter_ff_b
        else:
            orig_module.intermediate.dense.weight.data = inter_ff_w
            orig_module.intermediate.dense.bias.data = inter_ff_b

        orig_module.output.dense.weight.data = child.output_w.data
        orig_module.output.dense.bias.data = child.output_b.data

        transformer_ln_w = child.norm_w.data
        transformer_ln_b = child.norm_b.data
        if preln:
            orig_module.PreAttentionLayerNorm.weight.data = transformer_ln_w
            orig_module.PreAttentionLayerNorm.bias.data = transformer_ln_b
        else:
            orig_module.output.LayerNorm.weight.data = transformer_ln_w
            orig_module.output.LayerNorm.bias.data = transformer_ln_b

        return orig_module

    return replace_module(model=model,
                          orig_class=deepspeed.DeepSpeedTransformerLayer,
                          replace_fn=replace_fn)
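
# Illustrative revert sketch (placeholders as in the example above): after
# training with the DeepSpeed kernel, layers can be swapped back so the state
# loads into the original implementation.
#
#   net = revert_transformer_layer(orig_layer_impl=BertLayer,
#                                  model=net,
#                                  bert_config=net_config,
#                                  preln=False)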
def replace_module(model, orig_class, replace_fn):
    """ Scan the model for instances of ``orig_class`` to replace using ``replace_fn``.

    Arguments:
        model (torch.nn.Module): the model to augment
        orig_class (torch.nn.Module): the module class to search for
        replace_fn (method): a method to convert instances of ``orig_class`` to the
            desired type, returning a new instance.

    Returns:
        A modified ``model``.
    """
    policy = {orig_class: replace_fn}
    return _replace_module(model, policy)
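
# Sketch of the generic hook (hypothetical classes shown for illustration):
# any module type can be swapped by supplying a conversion function.
#
#   def to_fused(child):
#       new = FusedLinear(child.in_features, child.out_features)
#       new.weight.data.copy_(child.weight.data)
#       new.bias.data.copy_(child.bias.data)
#       return new
#
#   net = replace_module(model=net, orig_class=MyLinear, replace_fn=to_fused)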
def _replace_module(model, policies):
    """ Traverse model's children recursively and apply any transformations in ``policies``.

    Arguments:
        model (torch.nn.Module): model to augment
        policies (dict): Mapping of source class to replacement function.

    Returns:
        Modified ``model``.
    """
    for name, child in model.named_children():
        if child.__class__ in policies:
            orig = repr(child)  # kept for inspecting the before/after modules when debugging
            setattr(model, name, policies[child.__class__](child))
            new = getattr(model, name)
        else:
            _replace_module(child, policies)
    return model
@@ -3,4 +3,7 @@ from . import lamb
 from . import sparse_attention
 from . import transformer
+from .transformer import DeepSpeedTransformerLayer, DeepSpeedTransformerConfig
+from .module_inject import replace_module
 from ..git_version_info import compatible_ops as __compatible_ops__
import copy
import torch
import deepspeed
from deepspeed.ops import DeepSpeedTransformerConfig
def _copy_child_transformer_state(new_module, orig_child, pre_layer_norm):
    # copy relevant state from original child -> new module
    qw = orig_child.attention.self.query.weight
    qb = orig_child.attention.self.query.bias
    kw = orig_child.attention.self.key.weight
    kb = orig_child.attention.self.key.bias
    vw = orig_child.attention.self.value.weight
    vb = orig_child.attention.self.value.bias

    # the DeepSpeed kernel stores Q, K and V as a single fused parameter,
    # concatenated along the output dimension
    qkvw = torch.cat((qw, kw, vw), 0)
    qkvb = torch.cat((qb, kb, vb), 0)
    #qw.data,kw.data,vw.data = torch.chunk(qkvw, 3, dim=0)
    #qb.data,kb.data,vb.data = torch.chunk(qkvb, 3, dim=0)
    new_module.attn_qkvw.data = qkvw
    new_module.attn_qkvb.data = qkvb
    new_module.attn_ow.data = orig_child.attention.output.dense.weight
    new_module.attn_ob.data = orig_child.attention.output.dense.bias

    if pre_layer_norm:
        attention_layernorm = orig_child.PostAttentionLayerNorm
    else:
        attention_layernorm = orig_child.attention.output.LayerNorm
    new_module.attn_nw.data = attention_layernorm.weight
    new_module.attn_nb.data = attention_layernorm.bias

    if pre_layer_norm:
        intermediate_ff = orig_child.intermediate.dense_act
    else:
        intermediate_ff = orig_child.intermediate.dense
    new_module.inter_w.data = intermediate_ff.weight
    new_module.inter_b.data = intermediate_ff.bias

    new_module.output_w.data = orig_child.output.dense.weight
    new_module.output_b.data = orig_child.output.dense.bias

    if pre_layer_norm:
        transformer_layernorm = orig_child.PreAttentionLayerNorm
    else:
        transformer_layernorm = orig_child.output.LayerNorm
    new_module.norm_w.data = transformer_layernorm.weight
    new_module.norm_b.data = transformer_layernorm.bias
def _replace_transformer_layer(orig_layer_impl, model, transformer_config):
    """ Replace bert-style transformer layers with DeepSpeed's transformer layer.

    Arguments:
        orig_layer_impl (torch.nn.Module): the original transformer layer implementation to look for,
            e.g., transformers.modeling_bert.BertLayer.
        model (torch.nn.Module): user's nn.Module representing their model
        transformer_config (DeepSpeedTransformerConfig): deepspeed transformer layer config containing
            hidden size, attention heads, etc.

    Returns:
        Updated nn.Module with replaced transformer layers
    """
    def replace_fn(child):
        new_module = deepspeed.DeepSpeedTransformerLayer(transformer_config)
        _copy_child_transformer_state(new_module,
                                      child,
                                      transformer_config.pre_layer_norm)
        return new_module

    return _replace_module(model=model,
                           orig_class=orig_layer_impl,
                           replace_fn=replace_fn)
def replace_module(orig_module_impl, model, replacement_module_config):
    """ Replace a client module with its DeepSpeed counterpart.

    Arguments:
        orig_module_impl (torch.nn.Module): original module implementation to replace,
            e.g., transformers.modeling_bert.BertLayer.
        model (torch.nn.Module): user's nn.Module representing their model
        replacement_module_config (DeepSpeedTransformerConfig): deepspeed replacement module config.

    Returns:
        Updated nn.Module with replaced modules
    """
    assert isinstance(replacement_module_config, DeepSpeedTransformerConfig), \
        'Only DeepSpeedTransformerConfig is currently supported as replacement config'
    return _replace_transformer_layer(orig_layer_impl=orig_module_impl,
                                      model=model,
                                      transformer_config=replacement_module_config)
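
# Illustrative call (values are placeholders): the replacement config is
# built from the same fields replace_fn uses above, then handed to
# replace_module.
#
#   ds_config = DeepSpeedTransformerConfig(batch_size=8,
#                                          max_seq_length=128,
#                                          hidden_size=1024,
#                                          heads=16,
#                                          fp16=True,
#                                          pre_layer_norm=False)
#   net = replace_module(orig_module_impl=BertLayer,
#                        model=net,
#                        replacement_module_config=ds_config)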
def _revert_transformer_layer(orig_layer_impl, model, bert_config, transformer_config):
    """ Revert DeepSpeed's transformer layer back to the original bert-style transformer layer.

    Arguments:
        orig_layer_impl (torch.nn.Module): the original transformer layer implementation that was replaced,
            e.g., transformers.modeling_bert.BertLayer.
        model (torch.nn.Module): user's nn.Module representing their model
        bert_config (dict): model config containing hidden size, attention heads, etc.
        transformer_config (DeepSpeedTransformerConfig): deepspeed transformer config used for replacement

    Returns:
        Updated nn.Module with original bert-style transformer layers
    """
    def replace_fn(child):
        #from turing.nvidia_modelingpreln import BertLayer
        orig_module = orig_layer_impl(bert_config)

        # copy relevant state from child -> original module
        qkvw = child.attn_qkvw.data
        qkvb = child.attn_qkvb.data
        # split the fused QKV parameter back into separate Q, K and V tensors
        qw, kw, vw = torch.chunk(qkvw, 3, dim=0)
        qb, kb, vb = torch.chunk(qkvb, 3, dim=0)

        orig_module.attention.self.query.weight.data = qw
        orig_module.attention.self.query.bias.data = qb
        orig_module.attention.self.key.weight.data = kw
        orig_module.attention.self.key.bias.data = kb
        orig_module.attention.self.value.weight.data = vw
        orig_module.attention.self.value.bias.data = vb

        orig_module.attention.output.dense.weight.data = child.attn_ow.data
        orig_module.attention.output.dense.bias.data = child.attn_ob.data

        attn_ln_w = child.attn_nw.data
        attn_ln_b = child.attn_nb.data
        if transformer_config.pre_layer_norm:
            orig_module.PostAttentionLayerNorm.weight.data = attn_ln_w
            orig_module.PostAttentionLayerNorm.bias.data = attn_ln_b
        else:
            orig_module.attention.output.LayerNorm.weight.data = attn_ln_w
            orig_module.attention.output.LayerNorm.bias.data = attn_ln_b

        inter_ff_w = child.inter_w.data
        inter_ff_b = child.inter_b.data
        if transformer_config.pre_layer_norm:
            orig_module.intermediate.dense_act.weight.data = inter_ff_w
            orig_module.intermediate.dense_act.bias.data = inter_ff_b
        else:
            orig_module.intermediate.dense.weight.data = inter_ff_w
            orig_module.intermediate.dense.bias.data = inter_ff_b

        orig_module.output.dense.weight.data = child.output_w.data
        orig_module.output.dense.bias.data = child.output_b.data

        transformer_ln_w = child.norm_w.data
        transformer_ln_b = child.norm_b.data
        if transformer_config.pre_layer_norm:
            orig_module.PreAttentionLayerNorm.weight.data = transformer_ln_w
            orig_module.PreAttentionLayerNorm.bias.data = transformer_ln_b
        else:
            orig_module.output.LayerNorm.weight.data = transformer_ln_w
            orig_module.output.LayerNorm.bias.data = transformer_ln_b

        return orig_module

    return _replace_module(model=model,
                           orig_class=deepspeed.DeepSpeedTransformerLayer,
                           replace_fn=replace_fn)
def revert_module(orig_module_impl,
                  model,
                  orig_module_config,
                  replacement_module_config):
    """ Revert DeepSpeed's module back to the original client module.

    Arguments:
        orig_module_impl (torch.nn.Module): the original module that was replaced,
            e.g., transformers.modeling_bert.BertLayer.
        model (torch.nn.Module): user's nn.Module representing their model
        orig_module_config (dict): original module configuration
        replacement_module_config (DeepSpeedTransformerConfig): replacement deepspeed module configuration

    Returns:
        Updated nn.Module with original client modules
    """
    assert isinstance(replacement_module_config, DeepSpeedTransformerConfig), \
        'Only DeepSpeedTransformerConfig is currently supported as replacement config'
    return _revert_transformer_layer(orig_layer_impl=orig_module_impl,
                                     model=model,
                                     bert_config=orig_module_config,
                                     transformer_config=replacement_module_config)
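
# Matching revert sketch (placeholders as above): `bert_config` is the
# original client config and `ds_config` the DeepSpeedTransformerConfig used
# for the replacement.
#
#   net = revert_module(orig_module_impl=BertLayer,
#                       model=net,
#                       orig_module_config=bert_config,
#                       replacement_module_config=ds_config)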
def _replace_module(model, orig_class, replace_fn):
    """ Scan the model for instances of ``orig_class`` to replace using ``replace_fn``.

    Arguments:
        model (torch.nn.Module): the model to augment
        orig_class (torch.nn.Module): the module class to search for
        replace_fn (method): a method to convert instances of ``orig_class`` to the
            desired type, returning a new instance.

    Returns:
        A modified ``model``.
    """
    policy = {orig_class: replace_fn}
    return _replace_module_using_policies(model, policy)
def _replace_module_using_policies(model, policies):
    """ Traverse model's children recursively and apply any transformations in ``policies``.

    Arguments:
        model (torch.nn.Module): model to augment
        policies (dict): Mapping of source class to replacement function.

    Returns:
        Modified ``model``.
    """
    for name, child in model.named_children():
        if child.__class__ in policies:
            orig = repr(child)  # kept for inspecting the before/after modules when debugging
            setattr(model, name, policies[child.__class__](child))
            new = getattr(model, name)
        else:
            _replace_module_using_policies(child, policies)
    return model
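
# The policy mapping generalizes to several module types in one traversal
# (hypothetical replace functions shown for illustration):
#
#   policies = {BertLayer: bert_replace_fn, GPT2Block: gpt2_replace_fn}
#   model = _replace_module_using_policies(model, policies)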
@@ -87,6 +87,10 @@ class DeepSpeedTransformerConfig(TransformerConfig):
                 that by enabling it, the pretraining tasks such as BERT are not affected and can obtain
                 a high accuracy level. On the other hand, for the downstream tasks, such as fine-tuning, we recommend
                 to turn it off in order to be able to reproduce the same result through the regular kernel execution.
+
+            huggingface: Enable when using the HuggingFace-style interface, which returns the forward results as a tuple.
+
+            training: Enable for training rather than inference.
     """

     def __init__(self,
                  batch_size=-1,
@@ -105,7 +109,9 @@ class DeepSpeedTransformerConfig(TransformerConfig):
                  gelu_checkpoint=False,
                  adjust_init_range=True,
                  attn_dropout_checkpoint=False,
-                 stochastic_mode=False):
+                 stochastic_mode=False,
+                 huggingface=False,
+                 training=True):
         super(DeepSpeedTransformerConfig,
               self).__init__(
                   batch_size,
@@ -124,10 +130,11 @@ class DeepSpeedTransformerConfig(TransformerConfig):
         self.gelu_checkpoint = gelu_checkpoint  # True: if higher batch size is required
         self.adjust_init_range = adjust_init_range
         self.test_gemm = False
-        self.training = True
+        self.training = training
         self.is_grad_enabled = True
         self.attn_dropout_checkpoint = attn_dropout_checkpoint
         self.stochastic_mode = stochastic_mode
+        self.huggingface = huggingface

     @classmethod
     def from_dict(cls, json_object):
@@ -252,7 +259,7 @@ class DeepSpeedTransformerFunction(Function):
             norm_w.register_hook(lambda x, self=self: grads.append([x, "norm_W"]))
             norm_b.register_hook(lambda x, self=self: grads.append([x, "norm_B"]))

-        if config.is_grad_enabled:
+        if config.is_grad_enabled and config.training:
             if (config.pre_layer_norm and config.normalize_invertible):
                 ctx.save_for_backward(input_mask,
                                       attn_qkvw,
@@ -313,7 +320,11 @@ class DeepSpeedTransformerFunction(Function):
         if inp_size[1] % 16 != 0:
             output = torch.narrow(output, 1, 0, inp_size[1])
-        return output
+
+        if config.huggingface:
+            return (output, )  # outputs -> (output) : outputs[0] = output
+        else:
+            return output

     @staticmethod
     def backward(ctx, grad_output):
@@ -412,6 +423,25 @@ class DeepSpeedTransformerFunction(Function):
                                            norm_w,
                                            norm_b)

+        # This appears to be an effective way to release context memory
+        ctx.qkv_tf = None
+        ctx.soft_inp = None
+        ctx.ctx_bufB = None
+        ctx.gelu_inp = None
+        ctx.ff2_inp = None
+        ctx.attn_o_inp = None
+        ctx.ff1_inp = None
+        ctx.add_res = None
+        ctx.inp_norm = None
+        ctx.config = None
+        ctx.attn_layer_norm_mean = None
+        ctx.layer_norm_mean = None
+        ctx.attn_prob_dropout_mask = None
+        ctx.attn_output_dropout_mask = None
+        ctx.layer_output_dropout_mask = None
+        ctx.attn_layer_norm_var = None
+        ctx.layer_norm_var = None
+
         if grad_output_shape[1] % 16 != 0:
             grad_input = torch.narrow(grad_input, 1, 0, grad_output_shape[1])
@@ -438,21 +468,24 @@
 class DeepSpeedTransformerLayer(nn.Module):
     """Initialize the DeepSpeed Transformer Layer.

+        Static variable:
+            layer_id: The layer-index counter, starting from 0 and incremented by 1 each time
+                a layer object is instantiated, e.g. if a model has 24 transformer layers,
+                layer_id goes from 0 to 23.
         Arguments:
-            layer_id: The layer index starting from 0, e.g. if model has 24 transformer layers,
-                layer_id will be 0,1,2...23 when each layer object is instantiated
             config: An object of DeepSpeedTransformerConfig
             initial_weights: Optional: Only used for unit test
             initial_biases: Optional: Only used for unit test
     """
-    def __init__(self, layer_id, config, initial_weights=None, initial_biases=None):
+    layer_id = 0
+
+    def __init__(self, config, initial_weights=None, initial_biases=None):
         super(DeepSpeedTransformerLayer, self).__init__()
         self.config = config
-        self.config.layer_id = layer_id
+        self.config.layer_id = DeepSpeedTransformerLayer.layer_id
+        DeepSpeedTransformerLayer.layer_id = DeepSpeedTransformerLayer.layer_id + 1

         print("DeepSpeed Transformer config is ", self.config.__dict__)
@@ -548,11 +581,18 @@ class DeepSpeedTransformerLayer(nn.Module):
         self.norm_w.data.fill_(1.0)
         self.norm_b.data.zero_()

-    def forward(self, input, input_mask, grads=None):
+    def forward(self,
+                hidden_states,
+                attention_mask=None,
+                head_mask=None,
+                encoder_hidden_states=None,
+                encoder_attention_mask=None,
+                output_attentions=False,
+                grads=None):
         self.config.training = self.training
         self.config.is_grad_enabled = torch.is_grad_enabled()
-        return DeepSpeedTransformerFunction.apply(input,
-                                                  input_mask,
+        return DeepSpeedTransformerFunction.apply(hidden_states,
+                                                  attention_mask,
                                                   self,
                                                   grads,
                                                   self.config.layer_id,
@@ -530,21 +530,10 @@ def see_memory_usage(message):
     # Print message except when distributed but not rank 0
     logger.info(message)
     logger.info(
-        "Memory Allocated %s GigaBytes ",
-        torch.cuda.memory_allocated() / (1024 * 1024 * 1024),
-    )
-    logger.info(
-        "Max Memory Allocated %s GigaBytes",
-        torch.cuda.max_memory_allocated() / (1024 * 1024 * 1024),
-    )
-    logger.info(
-        "Cache Allocated %s GigaBytes",
-        torch.cuda.memory_cached() / (1024 * 1024 * 1024),
-    )
-    logger.info(
-        "Max cache Allocated %s GigaBytes",
-        torch.cuda.max_memory_cached() / (1024 * 1024 * 1024),
-    )
+        f"MA {round(torch.cuda.memory_allocated() / (1024 * 1024 * 1024), 2)} GB \
+        Max_MA {round(torch.cuda.max_memory_allocated() / (1024 * 1024 * 1024), 2)} GB \
+        CA {round(torch.cuda.memory_cached() / (1024 * 1024 * 1024), 2)} GB \
+        Max_CA {round(torch.cuda.max_memory_cached() / (1024 * 1024 * 1024))} GB ")

 def call_to_str(base, *args, **kwargs):
@@ -284,10 +284,10 @@ transformer layers using DeepSpeed transformer kernel as below.
                                                  gelu_checkpoint=args.gelu_checkpoint,
                                                  stochastic_mode=True)

-            self.layer = nn.ModuleList([copy.deepcopy(DeepSpeedTransformerLayer(i, cuda_config)) for i in range(config.num_hidden_layers)])
+            layer = DeepSpeedTransformerLayer(cuda_config)
         else:
             layer = BertLayer(config)
-            self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(config.num_hidden_layers)])
+        self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(config.num_hidden_layers)])
 ```
All configuration settings come from the DeepSpeed configuration file and
command-line arguments, so we must pass the `args` variable here in this model.
@@ -43,8 +43,8 @@ config = DeepSpeedTransformerConfig(batch_size = 64,
                                     normalize_invertible=False,
                                     gelu_checkpoint=False)

 self.layer = nn.ModuleList([
-    copy.deepcopy(DeepSpeedTransformerLayer(i, cuda_config))
-    for i in range(config.num_hidden_layers)
+    copy.deepcopy(DeepSpeedTransformerLayer(cuda_config))
+    for _ in range(config.num_hidden_layers)
 ])
 ```
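Note that the layer index is now assigned automatically: each `DeepSpeedTransformerLayer` construction increments a class-level `layer_id` counter, so the loop index no longer needs to be passed to the constructor.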
### Transformer kernel Parameters
......@@ -83,11 +83,10 @@ class DSEncoder(nn.Module):
super(DSEncoder, self).__init__()
self.FinalLayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12)
self.layer = nn.ModuleList([
copy.deepcopy(DeepSpeedTransformerLayer(i,
config,
copy.deepcopy(DeepSpeedTransformerLayer(config,
weights,
biases))
for i in range(config.num_hidden_layers)
for _ in range(config.num_hidden_layers)
])
self.grads = []
self.pre_or_post = config.pre_layer_norm
......@@ -122,7 +121,9 @@ class DSEncoder(nn.Module):
# decoder layers
else:
for i, layer_module in enumerate(self.layer):
hidden_states = layer_module(hidden_states, attention_mask, self.grads)
hidden_states = layer_module(hidden_states,
attention_mask,
grads=self.grads)
hidden_states.register_hook(
lambda x,
self=self: self.grads.append([x,
......@@ -48,11 +48,10 @@ class DSEncoder(nn.Module):
super(DSEncoder, self).__init__()
self.FinalLayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12)
self.layer = nn.ModuleList([
copy.deepcopy(DeepSpeedTransformerLayer(i,
config,
copy.deepcopy(DeepSpeedTransformerLayer(config,
weights,
biases))
for i in range(config.num_hidden_layers)
for _ in range(config.num_hidden_layers)
])
self.grads = []
self.pre_or_post = config.pre_layer_norm
......@@ -88,11 +87,6 @@ class DSEncoder(nn.Module):
else:
for i, layer_module in enumerate(self.layer):
hidden_states = layer_module(hidden_states, attention_mask)
hidden_states.register_hook(
lambda x,
i=i,
self=self: self.grads.append([x,
"hidden_state"]))
if output_all_encoded_layers:
all_encoder_layers.append(hidden_states)
......@@ -103,9 +97,6 @@ class DSEncoder(nn.Module):
all_encoder_layers.append(hidden_states)
return all_encoder_layers
def get_grads(self):
return self.grads
def create_models(ds_config):
bert_config = BertConfig(vocab_size_or_config_json_file=119547,
......@@ -201,7 +192,7 @@ def run_forward(ds_config, seq_len, atol=1e-2, verbose=False, test_bsz=None):
output_all_encoded_layers=False,
checkpoint_activations=False)
# check grads
# check forward evaluation
check_equal(base_results, ds_results, atol=atol, verbose=verbose)