Commit 856d8b82 authored by Xian Li, committed by Facebook Github Bot

layer drop

Summary: This diff enables LayerDrop in the transformer decoder in the production training pipeline (ptt_transformer). It builds on top of the fairseq implementation added by Angela Fan in D18094657, and adds logic to drop the corresponding layers at test time in the exported model.

Reviewed By: jhcross

Differential Revision: D18165586

fbshipit-source-id: 373ac00268a25fa9e412edcb483becdfe792d992
parent 50cf3bb5
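
For context, LayerDrop trains a stack in which each layer is skipped with some probability, so that at inference time a subset of layers can be kept and the rest pruned away. Below is a minimal sketch of the training-time behavior in PyTorch; the class name, constructor arguments, and drop probability are hypothetical illustrations, not the fairseq or ptt_transformer implementation.

import torch
import torch.nn as nn

class LayerDropDecoder(nn.Module):
    """Decoder stack that skips each layer with probability `layerdrop`
    during training (hypothetical sketch of the LayerDrop idea)."""

    def __init__(self, layers, layerdrop=0.2):
        super().__init__()
        self.layers = nn.ModuleList(layers)
        self.layerdrop = layerdrop

    def forward(self, x):
        for layer in self.layers:
            # Draw one Bernoulli sample per layer, per forward pass;
            # layers are only dropped in training mode.
            if self.training and torch.rand(()).item() < self.layerdrop:
                continue  # skip this layer's computation entirely
            x = layer(x)
        return x

# Usage sketch: a 6-layer stack where roughly 1 in 5 layers is
# skipped on each training forward pass.
# decoder = LayerDropDecoder([nn.Linear(8, 8) for _ in range(6)])
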
@@ -345,7 +345,7 @@ def prune_state_dict(state_dict, args):
     It's called by functions that load models from checkpoints and does not
     need to be called directly.
     """
-    if not args:
+    if not args or args.arch == "ptt_transformer":
         # args should not be none, but don't crash if it is.
         return state_dict
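
The hunk above makes fairseq's load-time state-dict pruning a no-op for the ptt_transformer arch, since per the summary the exported model handles dropping the corresponding layers at test time itself. For illustration only, pruning a state dict down to a kept subset of decoder layers amounts to discarding the dropped layers' parameters and renumbering the survivors densely; the helper below is a hypothetical sketch, not the fairseq code.

import re

def keep_decoder_layers(state_dict, layers_to_keep):
    """Hypothetical helper: keep only the given decoder layers and
    renumber the survivors so a shallower model can load them."""
    remap = {old: new for new, old in enumerate(sorted(layers_to_keep))}
    layer_re = re.compile(r"^(decoder\.layers\.)(\d+)(\..+)$")
    pruned = {}
    for key, value in state_dict.items():
        m = layer_re.match(key)
        if m is None:
            pruned[key] = value  # non-layer parameters pass through
        elif int(m.group(2)) in remap:
            new_key = f"{m.group(1)}{remap[int(m.group(2))]}{m.group(3)}"
            pruned[new_key] = value  # renumber the kept layer
        # parameters of dropped layers are discarded
    return pruned
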