Commit af9500dc authored by Michael Wu, committed by Facebook GitHub Bot

Add option to freeze transformer params for fine-tuning

Summary: Add flags to `TransformerSentenceEncoder` to freeze the embedding parameters (`freeze_embeddings`) and the first `n_trans_layers_to_freeze` transformer layers, so those parameters are not updated during fine-tuning.
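As an illustration, a minimal usage sketch (not part of this commit): only the two freeze flags come from this change; the import path and the remaining constructor arguments and their values are assumptions for the example.

```python
# Hypothetical fine-tuning setup: freeze the embeddings and the bottom 6 layers,
# leaving only the upper transformer layers trainable.
from fairseq.modules import TransformerSentenceEncoder  # import path assumed

encoder = TransformerSentenceEncoder(
    padding_idx=1,                # example values; use ones matching your dictionary/model
    vocab_size=32000,
    num_encoder_layers=12,
    embedding_dim=768,
    freeze_embeddings=True,       # new flag: freeze token/segment/positional embeddings and the embedding LayerNorm
    n_trans_layers_to_freeze=6,   # new flag: freeze the first 6 transformer layers
)
```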

Reviewed By: myleott

Differential Revision: D15866135

fbshipit-source-id: e634d7adfd5e81eacccf2b9cf6bc15bad30bd1fe
parent 461a366d
@@ -5,12 +5,11 @@
 # the root directory of this source tree. An additional grant of patent rights
 # can be found in the PATENTS file in the same directory.
-from typing import Tuple, Optional
+from typing import Optional, Tuple
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from fairseq.modules import (
     LayerNorm,
     MultiheadAttention,
@@ -84,11 +83,13 @@ class TransformerSentenceEncoder(nn.Module):
         offset_positions_by_padding: bool = True,
         encoder_normalize_before: bool = False,
         apply_bert_init: bool = False,
-        activation_fn: str = 'relu',
+        activation_fn: str = "relu",
         learned_pos_embedding: bool = True,
         add_bias_kv: bool = False,
         add_zero_attn: bool = False,
         embed_scale: float = None,
+        freeze_embeddings: bool = False,
+        n_trans_layers_to_freeze: int = 0,
         export: bool = False,
     ) -> None:
@@ -104,7 +105,7 @@ class TransformerSentenceEncoder(nn.Module):
         self.learned_pos_embedding = learned_pos_embedding
         self.embed_tokens = nn.Embedding(
-            self.vocab_size, self.embedding_dim, self.padding_idx,
+            self.vocab_size, self.embedding_dim, self.padding_idx
         )
         self.embed_scale = embed_scale
@@ -118,10 +119,7 @@ class TransformerSentenceEncoder(nn.Module):
             PositionalEmbedding(
                 self.max_seq_len,
                 self.embedding_dim,
-                padding_idx=(
-                    self.padding_idx if offset_positions_by_padding
-                    else None
-                ),
+                padding_idx=(self.padding_idx if offset_positions_by_padding else None),
                 learned=self.learned_pos_embedding,
             )
             if self.use_position_embeddings
@@ -155,12 +153,26 @@ class TransformerSentenceEncoder(nn.Module):
         if self.apply_bert_init:
             self.apply(init_bert_params)
+        def freeze_module_params(m):
+            if m is not None:
+                for p in m.parameters():
+                    p.requires_grad = False
+        if freeze_embeddings:
+            freeze_module_params(self.embed_tokens)
+            freeze_module_params(self.segment_embeddings)
+            freeze_module_params(self.embed_positions)
+            freeze_module_params(self.emb_layer_norm)
+        for layer in range(n_trans_layers_to_freeze):
+            freeze_module_params(self.layers[layer])
     def forward(
         self,
         tokens: torch.Tensor,
         segment_labels: torch.Tensor,
         last_state_only: bool = False,
-        positions: Optional[torch.Tensor] = None
+        positions: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         # compute padding mask. This is needed for multi-head attention
@@ -186,7 +198,7 @@ class TransformerSentenceEncoder(nn.Module):
         # account for padding while computing the representation
         if padding_mask is not None:
-            x *= (1 - padding_mask.unsqueeze(-1).type_as(x))
+            x *= 1 - padding_mask.unsqueeze(-1).type_as(x)
         # B x T x C -> T x B x C
         x = x.transpose(0, 1)
@@ -196,10 +208,7 @@ class TransformerSentenceEncoder(nn.Module):
             inner_states.append(x)
         for layer in self.layers:
-            x, _ = layer(
-                x,
-                self_attn_padding_mask=padding_mask,
-            )
+            x, _ = layer(x, self_attn_padding_mask=padding_mask)
             if not last_state_only:
                 inner_states.append(x)
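A follow-up sketch of how the frozen parameters might be handled downstream (assuming an `encoder` built as in the example above): frozen submodules have `requires_grad=False`, so they can be counted and excluded when constructing the optimizer.

```python
import torch

# Count trainable vs. total parameters to confirm what the freeze flags excluded.
n_total = sum(p.numel() for p in encoder.parameters())
n_trainable = sum(p.numel() for p in encoder.parameters() if p.requires_grad)
print(f"trainable parameters: {n_trainable} / {n_total}")

# Pass only the still-trainable parameters to the optimizer.
optimizer = torch.optim.Adam(
    (p for p in encoder.parameters() if p.requires_grad),
    lr=1e-5,  # example learning rate
)
```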