Commit f1b01874 authored by Shijie Wu

Add use_lang_emb to config

parent e85123d3
@@ -114,6 +114,7 @@ class XLMConfig(PretrainedConfig):
                  causal=False,
                  asm=False,
                  n_langs=1,
+                 use_lang_emb=True,
                  max_position_embeddings=512,
                  embed_init_std=2048 ** -0.5,
                  layer_norm_eps=1e-12,
@@ -157,6 +158,7 @@ class XLMConfig(PretrainedConfig):
             self.causal = causal
             self.asm = asm
             self.n_langs = n_langs
+            self.use_lang_emb = use_lang_emb
             self.layer_norm_eps = layer_norm_eps
             self.bos_index = bos_index
             self.eos_index = eos_index
@@ -488,7 +490,7 @@ class XLMModel(XLMPreTrainedModel):
     """
     ATTRIBUTES = ['encoder', 'eos_index', 'pad_index',  # 'with_output',
-                  'n_langs', 'n_words', 'dim', 'n_layers', 'n_heads',
+                  'n_langs', 'use_lang_emb', 'n_words', 'dim', 'n_layers', 'n_heads',
                   'hidden_dim', 'dropout', 'attention_dropout', 'asm',
                   'asm_cutoffs', 'asm_div_value']
@@ -507,6 +509,7 @@ class XLMModel(XLMPreTrainedModel):
         # dictionary / languages
         self.n_langs = config.n_langs
+        self.use_lang_emb = config.use_lang_emb
         self.n_words = config.n_words
         self.eos_index = config.eos_index
         self.pad_index = config.pad_index
@@ -529,7 +532,7 @@ class XLMModel(XLMPreTrainedModel):
         self.position_embeddings = nn.Embedding(config.max_position_embeddings, self.dim)
         if config.sinusoidal_embeddings:
             create_sinusoidal_embeddings(config.max_position_embeddings, self.dim, out=self.position_embeddings.weight)
-        if config.n_langs > 1:
+        if config.n_langs > 1 and config.use_lang_emb:
             self.lang_embeddings = nn.Embedding(self.n_langs, self.dim)
         self.embeddings = nn.Embedding(self.n_words, self.dim, padding_idx=self.pad_index)
         self.layer_norm_emb = nn.LayerNorm(self.dim, eps=config.layer_norm_eps)
@@ -628,7 +631,7 @@ class XLMModel(XLMPreTrainedModel):
         # embeddings
         tensor = self.embeddings(input_ids)
         tensor = tensor + self.position_embeddings(position_ids).expand_as(tensor)
-        if langs is not None:
+        if langs is not None and self.use_lang_emb:
             tensor = tensor + self.lang_embeddings(langs)
         if token_type_ids is not None:
             tensor = tensor + self.embeddings(token_type_ids)
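A minimal sketch (not part of the commit) of how the new flag behaves, assuming the pytorch_transformers import path of this era of the repo; the toy token ids are purely illustrative. With use_lang_emb=False, XLMModel never creates the lang_embeddings table even when n_langs > 1, and the forward pass ignores any langs tensor it receives:

import torch
from pytorch_transformers import XLMConfig, XLMModel  # import path assumed for this era of the repo

# Hypothetical multilingual setup that opts out of language embeddings:
# the new guard in __init__ skips nn.Embedding(self.n_langs, self.dim)
# because config.use_lang_emb is False, despite n_langs > 1.
config = XLMConfig(n_langs=100, use_lang_emb=False)
model = XLMModel(config)

input_ids = torch.tensor([[0, 5, 6, 1]])  # toy token ids, for illustration only
langs = torch.zeros_like(input_ids)       # would normally index a language embedding
hidden_states = model(input_ids, langs=langs)[0]  # langs is skipped by the new forward() guard
print(hidden_states.shape)  # (batch, seq_len, config.emb_dim)

The guard in forward() mirrors the one in __init__: since the embedding table is never created when the flag is off, silently ignoring langs is what keeps such a call from raising an AttributeError. This presumably supports multilingual XLM checkpoints trained without per-language embedding vectors.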