chenpangpang/transformers · Commit 7dcd8703 (unverified)
Authored Mar 27, 2023 by Joao Gante, committed by GitHub on Mar 27, 2023

Generate: support for left-padding on GPTNeoX and Llama (#22382)

parent 5506d049
Changes: showing 4 changed files with 108 additions and 124 deletions (+108 −124)
src/transformers/models/gpt_neox/modeling_gpt_neox.py   +56 −24
src/transformers/models/gptj/modeling_gptj.py           +1 −1
src/transformers/models/llama/modeling_llama.py         +50 −98
tests/models/gpt_neox/test_modeling_gpt_neox.py         +1 −1
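The common thread across all four files is plumbing `position_ids` through the attention layers so that left-padded prompts keep correct rotary positions during generation. A minimal standalone sketch of the `position_ids` recipe the diff adds to `prepare_inputs_for_generation` (toy tensors only, not tied to any model):

import torch

# A left-padded batch: 0 = padding, 1 = real token (pad tokens sit on the left).
attention_mask = torch.tensor(
    [
        [1, 1, 1, 1],  # no padding
        [0, 0, 1, 1],  # two pad tokens on the left
    ]
)

# Same recipe as the diff: positions count only real tokens, pad slots get a dummy value of 1.
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)

print(position_ids)
# tensor([[0, 1, 2, 3],
#         [1, 1, 0, 1]])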
src/transformers/models/gpt_neox/modeling_gpt_neox.py

@@ -100,12 +100,13 @@ class GPTNeoXAttention(nn.Module):
     def forward(
         self,
-        hidden_states,
-        attention_mask,
-        head_mask=None,
-        layer_past=None,
-        use_cache=False,
-        output_attentions=False,
+        hidden_states: torch.FloatTensor,
+        attention_mask: torch.FloatTensor,
+        position_ids: torch.LongTensor,
+        head_mask: Optional[torch.FloatTensor] = None,
+        layer_past: Optional[Tuple[torch.Tensor]] = None,
+        use_cache: Optional[bool] = False,
+        output_attentions: Optional[bool] = False,
     ):
         has_layer_past = layer_past is not None

@@ -132,12 +133,10 @@ class GPTNeoXAttention(nn.Module):
         # Compute token offset for rotary embeddings (when decoding)
         seq_len = key.shape[-2]
-        offset = 0
         if has_layer_past:
-            offset = layer_past[0].shape[-2]
-            seq_len += offset
+            seq_len += layer_past[0].shape[-2]
         cos, sin = self.rotary_emb(value, seq_len=seq_len)
-        query, key = apply_rotary_pos_emb(query_rot, key_rot, cos, sin, offset=offset)
+        query, key = apply_rotary_pos_emb(query_rot, key_rot, cos, sin, position_ids)
         query = torch.cat((query, query_pass), dim=-1)
         key = torch.cat((key, key_pass), dim=-1)

@@ -275,9 +274,11 @@ def rotate_half(x):
     return torch.cat((-x2, x1), dim=-1)


-def apply_rotary_pos_emb(q, k, cos, sin, offset: int = 0):
-    cos = cos[..., offset : q.shape[-2] + offset, :]
-    sin = sin[..., offset : q.shape[-2] + offset, :]
+def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
+    gather_indices = position_ids[:, None, :, None]  # [bs, 1, seq_len, 1]
+    gather_indices = gather_indices.repeat(1, cos.shape[1], 1, cos.shape[3])
+    cos = torch.gather(cos.repeat(gather_indices.shape[0], 1, 1, 1), 2, gather_indices)
+    sin = torch.gather(sin.repeat(gather_indices.shape[0], 1, 1, 1), 2, gather_indices)
     q_embed = (q * cos) + (rotate_half(q) * sin)
     k_embed = (k * cos) + (rotate_half(k) * sin)
     return q_embed, k_embed
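The rewritten apply_rotary_pos_emb gathers the cos/sin caches per example instead of slicing a single shared offset, which is what lets each row of a left-padded batch use its own positions. A toy sketch of what the gather computes, assuming the [1, 1, seq_len, dim] cache layout produced by the rotary embedding module here:

import torch

bs, q_len, max_len, dim = 2, 4, 8, 6
cos = torch.randn(1, 1, max_len, dim)  # rotary cache: [1, 1, max_len, dim]
position_ids = torch.tensor([[0, 1, 2, 3], [1, 1, 0, 1]])  # second row is left-padded

# Same indexing as the new apply_rotary_pos_emb: each row picks its own positions.
gather_indices = position_ids[:, None, :, None]                           # [bs, 1, q_len, 1]
gather_indices = gather_indices.repeat(1, cos.shape[1], 1, cos.shape[3])  # [bs, 1, q_len, dim]
cos_per_row = torch.gather(cos.repeat(bs, 1, 1, 1), 2, gather_indices)

print(cos_per_row.shape)                                # torch.Size([2, 1, 4, 6])
assert torch.equal(cos_per_row[1, 0, 2], cos[0, 0, 0])  # row 1, step 2 uses position 0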
@@ -308,16 +309,18 @@ class GPTNeoXLayer(nn.Module):
     def forward(
         self,
-        hidden_states,
-        attention_mask=None,
-        head_mask=None,
-        use_cache=False,
-        layer_past=None,
-        output_attentions=False,
+        hidden_states: Optional[torch.FloatTensor],
+        attention_mask: Optional[torch.FloatTensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        use_cache: Optional[bool] = False,
+        layer_past: Optional[Tuple[torch.Tensor]] = None,
+        output_attentions: Optional[bool] = False,
     ):
         attention_layer_outputs = self.attention(
             self.input_layernorm(hidden_states),
             attention_mask=attention_mask,
+            position_ids=position_ids,
             layer_past=layer_past,
             head_mask=head_mask,
             use_cache=use_cache,

@@ -374,6 +377,11 @@ GPT_NEOX_INPUTS_DOCSTRING = r"""
             - 0 for tokens that are **masked**.

             [What are attention masks?](../glossary#attention-mask)
+        position_ids (`torch.LongTensor` of shape `({0})`, *optional*):
+            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
+            config.n_positions - 1]`.
+
+            [What are position IDs?](../glossary#position-ids)
         head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
             Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:

@@ -430,6 +438,7 @@ class GPTNeoXModel(GPTNeoXPreTrainedModel):
         self,
         input_ids: Optional[torch.LongTensor] = None,
         attention_mask: Optional[torch.FloatTensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
         head_mask: Optional[torch.FloatTensor] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
         past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,

@@ -467,7 +476,17 @@ class GPTNeoXModel(GPTNeoXPreTrainedModel):
         batch_size, seq_length = input_shape

         if past_key_values is None:
+            past_length = 0
             past_key_values = tuple([None] * self.config.num_hidden_layers)
+        else:
+            past_length = past_key_values[0][0].size(-2)
+
+        if position_ids is None:
+            device = input_ids.device if input_ids is not None else inputs_embeds.device
+            position_ids = torch.arange(past_length, seq_length + past_length, dtype=torch.long, device=device)
+            position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
+        else:
+            position_ids = position_ids.view(-1, seq_length).long()

         # Attention mask.
         if attention_mask is not None:

@@ -527,12 +546,14 @@ class GPTNeoXModel(GPTNeoXPreTrainedModel):
                     create_custom_forward(layer),
                     hidden_states,
                     attention_mask,
+                    position_ids,
                     head_mask[i],
                 )
             else:
                 outputs = layer(
                     hidden_states,
                     attention_mask=attention_mask,
+                    position_ids=position_ids,
                     head_mask=head_mask[i],
                     layer_past=layer_past,
                     use_cache=use_cache,

@@ -587,6 +608,7 @@ class GPTNeoXForCausalLM(GPTNeoXPreTrainedModel):
         self,
         input_ids: Optional[torch.LongTensor] = None,
         attention_mask: Optional[torch.FloatTensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
         head_mask: Optional[torch.FloatTensor] = None,
         past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,

@@ -640,6 +662,7 @@ class GPTNeoXForCausalLM(GPTNeoXPreTrainedModel):
         outputs = self.gpt_neox(
             input_ids,
             attention_mask=attention_mask,
+            position_ids=position_ids,
             head_mask=head_mask,
             inputs_embeds=inputs_embeds,
             past_key_values=past_key_values,

@@ -672,20 +695,29 @@ class GPTNeoXForCausalLM(GPTNeoXPreTrainedModel):
             attentions=outputs.attentions,
         )

-    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attention_mask=None, **model_kwargs):
+    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attention_mask=None, **kwargs):
         input_shape = input_ids.shape

-        # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
-        if attention_mask is None:
-            attention_mask = input_ids.new_ones(input_shape)
-
         # cut decoder_input_ids if past is used
         if past_key_values and past_key_values[0] is not None:
             input_ids = input_ids[:, -1:]

+        position_ids = kwargs.get("position_ids", None)
+        if attention_mask is not None and position_ids is None:
+            # create position_ids on the fly for batch generation
+            position_ids = attention_mask.long().cumsum(-1) - 1
+            position_ids.masked_fill_(attention_mask == 0, 1)
+            if past_key_values:
+                position_ids = position_ids[:, -1].unsqueeze(-1)
+
+        # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
+        if attention_mask is None:
+            attention_mask = input_ids.new_ones(input_shape)
+
         return {
             "input_ids": input_ids,
             "attention_mask": attention_mask,
+            "position_ids": position_ids,
             "past_key_values": past_key_values,
         }
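With position_ids derived from the attention mask, left-padded batches should generate from the correct positions. A hedged usage sketch, reusing the Pythia checkpoint referenced in the test file; the tokenizer setup (pad token, padding side) is an assumption for illustration and the expected output text is not part of this diff:

from transformers import AutoTokenizer, GPTNeoXForCausalLM

tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-410m-deduped")
tokenizer.pad_token = tokenizer.eos_token  # assumed: checkpoint has no pad token by default
tokenizer.padding_side = "left"            # pad prompts on the left so generation continues from real tokens

model = GPTNeoXForCausalLM.from_pretrained("EleutherAI/pythia-410m-deduped")

# Prompts of different lengths: the shorter one is left-padded.
inputs = tokenizer(["My favorite food is", "Hello"], return_tensors="pt", padding=True)

# generate() goes through prepare_inputs_for_generation, which now builds
# position_ids from the attention mask so the padding does not shift positions.
outputs = model.generate(**inputs, max_new_tokens=10)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))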
src/transformers/models/gptj/modeling_gptj.py

@@ -192,7 +192,7 @@ class GPTJAttention(nn.Module):
     def forward(
         self,
-        hidden_states: Optional[torch.FloatTensor],
+        hidden_states: torch.FloatTensor,
         layer_past: Optional[Tuple[torch.Tensor]] = None,
         attention_mask: Optional[torch.FloatTensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
src/transformers/models/llama/modeling_llama.py

@@ -38,6 +38,7 @@ logger = logging.get_logger(__name__)
 _CONFIG_FOR_DOC = "LlamaConfig"


+# Copied from transformers.models.bart.modeling_bart._make_causal_mask
 def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_values_length: int = 0):
     """
     Make causal mask used for bi-directional self-attention.

@@ -53,6 +54,7 @@ def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_
     return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)


+# Copied from transformers.models.bart.modeling_bart._expand_mask
 def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
     """
     Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.

@@ -126,9 +128,11 @@ def rotate_half(x):
     return torch.cat((-x2, x1), dim=-1)


-def apply_rotary_pos_emb(q, k, cos, sin, offset: int = 0):
-    cos = cos[..., offset : q.shape[-2] + offset, :]
-    sin = sin[..., offset : q.shape[-2] + offset, :]
+def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
+    gather_indices = position_ids[:, None, :, None]  # [bs, 1, seq_len, 1]
+    gather_indices = gather_indices.repeat(1, cos.shape[1], 1, cos.shape[3])
+    cos = torch.gather(cos.repeat(gather_indices.shape[0], 1, 1, 1), 2, gather_indices)
+    sin = torch.gather(sin.repeat(gather_indices.shape[0], 1, 1, 1), 2, gather_indices)
     q_embed = (q * cos) + (rotate_half(q) * sin)
     k_embed = (k * cos) + (rotate_half(k) * sin)
     return q_embed, k_embed
@@ -197,13 +201,12 @@ class LlamaAttention(nn.Module):
     def forward(
         self,
         hidden_states: torch.Tensor,
-        past_key_value: Optional[Tuple[torch.Tensor]] = None,
         attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_value: Optional[Tuple[torch.Tensor]] = None,
         output_attentions: bool = False,
         use_cache: bool = False,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
-        """Input shape: Batch x Time x Channel"""
-
         bsz, q_len, _ = hidden_states.size()

         query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)

@@ -211,12 +214,10 @@ class LlamaAttention(nn.Module):
         value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)

         kv_seq_len = key_states.shape[-2]
-        offset = 0
         if past_key_value is not None:
-            offset = past_key_value[0].shape[-2]
-            kv_seq_len += offset
+            kv_seq_len += past_key_value[0].shape[-2]
         cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
-        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, offset=offset)
+        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
         # [bsz, nh, t, hd]

         if past_key_value is not None:

@@ -283,9 +284,10 @@ class LlamaDecoderLayer(nn.Module):
         self,
         hidden_states: torch.Tensor,
         attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_value: Optional[Tuple[torch.Tensor]] = None,
         output_attentions: Optional[bool] = False,
         use_cache: Optional[bool] = False,
-        past_key_value: Optional[Tuple[torch.Tensor]] = None,
     ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
         """
         Args:

@@ -308,8 +310,9 @@ class LlamaDecoderLayer(nn.Module):
         # Self Attention
         hidden_states, self_attn_weights, present_key_value = self.self_attn(
             hidden_states=hidden_states,
-            past_key_value=past_key_value,
             attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_value=past_key_value,
             output_attentions=output_attentions,
             use_cache=use_cache,
         )
@@ -406,7 +409,11 @@ LLAMA_INPUTS_DOCSTRING = r"""
             - 1 indicates the head is **not masked**,
             - 0 indicates the head is **masked**.

+        position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
+            config.n_positions - 1]`.
+
+            [What are position IDs?](../glossary#position-ids)
         past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
             Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
             `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape

@@ -488,10 +495,12 @@ class LlamaModel(LlamaPreTrainedModel):
         return combined_attention_mask

+    @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
     def forward(
         self,
         input_ids: torch.LongTensor = None,
         attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
         past_key_values: Optional[List[torch.FloatTensor]] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
         use_cache: Optional[bool] = None,

@@ -499,49 +508,6 @@ class LlamaModel(LlamaPreTrainedModel):
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
     ) -> Union[Tuple, BaseModelOutputWithPast]:
-        r"""
-        Args:
-            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
-                Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
-                provide it.
-
-                Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
-                [`PreTrainedTokenizer.__call__`] for details.
-
-                [What are input IDs?](../glossary#input-ids)
-            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
-                Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
-
-                - 1 for tokens that are **not masked**,
-                - 0 for tokens that are **masked**.
-
-                [What are attention masks?](../glossary#attention-mask)
-            past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
-                Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
-                shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of
-                shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
-
-                Contains pre-computed hidden-states (key and values in the self-attention blocks and in the
-                cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
-
-                If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those
-                that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of
-                all `decoder_input_ids` of shape `(batch_size, sequence_length)`.
-            use_cache (`bool`, *optional*):
-                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
-                (see `past_key_values`).
-            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
-                Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
-                This is useful if you want more control over how to convert `input_ids` indices into associated vectors
-                than the model's internal embedding lookup matrix.
-            output_attentions (`bool`, *optional*):
-                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
-                returned tensors for more detail.
-            output_hidden_states (`bool`, *optional*):
-                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
-                for more detail.
-            return_dict (`bool`, *optional*):
-                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
-        """
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states

@@ -559,11 +525,23 @@ class LlamaModel(LlamaPreTrainedModel):
             batch_size, seq_length, _ = inputs_embeds.shape
         else:
             raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")

         seq_length_with_past = seq_length
         past_key_values_length = 0

         if past_key_values is not None:
             past_key_values_length = past_key_values[0][0].shape[2]
             seq_length_with_past = seq_length_with_past + past_key_values_length

+        if position_ids is None:
+            device = input_ids.device if input_ids is not None else inputs_embeds.device
+            position_ids = torch.arange(
+                past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
+            )
+            position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
+        else:
+            position_ids = position_ids.view(-1, seq_length).long()
+
         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids)
         # embed positions

@@ -608,12 +586,14 @@ class LlamaModel(LlamaPreTrainedModel):
                     create_custom_forward(decoder_layer),
                     hidden_states,
                     attention_mask,
+                    position_ids,
                     None,
                 )
             else:
                 layer_outputs = decoder_layer(
                     hidden_states,
                     attention_mask=attention_mask,
+                    position_ids=position_ids,
                     past_key_value=past_key_value,
                     output_attentions=output_attentions,
                     use_cache=use_cache,
@@ -674,11 +654,13 @@ class LlamaForCausalLM(LlamaPreTrainedModel):
     def get_decoder(self):
         return self.model

+    @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
     @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
     def forward(
         self,
         input_ids: torch.LongTensor = None,
         attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
         past_key_values: Optional[List[torch.FloatTensor]] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
         labels: Optional[torch.LongTensor] = None,

@@ -689,52 +671,10 @@ class LlamaForCausalLM(LlamaPreTrainedModel):
     ) -> Union[Tuple, CausalLMOutputWithPast]:
         r"""
         Args:
-            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
-                Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
-                provide it.
-
-                Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
-                [`PreTrainedTokenizer.__call__`] for details.
-
-                [What are input IDs?](../glossary#input-ids)
-            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
-                Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
-
-                - 1 for tokens that are **not masked**,
-                - 0 for tokens that are **masked**.
-
-                [What are attention masks?](../glossary#attention-mask)
-            past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
-                Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
-                shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of
-                shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. The two additional
-                tensors are only required when the model is used as a decoder in a Sequence to Sequence model.
-
-                Contains pre-computed hidden-states (key and values in the self-attention blocks and in the
-                cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
-
-                If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those
-                that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of
-                all `decoder_input_ids` of shape `(batch_size, sequence_length)`.
-            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
-                Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
-                This is useful if you want more control over how to convert `input_ids` indices into associated vectors
-                than the model's internal embedding lookup matrix.
             labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                 Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
                 config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
                 (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
-            use_cache (`bool`, *optional*):
-                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
-                (see `past_key_values`).
-            output_attentions (`bool`, *optional*):
-                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
-                returned tensors for more detail.
-            output_hidden_states (`bool`, *optional*):
-                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
-                for more detail.
-            return_dict (`bool`, *optional*):
-                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.

         Returns:

@@ -765,6 +705,7 @@ class LlamaForCausalLM(LlamaPreTrainedModel):
         outputs = self.model(
             input_ids=input_ids,
             attention_mask=attention_mask,
+            position_ids=position_ids,
             past_key_values=past_key_values,
             inputs_embeds=inputs_embeds,
             use_cache=use_cache,

@@ -807,6 +748,14 @@ class LlamaForCausalLM(LlamaPreTrainedModel):
         if past_key_values:
             input_ids = input_ids[:, -1:]

+        position_ids = kwargs.get("position_ids", None)
+        if attention_mask is not None and position_ids is None:
+            # create position_ids on the fly for batch generation
+            position_ids = attention_mask.long().cumsum(-1) - 1
+            position_ids.masked_fill_(attention_mask == 0, 1)
+            if past_key_values:
+                position_ids = position_ids[:, -1].unsqueeze(-1)
+
         # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
         if inputs_embeds is not None and past_key_values is None:
             model_inputs = {"inputs_embeds": inputs_embeds}

@@ -815,6 +764,7 @@ class LlamaForCausalLM(LlamaPreTrainedModel):
         model_inputs.update(
             {
+                "position_ids": position_ids,
                 "past_key_values": past_key_values,
                 "use_cache": kwargs.get("use_cache"),
                 "attention_mask": attention_mask,

@@ -868,6 +818,7 @@ class LlamaForSequenceClassification(LlamaPreTrainedModel):
         self,
         input_ids: torch.LongTensor = None,
         attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
         past_key_values: Optional[List[torch.FloatTensor]] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
         labels: Optional[torch.LongTensor] = None,

@@ -886,8 +837,9 @@ class LlamaForSequenceClassification(LlamaPreTrainedModel):
         transformer_outputs = self.model(
             input_ids,
-            past_key_values=past_key_values,
             attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_values=past_key_values,
             inputs_embeds=inputs_embeds,
             use_cache=use_cache,
             output_attentions=output_attentions,
tests/models/gpt_neox/test_modeling_gpt_neox.py

@@ -237,7 +237,7 @@ class GPTNeoXModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
 @require_torch
 class GPTNeoXLanguageGenerationTest(unittest.TestCase):
     @slow
-    def test_lm_generate_codegen(self):
+    def test_lm_generate_gptneox(self):
         tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-410m-deduped")
         for checkpointing in [True, False]:
             model = GPTNeoXForCausalLM.from_pretrained("EleutherAI/pythia-410m-deduped")