chenpangpang / transformers / Commits

Commit 7b090876 (unverified)
Authored Jul 28, 2022 by Thomas Wang; committed by GitHub on Jul 28, 2022
[BLOOM] Deprecate `position_ids` (#18342)
Parent: 9c336657
Showing 1 changed file with 44 additions and 22 deletions.

src/transformers/models/bloom/modeling_bloom.py (+44, -22)
@@ -15,6 +15,7 @@
 """PyTorch BLOOM model."""

 import math
+import warnings
 from typing import Tuple, Union

 import torch
@@ -522,11 +523,6 @@ BLOOM_INPUTS_DOCSTRING = r"""
             - 0 for tokens that are **masked**.

             [What are attention masks?](../glossary#attention-mask)

-        position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
-            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
-            config.max_position_embeddings - 1]`.
-
-            [What are position IDs?](../glossary#position-ids)
         head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
             Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
@@ -617,14 +613,24 @@ class BloomModel(BloomPreTrainedModel):
         input_ids=None,
         past_key_values=None,
         attention_mask=None,
-        position_ids=None,
         head_mask=None,
         inputs_embeds=None,
         use_cache=None,
         output_attentions=None,
         output_hidden_states=None,
         return_dict=None,
+        **deprecated_arguments
     ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]:
+        if deprecated_arguments.pop("position_ids", False) is not False:
+            # `position_ids` could have been `torch.Tensor` or `None` so defaulting pop to `False` allows to detect if users were passing explicitly `None`
+            warnings.warn(
+                "`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. You can safely ignore"
+                " passing `position_ids`.",
+                FutureWarning,
+            )
+        if len(deprecated_arguments) > 0:
+            raise ValueError(f"Got unexpected arguments: {deprecated_arguments}")
+
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
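Aside: the guard added above leans on a sentinel default. A real `position_ids` is either a tensor or `None`, so popping with a `False` default lets the code tell "never passed" apart from "passed explicitly as `None`" (which should still warn). A minimal standalone sketch of the pattern, with illustrative names that are not part of this diff:

import warnings

def forward(**deprecated_arguments):
    # `False` can never be a legitimate `position_ids` value, so it is a safe
    # sentinel: an explicit `position_ids=None` still trips `is not False`,
    # while a genuinely absent argument does not.
    if deprecated_arguments.pop("position_ids", False) is not False:
        warnings.warn("`position_ids` is deprecated and ignored.", FutureWarning)
    if len(deprecated_arguments) > 0:
        raise ValueError(f"Got unexpected arguments: {deprecated_arguments}")

forward(position_ids=None)  # warns: explicitly passed, even though it is None
forward()                   # silent: the argument was never passed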
@@ -772,16 +778,7 @@ class BloomForCausalLM(BloomPreTrainedModel):
             input_ids = input_ids[:, -1].unsqueeze(-1)

         attention_mask = kwargs.get("attention_mask", None)
-        position_ids = kwargs.get("position_ids", None)
-
-        if attention_mask is not None and position_ids is None:
-            # create position_ids on the fly for batch generation
-            position_ids = attention_mask.long().cumsum(-1) - 1
-            position_ids.masked_fill_(attention_mask == 0, 1)
-            if past:
-                position_ids = position_ids[:, -1].unsqueeze(-1)
-        else:
-            position_ids = None
+
         return {
             "input_ids": input_ids,
             "past_key_values": past,
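For context, the deleted branch built `position_ids` from the attention mask so that left-padded sequences in a batch start counting at their first real token; since BLOOM encodes positions through ALiBi attention biases rather than position embeddings, the tensor was computed but never consumed downstream. A small reproduction of what the removed lines calculated, on made-up inputs:

import torch

# Batch of two sequences; the first row is left-padded with two pad tokens.
attention_mask = torch.tensor([[0, 0, 1, 1, 1],
                               [1, 1, 1, 1, 1]])

# Running count of real tokens, shifted so the first real token sits at 0.
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)  # padded slots pinned to 1
print(position_ids)
# tensor([[1, 1, 0, 1, 2],
#         [0, 1, 2, 3, 4]])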
@@ -801,7 +798,6 @@ class BloomForCausalLM(BloomPreTrainedModel):
         input_ids=None,
         past_key_values=None,
         attention_mask=None,
-        position_ids=None,
         head_mask=None,
         inputs_embeds=None,
         labels=None,
@@ -809,6 +805,7 @@ class BloomForCausalLM(BloomPreTrainedModel):
         output_attentions=None,
         output_hidden_states=None,
         return_dict=None,
+        **deprecated_arguments
     ) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -816,13 +813,22 @@ class BloomForCausalLM(BloomPreTrainedModel):
             `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
             are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
         """
+        if deprecated_arguments.pop("position_ids", False) is not False:
+            # `position_ids` could have been `torch.Tensor` or `None` so defaulting pop to `False` allows to detect if users were passing explicitly `None`
+            warnings.warn(
+                "`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. You can safely ignore"
+                " passing `position_ids`.",
+                FutureWarning,
+            )
+        if len(deprecated_arguments) > 0:
+            raise ValueError(f"Got unexpected arguments: {deprecated_arguments}")
+
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict

         transformer_outputs = self.transformer(
             input_ids,
             past_key_values=past_key_values,
             attention_mask=attention_mask,
-            position_ids=position_ids,
             head_mask=head_mask,
             inputs_embeds=inputs_embeds,
             use_cache=use_cache,
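Caller-facing effect of the hunks above: passing `position_ids` to any of these `forward` methods now emits a `FutureWarning` (removal slated for v5.0.0) instead of being accepted silently, and any other unknown keyword raises a `ValueError`. A hedged sketch of a transitional caller; `model` and `input_ids` are placeholders for a loaded BLOOM checkpoint and its tokenized input, not names from this diff:

import warnings

# Until call sites are updated to drop the argument, the deprecation warning
# can be filtered by its message:
warnings.filterwarnings(
    "ignore",
    message=r"`position_ids` have no functionality in BLOOM",
    category=FutureWarning,
)

# outputs = model(input_ids, position_ids=None)  # would warn without the filter
# outputs = model(input_ids)                     # the real fix: stop passing it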
@@ -907,7 +913,6 @@ class BloomForSequenceClassification(BloomPreTrainedModel):
         input_ids=None,
         past_key_values=None,
         attention_mask=None,
-        position_ids=None,
         head_mask=None,
         inputs_embeds=None,
         labels=None,
@@ -915,6 +920,7 @@ class BloomForSequenceClassification(BloomPreTrainedModel):
         output_attentions=None,
         output_hidden_states=None,
         return_dict=None,
+        **deprecated_arguments
     ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutputWithPast]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
@@ -922,6 +928,15 @@ class BloomForSequenceClassification(BloomPreTrainedModel):
             config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
             `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
         """
+        if deprecated_arguments.pop("position_ids", False) is not False:
+            # `position_ids` could have been `torch.Tensor` or `None` so defaulting pop to `False` allows to detect if users were passing explicitly `None`
+            warnings.warn(
+                "`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. You can safely ignore"
+                " passing `position_ids`.",
+                FutureWarning,
+            )
+        if len(deprecated_arguments) > 0:
+            raise ValueError(f"Got unexpected arguments: {deprecated_arguments}")
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
@@ -929,7 +944,6 @@ class BloomForSequenceClassification(BloomPreTrainedModel):
             input_ids,
             past_key_values=past_key_values,
             attention_mask=attention_mask,
-            position_ids=position_ids,
             head_mask=head_mask,
             inputs_embeds=inputs_embeds,
             use_cache=use_cache,
@@ -1036,7 +1050,6 @@ class BloomForTokenClassification(BloomPreTrainedModel):
         input_ids=None,
         past_key_values=None,
         attention_mask=None,
-        position_ids=None,
         head_mask=None,
         inputs_embeds=None,
         labels=None,
@@ -1044,6 +1057,7 @@ class BloomForTokenClassification(BloomPreTrainedModel):
         output_attentions=None,
         output_hidden_states=None,
         return_dict=None,
+        **deprecated_arguments
     ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
@@ -1051,6 +1065,15 @@ class BloomForTokenClassification(BloomPreTrainedModel):
             config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
             `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
         """
+        if deprecated_arguments.pop("position_ids", False) is not False:
+            # `position_ids` could have been `torch.Tensor` or `None` so defaulting pop to `False` allows to detect if users were passing explicitly `None`
+            warnings.warn(
+                "`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. You can safely ignore"
+                " passing `position_ids`.",
+                FutureWarning,
+            )
+        if len(deprecated_arguments) > 0:
+            raise ValueError(f"Got unexpected arguments: {deprecated_arguments}")
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
@@ -1058,7 +1081,6 @@ class BloomForTokenClassification(BloomPreTrainedModel):
             input_ids,
             past_key_values=past_key_values,
             attention_mask=attention_mask,
-            position_ids=position_ids,
             head_mask=head_mask,
             inputs_embeds=inputs_embeds,
             use_cache=use_cache,