chenpangpang / transformers / Commits / ad28ca29

Unverified commit ad28ca29, authored Jul 10, 2022 by Stas Bekman, committed via GitHub Jul 10, 2022.
[bloom] fix alibi device placement (#18087)
Parent: 8b332a6a

Showing 1 changed file with 5 additions and 5 deletions (+5 / -5):

src/transformers/models/bloom/modeling_bloom.py
src/transformers/models/bloom/modeling_bloom.py (view file @ ad28ca29)

@@ -93,7 +93,7 @@ def attention_mask_func(attention_scores, attention_mask, causal_mask):
     )


-def build_alibi_tensor(max_seq_len, n_head, dtype=torch.bfloat16):
+def build_alibi_tensor(max_seq_len, n_head, device, dtype=torch.bfloat16):
     """
     Link to paper: https://arxiv.org/abs/2108.12409 Alibi tensor is not causal as the original paper mentions, it
     relies on a translation invariance of softmax for quick implementation: with l being a tensor, and a fixed value
@@ -129,7 +129,7 @@ def build_alibi_tensor(max_seq_len, n_head, dtype=torch.bfloat16):
     arange_tensor = torch.arange(max_seq_len).unsqueeze(0).unsqueeze(0)
     alibi = slopes * arange_tensor.expand(n_head, -1, -1)
-    alibi = alibi.to(dtype)
+    alibi = alibi.to(device=device, dtype=dtype)
     return alibi
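
The substance of the fix is the single `.to()` call above: the old code cast only the dtype, so the tensor stayed on whatever device `torch.arange` allocated it (CPU by default), even when the model ran on GPU. Below is a minimal runnable sketch of the patched function; the diff elides the slope computation, so the sketch substitutes the geometric slope schedule from the ALiBi paper and assumes n_head is a power of 2. It is an illustration, not the library's implementation.

import torch

def build_alibi_tensor_sketch(max_seq_len, n_head, device, dtype=torch.bfloat16):
    # Sketch of the patched function. The slope computation is not shown in
    # the diff; this uses the geometric schedule from the ALiBi paper,
    # assuming n_head is a power of 2.
    slopes = torch.pow(2.0, -8.0 * torch.arange(1, n_head + 1) / n_head)
    slopes = slopes.view(n_head, 1, 1)
    arange_tensor = torch.arange(max_seq_len).unsqueeze(0).unsqueeze(0)
    alibi = slopes * arange_tensor.expand(n_head, -1, -1)
    # The fix: move to the target device and cast in one call, instead of
    # alibi.to(dtype), which left the tensor on the CPU.
    alibi = alibi.to(device=device, dtype=dtype)
    return alibi

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
alibi = build_alibi_tensor_sketch(16, 8, device)
print(alibi.shape, alibi.device)  # torch.Size([8, 1, 16]) cuda:0 or cpu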
@@ -147,7 +147,7 @@ def pre_process_alibi_for_pad(alibi, attention_mask, num_heads):
     # This usually happens when the inference is done with past_key_values
     # In this case we re-create the alibi tensor with the correct sequence length
     if attention_mask.shape[-1] != alibi.shape[-1]:
-        alibi = build_alibi_tensor(attention_mask.shape[-1], num_heads, alibi.dtype).repeat(
+        alibi = build_alibi_tensor(attention_mask.shape[-1], num_heads, alibi.device, alibi.dtype).repeat(
             attention_mask.shape[0], 1, 1
         )
     # Get the indexes of the padding tokens
@@ -156,7 +156,7 @@ def pre_process_alibi_for_pad(alibi, attention_mask, num_heads):
     # Clone the embeddings - we can detach because the embeddings are not learned
     # Get a refence tensor
-    slice_reference_alibi = build_alibi_tensor(alibi.shape[-1], num_heads, alibi.dtype)
+    slice_reference_alibi = build_alibi_tensor(alibi.shape[-1], num_heads, alibi.device, alibi.dtype)
     # Loop over the batch where the padding is and replace the alibi tensor by the reference tensor
     # Only where you do not have padding. Replace padding tokens by zeros
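
Both call sites in pre_process_alibi_for_pad rebuild the tensor from an existing alibi, so they source the device and dtype from that tensor rather than hard-coding either. Continuing the sketch above, a hypothetical usage with a batch of 2 and sequence length 16 (the attention_mask here is a stand-in, not real model input):

# Rebuild at the sequence length implied by the attention mask, keeping the
# new tensor on the same device and in the same dtype as the old one.
attention_mask = torch.ones(2, 16)
num_heads = 8
alibi = build_alibi_tensor_sketch(
    attention_mask.shape[-1], num_heads, alibi.device, alibi.dtype
).repeat(attention_mask.shape[0], 1, 1)
print(alibi.shape)  # torch.Size([16, 1, 16]): batch * n_head rows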
@@ -767,7 +767,7 @@ class BloomModel(BloomPreTrainedModel):
         current_sequence_length = hidden_states.shape[1]
         if past_key_values[0] is not None:
             current_sequence_length += past_key_values[0][0].shape[1]
-        alibi = build_alibi_tensor(current_sequence_length, self.n_head, hidden_states.dtype)
+        alibi = build_alibi_tensor(current_sequence_length, self.n_head, hidden_states.device, hidden_states.dtype)

         for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
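
For context, a hypothetical reproduction of the failure mode this commit addresses: PyTorch refuses elementwise ops between tensors on different devices, so an alibi tensor left on the CPU crashes as soon as it is added to CUDA attention scores. This repro is an assumption inferred from the patch, not taken from the commit discussion.

import torch

if torch.cuda.is_available():
    scores = torch.zeros(8, 1, 16, device="cuda")   # attention scores on GPU
    alibi_cpu = torch.zeros(8, 1, 16)               # built without a device arg
    try:
        scores + alibi_cpu                          # device mismatch
    except RuntimeError as err:
        print(err)  # "Expected all tensors to be on the same device..."
    scores + alibi_cpu.to(device=scores.device)     # fine after the fix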