Commit bc3d72d1 authored by silencealiang
Browse files

fix bug

parent bcd9442e
...@@ -8,7 +8,8 @@ def transformer_block_init_wrapper(fn): ...@@ -8,7 +8,8 @@ def transformer_block_init_wrapper(fn):
# mtp requires separate layernorms for main model and mtp modules, thus move finalnorm out of block # mtp requires separate layernorms for main model and mtp modules, thus move finalnorm out of block
config = args[0] if len(args) > 1 else kwargs['config'] config = args[0] if len(args) > 1 else kwargs['config']
if getattr(config, "mtp_num_layers", 0) > 0: mtp_layers = getattr(config, "mtp_num_layers", None)
if isinstance(mtp_layers, int) and mtp_layers > 0:
self.main_final_layernorm = self.final_layernorm self.main_final_layernorm = self.final_layernorm
self.final_layernorm = None self.final_layernorm = None
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment