Commit af9500dc authored by Michael Wu, committed by Facebook Github Bot

Add option to freeze transformer params for fine-tuning

Summary: Add `freeze_embeddings` and `n_trans_layers_to_freeze` flags to `TransformerSentenceEncoder`: the former freezes the embedding parameters (token, segment, and positional embeddings plus the embedding layer norm), and the latter freezes the first N transformer layers, so they stay fixed during fine-tuning.
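A minimal usage sketch (the required positional arguments and the values below are illustrative assumptions, not part of this diff):

```python
from fairseq.modules import TransformerSentenceEncoder

# Hypothetical fine-tuning setup; vocab_size and padding_idx are made-up values.
encoder = TransformerSentenceEncoder(
    padding_idx=1,
    vocab_size=32000,
    freeze_embeddings=True,      # new flag: freeze token/segment/positional embeddings and emb layer norm
    n_trans_layers_to_freeze=6,  # new flag: freeze the first 6 transformer layers
)
```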

Reviewed By: myleott

Differential Revision: D15866135

fbshipit-source-id: e634d7adfd5e81eacccf2b9cf6bc15bad30bd1fe
parent 461a366d
@@ -5,12 +5,11 @@
 # the root directory of this source tree. An additional grant of patent rights
 # can be found in the PATENTS file in the same directory.
-from typing import Tuple, Optional
+from typing import Optional, Tuple
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from fairseq.modules import (
     LayerNorm,
     MultiheadAttention,
@@ -84,11 +83,13 @@ class TransformerSentenceEncoder(nn.Module):
         offset_positions_by_padding: bool = True,
         encoder_normalize_before: bool = False,
         apply_bert_init: bool = False,
-        activation_fn: str = 'relu',
+        activation_fn: str = "relu",
         learned_pos_embedding: bool = True,
         add_bias_kv: bool = False,
         add_zero_attn: bool = False,
         embed_scale: float = None,
+        freeze_embeddings: bool = False,
+        n_trans_layers_to_freeze: int = 0,
         export: bool = False,
     ) -> None:
@@ -104,7 +105,7 @@ class TransformerSentenceEncoder(nn.Module):
         self.learned_pos_embedding = learned_pos_embedding
         self.embed_tokens = nn.Embedding(
-            self.vocab_size, self.embedding_dim, self.padding_idx,
+            self.vocab_size, self.embedding_dim, self.padding_idx
         )
         self.embed_scale = embed_scale
@@ -118,10 +119,7 @@ class TransformerSentenceEncoder(nn.Module):
             PositionalEmbedding(
                 self.max_seq_len,
                 self.embedding_dim,
-                padding_idx=(
-                    self.padding_idx if offset_positions_by_padding
-                    else None
-                ),
+                padding_idx=(self.padding_idx if offset_positions_by_padding else None),
                 learned=self.learned_pos_embedding,
             )
             if self.use_position_embeddings
@@ -155,12 +153,26 @@ class TransformerSentenceEncoder(nn.Module):
         if self.apply_bert_init:
             self.apply(init_bert_params)
 
+        def freeze_module_params(m):
+            if m is not None:
+                for p in m.parameters():
+                    p.requires_grad = False
+
+        if freeze_embeddings:
+            freeze_module_params(self.embed_tokens)
+            freeze_module_params(self.segment_embeddings)
+            freeze_module_params(self.embed_positions)
+            freeze_module_params(self.emb_layer_norm)
+
+        for layer in range(n_trans_layers_to_freeze):
+            freeze_module_params(self.layers[layer])
+
     def forward(
         self,
         tokens: torch.Tensor,
         segment_labels: torch.Tensor,
         last_state_only: bool = False,
-        positions: Optional[torch.Tensor] = None
+        positions: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
 
         # compute padding mask. This is needed for multi-head attention
@@ -186,7 +198,7 @@ class TransformerSentenceEncoder(nn.Module):
         # account for padding while computing the representation
         if padding_mask is not None:
-            x *= (1 - padding_mask.unsqueeze(-1).type_as(x))
+            x *= 1 - padding_mask.unsqueeze(-1).type_as(x)
 
         # B x T x C -> T x B x C
         x = x.transpose(0, 1)
@@ -196,10 +208,7 @@ class TransformerSentenceEncoder(nn.Module):
             inner_states.append(x)
 
         for layer in self.layers:
-            x, _ = layer(
-                x,
-                self_attn_padding_mask=padding_mask,
-            )
+            x, _ = layer(x, self_attn_padding_mask=padding_mask)
             if not last_state_only:
                 inner_states.append(x)
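The freezing above is the standard PyTorch `requires_grad` pattern. A minimal standalone sketch of the same idea (the toy module and optimizer choice are assumptions for illustration, not part of fairseq):

```python
import torch
import torch.nn as nn

# Same pattern as freeze_module_params in the diff above.
def freeze_module_params(m):
    if m is not None:
        for p in m.parameters():
            p.requires_grad = False

# Toy stand-in for an encoder: an embedding table, two "layers", and a head.
model = nn.Sequential(nn.Embedding(100, 16), nn.Linear(16, 16), nn.Linear(16, 4))
freeze_module_params(model[0])  # analogous to freeze_embeddings=True
freeze_module_params(model[1])  # analogous to n_trans_layers_to_freeze=1

# Frozen parameters no longer receive gradients; passing only the still-trainable
# parameters to the optimizer also avoids allocating optimizer state for them.
optimizer = torch.optim.Adam(
    [p for p in model.parameters() if p.requires_grad], lr=1e-4
)
```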