"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "fd3de2000fc087cf04361ea8d295f7554566854a"
Unverified commit a139d1a1, authored by Sam Shleifer, committed by GitHub

[cleanup] consolidate some prune_heads logic (#4799)

parent 4c7f564f
@@ -26,7 +26,7 @@ from torch.nn import CrossEntropyLoss, MSELoss
 from .configuration_albert import AlbertConfig
 from .file_utils import add_start_docstrings, add_start_docstrings_to_callable
 from .modeling_bert import ACT2FN, BertEmbeddings, BertSelfAttention, prune_linear_layer
-from .modeling_utils import PreTrainedModel
+from .modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices


 logger = logging.getLogger(__name__)

@@ -199,14 +199,9 @@ class AlbertAttention(BertSelfAttention):
     def prune_heads(self, heads):
         if len(heads) == 0:
             return
-        mask = torch.ones(self.num_attention_heads, self.attention_head_size)
-        heads = set(heads) - self.pruned_heads  # Convert to set and remove already pruned heads
-        for head in heads:
-            # Compute how many pruned heads are before the head and move the index accordingly
-            head = head - sum(1 if h < head else 0 for h in self.pruned_heads)
-            mask[head] = 0
-        mask = mask.view(-1).contiguous().eq(1)
-        index = torch.arange(len(mask))[mask].long()
+        heads, index = find_pruneable_heads_and_indices(
+            heads, self.num_attention_heads, self.attention_head_size, self.pruned_heads
+        )

         # Prune linear layers
         self.query = prune_linear_layer(self.query, index)
......
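The helper call above is behavior-preserving: it returns the same `heads` set and kept-feature `index` that the deleted mask bookkeeping produced, and the `index` then feeds `prune_linear_layer` exactly as before. A minimal sketch of that downstream step, with hypothetical toy sizes (4 heads of size 2):

```python
import torch
from torch import nn
from transformers.modeling_utils import prune_linear_layer  # module path as of this commit

# Toy query projection: hidden size 8 = 4 heads x head size 2.
query = nn.Linear(8, 8)
# Suppose the helper said to keep heads {0, 2, 3}, i.e. drop head 1
# (output rows 2 and 3 of the weight matrix):
index = torch.tensor([0, 1, 4, 5, 6, 7])
query = prune_linear_layer(query, index, dim=0)  # dim=0 prunes output features
print(query.weight.shape)  # torch.Size([6, 8])
```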
@@ -28,7 +28,7 @@ from torch.nn import CrossEntropyLoss, MSELoss
 from .activations import gelu, gelu_new, swish
 from .configuration_bert import BertConfig
 from .file_utils import add_start_docstrings, add_start_docstrings_to_callable
-from .modeling_utils import PreTrainedModel, prune_linear_layer
+from .modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer


 logger = logging.getLogger(__name__)

@@ -284,14 +284,9 @@ class BertAttention(nn.Module):
     def prune_heads(self, heads):
         if len(heads) == 0:
             return
-        mask = torch.ones(self.self.num_attention_heads, self.self.attention_head_size)
-        heads = set(heads) - self.pruned_heads  # Convert to set and remove already pruned heads
-        for head in heads:
-            # Compute how many pruned heads are before the head and move the index accordingly
-            head = head - sum(1 if h < head else 0 for h in self.pruned_heads)
-            mask[head] = 0
-        mask = mask.view(-1).contiguous().eq(1)
-        index = torch.arange(len(mask))[mask].long()
+        heads, index = find_pruneable_heads_and_indices(
+            heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
+        )

         # Prune linear layers
         self.self.query = prune_linear_layer(self.self.query, index)
......
@@ -31,7 +31,7 @@ from torch.nn import CrossEntropyLoss
 from .activations import gelu
 from .configuration_distilbert import DistilBertConfig
 from .file_utils import add_start_docstrings, add_start_docstrings_to_callable
-from .modeling_utils import PreTrainedModel, prune_linear_layer
+from .modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer


 logger = logging.getLogger(__name__)

@@ -120,13 +120,7 @@ class MultiHeadSelfAttention(nn.Module):
         attention_head_size = self.dim // self.n_heads
         if len(heads) == 0:
             return
-        mask = torch.ones(self.n_heads, attention_head_size)
-        heads = set(heads) - self.pruned_heads
-        for head in heads:
-            head -= sum(1 if h < head else 0 for h in self.pruned_heads)
-            mask[head] = 0
-        mask = mask.view(-1).contiguous().eq(1)
-        index = torch.arange(len(mask))[mask].long()
+        heads, index = find_pruneable_heads_and_indices(heads, self.n_heads, attention_head_size, self.pruned_heads)
         # Prune linear layers
         self.q_lin = prune_linear_layer(self.q_lin, index)
         self.k_lin = prune_linear_layer(self.k_lin, index)
......
@@ -27,7 +27,13 @@ from torch.nn import CrossEntropyLoss
 from .activations import ACT2FN
 from .configuration_gpt2 import GPT2Config
 from .file_utils import add_start_docstrings, add_start_docstrings_to_callable
-from .modeling_utils import Conv1D, PreTrainedModel, SequenceSummary, prune_conv1d_layer
+from .modeling_utils import (
+    Conv1D,
+    PreTrainedModel,
+    SequenceSummary,
+    find_pruneable_heads_and_indices,
+    prune_conv1d_layer,
+)


 logger = logging.getLogger(__name__)

@@ -122,14 +128,9 @@ class Attention(nn.Module):
     def prune_heads(self, heads):
         if len(heads) == 0:
             return
-        mask = torch.ones(self.n_head, self.split_size // self.n_head)
-        heads = set(heads) - self.pruned_heads  # Convert to set and remove already pruned heads
-        for head in heads:
-            # Compute how many pruned heads are before the head and move the index accordingly
-            head = head - sum(1 if h < head else 0 for h in self.pruned_heads)
-            mask[head] = 0
-        mask = mask.view(-1).contiguous().eq(1)
-        index = torch.arange(len(mask))[mask].long()
+        heads, index = find_pruneable_heads_and_indices(
+            heads, self.n_head, self.split_size // self.n_head, self.pruned_heads
+        )
         index_attn = torch.cat([index, index + self.split_size, index + (2 * self.split_size)])

         # Prune conv1d layers
......
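Unlike the BERT-family models, GPT-2 fuses query, key and value into a single `Conv1D` (`c_attn`) whose output dimension is `3 * split_size`, so the kept indices must be replicated at three offsets before pruning, which is what the `index_attn` line above does. A small sketch of that offset arithmetic with made-up toy numbers:

```python
import torch

split_size = 8  # toy hidden size; c_attn's output width is 3 * split_size
index = torch.tensor([0, 1, 4, 5, 6, 7])  # features kept within one of q/k/v

# Repeat the kept features inside the query, key and value blocks:
index_attn = torch.cat([index, index + split_size, index + (2 * split_size)])
print(index_attn.tolist())
# [0, 1, 4, 5, 6, 7, 8, 9, 12, 13, 14, 15, 16, 17, 20, 21, 22, 23]
```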
@@ -29,7 +29,13 @@ from torch.nn import CrossEntropyLoss
 from .activations import gelu_new, swish
 from .configuration_openai import OpenAIGPTConfig
 from .file_utils import add_start_docstrings, add_start_docstrings_to_callable
-from .modeling_utils import Conv1D, PreTrainedModel, SequenceSummary, prune_conv1d_layer
+from .modeling_utils import (
+    Conv1D,
+    PreTrainedModel,
+    SequenceSummary,
+    find_pruneable_heads_and_indices,
+    prune_conv1d_layer,
+)


 logger = logging.getLogger(__name__)

@@ -142,13 +148,9 @@ class Attention(nn.Module):
     def prune_heads(self, heads):
         if len(heads) == 0:
             return
-        mask = torch.ones(self.n_head, self.split_size // self.n_head)
-        heads = set(heads) - self.pruned_heads
-        for head in heads:
-            head -= sum(1 if h < head else 0 for h in self.pruned_heads)
-            mask[head] = 0
-        mask = mask.view(-1).contiguous().eq(1)
-        index = torch.arange(len(mask))[mask].long()
+        heads, index = find_pruneable_heads_and_indices(
+            heads, self.n_head, self.split_size // self.n_head, self.pruned_heads
+        )
         index_attn = torch.cat([index, index + self.split_size, index + (2 * self.split_size)])
         # Prune conv1d layers
         self.c_attn = prune_conv1d_layer(self.c_attn, index_attn, dim=1)
......
@@ -28,7 +28,7 @@ from torch.nn import CrossEntropyLoss
 from .configuration_t5 import T5Config
 from .file_utils import DUMMY_INPUTS, DUMMY_MASK, add_start_docstrings, add_start_docstrings_to_callable
-from .modeling_utils import PreTrainedModel, prune_linear_layer
+from .modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer


 logger = logging.getLogger(__name__)

@@ -216,13 +216,7 @@ class T5Attention(nn.Module):
     def prune_heads(self, heads):
         if len(heads) == 0:
             return
-        mask = torch.ones(self.n_heads, self.d_kv)
-        heads = set(heads) - self.pruned_heads
-        for head in heads:
-            head -= sum(1 if h < head else 0 for h in self.pruned_heads)
-            mask[head] = 0
-        mask = mask.view(-1).contiguous().eq(1)
-        index = torch.arange(len(mask))[mask].long()
+        heads, index = find_pruneable_heads_and_indices(heads, self.n_heads, self.d_kv, self.pruned_heads)
         # Prune linear layers
         self.q = prune_linear_layer(self.q, index)
         self.k = prune_linear_layer(self.k, index)
......
@@ -55,6 +55,20 @@ except ImportError:
             return input


+def find_pruneable_heads_and_indices(
+    heads: List, n_heads: int, head_size: int, already_pruned_heads: set
+) -> Tuple[set, "torch.LongTensor"]:
+    mask = torch.ones(n_heads, head_size)
+    heads = set(heads) - already_pruned_heads  # Convert to set and remove already pruned heads
+    for head in heads:
+        # Compute how many pruned heads are before the head and move the index accordingly
+        head = head - sum(1 if h < head else 0 for h in already_pruned_heads)
+        mask[head] = 0
+    mask = mask.view(-1).contiguous().eq(1)
+    index: torch.LongTensor = torch.arange(len(mask))[mask].long()
+    return heads, index
+
+
 class ModuleUtilsMixin:
     """
     A few utilities for torch.nn.Modules, to be used as a mixin.
......
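To make the new helper's contract concrete, here is a usage sketch with made-up numbers. Note that `n_heads` is the current (already reduced) head count, and that the offset loop compensates for heads pruned in earlier calls:

```python
import torch
from transformers.modeling_utils import find_pruneable_heads_and_indices

# 3 heads of size 2 remain after head 1 was pruned earlier; now prune head 2.
heads, index = find_pruneable_heads_and_indices(
    heads=[2], n_heads=3, head_size=2, already_pruned_heads={1}
)
print(heads)           # {2} -- the heads newly pruned, in original numbering
print(index.tolist())  # [0, 1, 4, 5] -- flat feature indices to keep
```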
@@ -29,7 +29,13 @@ from torch.nn import functional as F
 from .activations import gelu
 from .configuration_xlm import XLMConfig
 from .file_utils import add_start_docstrings, add_start_docstrings_to_callable
-from .modeling_utils import PreTrainedModel, SequenceSummary, SQuADHead, prune_linear_layer
+from .modeling_utils import (
+    PreTrainedModel,
+    SequenceSummary,
+    SQuADHead,
+    find_pruneable_heads_and_indices,
+    prune_linear_layer,
+)


 logger = logging.getLogger(__name__)

@@ -105,13 +111,7 @@ class MultiHeadAttention(nn.Module):
         attention_head_size = self.dim // self.n_heads
         if len(heads) == 0:
             return
-        mask = torch.ones(self.n_heads, attention_head_size)
-        heads = set(heads) - self.pruned_heads
-        for head in heads:
-            head -= sum(1 if h < head else 0 for h in self.pruned_heads)
-            mask[head] = 0
-        mask = mask.view(-1).contiguous().eq(1)
-        index = torch.arange(len(mask))[mask].long()
+        heads, index = find_pruneable_heads_and_indices(heads, self.n_heads, attention_head_size, self.pruned_heads)
         # Prune linear layers
         self.q_lin = prune_linear_layer(self.q_lin, index)
         self.k_lin = prune_linear_layer(self.k_lin, index)
......
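All of the per-module `prune_heads` implementations touched here are normally reached through `PreTrainedModel.prune_heads`, which takes a dict mapping layer index to the head indices to drop. A hedged end-to-end sketch (the checkpoint name is just an example):

```python
from transformers import BertModel

model = BertModel.from_pretrained("bert-base-uncased")
# Drop heads 0 and 2 in layer 0 and head 5 in layer 4; each layer's
# attention module now routes through find_pruneable_heads_and_indices.
model.prune_heads({0: [0, 2], 4: [5]})
print(model.config.pruned_heads)  # {0: [0, 2], 4: [5]}
```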