"IPXE/git@developer.sourcefind.cn:dadigang/Ventoy.git" did not exist on "fd0d335eb69950ceaed69adeac72064987cd79b9"
Unverified commit 684196b8, authored by Kiarash Jamali, committed by GitHub

Allow rotary embeddings for Bert (#363)

parent cbf982af
@@ -52,10 +52,16 @@ logger = logging.getLogger(__name__)
 
 
 def create_mixer_cls(config, cross_attn=False, return_residual=False):
     use_flash_attn = getattr(config, 'use_flash_attn', False)
     fused_bias_fc = getattr(config, 'fused_bias_fc', False)
+    rotary_kwargs = {}
+    if config.position_embedding_type == "rotary":
+        rotary_kwargs["rotary_emb_dim"] = getattr(config, "rotary_emb_dim", config.hidden_size)
+        rotary_kwargs["rotary_emb_base"] = getattr(config, "rotary_emb_base", 10000.0)
+        rotary_kwargs["rotary_emb_scale_base"] = getattr(config, "rotary_emb_scale_base", None)
+        rotary_kwargs["rotary_emb_interleaved"] = getattr(config, "rotary_emb_interleaved", False)
     mixer_cls = partial(MHA, num_heads=config.num_attention_heads, cross_attn=cross_attn,
                         dropout=config.attention_probs_dropout_prob, causal=False,
                         fused_bias_fc=fused_bias_fc, use_flash_attn=use_flash_attn,
-                        return_residual=return_residual)
+                        return_residual=return_residual, **rotary_kwargs)
     return mixer_cls
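
Aside (not part of the commit): when config.position_embedding_type == "rotary", the **rotary_kwargs expansion makes the partial above equivalent to the binding below. This is a minimal sketch; the rotary keyword names are copied verbatim from the diff, the MHA import path flash_attn.modules.mha is an assumption about the surrounding module, and config is the same HuggingFace-style config object that create_mixer_cls receives.

    from functools import partial
    from flash_attn.modules.mha import MHA

    # Equivalent expansion of the partial when rotary embeddings are requested.
    # Each attention layer built from this class constructs its rotary embedding
    # from these settings; scale_base=None leaves rotary scaling disabled.
    mixer_cls = partial(
        MHA,
        num_heads=config.num_attention_heads,
        cross_attn=False,
        dropout=config.attention_probs_dropout_prob,
        causal=False,  # Bert attends bidirectionally, so no causal mask
        fused_bias_fc=getattr(config, 'fused_bias_fc', False),
        use_flash_attn=getattr(config, 'use_flash_attn', False),
        return_residual=False,
        rotary_emb_dim=getattr(config, 'rotary_emb_dim', config.hidden_size),
        rotary_emb_base=getattr(config, 'rotary_emb_base', 10000.0),
        rotary_emb_scale_base=getattr(config, 'rotary_emb_scale_base', None),
        rotary_emb_interleaved=getattr(config, 'rotary_emb_interleaved', False),
    )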
@@ -298,7 +304,6 @@ class BertModel(BertPreTrainedModel):
         self.fused_dropout_add_ln = getattr(config, 'fused_dropout_add_ln', False)
         if self.fused_dropout_add_ln and layer_norm is None:
             raise ImportError('dropout_add_layer_norm is not installed')
-        assert config.position_embedding_type == 'absolute'
         assert config.hidden_act in ['gelu', 'gelu_new', 'gelu_fast']
         self.embeddings = BertEmbeddings(config.hidden_size, config.vocab_size,
...
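
Usage sketch (hedged, not from the commit): with the 'absolute'-only assert removed, a config can now opt into rotary embeddings when building the flash-attn Bert model. The four rotary_emb_* attributes and their defaults come straight from the getattr calls in the first hunk; the BertConfig/BertModel import paths and the constructor call are assumptions about the surrounding module.

    from transformers import BertConfig
    from flash_attn.models.bert import BertModel

    config = BertConfig()                      # stock HF Bert config (assumed entry point)
    config.use_flash_attn = True
    config.position_embedding_type = 'rotary'  # previously rejected by the removed assert
    config.rotary_emb_dim = 64                 # per the diff, defaults to config.hidden_size if unset
    config.rotary_emb_base = 10000.0           # diff default
    config.rotary_emb_scale_base = None        # diff default (rotary scaling disabled)
    config.rotary_emb_interleaved = False      # diff default

    model = BertModel(config)                  # attention layers now apply rotary embeddings

One thing worth noting: rotary_emb_dim falls back to the full config.hidden_size here, whereas rotary setups commonly rotate a per-head dimension, so setting rotary_emb_dim explicitly (as above) seems the safer choice.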