chenpangpang / transformers / Commits / ad7233fc

Commit ad7233fc (unverified), authored Mar 19, 2020 by Sam Shleifer, committed by GitHub on Mar 19, 2020
Parent: cd21d8bc

[BART] cleanup: remove redundant kwargs, improve docstrings (#3319)

Showing 2 changed files with 39 additions and 80 deletions:
    src/transformers/modeling_bart.py   +28 -80
    tests/test_modeling_bart.py         +11 -0
src/transformers/modeling_bart.py (view file @ ad7233fc)
@@ -56,7 +56,7 @@ BART_GENERATION_EXAMPLE = r"""
     ARTICLE_TO_SUMMARIZE = "My friends are cool but they eat too many carbs."
     inputs = tokenizer.batch_encode_plus([ARTICLE_TO_SUMMARIZE], max_length=1024, return_tensors='pt')
     # Generate Summary
-    summary_ids = model.generate(inputs['input_ids'], attention_mask=inputs['attention_mask'], num_beams=4, max_length=5)
+    summary_ids = model.generate(inputs['input_ids'], num_beams=4, max_length=5, early_stopping=True)
     print([tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in summary_ids])
 """
@@ -84,8 +84,9 @@ LARGE_NEGATIVE = -1e8
 def _prepare_bart_decoder_inputs(
     config, input_ids, decoder_input_ids=None, decoder_attn_mask=None, mask_dtype=None,
 ):
-    """Prepare masks that ignore padding tokens decoder and a causal lm mask for the decoder if
+    """Prepare masks that ignore padding tokens in the decoder and a causal lm mask for the decoder if
     none are provided. This mimics the default behavior in fairseq. To override it pass in masks.
+    Note: this is not called during generation
     """
     pad_token_id = config.pad_token_id
     need_causal_mask = not config.output_past
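For orientation, here is a minimal, self-contained sketch of the two masks this docstring refers to: a padding mask over the decoder inputs and a causal LM mask. The tensors and the pad_token_id value below are illustrative only and are independent of the helper itself.

import torch

pad_token_id = 1  # illustrative pad id
decoder_input_ids = torch.tensor([[0, 8, 12, 2, pad_token_id]])

# Padding mask: True wherever the decoder input is a pad token.
padding_mask = decoder_input_ids.eq(pad_token_id)  # shape (bsz, tgt_len)

# Causal LM mask: large negative values above the diagonal so that
# position i cannot attend to positions > i.
tgt_len = decoder_input_ids.size(1)
causal_mask = torch.triu(torch.full((tgt_len, tgt_len), float("-inf")), diagonal=1)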
@@ -114,8 +115,6 @@ class PretrainedBartModel(PreTrainedModel):
     def _init_weights(self, module):
         std = self.config.init_std
         # called init_bert_params in fairseq
         if isinstance(module, nn.Linear):
             module.weight.data.normal_(mean=0.0, std=std)
             if module.bias is not None:
@@ -127,16 +126,9 @@ class PretrainedBartModel(PreTrainedModel):
     @property
     def dummy_inputs(self):
-        pad_token = 1
-        input_ids = torch.Tensor(
-            [
-                [0, 31414, 232, 328, 740, 1140, 12695, 69, 46078, 1588, 2],
-                [0, 31414, 232, 328, 740, 1140, 12695, 69, 46078, 2, pad_token],
-            ]
-        ).long()
-        decoder_input_ids, decoder_attn_mask = _prepare_bart_decoder_inputs(
-            self.config, input_ids, attention_mask=None, decoder_input_ids=None, decoder_attn_mask=None
-        )
+        pad_token = self.config.pad_token_id
+        input_ids = torch.tensor([[0, 6, 10, 4, 2], [0, 8, 12, 2, pad_token]])
+        decoder_input_ids, decoder_attn_mask = _prepare_bart_decoder_inputs(self.config, input_ids,)
         dummy_inputs = {
             "decoder_input_ids": decoder_input_ids,
             "attention_mask": input_ids.ne(pad_token),
@@ -149,7 +141,7 @@ class PretrainedBartModel(PreTrainedModel):
 def _make_linear_from_emb(emb):
     vocab_size, emb_size = emb.weight.shape
     lin_layer = nn.Linear(vocab_size, emb_size, bias=False)
-    lin_layer.weight.data = emb.weight.data  # .T
+    lin_layer.weight.data = emb.weight.data
     return lin_layer
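Context for the one-line change above: nn.Linear stores its weight as (out_features, in_features) and applies the transpose inside F.linear, so sharing the embedding matrix needs no explicit .T. Below is a self-contained sketch of that weight-tying pattern with toy sizes; it illustrates the idea and is not the BART code itself.

import torch
from torch import nn

# Toy sizes for illustration only.
vocab_size, emb_size, bsz, seq_len = 10, 4, 2, 3
emb = nn.Embedding(vocab_size, emb_size)

# Output projection tied to the embedding: F.linear computes x @ W.T, so the
# shared (vocab_size, emb_size) weight maps hidden states to vocab logits.
lin_layer = nn.Linear(emb_size, vocab_size, bias=False)
lin_layer.weight = emb.weight

hidden = torch.randn(bsz, seq_len, emb_size)
assert torch.allclose(lin_layer(hidden), hidden @ emb.weight.t())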
@@ -160,8 +152,8 @@ def _check_shapes(shape_1, shape2):
 def _combine_masks(key_padding_mask, causal_lm_mask, targ_size):
-    # targ_size = (bsz, tgt_len, src_len)
-    a = torch.zeros(targ_size)
+    """Make one mask of shape (bsz, 1, tgt_len, src_len)"""
+    a = torch.zeros(targ_size)  # targ_size is (bsz, tgt_len, src_len)
     b = torch.zeros(targ_size)
     if key_padding_mask is not None:  # (bsz, tgt_len) -> targ_size
         _check_shapes(key_padding_mask.shape, targ_size[:2])
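As a rough illustration of the new docstring, here is one way to build a single additive mask of shape (bsz, 1, tgt_len, src_len) from a key padding mask and a causal mask. The shapes and values are made up, and this is not the exact implementation above.

import torch

bsz, tgt_len, src_len = 2, 4, 4
LARGE_NEGATIVE = -1e8

key_padding_mask = torch.tensor([[0, 0, 0, 1],
                                 [0, 0, 1, 1]]).bool()  # (bsz, src_len), True = pad
causal_lm_mask = torch.triu(torch.ones(tgt_len, src_len), diagonal=1).bool()

combined = torch.zeros(bsz, 1, tgt_len, src_len)
combined = combined.masked_fill(key_padding_mask[:, None, None, :], LARGE_NEGATIVE)
combined = combined.masked_fill(causal_lm_mask[None, None, :, :], LARGE_NEGATIVE)
print(combined.shape)  # torch.Size([2, 1, 4, 4])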
@@ -223,7 +215,7 @@ class EncoderLayer(nn.Module):
             encoded output of shape `(seq_len, batch, embed_dim)`
         """
         residual = x
-        x, attn_weights = self.self_attn(query=x, key=x, value=x, key_padding_mask=encoder_padding_mask,)
+        x, attn_weights = self.self_attn(query=x, key=x, key_padding_mask=encoder_padding_mask,)
         x = F.dropout(x, p=self.dropout, training=self.training)
         x = residual + x
         x = self.self_attn_layer_norm(x)
@@ -266,7 +258,7 @@ class BartEncoder(nn.Module):
         self.layernorm_embedding = LayerNorm(embed_dim)

     def forward(
-        self, input_ids=None, attention_mask=None,
+        self, input_ids, attention_mask=None,
     ):
         """
         Args:
@@ -274,21 +266,19 @@ class BartEncoder(nn.Module):
                 `(batch, src_len)`
             attention_mask (torch.LongTensor): indicating which indices are padding tokens.
         Returns:
-            namedtuple:
+            Tuple comprised of:
                 - **x** (Tensor): the last encoder layer's output of
                   shape `(src_len, batch, embed_dim)`
                 - **encoder_states** (List[Tensor]): all intermediate
                   hidden states of shape `(src_len, batch, embed_dim)`.
-                  Only populated if *return_all_hiddens* is True.
+                  Only populated if *self.output_hidden_states:* is True.
                 - **all_attentions** (List[Tensor]): Attention weights for each layer.
                   During training might not be of length n_layers because of layer dropout.
         """
         # check attention mask and invert
         if attention_mask is not None:
             assert attention_mask.dim() == 2
-            attention_mask = (1.0 - attention_mask.long()) * -10000.0
+            attention_mask = (1.0 - attention_mask.long()) * LARGE_NEGATIVE
             assert attention_mask.max() <= 0
         inputs_embeds = self.embed_tokens(input_ids)
         embed_pos = self.embed_positions(input_ids)
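The change above replaces the hard-coded -10000.0 with the module-level LARGE_NEGATIVE constant. A small standalone example of what this inversion does to a tokenizer-style mask (1 = real token, 0 = padding); the values are illustrative:

import torch

LARGE_NEGATIVE = -1e8
attention_mask = torch.tensor([[1, 1, 1, 0, 0]])  # 1 = keep, 0 = pad

inverted = (1.0 - attention_mask.long()) * LARGE_NEGATIVE
print(inverted)  # 0 where tokens are kept, -1e8 where padded
assert inverted.max() <= 0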
@@ -300,10 +290,7 @@ class BartEncoder(nn.Module):
         x = x.transpose(0, 1)

         encoder_states, all_attentions = [], []
         # encoder layers
         for encoder_layer in self.layers:
             if self.output_hidden_states:
                 encoder_states.append(x)
             # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
@@ -320,7 +307,6 @@ class BartEncoder(nn.Module):
             encoder_states.append(x)
         encoder_states = [hidden_state.transpose(0, 1) for hidden_state in encoder_states]
         return x, encoder_states, all_attentions
@@ -356,28 +342,12 @@ class DecoderLayer(nn.Module):
         attention_mask=None,
         need_attn_weights=False,
     ):
         """
         Args:
             x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)`
             encoder_attn_mask (ByteTensor, optional): binary
                 ByteTensor of shape `(batch, src_len)` where padding
                 elements are indicated by ``1``.
             need_attn_weights (bool, optional): return attention weights
                 for each head (default: return average over heads).
         Returns:
             encoded output of shape `(seq_len, batch, embed_dim)`
         """
         residual = x
-        y = x  # TODO(SS): figure out why fairseq did this, then hopefully delete it
         if layer_state is None:
             layer_state = {}
         # next line mutates layer state
-        x, self_attn_weights = self.self_attn(
-            query=x, key=y, value=y, layer_state=layer_state, attn_mask=attention_mask,
-        )
+        x, self_attn_weights = self.self_attn(
+            query=x, key=x, layer_state=layer_state, attn_mask=attention_mask,)
         x = F.dropout(x, p=self.dropout, training=self.training)
         x = residual + x
         x = self.self_attn_layer_norm(x)
@@ -386,11 +356,9 @@ class DecoderLayer(nn.Module):
         x, encoder_attn_weights = self.encoder_attn(
             query=x,
-            key=encoder_hidden_states,  # could be None
-            value=encoder_hidden_states,
+            key=encoder_hidden_states,
             key_padding_mask=encoder_attn_mask,
             layer_state=layer_state,  # mutates layer state
-            static_kv=True,
         )
         x = F.dropout(x, p=self.dropout, training=self.training)
         x = residual + x
@@ -527,19 +495,15 @@ class BartDecoder(nn.Module):
         return x, next_cache, all_hidden_states, list(all_self_attns)


-def reorder_attn_buffer(input_buffer, new_order):
-    """Reorder buffered internal state (for incremental generation)."""
-    # input_buffer = self._get_input_buffer(incremental_state)
-    for k in input_buffer.keys():
-        input_buffer_k = input_buffer[k]
+def _reorder_buffer(attn_cache, new_order):
+    for k, input_buffer_k in attn_cache.items():
         if input_buffer_k is not None:
-            input_buffer[k] = input_buffer_k.index_select(0, new_order)
-    # incremental_state = self._set_input_buffer(incremental_state, input_buffer)
-    return input_buffer
+            attn_cache[k] = input_buffer_k.index_select(0, new_order)
+    return attn_cache


 class SelfAttention(nn.Module):
-    """Multi-headed attention from "Attention Is All You Need"""
+    """Multi-headed attention from 'Attention Is All You Need' paper"""

     def __init__(
         self,
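Background for the rename above: during beam search, each layer's attention cache must be permuted so that every surviving beam keeps the key/value history of the beam it branched from. Below is a toy sketch of that index_select reordering; the 'prev_key' name matches the cache key seen in SelfAttention, while the other entries and all shapes are invented for illustration.

import torch

# A toy attention cache: tensors have the beam/batch dimension first.
attn_cache = {
    "prev_key": torch.arange(3 * 2 * 4, dtype=torch.float).view(3, 2, 4),
    "prev_value": torch.arange(3 * 2 * 4, dtype=torch.float).view(3, 2, 4),
    "prev_key_padding_mask": None,
}

# beam_idx says which source beam each new beam came from.
beam_idx = torch.tensor([2, 0, 0])

for k, v in attn_cache.items():
    if v is not None:
        attn_cache[k] = v.index_select(0, beam_idx)

# New beam 1 now carries the history that belonged to old beam 0.
assert torch.equal(attn_cache["prev_key"][1], torch.arange(8, dtype=torch.float).view(2, 4))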
@@ -551,7 +515,6 @@ class SelfAttention(nn.Module):
     ):
         super().__init__()
         self.embed_dim = embed_dim
         self.num_heads = num_heads
         self.dropout = dropout
         self.head_dim = embed_dim // num_heads
@@ -572,42 +535,29 @@ class SelfAttention(nn.Module):
         self,
         query,
         key: Optional[Tensor],
-        value: Optional[Tensor],
         key_padding_mask: Optional[Tensor] = None,
-        layer_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
-        static_kv: bool = False,
+        layer_state: Optional[Dict[str, Optional[Tensor]]] = None,
         attn_mask: Optional[Tensor] = None,
     ) -> Tuple[Tensor, Optional[Tensor]]:
-        """Input shape: Time(SeqLen) x Batch x Channel
-        Args:
-            key_padding_mask (ByteTensor, optional): mask to exclude
-                keys that are pads, of shape `(batch, src_len)`, where
-                padding elements are indicated by 1s.
-            attn_mask (ByteTensor, optional): typically used to
-                implement causal attention, where the mask prevents the
-                attention from looking forward in time (default: None).
-        """
+        """Input shape: Time(SeqLen) x Batch x Channel"""
+        static_kv = self.encoder_decoder_attention  # type: bool
         tgt_len, bsz, embed_dim = query.size()
         assert embed_dim == self.embed_dim
         assert list(query.size()) == [tgt_len, bsz, embed_dim]
         # get here for encoder decoder cause of static_kv
-        if layer_state is not None:  # get the last k,v and mask for reuse
+        if layer_state is not None:  # reuse k,v and encoder_padding_mask
             saved_state = layer_state.get(self.cache_key, {})
             if "prev_key" in saved_state:
                 # previous time steps are cached - no need to recompute key and value if they are static
                 if static_kv:
                     assert self.encoder_decoder_attention
-                    key = value = None
+                    key = None
         else:
             saved_state = None
             layer_state = {}

         q = self.q_proj(query) * self.scaling
-        if self.encoder_decoder_attention:
+        if static_kv:
             if key is None:
-                assert value is None
                 k = v = None
             else:
                 k = self.k_proj(key)
@@ -624,7 +574,6 @@ class SelfAttention(nn.Module):
         if saved_state is not None:
             k, v, key_padding_mask = self._use_saved_state(k, v, saved_state, key_padding_mask, static_kv, bsz)
         # assert self.cache_key != 'encoder_decoder' or key_padding_mask is None
         # Update cache
         layer_state[self.cache_key] = {
@@ -636,7 +585,6 @@ class SelfAttention(nn.Module):
         assert k is not None
         src_len = k.size(1)
         attn_weights = torch.bmm(q, k.transpose(1, 2))
         assert attn_weights.size() == (bsz * self.num_heads, tgt_len, src_len)

         if attn_mask is not None:
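For reference, the shape bookkeeping behind the assertion above, in isolation: inputs arrive as (seq_len, batch, embed_dim), the heads are folded into the batch dimension, and the bmm yields attention weights of shape (bsz * num_heads, tgt_len, src_len). A toy-sized sketch with arbitrary sizes:

import torch

tgt_len, src_len, bsz, num_heads, head_dim = 5, 7, 2, 4, 8
embed_dim = num_heads * head_dim

# Queries and keys after the per-head reshape: heads folded into the batch dim.
q = torch.randn(bsz * num_heads, tgt_len, head_dim)
k = torch.randn(bsz * num_heads, src_len, head_dim)

attn_weights = torch.bmm(q, k.transpose(1, 2))
assert attn_weights.size() == (bsz * num_heads, tgt_len, src_len)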
@@ -984,7 +932,7 @@ class BartForConditionalGeneration(PretrainedBartModel):
         for layer_past in decoder_cached_states:
             # get the correct batch idx from decoder layer's batch dim for cross and self-attn
             layer_past_new = {
-                attn_key: reorder_attn_buffer(attn_cache, beam_idx) for attn_key, attn_cache in layer_past.items()
+                attn_key: _reorder_buffer(attn_cache, beam_idx) for attn_key, attn_cache in layer_past.items()
             }
             # reordered_layer_past = [layer_past[:, i].unsqueeze(1).clone().detach() for i in beam_idx]
             # reordered_layer_past = torch.cat(reordered_layer_past, dim=1)
tests/test_modeling_bart.py (view file @ ad7233fc)
@@ -330,6 +330,17 @@ class BartHeadTests(unittest.TestCase):
         lm_model = BartForConditionalGeneration(config).eval().to(torch_device).half()
         lm_model(input_ids, attention_mask=attention_mask)

+    def test_default_generate_kwargs(self):
+        config, input_ids, _ = self._get_config_and_data(output_past=True)
+        model = BartForConditionalGeneration(config).eval().to(torch_device)
+        model.generate(input_ids)
+        model.generate(num_beams=4, do_sample=True, early_stopping=False, num_return_sequences=3)
+
+    def test_dummy_inputs(self):
+        config, *_ = self._get_config_and_data(output_past=True)
+        model = BartForConditionalGeneration(config).eval().to(torch_device)
+        model(**model.dummy_inputs)
+
     def test_prepare_bart_decoder_inputs(self):
         config, *_ = self._get_config_and_data(output_past=False)
         input_ids = _long_tensor(([4, 4, 2]))
         # only used for .device if decoder_input_ids is passed