Commit 856d8b82 authored by Xian Li's avatar Xian Li Committed by Facebook Github Bot

layer drop

Summary: This diff enables layer drop (LayerDrop) in the transformer decoder in the production training pipeline (ptt_transformer). It builds on top of the fairseq implementation D18094657 added by Angela Fan, and adds additional logic to drop the corresponding layers at test time in the exported model.
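For context, LayerDrop stochastically skips decoder layers during training so that a subset of layers can be removed at inference time without retraining. Below is a minimal sketch of the training-time behavior in plain PyTorch; the class name DecoderWithLayerDrop and the layerdrop argument are illustrative assumptions, not the actual fairseq/ptt_transformer code.

    import torch
    import torch.nn as nn

    class DecoderWithLayerDrop(nn.Module):
        def __init__(self, layers, layerdrop=0.2):
            super().__init__()
            self.layers = nn.ModuleList(layers)
            # Probability of skipping each layer on a given forward pass.
            self.layerdrop = layerdrop

        def forward(self, x):
            for layer in self.layers:
                # During training, each layer is skipped independently with
                # probability `layerdrop`; at eval time every retained layer
                # runs deterministically.
                if self.training and torch.rand(1).item() < self.layerdrop:
                    continue
                x = layer(x)
            return x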

Reviewed By: jhcross

Differential Revision: D18165586

fbshipit-source-id: 373ac00268a25fa9e412edcb483becdfe792d992
parent 50cf3bb5
...
@@ -345,7 +345,7 @@ def prune_state_dict(state_dict, args):
     It's called by functions that load models from checkpoints and does not
     need to be called directly.
     """
-    if not args:
+    if not args or args.arch == "ptt_transformer":
         # args should not be none, but don't crash if it is.
         return state_dict
...
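The early return above leaves the state dict untouched for ptt_transformer, since the exported model applies its own layer-dropping logic at test time. As a rough illustration of what test-time pruning looks like in general, the sketch below keeps a chosen subset of decoder layers and filters/renumbers their parameters in a state dict; the key pattern, the helper name prune_decoder_layers, and the "keep every other layer" example are assumptions for illustration, not the actual implementation.

    import re

    def prune_decoder_layers(state_dict, layers_to_keep):
        keep = set(layers_to_keep)
        # Map surviving original indices to a new contiguous numbering.
        remap = {old: new for new, old in enumerate(sorted(keep))}
        pattern = re.compile(r"decoder\.layers\.(\d+)\.")
        pruned = {}
        for key, value in state_dict.items():
            match = pattern.match(key)
            if match is None:
                # Not a decoder-layer parameter; keep it unchanged.
                pruned[key] = value
                continue
            old_idx = int(match.group(1))
            if old_idx not in keep:
                # Drop parameters belonging to removed layers.
                continue
            new_key = pattern.sub("decoder.layers.%d." % remap[old_idx], key, count=1)
            pruned[new_key] = value
        return pruned

    # Example: keep every other layer of a 6-layer decoder.
    # pruned = prune_decoder_layers(model.state_dict(), layers_to_keep=[0, 2, 4])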