Commit fa2ccbc0 authored by Aymeric Augustin

Fix E266 flake8 warning (x90).

parent 2ab78325
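For context: E266 is the pycodestyle check "too many leading '#' for block comment", which flake8 reports whenever a block comment starts with more than one '#'. A minimal illustration of the pattern fixed ninety times in the diff below (this snippet is illustrative, not taken from the commit):

    import argparse

    parser = argparse.ArgumentParser()

    ## Required parameters   <- flake8 flags this line as E266
    # Required parameters    <- compliant: a single '#' followed by a space
    parser.add_argument("--train_file", default=None, type=str, required=True)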
@@ -487,7 +487,7 @@ def evaluate(args, model, tokenizer, prefix=""):
 def main():
     parser = argparse.ArgumentParser()

-    ## Required parameters
+    # Required parameters
     parser.add_argument(
         "--train_file", default=None, type=str, required=True, help="SWAG csv for training. E.g., train.csv"
     )
@@ -520,7 +520,7 @@ def main():
         help="The output directory where the model checkpoints and predictions will be written.",
     )

-    ## Other parameters
+    # Other parameters
     parser.add_argument(
         "--config_name", default="", type=str, help="Pretrained config name or path if not the same as model_name"
     )
...
@@ -430,7 +430,7 @@ def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=Fal
 def main():
     parser = argparse.ArgumentParser()

-    ## Required parameters
+    # Required parameters
     parser.add_argument(
         "--train_file", default=None, type=str, required=True, help="SQuAD json for training. E.g., train-v1.1.json"
     )
@@ -486,7 +486,7 @@ def main():
         "--temperature", default=2.0, type=float, help="Distillation temperature. Only for distillation."
     )

-    ## Other parameters
+    # Other parameters
     parser.add_argument(
         "--config_name", default="", type=str, help="Pretrained config name or path if not the same as model_name"
     )
...
@@ -43,7 +43,7 @@ if __name__ == "__main__":
     state_dict = model.state_dict()
     compressed_sd = {}

-    ### Embeddings ###
+    # Embeddings #
     if args.model_type == "gpt2":
         for param_name in ["wte.weight", "wpe.weight"]:
             compressed_sd[f"{prefix}.{param_name}"] = state_dict[f"{prefix}.{param_name}"]
@@ -55,7 +55,7 @@ if __name__ == "__main__":
             param_name = f"{prefix}.embeddings.LayerNorm.{w}"
             compressed_sd[param_name] = state_dict[param_name]

-    ### Transformer Blocks ###
+    # Transformer Blocks #
     std_idx = 0
     for teacher_idx in [0, 2, 4, 7, 9, 11]:
         if args.model_type == "gpt2":
@@ -82,7 +82,7 @@ if __name__ == "__main__":
             ]
             std_idx += 1

-    ### Language Modeling Head ###s
+    # Language Modeling Head ###s
     if args.model_type == "roberta":
         for layer in ["lm_head.decoder.weight", "lm_head.bias"]:
             compressed_sd[f"{layer}"] = state_dict[f"{layer}"]
...
@@ -219,7 +219,7 @@ def main():
     args = parser.parse_args()
     sanity_checks(args)

-    ## ARGS ##
+    # ARGS #
     init_gpu_params(args)
     set_seed(args)
     if args.is_master:
@@ -236,7 +236,7 @@ def main():
            os.makedirs(args.dump_path)
         logger.info(f"Experiment will be dumped and logged in {args.dump_path}")

-        ### SAVE PARAMS ###
+        # SAVE PARAMS #
         logger.info(f"Param: {args}")
         with open(os.path.join(args.dump_path, "parameters.json"), "w") as f:
             json.dump(vars(args), f, indent=4)
@@ -245,7 +245,7 @@ def main():
     student_config_class, student_model_class, _ = MODEL_CLASSES[args.student_type]
     teacher_config_class, teacher_model_class, teacher_tokenizer_class = MODEL_CLASSES[args.teacher_type]

-    ### TOKENIZER ###
+    # TOKENIZER #
     tokenizer = teacher_tokenizer_class.from_pretrained(args.teacher_name)
     special_tok_ids = {}
     for tok_name, tok_symbol in tokenizer.special_tokens_map.items():
@@ -255,7 +255,7 @@ def main():
     args.special_tok_ids = special_tok_ids
     args.max_model_input_size = tokenizer.max_model_input_sizes[args.teacher_name]

-    ## DATA LOADER ##
+    # DATA LOADER #
     logger.info(f"Loading data from {args.data_file}")
     with open(args.data_file, "rb") as fp:
         data = pickle.load(fp)
@@ -275,7 +275,7 @@ def main():
     train_lm_seq_dataset = LmSeqsDataset(params=args, data=data)
     logger.info(f"Data loader created.")

-    ## STUDENT ##
+    # STUDENT #
     logger.info(f"Loading student config from {args.student_config}")
     stu_architecture_config = student_config_class.from_pretrained(args.student_config)
     stu_architecture_config.output_hidden_states = True
@@ -290,26 +290,26 @@ def main():
         student.to(f"cuda:{args.local_rank}")
     logger.info(f"Student loaded.")

-    ## TEACHER ##
+    # TEACHER #
     teacher = teacher_model_class.from_pretrained(args.teacher_name, output_hidden_states=True)
     if args.n_gpu > 0:
         teacher.to(f"cuda:{args.local_rank}")
     logger.info(f"Teacher loaded from {args.teacher_name}.")

-    ## FREEZING ##
+    # FREEZING #
     if args.freeze_pos_embs:
         freeze_pos_embeddings(student, args)
     if args.freeze_token_type_embds:
         freeze_token_type_embeddings(student, args)

-    ## SANITY CHECKS ##
+    # SANITY CHECKS #
     assert student.config.vocab_size == teacher.config.vocab_size
     assert student.config.hidden_size == teacher.config.hidden_size
     assert student.config.max_position_embeddings == teacher.config.max_position_embeddings
     if args.mlm:
         assert token_probs.size(0) == stu_architecture_config.vocab_size

-    ## DISTILLER ##
+    # DISTILLER #
     torch.cuda.empty_cache()
     distiller = Distiller(
         params=args, dataset=train_lm_seq_dataset, token_probs=token_probs, student=student, teacher=teacher
...
@@ -344,7 +344,7 @@ def load_examples(args, tokenizer, evaluate=False):
 def main():
     parser = argparse.ArgumentParser()

-    ## Required parameters
+    # Required parameters
     parser.add_argument(
         "--data_dir",
         default=None,
@@ -374,7 +374,7 @@ def main():
         help="The output directory where the model predictions and checkpoints will be written.",
     )

-    ## Other parameters
+    # Other parameters
     parser.add_argument(
         "--config_name", default="", type=str, help="Pretrained config name or path if not the same as model_name"
     )
...
@@ -242,7 +242,7 @@ def prune_heads(args, model, eval_dataloader, head_mask):
 def main():
     parser = argparse.ArgumentParser()

-    ## Required parameters
+    # Required parameters
     parser.add_argument(
         "--data_dir",
         default=None,
@@ -272,7 +272,7 @@ def main():
         help="The output directory where the model predictions and checkpoints will be written.",
     )

-    ## Other parameters
+    # Other parameters
     parser.add_argument(
         "--config_name",
         default="",
...
@@ -410,7 +410,7 @@ def load_and_cache_examples(args, task, tokenizer, evaluate=False):
 def main():
     parser = argparse.ArgumentParser()

-    ## Required parameters
+    # Required parameters
     parser.add_argument(
         "--data_dir",
         default=None,
@@ -447,7 +447,7 @@ def main():
         help="The output directory where the model predictions and checkpoints will be written.",
     )

-    ## Other parameters
+    # Other parameters
     parser.add_argument(
         "--config_name", default="", type=str, help="Pretrained config name or path if not the same as model_name"
     )
...
@@ -422,7 +422,7 @@ def evaluate(args, model, tokenizer, prefix=""):
 def main():
     parser = argparse.ArgumentParser()

-    ## Required parameters
+    # Required parameters
     parser.add_argument(
         "--train_data_file", default=None, type=str, required=True, help="The input training data file (a text file)."
     )
@@ -434,7 +434,7 @@ def main():
         help="The output directory where the model predictions and checkpoints will be written.",
     )

-    ## Other parameters
+    # Other parameters
     parser.add_argument(
         "--eval_data_file",
         default=None,
...
@@ -385,7 +385,7 @@ def load_and_cache_examples(args, task, tokenizer, evaluate=False, test=False):
 def main():
     parser = argparse.ArgumentParser()

-    ## Required parameters
+    # Required parameters
     parser.add_argument(
         "--data_dir",
         default=None,
@@ -422,7 +422,7 @@ def main():
         help="The output directory where the model predictions and checkpoints will be written.",
     )

-    ## Other parameters
+    # Other parameters
     parser.add_argument(
         "--config_name", default="", type=str, help="Pretrained config name or path if not the same as model_name"
     )
...
@@ -385,7 +385,7 @@ def load_and_cache_examples(args, tokenizer, labels, pad_token_label_id, mode):
 def main():
     parser = argparse.ArgumentParser()

-    ## Required parameters
+    # Required parameters
     parser.add_argument(
         "--data_dir",
         default=None,
@@ -415,7 +415,7 @@ def main():
         help="The output directory where the model predictions and checkpoints will be written.",
     )

-    ## Other parameters
+    # Other parameters
     parser.add_argument(
         "--labels",
         default="",
...
@@ -377,7 +377,7 @@ def load_and_cache_examples(args, task, tokenizer, evaluate=False):
 def main():
     parser = argparse.ArgumentParser()

-    ## Required parameters
+    # Required parameters
     parser.add_argument(
         "--data_dir",
         default=None,
@@ -417,7 +417,7 @@ def main():
         help="The output directory where the model predictions and checkpoints will be written.",
     )

-    ## Other parameters
+    # Other parameters
     parser.add_argument(
         "--config_name", default="", type=str, help="Pretrained config name or path if not the same as model_name"
     )
...
@@ -401,7 +401,7 @@ def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=Fal
 def main():
     parser = argparse.ArgumentParser()

-    ## Required parameters
+    # Required parameters
     parser.add_argument(
         "--train_file", default=None, type=str, required=True, help="SQuAD json for training. E.g., train-v1.1.json"
     )
@@ -434,7 +434,7 @@ def main():
         help="The output directory where the model checkpoints and predictions will be written.",
     )

-    ## Other parameters
+    # Other parameters
     parser.add_argument(
         "--config_name", default="", type=str, help="Pretrained config name or path if not the same as model_name"
     )
...
@@ -43,7 +43,7 @@ def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, config_file, pytorch_du
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()

-    ## Required parameters
+    # Required parameters
     parser.add_argument(
         "--tf_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path."
     )
...
@@ -43,7 +43,7 @@ def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, albert_config_file, pyt
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()

-    ## Required parameters
+    # Required parameters
     parser.add_argument(
         "--tf_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path."
     )
...
@@ -43,7 +43,7 @@ def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytor
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()

-    ## Required parameters
+    # Required parameters
     parser.add_argument(
         "--tf_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path."
     )
...
@@ -51,7 +51,7 @@ def convert_gpt2_checkpoint_to_pytorch(gpt2_checkpoint_path, gpt2_config_file, p
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()

-    ## Required parameters
+    # Required parameters
     parser.add_argument(
         "--gpt2_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path."
     )
...
@@ -51,7 +51,7 @@ def convert_openai_checkpoint_to_pytorch(openai_checkpoint_folder_path, openai_c
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()

-    ## Required parameters
+    # Required parameters
     parser.add_argument(
         "--openai_checkpoint_folder_path",
         default=None,
...
@@ -410,7 +410,7 @@ def convert_all_pt_checkpoints_to_tf(
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()

-    ## Required parameters
+    # Required parameters
     parser.add_argument(
         "--tf_dump_path", default=None, type=str, required=True, help="Path to the output Tensorflow dump file."
     )
...
@@ -94,7 +94,7 @@ def convert_roberta_checkpoint_to_pytorch(roberta_checkpoint_path, pytorch_dump_
         layer: BertLayer = model.roberta.encoder.layer[i]
         roberta_layer: TransformerSentenceEncoderLayer = roberta_sent_encoder.layers[i]

-        ### self attention
+        # self attention
         self_attn: BertSelfAttention = layer.attention.self
         assert (
             roberta_layer.self_attn.k_proj.weight.data.shape
@@ -110,7 +110,7 @@ def convert_roberta_checkpoint_to_pytorch(roberta_checkpoint_path, pytorch_dump_
         self_attn.value.weight.data = roberta_layer.self_attn.v_proj.weight
         self_attn.value.bias.data = roberta_layer.self_attn.v_proj.bias

-        ### self-attention output
+        # self-attention output
         self_output: BertSelfOutput = layer.attention.output
         assert self_output.dense.weight.shape == roberta_layer.self_attn.out_proj.weight.shape
         self_output.dense.weight = roberta_layer.self_attn.out_proj.weight
@@ -118,20 +118,20 @@ def convert_roberta_checkpoint_to_pytorch(roberta_checkpoint_path, pytorch_dump_
         self_output.LayerNorm.weight = roberta_layer.self_attn_layer_norm.weight
         self_output.LayerNorm.bias = roberta_layer.self_attn_layer_norm.bias

-        ### intermediate
+        # intermediate
         intermediate: BertIntermediate = layer.intermediate
         assert intermediate.dense.weight.shape == roberta_layer.fc1.weight.shape
         intermediate.dense.weight = roberta_layer.fc1.weight
         intermediate.dense.bias = roberta_layer.fc1.bias

-        ### output
+        # output
         bert_output: BertOutput = layer.output
         assert bert_output.dense.weight.shape == roberta_layer.fc2.weight.shape
         bert_output.dense.weight = roberta_layer.fc2.weight
         bert_output.dense.bias = roberta_layer.fc2.bias
         bert_output.LayerNorm.weight = roberta_layer.final_layer_norm.weight
         bert_output.LayerNorm.bias = roberta_layer.final_layer_norm.bias
-        #### end of layer
+        # end of layer

     if classification_head:
         model.classifier.dense.weight = roberta.model.classification_heads["mnli"].dense.weight
@@ -170,7 +170,7 @@ def convert_roberta_checkpoint_to_pytorch(roberta_checkpoint_path, pytorch_dump_
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()

-    ## Required parameters
+    # Required parameters
     parser.add_argument(
         "--roberta_checkpoint_path", default=None, type=str, required=True, help="Path the official PyTorch dump."
     )
...
@@ -43,7 +43,7 @@ def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, config_file, pytorch_du
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()

-    ## Required parameters
+    # Required parameters
     parser.add_argument(
         "--tf_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path."
     )
...
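A sweep like this can be verified by re-running flake8 restricted to the E266 code. A minimal sketch, assuming flake8 is installed and on PATH (running it via subprocess from the repository root):

    import subprocess

    # Re-run flake8 over the repository, reporting only E266 violations.
    result = subprocess.run(
        ["flake8", "--select=E266", "."],
        capture_output=True,
        text=True,
    )
    print(result.stdout or "No E266 warnings remaining.")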