Commit c14a2227 authored by Lysandre's avatar Lysandre Committed by Lysandre Debut
Browse files

ALBERT passes all tests

parent 870320a2
......@@ -7,7 +7,7 @@ class AlbertConfig(PretrainedConfig):
"""
def __init__(self,
vocab_size_or_config_json_file,
vocab_size_or_config_json_file=30000,
embedding_size=128,
hidden_size=4096,
num_hidden_layers=12,
......@@ -15,7 +15,6 @@ class AlbertConfig(PretrainedConfig):
num_attention_heads=64,
intermediate_size=16384,
inner_group_num=1,
down_scale_factor=1,
hidden_act="gelu_new",
hidden_dropout_prob=0,
attention_probs_dropout_prob=0,
......@@ -61,7 +60,6 @@ class AlbertConfig(PretrainedConfig):
self.num_hidden_groups = num_hidden_groups
self.num_attention_heads = num_attention_heads
self.inner_group_num = inner_group_num
self.down_scale_factor = down_scale_factor
self.hidden_act = hidden_act
self.intermediate_size = intermediate_size
self.hidden_dropout_prob = hidden_dropout_prob
......
......@@ -202,9 +202,6 @@ class AlbertLayerGroup(nn.Module):
layer_attentions = ()
for albert_layer in self.albert_layers:
if self.output_hidden_states:
layer_hidden_states = layer_hidden_states + (hidden_states,)
layer_output = albert_layer(hidden_states, attention_mask, head_mask)
hidden_states = layer_output[0]
......@@ -247,7 +244,7 @@ class AlbertTransformer(nn.Module):
hidden_states = layer_group_output[0]
if self.output_attentions:
all_attentions = all_attentions + layer_group_output[1]
all_attentions = all_attentions + layer_group_output[-1]
if self.output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
......
......@@ -22,7 +22,7 @@ from transformers.tokenization_albert import (AlbertTokenizer, SPIECE_UNDERLINE)
from .tokenization_tests_commons import CommonTestCases
SAMPLE_VOCAB = os.path.join(os.path.dirname(os.path.abspath(__file__)),
'fixtures/30k-clean.model')
'fixtures/spiece.model')
class AlbertTokenizationTest(CommonTestCases.CommonTokenizerTester):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment