Commit 5ad1d93f authored by Zihan Wang
Browse files

fix bug

parent fd1528b1
...@@ -227,11 +227,10 @@ class SentencePredictionTask(base_task.Task): ...@@ -227,11 +227,10 @@ class SentencePredictionTask(base_task.Task):
"""Load a pretrained checkpoint (if exists) and then train from iter 0.""" """Load a pretrained checkpoint (if exists) and then train from iter 0."""
ckpt_dir_or_file = self.task_config.init_checkpoint ckpt_dir_or_file = self.task_config.init_checkpoint
if self.task_config.initial_parameters_from_pk: if self.task_config.initial_parameters_from_pk:
num_layers = self.task_config.model.encoder.num_layers num_layers = self.task_config.model.encoder.any.num_layers
num_attention_heads = self.task_config.model.encoder.num_attention_heads num_attention_heads = self.task_config.model.encoder.any.num_attention_heads
hidden_size = self.task_config.model.encoder.hidden_size hidden_size = self.task_config.model.encoder.any.hidden_size
inner_dim = self.task_config.model.encoder.inner_dim head_size = hidden_size // num_attention_heads
head_size = hidden_size / num_attention_heads
assert head_size * num_attention_heads == hidden_size assert head_size * num_attention_heads == hidden_size
encoder = model.checkpoint_items['encoder'] encoder = model.checkpoint_items['encoder']
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment