chenpangpang / transformers · Commits · ad7233fc

Unverified commit ad7233fc, authored Mar 19, 2020 by Sam Shleifer, committed by GitHub on Mar 19, 2020

[BART] cleanup: remove redundant kwargs, improve docstrings (#3319)

Parent: cd21d8bc

Showing 2 changed files with 39 additions and 80 deletions (+39, -80):
  src/transformers/modeling_bart.py  +28, -80
  tests/test_modeling_bart.py        +11, -0
src/transformers/modeling_bart.py (view file @ ad7233fc)
@@ -56,7 +56,7 @@ BART_GENERATION_EXAMPLE = r"""
     ARTICLE_TO_SUMMARIZE = "My friends are cool but they eat too many carbs."
     inputs = tokenizer.batch_encode_plus([ARTICLE_TO_SUMMARIZE], max_length=1024, return_tensors='pt')
     # Generate Summary
-    summary_ids = model.generate(inputs['input_ids'], attention_mask=inputs['attention_mask'], num_beams=4, max_length=5)
+    summary_ids = model.generate(inputs['input_ids'], num_beams=4, max_length=5, early_stopping=True)
     print([tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in summary_ids])
 """
@@ -84,8 +84,9 @@ LARGE_NEGATIVE = -1e8
 def _prepare_bart_decoder_inputs(
     config, input_ids, decoder_input_ids=None, decoder_attn_mask=None, mask_dtype=None,
 ):
-    """Prepare masks that ignore padding tokens decoder and a causal lm mask for the decoder if
+    """Prepare masks that ignore padding tokens in the decoder and a causal lm mask for the decoder if
     none are provided. This mimics the default behavior in fairseq. To override it pass in masks.
+    Note: this is not called during generation
     """
     pad_token_id = config.pad_token_id
     need_causal_mask = not config.output_past
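The amended docstring distinguishes two masks: one that ignores padding tokens in the decoder inputs and a causal LM mask. As a rough, standalone illustration of that idea (not the body of _prepare_bart_decoder_inputs itself, which is largely outside this hunk):

import torch

def make_decoder_masks(decoder_input_ids, pad_token_id):
    # Illustrative only: a padding mask plus a causal mask in the spirit of the
    # docstring above; the real helper differs in shapes and dtype handling.
    bsz, tgt_len = decoder_input_ids.shape
    padding_mask = decoder_input_ids.eq(pad_token_id)  # (bsz, tgt_len), True at pad positions
    causal_mask = torch.triu(                          # (tgt_len, tgt_len), -inf above the diagonal
        torch.full((tgt_len, tgt_len), float("-inf")), diagonal=1
    )
    return padding_mask, causal_mask

padding_mask, causal_mask = make_decoder_masks(torch.tensor([[0, 8, 12, 2, 1]]), pad_token_id=1)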
@@ -114,8 +115,6 @@ class PretrainedBartModel(PreTrainedModel):
     def _init_weights(self, module):
         std = self.config.init_std
-        # called init_bert_params in fairseq
         if isinstance(module, nn.Linear):
             module.weight.data.normal_(mean=0.0, std=std)
             if module.bias is not None:
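The initialization scheme itself is unchanged: Linear weights are drawn from a normal distribution with std config.init_std. A self-contained sketch of how such an initializer is applied with nn.Module.apply (the 0.02 value is an assumption mirroring BART's usual init_std default):

import torch.nn as nn

def init_weights(module, std=0.02):  # 0.02 is assumed; BART reads this from config.init_std
    if isinstance(module, nn.Linear):
        module.weight.data.normal_(mean=0.0, std=std)
        if module.bias is not None:
            module.bias.data.zero_()

toy = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
toy.apply(init_weights)  # apply() visits every submodule recursively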
@@ -127,16 +126,9 @@ class PretrainedBartModel(PreTrainedModel):
     @property
     def dummy_inputs(self):
-        pad_token = 1
-        input_ids = torch.Tensor(
-            [
-                [0, 31414, 232, 328, 740, 1140, 12695, 69, 46078, 1588, 2],
-                [0, 31414, 232, 328, 740, 1140, 12695, 69, 46078, 2, pad_token],
-            ]
-        ).long()
-        decoder_input_ids, decoder_attn_mask = _prepare_bart_decoder_inputs(
-            self.config, input_ids, attention_mask=None, decoder_input_ids=None, decoder_attn_mask=None
-        )
+        pad_token = self.config.pad_token_id
+        input_ids = torch.tensor([[0, 6, 10, 4, 2], [0, 8, 12, 2, pad_token]])
+        decoder_input_ids, decoder_attn_mask = _prepare_bart_decoder_inputs(self.config, input_ids,)
         dummy_inputs = {
             "decoder_input_ids": decoder_input_ids,
             "attention_mask": input_ids.ne(pad_token),
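dummy_inputs is essentially a smoke-test payload, and the new version derives the pad id from the config instead of hard-coding 1. A sketch of how it gets exercised (the tiny config sizes are assumptions; the call mirrors the new test_dummy_inputs below):

import torch
from transformers import BartConfig, BartForConditionalGeneration

# Tiny, randomly initialized model purely for a forward-pass smoke test (assumed sizes).
config = BartConfig(
    vocab_size=64, d_model=16,
    encoder_layers=1, decoder_layers=1,
    encoder_attention_heads=2, decoder_attention_heads=2,
    encoder_ffn_dim=32, decoder_ffn_dim=32,
)
model = BartForConditionalGeneration(config).eval()
with torch.no_grad():
    model(**model.dummy_inputs)  # the same call the new test_dummy_inputs makes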
@@ -149,7 +141,7 @@ class PretrainedBartModel(PreTrainedModel):
 def _make_linear_from_emb(emb):
     vocab_size, emb_size = emb.weight.shape
     lin_layer = nn.Linear(vocab_size, emb_size, bias=False)
-    lin_layer.weight.data = emb.weight.data  # .T
+    lin_layer.weight.data = emb.weight.data
     return lin_layer
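Dropping the "# .T" note is consistent with how nn.Linear applies its weight: F.linear computes x @ W.T, so an embedding matrix of shape (vocab_size, emb_size) can be shared with an output projection without any transpose. A standalone sketch of that weight-tying pattern (illustrative, not the helper above):

import torch
import torch.nn as nn

vocab_size, emb_size = 100, 16
emb = nn.Embedding(vocab_size, emb_size)

lm_head = nn.Linear(emb_size, vocab_size, bias=False)
lm_head.weight = emb.weight            # shared Parameter of shape (vocab_size, emb_size), no transpose

hidden = torch.randn(2, 5, emb_size)   # (batch, seq, emb_size)
logits = lm_head(hidden)               # (batch, seq, vocab_size) because F.linear does x @ W.T
assert logits.shape == (2, 5, vocab_size)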
@@ -160,8 +152,8 @@ def _check_shapes(shape_1, shape2):
 def _combine_masks(key_padding_mask, causal_lm_mask, targ_size):
-    # targ_size = (bsz, tgt_len, src_len)
-    a = torch.zeros(targ_size)
+    """Make one mask of shape (bsz, 1, tgt_len, src_len)"""
+    a = torch.zeros(targ_size)  # targ_size is(bsz, tgt_len, src_len)
     b = torch.zeros(targ_size)
     if key_padding_mask is not None:  # (bsz, tgt_len) -> targ_size
         _check_shapes(key_padding_mask.shape, targ_size[:2])
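To make the new one-line docstring concrete: combining a key-padding mask with a causal mask yields a single additive mask of shape (bsz, 1, tgt_len, src_len) that broadcasts over attention heads. A rough sketch under that assumption (not the exact _combine_masks implementation):

import torch

def combine_masks_sketch(key_padding_mask, causal_lm_mask, large_negative=-1e8):
    # key_padding_mask: (bsz, src_len) bool, True at pads.
    # causal_lm_mask: (tgt_len, src_len) additive float mask.
    # Returns an additive mask of shape (bsz, 1, tgt_len, src_len).
    pad = key_padding_mask.float() * large_negative        # (bsz, src_len)
    mask = causal_lm_mask.unsqueeze(0) + pad[:, None, :]   # broadcast to (bsz, tgt_len, src_len)
    return mask.unsqueeze(1)                               # (bsz, 1, tgt_len, src_len)

m = combine_masks_sketch(torch.tensor([[False, False, True]]), torch.zeros(3, 3))
assert m.shape == (1, 1, 3, 3)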
@@ -223,7 +215,7 @@ class EncoderLayer(nn.Module):
             encoded output of shape `(seq_len, batch, embed_dim)`
         """
         residual = x
-        x, attn_weights = self.self_attn(query=x, key=x, value=x, key_padding_mask=encoder_padding_mask,)
+        x, attn_weights = self.self_attn(query=x, key=x, key_padding_mask=encoder_padding_mask,)
         x = F.dropout(x, p=self.dropout, training=self.training)
         x = residual + x
         x = self.self_attn_layer_norm(x)
@@ -266,7 +258,7 @@ class BartEncoder(nn.Module):
         self.layernorm_embedding = LayerNorm(embed_dim)

     def forward(
-        self, input_ids=None, attention_mask=None,
+        self, input_ids, attention_mask=None,
     ):
         """
         Args:
@@ -274,21 +266,19 @@ class BartEncoder(nn.Module):
                 `(batch, src_len)`
             attention_mask (torch.LongTensor): indicating which indices are padding tokens.
         Returns:
-            namedtuple:
+            Tuple comprised of:
                 - **x** (Tensor): the last encoder layer's output of
                   shape `(src_len, batch, embed_dim)`
                 - **encoder_states** (List[Tensor]): all intermediate
                   hidden states of shape `(src_len, batch, embed_dim)`.
-                  Only populated if *return_all_hiddens* is True.
+                  Only populated if *self.output_hidden_states:* is True.
                 - **all_attentions** (List[Tensor]): Attention weights for each layer.
                 During training might not be of length n_layers because of layer dropout.
         """
         # check attention mask and invert
         if attention_mask is not None:
             assert attention_mask.dim() == 2
-            attention_mask = (1.0 - attention_mask.long()) * LARGE_NEGATIVE
+            attention_mask = (1.0 - attention_mask.long()) * -10000.0
             assert attention_mask.max() <= 0

         inputs_embeds = self.embed_tokens(input_ids)
         embed_pos = self.embed_positions(input_ids)
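The inverted mask is additive: real tokens map to 0 and padding positions to a large negative number, so they vanish under softmax. A small illustration of the line above (toy shapes; the constant mirrors the new -10000.0):

import torch

attention_mask = torch.tensor([[1, 1, 1, 0]])              # 1 = real token, 0 = padding
additive_mask = (1.0 - attention_mask.float()) * -10000.0  # (batch, src_len), 0 or -10000
scores = torch.randn(1, 4)                                 # toy attention scores for one query
probs = torch.softmax(scores + additive_mask, dim=-1)
assert probs[0, -1] < 1e-4                                 # padded position gets ~zero weight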
@@ -300,10 +290,7 @@ class BartEncoder(nn.Module):
         x = x.transpose(0, 1)

         encoder_states, all_attentions = [], []
-        # encoder layers
         for encoder_layer in self.layers:
             if self.output_hidden_states:
                 encoder_states.append(x)
             # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
@@ -320,7 +307,6 @@ class BartEncoder(nn.Module):
             encoder_states.append(x)
         encoder_states = [hidden_state.transpose(0, 1) for hidden_state in encoder_states]
         return x, encoder_states, all_attentions
@@ -356,28 +342,12 @@ class DecoderLayer(nn.Module):
         attention_mask=None,
         need_attn_weights=False,
     ):
-        """
-        Args:
-            x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)`
-            encoder_attn_mask (ByteTensor, optional): binary
-                ByteTensor of shape `(batch, src_len)` where padding
-                elements are indicated by ``1``.
-            need_attn_weights (bool, optional): return attention weights
-                for each head (default: return average over heads).
-
-        Returns:
-            encoded output of shape `(seq_len, batch, embed_dim)`
-        """
         residual = x
-        y = x  # TODO(SS): figure out why fairseq did this, then hopefully delete it
         if layer_state is None:
             layer_state = {}
         # next line mutates layer state
         x, self_attn_weights = self.self_attn(
-            query=x, key=y, value=y, layer_state=layer_state, attn_mask=attention_mask,
-        )
+            query=x, key=x, layer_state=layer_state, attn_mask=attention_mask,)
         x = F.dropout(x, p=self.dropout, training=self.training)
         x = residual + x
         x = self.self_attn_layer_norm(x)
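The unchanged context around this call shows BART's post-layernorm residual pattern: the attention output is dropped out, added back to the residual, then layer-normalized. A standalone sketch of that block, using torch.nn.MultiheadAttention as a stand-in for the module's SelfAttention (shapes are assumptions):

import torch
import torch.nn as nn
import torch.nn.functional as F

embed_dim, dropout = 16, 0.1
self_attn = nn.MultiheadAttention(embed_dim, num_heads=2)  # stand-in for BART's SelfAttention
self_attn_layer_norm = nn.LayerNorm(embed_dim)

x = torch.randn(7, 3, embed_dim)       # (seq_len, batch, embed_dim)
residual = x
x, _ = self_attn(query=x, key=x, value=x)
x = F.dropout(x, p=dropout, training=False)
x = residual + x
x = self_attn_layer_norm(x)            # post-LN, as in the hunk above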
@@ -386,11 +356,9 @@ class DecoderLayer(nn.Module):
         x, encoder_attn_weights = self.encoder_attn(
             query=x,
-            key=encoder_hidden_states,  # could be None
-            value=encoder_hidden_states,
+            key=encoder_hidden_states,
             key_padding_mask=encoder_attn_mask,
             layer_state=layer_state,  # mutates layer state
-            static_kv=True,
         )
         x = F.dropout(x, p=self.dropout, training=self.training)
         x = residual + x
@@ -527,19 +495,15 @@ class BartDecoder(nn.Module):
         return x, next_cache, all_hidden_states, list(all_self_attns)


-def reorder_attn_buffer(input_buffer, new_order):
-    """Reorder buffered internal state (for incremental generation)."""
-    # input_buffer = self._get_input_buffer(incremental_state)
-    for k in input_buffer.keys():
-        input_buffer_k = input_buffer[k]
+def _reorder_buffer(attn_cache, new_order):
+    for k, input_buffer_k in attn_cache.items():
         if input_buffer_k is not None:
-            input_buffer[k] = input_buffer_k.index_select(0, new_order)
-    # incremental_state = self._set_input_buffer(incremental_state, input_buffer)
-    return input_buffer
+            attn_cache[k] = input_buffer_k.index_select(0, new_order)
+    return attn_cache


 class SelfAttention(nn.Module):
-    """Multi-headed attention from "Attention Is All You Need"""
+    """Multi-headed attention from 'Attention Is All You Need' paper"""

     def __init__(
         self,
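To see what the renamed _reorder_buffer above is for: during beam search every cached tensor is beam-major along dim 0, and index_select(0, beam_idx) re-aligns the cache after beams are re-ranked. A minimal sketch (the cache layout is illustrative; only the "prev_key" key is visible in this diff):

import torch

def _reorder_buffer_sketch(attn_cache, new_order):
    # Mirrors the renamed helper: reorder every cached tensor along dim 0.
    for k, input_buffer_k in attn_cache.items():
        if input_buffer_k is not None:
            attn_cache[k] = input_buffer_k.index_select(0, new_order)
    return attn_cache

cache = {"prev_key": torch.arange(6.0).view(3, 1, 2), "prev_value": None}  # keys assumed
beam_idx = torch.tensor([2, 2, 0])  # e.g. beam 2 survived twice, beam 0 once
cache = _reorder_buffer_sketch(cache, beam_idx)
assert torch.equal(cache["prev_key"][0], cache["prev_key"][1])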
@@ -551,7 +515,6 @@ class SelfAttention(nn.Module):
     ):
         super().__init__()
         self.embed_dim = embed_dim
         self.num_heads = num_heads
         self.dropout = dropout
         self.head_dim = embed_dim // num_heads
@@ -572,42 +535,29 @@ class SelfAttention(nn.Module):
         self,
         query,
         key: Optional[Tensor],
-        value: Optional[Tensor],
         key_padding_mask: Optional[Tensor] = None,
-        layer_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
-        static_kv: bool = False,
+        layer_state: Optional[Dict[str, Optional[Tensor]]] = None,
         attn_mask: Optional[Tensor] = None,
     ) -> Tuple[Tensor, Optional[Tensor]]:
-        """Input shape: Time(SeqLen) x Batch x Channel
-
-        Args:
-            key_padding_mask (ByteTensor, optional): mask to exclude
-                keys that are pads, of shape `(batch, src_len)`, where
-                padding elements are indicated by 1s.
-            attn_mask (ByteTensor, optional): typically used to
-                implement causal attention, where the mask prevents the
-                attention from looking forward in time (default: None).
-        """
+        """Input shape: Time(SeqLen) x Batch x Channel"""
+        static_kv = self.encoder_decoder_attention  # type: bool
         tgt_len, bsz, embed_dim = query.size()
         assert embed_dim == self.embed_dim
         assert list(query.size()) == [tgt_len, bsz, embed_dim]
         # get here for encoder decoder cause of static_kv
-        if layer_state is not None:  # get the last k,v and mask for reuse
+        if layer_state is not None:  # reuse k,v and encoder_padding_mask
             saved_state = layer_state.get(self.cache_key, {})
             if "prev_key" in saved_state:
                 # previous time steps are cached - no need to recompute key and value if they are static
                 if static_kv:
-                    key = value = None
+                    assert self.encoder_decoder_attention
+                    key = None
         else:
             saved_state = None
             layer_state = {}

         q = self.q_proj(query) * self.scaling
-        if self.encoder_decoder_attention:
+        if static_kv:
             if key is None:
-                assert value is None
                 k = v = None
             else:
                 k = self.k_proj(key)
@@ -624,7 +574,6 @@ class SelfAttention(nn.Module):
         if saved_state is not None:
             k, v, key_padding_mask = self._use_saved_state(k, v, saved_state, key_padding_mask, static_kv, bsz)
-        # assert self.cache_key != 'encoder_decoder' or key_padding_mask is None
         # Update cache
         layer_state[self.cache_key] = {
@@ -636,7 +585,6 @@ class SelfAttention(nn.Module):
         assert k is not None
         src_len = k.size(1)
         attn_weights = torch.bmm(q, k.transpose(1, 2))
         assert attn_weights.size() == (bsz * self.num_heads, tgt_len, src_len)

         if attn_mask is not None:
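The assert kept in this hunk pins down the expected shapes: with heads folded into the batch dimension, q is (bsz * num_heads, tgt_len, head_dim) and k is (bsz * num_heads, src_len, head_dim), so torch.bmm(q, k.transpose(1, 2)) gives scores of shape (bsz * num_heads, tgt_len, src_len). A quick standalone shape check (sizes are arbitrary):

import torch

bsz, num_heads, tgt_len, src_len, head_dim = 2, 4, 5, 7, 8
q = torch.randn(bsz * num_heads, tgt_len, head_dim)
k = torch.randn(bsz * num_heads, src_len, head_dim)
attn_weights = torch.bmm(q, k.transpose(1, 2))
assert attn_weights.size() == (bsz * num_heads, tgt_len, src_len)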
@@ -984,7 +932,7 @@ class BartForConditionalGeneration(PretrainedBartModel):
         for layer_past in decoder_cached_states:
             # get the correct batch idx from decoder layer's batch dim for cross and self-attn
             layer_past_new = {
-                attn_key: reorder_attn_buffer(attn_cache, beam_idx) for attn_key, attn_cache in layer_past.items()
+                attn_key: _reorder_buffer(attn_cache, beam_idx) for attn_key, attn_cache in layer_past.items()
             }
             # reordered_layer_past = [layer_past[:, i].unsqueeze(1).clone().detach() for i in beam_idx]
             # reordered_layer_past = torch.cat(reordered_layer_past, dim=1)
tests/test_modeling_bart.py (view file @ ad7233fc)
@@ -330,6 +330,17 @@ class BartHeadTests(unittest.TestCase):
         lm_model = BartForConditionalGeneration(config).eval().to(torch_device).half()
         lm_model(input_ids, attention_mask=attention_mask)

+    def test_default_generate_kwargs(self):
+        config, input_ids, _ = self._get_config_and_data(output_past=True)
+        model = BartForConditionalGeneration(config).eval().to(torch_device)
+        model.generate(input_ids)
+        model.generate(num_beams=4, do_sample=True, early_stopping=False, num_return_sequences=3)
+
+    def test_dummy_inputs(self):
+        config, *_ = self._get_config_and_data(output_past=True)
+        model = BartForConditionalGeneration(config).eval().to(torch_device)
+        model(**model.dummy_inputs)
+
     def test_prepare_bart_decoder_inputs(self):
         config, *_ = self._get_config_and_data(output_past=False)
         input_ids = _long_tensor(([4, 4, 2]))  # only used for .device if decoder_input_ids is passed
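To run just the two new tests locally, something along these lines should work from the repository root (assumes pytest and torch are installed; the selection expression is only a convenience):

# Select the two tests added in this commit by name.
import pytest

pytest.main(["tests/test_modeling_bart.py", "-k", "default_generate_kwargs or dummy_inputs", "-q"])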