Commit bc3d72d1 authored by silencealiang's avatar silencealiang
Browse files

fix bug

parent bcd9442e
......@@ -8,7 +8,8 @@ def transformer_block_init_wrapper(fn):
# mtp require seperate layernorms for main model and mtp modules, thus move finalnorm out of block
config = args[0] if len(args) > 1 else kwargs['config']
if getattr(config, "mtp_num_layers", 0) > 0:
mtp_layers = getattr(config, "mtp_num_layers", None)
if isinstance(mtp_layers, int) and mtp_layers > 0:
self.main_final_layernorm = self.final_layernorm
self.final_layernorm = None
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment