Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
a796f7ee
Unverified
Commit
a796f7ee
authored
Sep 13, 2023
by
Joao Gante
Committed by
GitHub
Sep 13, 2023
Browse files
Falcon: batched generation (#26137)
parent
95a90410
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
105 additions
and
14 deletions
+105
-14
src/transformers/models/falcon/modeling_falcon.py
src/transformers/models/falcon/modeling_falcon.py
+66
-12
tests/models/falcon/test_modeling_falcon.py
tests/models/falcon/test_modeling_falcon.py
+39
-2
No files found.
src/transformers/models/falcon/modeling_falcon.py
View file @
a796f7ee
...
@@ -67,6 +67,7 @@ def rotate_half(x):
...
@@ -67,6 +67,7 @@ def rotate_half(x):
return
torch
.
cat
((
-
x2
,
x1
),
dim
=-
1
)
return
torch
.
cat
((
-
x2
,
x1
),
dim
=-
1
)
# TODO (joao): Is this the same implementation as in Llama? If so, let's make them the same and add the copy facilities
class
FalconRotaryEmbedding
(
nn
.
Module
):
class
FalconRotaryEmbedding
(
nn
.
Module
):
"""Implementation of RotaryEmbedding from GPT-NeoX.
"""Implementation of RotaryEmbedding from GPT-NeoX.
This implementation is designed to operate on queries and keys that are compatible with `[batch_size,
This implementation is designed to operate on queries and keys that are compatible with `[batch_size,
...
@@ -99,19 +100,40 @@ class FalconRotaryEmbedding(nn.Module):
...
@@ -99,19 +100,40 @@ class FalconRotaryEmbedding(nn.Module):
self
.
cos_cached
=
self
.
cos_cached
.
type
(
dtype
)
self
.
cos_cached
=
self
.
cos_cached
.
type
(
dtype
)
self
.
sin_cached
=
self
.
sin_cached
.
type
(
dtype
)
self
.
sin_cached
=
self
.
sin_cached
.
type
(
dtype
)
def
cos_sin
(
self
,
seq_len
:
int
,
past_key_values_length
:
int
,
device
=
"cpu"
,
dtype
=
torch
.
bfloat16
)
->
torch
.
Tensor
:
def
cos_sin
(
self
,
seq_len
:
int
,
past_key_values_length
:
int
,
position_ids
:
torch
.
Tensor
,
device
=
"cpu"
,
dtype
=
torch
.
bfloat16
)
->
torch
.
Tensor
:
total_length
=
seq_len
+
past_key_values_length
total_length
=
seq_len
+
past_key_values_length
if
total_length
>
self
.
seq_len_cached
:
if
total_length
>
self
.
seq_len_cached
:
self
.
_set_cos_sin_cache
(
total_length
,
device
,
dtype
)
self
.
_set_cos_sin_cache
(
total_length
,
device
,
dtype
)
return
(
# Gather cos, sin at the designated position ids
self
.
cos_cached
[:,
past_key_values_length
:
seq_len
+
past_key_values_length
],
cos
=
self
.
cos_cached
.
squeeze
(
0
)[
position_ids
]
# [bs, seq_len, dim]
self
.
sin_cached
[:,
past_key_values_length
:
seq_len
+
past_key_values_length
],
sin
=
self
.
sin_cached
.
squeeze
(
0
)[
position_ids
]
# [bs, seq_len, dim]
)
return
cos
,
sin
def
forward
(
self
,
query
,
key
,
past_key_values_length
,
position_ids
):
_
,
seq_len
,
_
=
query
.
shape
cos
,
sin
=
self
.
cos_sin
(
seq_len
,
past_key_values_length
,
position_ids
,
query
.
device
,
query
.
dtype
)
# Query and key's shapes are [bs * num_heads, seq_len, dim], might need manual expansion. Ifs and elses used to
# avoid unnecessary repeat_interleave operations.
query_expansion_factor
=
int
(
query
.
shape
[
0
]
/
cos
.
shape
[
0
])
if
query_expansion_factor
>
1
:
query_cos
=
torch
.
repeat_interleave
(
cos
,
query_expansion_factor
,
dim
=
0
)
query_sin
=
torch
.
repeat_interleave
(
sin
,
query_expansion_factor
,
dim
=
0
)
else
:
query_cos
,
query_sin
=
cos
,
sin
key_expansion_factor
=
int
(
key
.
shape
[
0
]
/
cos
.
shape
[
0
])
if
key_expansion_factor
>
1
:
if
key_expansion_factor
!=
query_expansion_factor
:
key_cos
=
torch
.
repeat_interleave
(
cos
,
key_expansion_factor
,
dim
=
0
)
key_sin
=
torch
.
repeat_interleave
(
sin
,
key_expansion_factor
,
dim
=
0
)
else
:
key_cos
,
key_sin
=
query_cos
,
query_sin
else
:
key_cos
,
key_sin
=
cos
,
sin
def
forward
(
self
,
query
,
key
,
past_key_values_length
=
0
):
return
(
query
*
query_cos
)
+
(
rotate_half
(
query
)
*
query_sin
),
(
key
*
key_cos
)
+
(
rotate_half
(
key
)
*
key_sin
)
batch
,
seq_len
,
head_dim
=
query
.
shape
cos
,
sin
=
self
.
cos_sin
(
seq_len
,
past_key_values_length
,
query
.
device
,
query
.
dtype
)
return
(
query
*
cos
)
+
(
rotate_half
(
query
)
*
sin
),
(
key
*
cos
)
+
(
rotate_half
(
key
)
*
sin
)
class
FalconLinearScalingRotaryEmbedding
(
FalconRotaryEmbedding
):
class
FalconLinearScalingRotaryEmbedding
(
FalconRotaryEmbedding
):
...
@@ -270,7 +292,7 @@ class FalconAttention(nn.Module):
...
@@ -270,7 +292,7 @@ class FalconAttention(nn.Module):
f
"
{
self
.
num_heads
}
)."
f
"
{
self
.
num_heads
}
)."
)
)
self
.
maybe_rotary
=
self
.
_init_rope
()
if
config
.
rotary
else
lambda
q
,
k
,
t
:
(
q
,
k
)
self
.
maybe_rotary
=
self
.
_init_rope
()
if
config
.
rotary
else
lambda
q
,
k
,
t
,
p
:
(
q
,
k
)
# Layer-wise attention scaling
# Layer-wise attention scaling
self
.
inv_norm_factor
=
1.0
/
math
.
sqrt
(
self
.
head_dim
)
self
.
inv_norm_factor
=
1.0
/
math
.
sqrt
(
self
.
head_dim
)
...
@@ -378,6 +400,7 @@ class FalconAttention(nn.Module):
...
@@ -378,6 +400,7 @@ class FalconAttention(nn.Module):
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
alibi
:
Optional
[
torch
.
Tensor
],
alibi
:
Optional
[
torch
.
Tensor
],
attention_mask
:
torch
.
Tensor
,
attention_mask
:
torch
.
Tensor
,
position_ids
:
Optional
[
torch
.
LongTensor
]
=
None
,
layer_past
:
Optional
[
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]]
=
None
,
layer_past
:
Optional
[
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]]
=
None
,
head_mask
:
Optional
[
torch
.
Tensor
]
=
None
,
head_mask
:
Optional
[
torch
.
Tensor
]
=
None
,
use_cache
:
bool
=
False
,
use_cache
:
bool
=
False
,
...
@@ -399,7 +422,7 @@ class FalconAttention(nn.Module):
...
@@ -399,7 +422,7 @@ class FalconAttention(nn.Module):
value_layer
=
value_layer
.
transpose
(
1
,
2
).
reshape
(
batch_size
*
num_kv_heads
,
query_length
,
self
.
head_dim
)
value_layer
=
value_layer
.
transpose
(
1
,
2
).
reshape
(
batch_size
*
num_kv_heads
,
query_length
,
self
.
head_dim
)
past_kv_length
=
0
if
layer_past
is
None
else
layer_past
[
0
].
shape
[
1
]
past_kv_length
=
0
if
layer_past
is
None
else
layer_past
[
0
].
shape
[
1
]
query_layer
,
key_layer
=
self
.
maybe_rotary
(
query_layer
,
key_layer
,
past_kv_length
)
query_layer
,
key_layer
=
self
.
maybe_rotary
(
query_layer
,
key_layer
,
past_kv_length
,
position_ids
)
if
layer_past
is
not
None
:
if
layer_past
is
not
None
:
past_key
,
past_value
=
layer_past
past_key
,
past_value
=
layer_past
...
@@ -415,7 +438,8 @@ class FalconAttention(nn.Module):
...
@@ -415,7 +438,8 @@ class FalconAttention(nn.Module):
else
:
else
:
present
=
None
present
=
None
attention_mask_float
=
(
attention_mask
*
1.0
).
masked_fill
(
attention_mask
,
float
(
"-1e9"
)).
to
(
query_layer
.
dtype
)
float_min
=
torch
.
finfo
(
query_layer
.
dtype
).
min
attention_mask_float
=
(
attention_mask
*
1.0
).
masked_fill
(
attention_mask
,
float_min
).
to
(
query_layer
.
dtype
)
query_layer_
=
query_layer
.
reshape
(
batch_size
,
self
.
num_heads
,
-
1
,
self
.
head_dim
)
query_layer_
=
query_layer
.
reshape
(
batch_size
,
self
.
num_heads
,
-
1
,
self
.
head_dim
)
key_layer_
=
key_layer
.
reshape
(
batch_size
,
num_kv_heads
,
-
1
,
self
.
head_dim
)
key_layer_
=
key_layer
.
reshape
(
batch_size
,
num_kv_heads
,
-
1
,
self
.
head_dim
)
...
@@ -536,6 +560,7 @@ class FalconDecoderLayer(nn.Module):
...
@@ -536,6 +560,7 @@ class FalconDecoderLayer(nn.Module):
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
alibi
:
Optional
[
torch
.
Tensor
],
alibi
:
Optional
[
torch
.
Tensor
],
attention_mask
:
torch
.
Tensor
,
attention_mask
:
torch
.
Tensor
,
position_ids
:
Optional
[
torch
.
LongTensor
]
=
None
,
layer_past
:
Optional
[
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]]
=
None
,
layer_past
:
Optional
[
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]]
=
None
,
head_mask
:
Optional
[
torch
.
Tensor
]
=
None
,
head_mask
:
Optional
[
torch
.
Tensor
]
=
None
,
use_cache
:
bool
=
False
,
use_cache
:
bool
=
False
,
...
@@ -554,6 +579,7 @@ class FalconDecoderLayer(nn.Module):
...
@@ -554,6 +579,7 @@ class FalconDecoderLayer(nn.Module):
attention_layernorm_out
,
attention_layernorm_out
,
layer_past
=
layer_past
,
layer_past
=
layer_past
,
attention_mask
=
attention_mask
,
attention_mask
=
attention_mask
,
position_ids
=
position_ids
,
alibi
=
alibi
,
alibi
=
alibi
,
head_mask
=
head_mask
,
head_mask
=
head_mask
,
use_cache
=
use_cache
,
use_cache
=
use_cache
,
...
@@ -632,6 +658,11 @@ FALCON_INPUTS_DOCSTRING = r"""
...
@@ -632,6 +658,11 @@ FALCON_INPUTS_DOCSTRING = r"""
- 0 for tokens that are **masked**.
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
[What are attention masks?](../glossary#attention-mask)
position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
config.n_positions - 1]`.
[What are position IDs?](../glossary#position-ids)
head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
...
@@ -836,6 +867,7 @@ class FalconModel(FalconPreTrainedModel):
...
@@ -836,6 +867,7 @@ class FalconModel(FalconPreTrainedModel):
input_ids
:
Optional
[
torch
.
LongTensor
]
=
None
,
input_ids
:
Optional
[
torch
.
LongTensor
]
=
None
,
past_key_values
:
Optional
[
Tuple
[
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
],
...]]
=
None
,
past_key_values
:
Optional
[
Tuple
[
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
],
...]]
=
None
,
attention_mask
:
Optional
[
torch
.
Tensor
]
=
None
,
attention_mask
:
Optional
[
torch
.
Tensor
]
=
None
,
position_ids
:
Optional
[
torch
.
LongTensor
]
=
None
,
head_mask
:
Optional
[
torch
.
LongTensor
]
=
None
,
head_mask
:
Optional
[
torch
.
LongTensor
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
LongTensor
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
LongTensor
]
=
None
,
use_cache
:
Optional
[
bool
]
=
None
,
use_cache
:
Optional
[
bool
]
=
None
,
...
@@ -892,6 +924,14 @@ class FalconModel(FalconPreTrainedModel):
...
@@ -892,6 +924,14 @@ class FalconModel(FalconPreTrainedModel):
alibi
=
build_alibi_tensor
(
attention_mask
,
self
.
num_heads
,
dtype
=
hidden_states
.
dtype
)
alibi
=
build_alibi_tensor
(
attention_mask
,
self
.
num_heads
,
dtype
=
hidden_states
.
dtype
)
else
:
else
:
alibi
=
None
alibi
=
None
if
position_ids
is
None
:
device
=
input_ids
.
device
if
input_ids
is
not
None
else
inputs_embeds
.
device
position_ids
=
torch
.
arange
(
past_key_values_length
,
seq_length
+
past_key_values_length
,
dtype
=
torch
.
long
,
device
=
device
)
position_ids
=
position_ids
.
unsqueeze
(
0
).
view
(
-
1
,
seq_length
)
else
:
position_ids
=
position_ids
.
view
(
-
1
,
seq_length
).
long
()
causal_mask
=
self
.
_prepare_attn_mask
(
causal_mask
=
self
.
_prepare_attn_mask
(
attention_mask
,
attention_mask
,
...
@@ -922,6 +962,7 @@ class FalconModel(FalconPreTrainedModel):
...
@@ -922,6 +962,7 @@ class FalconModel(FalconPreTrainedModel):
hidden_states
,
hidden_states
,
alibi
,
alibi
,
causal_mask
,
causal_mask
,
position_ids
,
head_mask
[
i
],
head_mask
[
i
],
)
)
else
:
else
:
...
@@ -929,6 +970,7 @@ class FalconModel(FalconPreTrainedModel):
...
@@ -929,6 +970,7 @@ class FalconModel(FalconPreTrainedModel):
hidden_states
,
hidden_states
,
layer_past
=
layer_past
,
layer_past
=
layer_past
,
attention_mask
=
causal_mask
,
attention_mask
=
causal_mask
,
position_ids
=
position_ids
,
head_mask
=
head_mask
[
i
],
head_mask
=
head_mask
[
i
],
use_cache
=
use_cache
,
use_cache
=
use_cache
,
output_attentions
=
output_attentions
,
output_attentions
=
output_attentions
,
...
@@ -988,13 +1030,23 @@ class FalconForCausalLM(FalconPreTrainedModel):
...
@@ -988,13 +1030,23 @@ class FalconForCausalLM(FalconPreTrainedModel):
input_ids
:
torch
.
LongTensor
,
input_ids
:
torch
.
LongTensor
,
past_key_values
:
Optional
[
torch
.
Tensor
]
=
None
,
past_key_values
:
Optional
[
torch
.
Tensor
]
=
None
,
attention_mask
:
Optional
[
torch
.
Tensor
]
=
None
,
attention_mask
:
Optional
[
torch
.
Tensor
]
=
None
,
position_ids
:
Optional
[
torch
.
Tensor
]
=
None
,
**
kwargs
,
**
kwargs
,
)
->
dict
:
)
->
dict
:
if
past_key_values
is
not
None
:
if
past_key_values
is
not
None
:
input_ids
=
input_ids
[:,
-
1
:]
input_ids
=
input_ids
[:,
-
1
:]
# Note: versions of Falcon with alibi do not use position_ids. It is used with RoPE.
if
not
self
.
transformer
.
use_alibi
and
attention_mask
is
not
None
and
position_ids
is
None
:
# create position_ids on the fly for batch generation
position_ids
=
attention_mask
.
long
().
cumsum
(
-
1
)
-
1
position_ids
.
masked_fill_
(
attention_mask
==
0
,
1
)
if
past_key_values
:
position_ids
=
position_ids
[:,
-
1
].
unsqueeze
(
-
1
)
return
{
return
{
"input_ids"
:
input_ids
,
"input_ids"
:
input_ids
,
"position_ids"
:
position_ids
,
"past_key_values"
:
past_key_values
,
"past_key_values"
:
past_key_values
,
"use_cache"
:
kwargs
.
get
(
"use_cache"
),
"use_cache"
:
kwargs
.
get
(
"use_cache"
),
"attention_mask"
:
attention_mask
,
"attention_mask"
:
attention_mask
,
...
@@ -1011,6 +1063,7 @@ class FalconForCausalLM(FalconPreTrainedModel):
...
@@ -1011,6 +1063,7 @@ class FalconForCausalLM(FalconPreTrainedModel):
input_ids
:
Optional
[
torch
.
LongTensor
]
=
None
,
input_ids
:
Optional
[
torch
.
LongTensor
]
=
None
,
past_key_values
:
Optional
[
Tuple
[
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
],
...]]
=
None
,
past_key_values
:
Optional
[
Tuple
[
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
],
...]]
=
None
,
attention_mask
:
Optional
[
torch
.
Tensor
]
=
None
,
attention_mask
:
Optional
[
torch
.
Tensor
]
=
None
,
position_ids
:
Optional
[
torch
.
LongTensor
]
=
None
,
head_mask
:
Optional
[
torch
.
Tensor
]
=
None
,
head_mask
:
Optional
[
torch
.
Tensor
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
labels
:
Optional
[
torch
.
Tensor
]
=
None
,
labels
:
Optional
[
torch
.
Tensor
]
=
None
,
...
@@ -1032,6 +1085,7 @@ class FalconForCausalLM(FalconPreTrainedModel):
...
@@ -1032,6 +1085,7 @@ class FalconForCausalLM(FalconPreTrainedModel):
input_ids
,
input_ids
,
past_key_values
=
past_key_values
,
past_key_values
=
past_key_values
,
attention_mask
=
attention_mask
,
attention_mask
=
attention_mask
,
position_ids
=
position_ids
,
head_mask
=
head_mask
,
head_mask
=
head_mask
,
inputs_embeds
=
inputs_embeds
,
inputs_embeds
=
inputs_embeds
,
use_cache
=
use_cache
,
use_cache
=
use_cache
,
...
...
tests/models/falcon/test_modeling_falcon.py
View file @
a796f7ee
...
@@ -19,8 +19,16 @@ import unittest
...
@@ -19,8 +19,16 @@ import unittest
from
parameterized
import
parameterized
from
parameterized
import
parameterized
from
transformers
import
AutoConfig
,
AutoModel
,
AutoTokenizer
,
FalconConfig
,
is_torch_available
,
set_seed
from
transformers
import
(
from
transformers.testing_utils
import
CaptureLogger
,
require_torch
,
slow
,
tooslow
,
torch_device
AutoConfig
,
AutoModel
,
AutoModelForCausalLM
,
AutoTokenizer
,
FalconConfig
,
is_torch_available
,
set_seed
,
)
from
transformers.testing_utils
import
CaptureLogger
,
require_bitsandbytes
,
require_torch
,
slow
,
tooslow
,
torch_device
from
transformers.utils
import
logging
as
transformers_logging
from
transformers.utils
import
logging
as
transformers_logging
from
...generation.test_utils
import
GenerationTesterMixin
from
...generation.test_utils
import
GenerationTesterMixin
...
@@ -502,6 +510,35 @@ class FalconLanguageGenerationTest(unittest.TestCase):
...
@@ -502,6 +510,35 @@ class FalconLanguageGenerationTest(unittest.TestCase):
outputs_cache
=
model
.
generate
(
**
inputs
,
do_sample
=
False
,
max_new_tokens
=
20
,
use_cache
=
True
)
outputs_cache
=
model
.
generate
(
**
inputs
,
do_sample
=
False
,
max_new_tokens
=
20
,
use_cache
=
True
)
self
.
assertTrue
((
outputs_cache
-
outputs_no_cache
).
sum
().
item
()
==
0
)
self
.
assertTrue
((
outputs_cache
-
outputs_no_cache
).
sum
().
item
()
==
0
)
@
require_bitsandbytes
@
slow
def
test_batched_generation
(
self
):
tokenizer
=
AutoTokenizer
.
from_pretrained
(
"tiiuae/falcon-7b"
,
padding_side
=
"left"
)
tokenizer
.
pad_token
=
tokenizer
.
eos_token
model
=
AutoModelForCausalLM
.
from_pretrained
(
"tiiuae/falcon-7b"
,
device_map
=
"auto"
,
load_in_4bit
=
True
,
)
test_text
=
"A sequence: 1, 2"
# should generate the rest of the sequence
unpadded_inputs
=
tokenizer
([
test_text
],
return_tensors
=
"pt"
).
to
(
"cuda:0"
)
unpadded_inputs
.
pop
(
"token_type_ids"
)
unpadded_gen_out
=
model
.
generate
(
**
unpadded_inputs
,
max_new_tokens
=
20
)
unpadded_gen_text
=
tokenizer
.
batch_decode
(
unpadded_gen_out
,
skip_special_tokens
=
True
)
dummy_text
=
"This is a longer text "
*
2
# forces left-padding on `test_text`
padded_inputs
=
tokenizer
([
test_text
,
dummy_text
],
return_tensors
=
"pt"
,
padding
=
True
).
to
(
"cuda:0"
)
padded_inputs
.
pop
(
"token_type_ids"
)
padded_gen_out
=
model
.
generate
(
**
padded_inputs
,
max_new_tokens
=
20
)
padded_gen_text
=
tokenizer
.
batch_decode
(
padded_gen_out
,
skip_special_tokens
=
True
)
expected_output
=
"A sequence: 1, 2, 3, 4, 5, 6, 7, 8, "
self
.
assertLess
(
unpadded_inputs
.
input_ids
.
shape
[
-
1
],
padded_inputs
.
input_ids
.
shape
[
-
1
])
# left-padding exists
self
.
assertEqual
(
unpadded_gen_text
[
0
],
expected_output
)
self
.
assertEqual
(
padded_gen_text
[
0
],
expected_output
)
# TODO Lysandre: Remove this in version v4.34
# TODO Lysandre: Remove this in version v4.34
class
FalconOverrideTest
(
unittest
.
TestCase
):
class
FalconOverrideTest
(
unittest
.
TestCase
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment