chenpangpang / transformers / Commits / a796f7ee

Unverified commit a796f7ee, authored Sep 13, 2023 by Joao Gante, committed by GitHub on Sep 13, 2023

Falcon: batched generation (#26137)

parent 95a90410

Showing 2 changed files with 105 additions and 14 deletions (+105 −14):

src/transformers/models/falcon/modeling_falcon.py   +66 −12
tests/models/falcon/test_modeling_falcon.py         +39 −2
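What this commit enables, as a minimal usage sketch (not part of the diff; the checkpoint name and expected continuation are taken from the new test, and bfloat16 loading is assumed here instead of the test's 4-bit setup — exact outputs may vary by environment):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Left padding is required for batched decoder-only generation; the new
# position_ids handling keeps padded and unpadded prompts on the same rotary positions.
tokenizer = AutoTokenizer.from_pretrained("tiiuae/falcon-7b", padding_side="left")
tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained(
    "tiiuae/falcon-7b", device_map="auto", torch_dtype=torch.bfloat16
)

prompts = ["A sequence: 1, 2", "This is a longer text This is a longer text "]
inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device)
inputs.pop("token_type_ids", None)  # the Falcon tokenizer may emit this key

outputs = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
# With this fix, the shorter (left-padded) prompt still continues "3, 4, 5, ..."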
src/transformers/models/falcon/modeling_falcon.py  (view file @ a796f7ee)
@@ -67,6 +67,7 @@ def rotate_half(x):
     return torch.cat((-x2, x1), dim=-1)


+# TODO (joao): Is this the same implementation as in Llama? If so, let's make them the same and add the copy facilities
 class FalconRotaryEmbedding(nn.Module):
     """Implementation of RotaryEmbedding from GPT-NeoX.

     This implementation is designed to operate on queries and keys that are compatible with `[batch_size,
@@ -99,19 +100,40 @@ class FalconRotaryEmbedding(nn.Module):
         self.cos_cached = self.cos_cached.type(dtype)
         self.sin_cached = self.sin_cached.type(dtype)

-    def cos_sin(self, seq_len: int, past_key_values_length: int, device="cpu", dtype=torch.bfloat16) -> torch.Tensor:
+    def cos_sin(
+        self, seq_len: int, past_key_values_length: int, position_ids: torch.Tensor, device="cpu", dtype=torch.bfloat16
+    ) -> torch.Tensor:
         total_length = seq_len + past_key_values_length
         if total_length > self.seq_len_cached:
             self._set_cos_sin_cache(total_length, device, dtype)
-        return (
-            self.cos_cached[:, past_key_values_length : seq_len + past_key_values_length],
-            self.sin_cached[:, past_key_values_length : seq_len + past_key_values_length],
-        )
+        # Gather cos, sin at the designated position ids
+        cos = self.cos_cached.squeeze(0)[position_ids]  # [bs, seq_len, dim]
+        sin = self.sin_cached.squeeze(0)[position_ids]  # [bs, seq_len, dim]
+        return cos, sin

-    def forward(self, query, key, past_key_values_length=0):
-        batch, seq_len, head_dim = query.shape
-        cos, sin = self.cos_sin(seq_len, past_key_values_length, query.device, query.dtype)
-        return (query * cos) + (rotate_half(query) * sin), (key * cos) + (rotate_half(key) * sin)
+    def forward(self, query, key, past_key_values_length, position_ids):
+        _, seq_len, _ = query.shape
+        cos, sin = self.cos_sin(seq_len, past_key_values_length, position_ids, query.device, query.dtype)
+        # Query and key's shapes are [bs * num_heads, seq_len, dim], might need manual expansion. Ifs and elses used to
+        # avoid unnecessary repeat_interleave operations.
+        query_expansion_factor = int(query.shape[0] / cos.shape[0])
+        if query_expansion_factor > 1:
+            query_cos = torch.repeat_interleave(cos, query_expansion_factor, dim=0)
+            query_sin = torch.repeat_interleave(sin, query_expansion_factor, dim=0)
+        else:
+            query_cos, query_sin = cos, sin
+
+        key_expansion_factor = int(key.shape[0] / cos.shape[0])
+        if key_expansion_factor > 1:
+            if key_expansion_factor != query_expansion_factor:
+                key_cos = torch.repeat_interleave(cos, key_expansion_factor, dim=0)
+                key_sin = torch.repeat_interleave(sin, key_expansion_factor, dim=0)
+            else:
+                key_cos, key_sin = query_cos, query_sin
+        else:
+            key_cos, key_sin = cos, sin
+
+        return (query * query_cos) + (rotate_half(query) * query_sin), (key * key_cos) + (rotate_half(key) * key_sin)


 class FalconLinearScalingRotaryEmbedding(FalconRotaryEmbedding):
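Why the gather by position_ids matters, as a standalone sketch (not part of the diff; shapes and values are illustrative): each batch row now indexes the cached cos/sin tables at its own positions, so a prompt that starts after a few pad tokens still receives rotary angles 0, 1, 2, ... instead of angles shifted by the padding.

import torch

# Illustrative stand-in for the cached table: [1, max_positions, head_dim]
max_positions, head_dim = 8, 4
cos_cached = torch.randn(1, max_positions, head_dim)

# Batch of 2 sequences of length 4; the second is left-padded by two tokens.
# Pads get a dummy position (1); they are excluded by the attention mask anyway.
position_ids = torch.tensor(
    [[0, 1, 2, 3],   # no padding: positions 0..3
     [1, 1, 0, 1]]   # two left pads, then the real tokens at positions 0, 1
)

cos = cos_cached.squeeze(0)[position_ids]  # [batch, seq_len, head_dim]
print(cos.shape)  # torch.Size([2, 4, 4])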
@@ -270,7 +292,7 @@ class FalconAttention(nn.Module):
                 f" {self.num_heads})."
             )

-        self.maybe_rotary = self._init_rope() if config.rotary else lambda q, k, t: (q, k)
+        self.maybe_rotary = self._init_rope() if config.rotary else lambda q, k, t, p: (q, k)

         # Layer-wise attention scaling
         self.inv_norm_factor = 1.0 / math.sqrt(self.head_dim)
@@ -378,6 +400,7 @@ class FalconAttention(nn.Module):
         hidden_states: torch.Tensor,
         alibi: Optional[torch.Tensor],
         attention_mask: torch.Tensor,
+        position_ids: Optional[torch.LongTensor] = None,
         layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
         head_mask: Optional[torch.Tensor] = None,
         use_cache: bool = False,
@@ -399,7 +422,7 @@ class FalconAttention(nn.Module):
         value_layer = value_layer.transpose(1, 2).reshape(batch_size * num_kv_heads, query_length, self.head_dim)

         past_kv_length = 0 if layer_past is None else layer_past[0].shape[1]
-        query_layer, key_layer = self.maybe_rotary(query_layer, key_layer, past_kv_length)
+        query_layer, key_layer = self.maybe_rotary(query_layer, key_layer, past_kv_length, position_ids)

         if layer_past is not None:
             past_key, past_value = layer_past
@@ -415,7 +438,8 @@ class FalconAttention(nn.Module):
         else:
             present = None

-        attention_mask_float = (attention_mask * 1.0).masked_fill(attention_mask, float("-1e9")).to(query_layer.dtype)
+        float_min = torch.finfo(query_layer.dtype).min
+        attention_mask_float = (attention_mask * 1.0).masked_fill(attention_mask, float_min).to(query_layer.dtype)

         query_layer_ = query_layer.reshape(batch_size, self.num_heads, -1, self.head_dim)
         key_layer_ = key_layer.reshape(batch_size, num_kv_heads, -1, self.head_dim)
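Why torch.finfo(dtype).min instead of a hard-coded -1e9: -1e9 does not fit in float16, so the mask value overflows to -inf and fully masked rows can produce NaNs after softmax. A small standalone check (illustrative, not part of the diff):

import torch

print(torch.tensor(-1e9, dtype=torch.float16))  # tensor(-inf, dtype=torch.float16) -- overflow
print(torch.finfo(torch.float16).min)           # -65504.0, the most negative representable float16
print(torch.finfo(torch.bfloat16).min)          # about -3.39e+38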
@@ -536,6 +560,7 @@ class FalconDecoderLayer(nn.Module):
         hidden_states: torch.Tensor,
         alibi: Optional[torch.Tensor],
         attention_mask: torch.Tensor,
+        position_ids: Optional[torch.LongTensor] = None,
         layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
         head_mask: Optional[torch.Tensor] = None,
         use_cache: bool = False,
@@ -554,6 +579,7 @@ class FalconDecoderLayer(nn.Module):
             attention_layernorm_out,
             layer_past=layer_past,
             attention_mask=attention_mask,
+            position_ids=position_ids,
             alibi=alibi,
             head_mask=head_mask,
             use_cache=use_cache,
@@ -632,6 +658,11 @@ FALCON_INPUTS_DOCSTRING = r"""
             - 0 for tokens that are **masked**.

             [What are attention masks?](../glossary#attention-mask)
+        position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
+            config.n_positions - 1]`.
+
+            [What are position IDs?](../glossary#position-ids)
         head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
             Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
@@ -836,6 +867,7 @@ class FalconModel(FalconPreTrainedModel):
         input_ids: Optional[torch.LongTensor] = None,
         past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
         attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
         head_mask: Optional[torch.LongTensor] = None,
         inputs_embeds: Optional[torch.LongTensor] = None,
         use_cache: Optional[bool] = None,
@@ -892,6 +924,14 @@ class FalconModel(FalconPreTrainedModel):
             alibi = build_alibi_tensor(attention_mask, self.num_heads, dtype=hidden_states.dtype)
         else:
             alibi = None
+            if position_ids is None:
+                device = input_ids.device if input_ids is not None else inputs_embeds.device
+                position_ids = torch.arange(
+                    past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
+                )
+                position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
+            else:
+                position_ids = position_ids.view(-1, seq_length).long()

         causal_mask = self._prepare_attn_mask(
             attention_mask,
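The fallback above covers callers that pass no position_ids (a plain forward pass or unpadded generation): positions are simply a contiguous range offset by the cache length. A tiny illustration with assumed values (not part of the diff):

import torch

past_key_values_length, seq_length = 5, 1  # one new token on top of 5 cached tokens
position_ids = torch.arange(
    past_key_values_length, seq_length + past_key_values_length, dtype=torch.long
)
print(position_ids.unsqueeze(0).view(-1, seq_length))  # tensor([[5]]), shared across the batch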
@@ -922,6 +962,7 @@ class FalconModel(FalconPreTrainedModel):
                     hidden_states,
                     alibi,
                     causal_mask,
+                    position_ids,
                     head_mask[i],
                 )
             else:
@@ -929,6 +970,7 @@ class FalconModel(FalconPreTrainedModel):
                     hidden_states,
                     layer_past=layer_past,
                     attention_mask=causal_mask,
+                    position_ids=position_ids,
                     head_mask=head_mask[i],
                     use_cache=use_cache,
                     output_attentions=output_attentions,
@@ -988,13 +1030,23 @@ class FalconForCausalLM(FalconPreTrainedModel):
         input_ids: torch.LongTensor,
         past_key_values: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.Tensor] = None,
         **kwargs,
     ) -> dict:
         if past_key_values is not None:
             input_ids = input_ids[:, -1:]

+        # Note: versions of Falcon with alibi do not use position_ids. It is used with RoPE.
+        if not self.transformer.use_alibi and attention_mask is not None and position_ids is None:
+            # create position_ids on the fly for batch generation
+            position_ids = attention_mask.long().cumsum(-1) - 1
+            position_ids.masked_fill_(attention_mask == 0, 1)
+            if past_key_values:
+                position_ids = position_ids[:, -1].unsqueeze(-1)
+
         return {
             "input_ids": input_ids,
+            "position_ids": position_ids,
             "past_key_values": past_key_values,
             "use_cache": kwargs.get("use_cache"),
             "attention_mask": attention_mask,
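How the on-the-fly position_ids look for a left-padded batch, as a standalone illustration (values assumed, not part of the diff):

import torch

# attention_mask for two prompts padded to length 5 (1 = real token, 0 = pad)
attention_mask = torch.tensor([[1, 1, 1, 1, 1],
                               [0, 0, 1, 1, 1]])

position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
print(position_ids)
# tensor([[0, 1, 2, 3, 4],
#         [1, 1, 0, 1, 2]])
# Real tokens count from 0 regardless of padding; pad slots get a dummy
# position (1) and are ignored thanks to the attention mask.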
@@ -1011,6 +1063,7 @@ class FalconForCausalLM(FalconPreTrainedModel):
         input_ids: Optional[torch.LongTensor] = None,
         past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
         attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
         head_mask: Optional[torch.Tensor] = None,
         inputs_embeds: Optional[torch.Tensor] = None,
         labels: Optional[torch.Tensor] = None,
@@ -1032,6 +1085,7 @@ class FalconForCausalLM(FalconPreTrainedModel):
             input_ids,
             past_key_values=past_key_values,
             attention_mask=attention_mask,
+            position_ids=position_ids,
             head_mask=head_mask,
             inputs_embeds=inputs_embeds,
             use_cache=use_cache,
tests/models/falcon/test_modeling_falcon.py  (view file @ a796f7ee)
@@ -19,8 +19,16 @@ import unittest
 from parameterized import parameterized

-from transformers import AutoConfig, AutoModel, AutoTokenizer, FalconConfig, is_torch_available, set_seed
-from transformers.testing_utils import CaptureLogger, require_torch, slow, tooslow, torch_device
+from transformers import (
+    AutoConfig,
+    AutoModel,
+    AutoModelForCausalLM,
+    AutoTokenizer,
+    FalconConfig,
+    is_torch_available,
+    set_seed,
+)
+from transformers.testing_utils import CaptureLogger, require_bitsandbytes, require_torch, slow, tooslow, torch_device
 from transformers.utils import logging as transformers_logging

 from ...generation.test_utils import GenerationTesterMixin
@@ -502,6 +510,35 @@ class FalconLanguageGenerationTest(unittest.TestCase):
         outputs_cache = model.generate(**inputs, do_sample=False, max_new_tokens=20, use_cache=True)
         self.assertTrue((outputs_cache - outputs_no_cache).sum().item() == 0)

+    @require_bitsandbytes
+    @slow
+    def test_batched_generation(self):
+        tokenizer = AutoTokenizer.from_pretrained("tiiuae/falcon-7b", padding_side="left")
+        tokenizer.pad_token = tokenizer.eos_token
+        model = AutoModelForCausalLM.from_pretrained(
+            "tiiuae/falcon-7b",
+            device_map="auto",
+            load_in_4bit=True,
+        )
+
+        test_text = "A sequence: 1, 2"  # should generate the rest of the sequence
+
+        unpadded_inputs = tokenizer([test_text], return_tensors="pt").to("cuda:0")
+        unpadded_inputs.pop("token_type_ids")
+        unpadded_gen_out = model.generate(**unpadded_inputs, max_new_tokens=20)
+        unpadded_gen_text = tokenizer.batch_decode(unpadded_gen_out, skip_special_tokens=True)
+
+        dummy_text = "This is a longer text " * 2  # forces left-padding on `test_text`
+        padded_inputs = tokenizer([test_text, dummy_text], return_tensors="pt", padding=True).to("cuda:0")
+        padded_inputs.pop("token_type_ids")
+        padded_gen_out = model.generate(**padded_inputs, max_new_tokens=20)
+        padded_gen_text = tokenizer.batch_decode(padded_gen_out, skip_special_tokens=True)
+
+        expected_output = "A sequence: 1, 2, 3, 4, 5, 6, 7, 8, "
+        self.assertLess(unpadded_inputs.input_ids.shape[-1], padded_inputs.input_ids.shape[-1])  # left-padding exists
+        self.assertEqual(unpadded_gen_text[0], expected_output)
+        self.assertEqual(padded_gen_text[0], expected_output)
+

 # TODO Lysandre: Remove this in version v4.34
 class FalconOverrideTest(unittest.TestCase):