# chenpangpang / transformers · Commit ffe60fdc

**v4.39 deprecations 🧼 (#29492)**

- Authored Mar 07, 2024 by Joao Gante; committed by GitHub on Mar 07, 2024 (unverified commit)
- Parent: 979fccc9
- Showing 14 changed files with 9 additions and 400 deletions (+9 −400)
### Changed files

| File | Additions | Deletions |
| --- | --- | --- |
| docs/source/en/internal/generation_utils.md | +0 | −6 |
| docs/source/ja/internal/generation_utils.md | +0 | −6 |
| docs/source/zh/internal/generation_utils.md | +0 | −6 |
| src/transformers/__init__.py | +0 | −4 |
| src/transformers/activations.py | +0 | −9 |
| src/transformers/generation/__init__.py | +0 | −4 |
| src/transformers/generation/tf_utils.py | +0 | −62 |
| src/transformers/generation/utils.py | +0 | −41 |
| src/transformers/models/llama/modeling_llama.py | +5 | −8 |
| src/transformers/models/opt/modeling_opt.py | +4 | −21 |
| src/transformers/utils/dummy_pt_objects.py | +0 | −4 |
| src/transformers/utils/dummy_tf_objects.py | +0 | −4 |
| tests/generation/test_tf_utils.py | +0 | −97 |
| tests/generation/test_utils.py | +0 | −128 |
### docs/source/en/internal/generation_utils.md (+0 −6)

```diff
@@ -336,12 +336,6 @@ A [`Constraint`] can be used to force the generation to include specific tokens
     - process
     - finalize
 
-## Utilities
-
-[[autodoc]] top_k_top_p_filtering
-
-[[autodoc]] tf_top_k_top_p_filtering
-
 ## Streamers
 
 [[autodoc]] TextStreamer
```
### docs/source/ja/internal/generation_utils.md (+0 −6)

```diff
@@ -335,12 +335,6 @@ generation_output[:2]
     - process
     - finalize
 
-## Utilities
-
-[[autodoc]] top_k_top_p_filtering
-
-[[autodoc]] tf_top_k_top_p_filtering
-
 ## Streamers
 
 [[autodoc]] TextStreamer
```
### docs/source/zh/internal/generation_utils.md (+0 −6)

```diff
@@ -330,12 +330,6 @@ generation_output[:2]
     - process
     - finalize
 
-## Utilities
-
-[[autodoc]] top_k_top_p_filtering
-
-[[autodoc]] tf_top_k_top_p_filtering
-
 ## Streamers
 
 [[autodoc]] TextStreamer
```
### src/transformers/__init__.py (+0 −4)

```diff
@@ -1409,7 +1409,6 @@ else:
             "TypicalLogitsWarper",
             "UnbatchedClassifierFreeGuidanceLogitsProcessor",
             "WhisperTimeStampLogitsProcessor",
-            "top_k_top_p_filtering",
         ]
     )
     _import_structure["generation_utils"] = []
@@ -3814,7 +3813,6 @@ else:
             "TFTemperatureLogitsWarper",
             "TFTopKLogitsWarper",
             "TFTopPLogitsWarper",
-            "tf_top_k_top_p_filtering",
         ]
     )
     _import_structure["generation_tf_utils"] = []
@@ -6206,7 +6204,6 @@ if TYPE_CHECKING:
         TypicalLogitsWarper,
         UnbatchedClassifierFreeGuidanceLogitsProcessor,
         WhisperTimeStampLogitsProcessor,
-        top_k_top_p_filtering,
     )
     from .modeling_utils import PreTrainedModel
     from .models.albert import (
@@ -8178,7 +8175,6 @@ if TYPE_CHECKING:
         TFTemperatureLogitsWarper,
         TFTopKLogitsWarper,
         TFTopPLogitsWarper,
-        tf_top_k_top_p_filtering,
     )
     from .keras_callbacks import KerasMetricCallback, PushToHubCallback
     from .modeling_tf_utils import (
```
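For reference, a minimal sketch of the import-level effect of this removal (hedged: exact behaviour depends on the installed version). The helper disappears from the public `transformers` namespace, while the warper classes named in its deprecation warning stay exported:

```python
# Hypothetical check, not part of this commit: on v4.39+ the first import fails.
try:
    from transformers import top_k_top_p_filtering  # removed by this commit
except ImportError:
    # the replacements named in the old deprecation warning remain public API
    from transformers import TopKLogitsWarper, TopPLogitsWarper
```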
### src/transformers/activations.py (+0 −9)

```diff
@@ -13,7 +13,6 @@
 # limitations under the License.
 
 import math
-import warnings
 from collections import OrderedDict
 
 import torch
@@ -138,14 +137,6 @@ class AccurateGELUActivation(nn.Module):
         return 0.5 * input * (1 + torch.tanh(self.precomputed_constant * (input + 0.044715 * torch.pow(input, 3))))
 
 
-class SiLUActivation(nn.SiLU):
-    def __init__(self, *args, **kwargs):
-        warnings.warn(
-            "The SiLUActivation class has been deprecated and will be removed in v4.39. Please use nn.SiLU instead.",
-        )
-        super().__init__(*args, **kwargs)
-
-
 class MishActivation(nn.Module):
     """
     See Mish: A Self-Regularized Non-Monotonic Activation Function (Misra., https://arxiv.org/abs/1908.08681). Also
```
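As the removed deprecation warning already stated, PyTorch's `nn.SiLU` is the drop-in replacement for the deleted `SiLUActivation` wrapper; a minimal sketch:

```python
import torch
from torch import nn

# nn.SiLU computes silu(x) = x * sigmoid(x), which is all SiLUActivation wrapped
act = nn.SiLU()
print(act(torch.tensor([-1.0, 0.0, 1.0])))
```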
### src/transformers/generation/__init__.py (+0 −4)

```diff
@@ -88,7 +88,6 @@ else:
     ]
     _import_structure["utils"] = [
         "GenerationMixin",
-        "top_k_top_p_filtering",
         "GreedySearchEncoderDecoderOutput",
         "GreedySearchDecoderOnlyOutput",
         "SampleEncoderDecoderOutput",
@@ -130,7 +129,6 @@ else:
     ]
     _import_structure["tf_utils"] = [
         "TFGenerationMixin",
-        "tf_top_k_top_p_filtering",
         "TFGreedySearchDecoderOnlyOutput",
         "TFGreedySearchEncoderDecoderOutput",
         "TFSampleEncoderDecoderOutput",
@@ -241,7 +239,6 @@ if TYPE_CHECKING:
         GreedySearchEncoderDecoderOutput,
         SampleDecoderOnlyOutput,
         SampleEncoderDecoderOutput,
-        top_k_top_p_filtering,
     )
 
     try:
@@ -279,7 +276,6 @@ if TYPE_CHECKING:
         TFGreedySearchEncoderDecoderOutput,
         TFSampleDecoderOnlyOutput,
         TFSampleEncoderDecoderOutput,
-        tf_top_k_top_p_filtering,
    )
 
     try:
```
### src/transformers/generation/tf_utils.py (+0 −62)

```diff
@@ -3088,68 +3088,6 @@ class TFGenerationMixin:
         return generated
 
 
-def tf_top_k_top_p_filtering(logits, top_k=0, top_p=1.0, filter_value=-float("Inf"), min_tokens_to_keep=1):
-    """
-    Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
-
-    Args:
-        logits: logits distribution shape (batch size, vocabulary size)
-        top_k (`int`, *optional*, defaults to 0):
-            If > 0, only keep the top k tokens with highest probability (top-k filtering)
-        top_p (`float`, *optional*, defaults to 1.0):
-            If < 1.0, only keep the top tokens with cumulative probability >= top_p (nucleus filtering). Nucleus
-            filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
-        min_tokens_to_keep (`int`, *optional*, defaults to 1):
-            Minimumber of tokens we keep per batch example in the output.
-
-    From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
-    """
-    warnings.warn(
-        "`tf_top_k_top_p_filtering` is scheduled for deletion in v4.39. Use `TFTopKLogitsWarper` and "
-        "`TFTopPLogitsWarper` instead.",
-        DeprecationWarning,
-    )
-    logits_shape = shape_list(logits)
-
-    if top_k > 0:
-        top_k = min(max(top_k, min_tokens_to_keep), logits_shape[-1])  # Safety check
-        # Remove all tokens with a probability less than the last token of the top-k
-        indices_to_remove = logits < tf.math.top_k(logits, k=top_k)[0][..., -1, None]
-        logits = tf.where(indices_to_remove, filter_value, logits)
-    if top_p < 1.0:
-        sorted_indices = tf.argsort(logits, direction="DESCENDING")
-        sorted_logits = tf.gather(
-            logits, sorted_indices, axis=-1, batch_dims=1
-        )  # expects logits to be of dim (batch_size, vocab_size)
-
-        cumulative_probs = tf.math.cumsum(stable_softmax(sorted_logits, axis=-1), axis=-1)
-
-        # Remove tokens with cumulative probability above the threshold (token with 0 are kept)
-        sorted_indices_to_remove = cumulative_probs > top_p
-
-        if min_tokens_to_keep > 1:
-            # Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below)
-            sorted_indices_to_remove = tf.concat(
-                [
-                    tf.zeros_like(sorted_indices_to_remove[:, :min_tokens_to_keep]),
-                    sorted_indices_to_remove[:, min_tokens_to_keep:],
-                ],
-                -1,
-            )
-
-        # Shift the indices to the right to keep also the first token above the threshold
-        sorted_indices_to_remove = tf.concat(
-            [tf.zeros_like(sorted_indices_to_remove[:, :1]), sorted_indices_to_remove[:, :-1]],
-            -1,
-        )
-        # scatter sorted tensors to original indexing
-        indices_to_remove = scatter_values_on_batch_indices(sorted_indices_to_remove, sorted_indices)
-        logits = tf.where(indices_to_remove, filter_value, logits)
-    return logits
-
-
 def scatter_values_on_batch_indices(values, batch_indices):
     shape = shape_list(batch_indices)
     # broadcast batch dim to shape
```
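A minimal TF sketch of the replacement named in the deletion warning (`TFTopKLogitsWarper` / `TFTopPLogitsWarper`). One assumption is flagged in the comments: TF logits processors are called with `(input_ids, scores, cur_len)`, and these two warpers only look at the scores, so placeholder values are used for the other arguments.

```python
import tensorflow as tf
from transformers import TFTopKLogitsWarper, TFTopPLogitsWarper

logits = tf.random.normal((2, 50))  # (batch_size, vocab_size)

# Assumption: TF warpers use the (input_ids, scores, cur_len) call signature;
# input_ids and cur_len are unused by these two warpers, hence the placeholders.
logits = TFTopKLogitsWarper(top_k=10, min_tokens_to_keep=4)(None, logits, cur_len=0)
logits = TFTopPLogitsWarper(top_p=0.6, min_tokens_to_keep=4)(None, logits, cur_len=0)
```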
### src/transformers/generation/utils.py (+0 −41)

```diff
@@ -4810,47 +4810,6 @@ def _split_model_outputs(outputs, new_outputs, cur_len, added_len, is_decoder_at
     return outputs
 
 
-def top_k_top_p_filtering(
-    logits: torch.FloatTensor,
-    top_k: int = 0,
-    top_p: float = 1.0,
-    filter_value: float = -float("Inf"),
-    min_tokens_to_keep: int = 1,
-) -> torch.FloatTensor:
-    """
-    Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
-
-    Args:
-        logits: logits distribution shape (batch size, vocabulary size)
-        top_k (`int`, *optional*, defaults to 0):
-            If > 0, only keep the top k tokens with highest probability (top-k filtering)
-        top_p (`float`, *optional*, defaults to 1.0):
-            If < 1.0, only keep the top tokens with cumulative probability >= top_p (nucleus filtering). Nucleus
-            filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
-        min_tokens_to_keep (`int`, *optional*, defaults to 1):
-            Minimumber of tokens we keep per batch example in the output.
-
-    From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
-    """
-    warnings.warn(
-        "`top_k_top_p_filtering` is scheduled for deletion in v4.39. Use `TopKLogitsWarper` and `TopPLogitsWarper` "
-        "instead.",
-        DeprecationWarning,
-    )
-    if top_k > 0:
-        logits = TopKLogitsWarper(top_k=top_k, filter_value=filter_value, min_tokens_to_keep=min_tokens_to_keep)(
-            None, logits
-        )
-
-    if 0 <= top_p <= 1.0:
-        logits = TopPLogitsWarper(top_p=top_p, filter_value=filter_value, min_tokens_to_keep=min_tokens_to_keep)(
-            None, logits
-        )
-
-    return logits
-
-
 def _ranking_fast(
     context_hidden: torch.FloatTensor,
     next_hidden: torch.FloatTensor,
```
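As the removed body shows, the helper was already a thin wrapper around `TopKLogitsWarper` and `TopPLogitsWarper`, so the migration is simply to call the warpers directly. A minimal sketch (the `None` placeholder mirrors how the removed helper passed `input_ids`, which these warpers ignore):

```python
import torch
from transformers import TopKLogitsWarper, TopPLogitsWarper

logits = torch.randn(2, 50)  # (batch_size, vocab_size)

# Equivalent of top_k_top_p_filtering(logits, top_k=10, top_p=0.6, min_tokens_to_keep=4)
logits = TopKLogitsWarper(top_k=10, min_tokens_to_keep=4)(None, logits)
logits = TopPLogitsWarper(top_p=0.6, min_tokens_to_keep=4)(None, logits)
```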
### src/transformers/models/llama/modeling_llama.py (+5 −8)

```diff
@@ -129,10 +129,7 @@ class LlamaRotaryEmbedding(nn.Module):
         return self._cos_cached
 
     @torch.no_grad()
-    def forward(self, x, position_ids, seq_len=None):
-        if seq_len is not None:
-            logger.warning_once("The `seq_len` argument is deprecated and unused. It will be removed in v4.39.")
-
+    def forward(self, x, position_ids):
         # x: [bs, num_attention_heads, seq_len, head_size]
         inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
         position_ids_expanded = position_ids[:, None, :].float()
@@ -151,17 +148,17 @@
 class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding):
     """LlamaRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
 
-    def forward(self, x, position_ids, seq_len=None):
+    def forward(self, x, position_ids):
         # difference to the original RoPE: a scaling factor is aplied to the position ids
         position_ids = position_ids.float() / self.scaling_factor
-        cos, sin = super().forward(x, position_ids, seq_len)
+        cos, sin = super().forward(x, position_ids)
         return cos, sin
 
 
 class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding):
     """LlamaRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
 
-    def forward(self, x, position_ids, seq_len=None):
+    def forward(self, x, position_ids):
         # difference to the original RoPE: inv_freq is recomputed when the sequence length > original length
         seq_len = torch.max(position_ids) + 1
         if seq_len > self.max_position_embeddings:
@@ -173,7 +170,7 @@ class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding):
             )
             self.register_buffer("inv_freq", inv_freq, persistent=False)  # TODO joao: this may break with compilation
 
-        cos, sin = super().forward(x, position_ids, seq_len)
+        cos, sin = super().forward(x, position_ids)
         return cos, sin
```
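A minimal sketch of the call-site change: the rotary embeddings now take only `x` and `position_ids`, and the sequence length is derived from `position_ids`. The constructor arguments below assume the v4.38-era signature `LlamaRotaryEmbedding(dim, max_position_embeddings, base)`, which is not shown in this diff.

```python
import torch
from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding

# Assumed constructor signature (not part of this diff): (dim, max_position_embeddings, base)
rope = LlamaRotaryEmbedding(dim=64, max_position_embeddings=2048, base=10000)

x = torch.randn(1, 8, 16, 64)                 # [bs, num_attention_heads, seq_len, head_size]
position_ids = torch.arange(16).unsqueeze(0)  # [bs, seq_len]

cos, sin = rope(x, position_ids)  # v4.39: no `seq_len` argument anymore
```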
### src/transformers/models/opt/modeling_opt.py (+4 −21)

```diff
@@ -120,27 +120,10 @@ class OPTAttention(nn.Module):
     ):
         super().__init__()
         self.config = config
-
-        def _handle_deprecated_argument(config_arg_name, config, fn_arg_name, kwargs):
-            """
-            If a the deprecated argument `fn_arg_name` is passed, raise a deprecation
-            warning and return that value, otherwise take the equivalent config.config_arg_name
-            """
-            val = None
-            if fn_arg_name in kwargs:
-                logging.warning(
-                    "Passing in {fn_arg_name} to {self.__class__.__name__} is deprecated and won't be supported from "
-                    "v4.39. Please set it in the config instead"
-                )
-                val = kwargs.pop(fn_arg_name)
-            else:
-                val = getattr(config, config_arg_name)
-            return val
-
-        self.embed_dim = _handle_deprecated_argument("hidden_size", config, "embed_dim", kwargs)
-        self.num_heads = _handle_deprecated_argument("num_attention_heads", config, "num_heads", kwargs)
-        self.dropout = _handle_deprecated_argument("attention_dropout", config, "dropout", kwargs)
-        self.enable_bias = _handle_deprecated_argument("enable_bias", config, "bias", kwargs)
+        self.embed_dim = config.hidden_size
+        self.num_heads = config.num_attention_heads
+        self.dropout = config.attention_dropout
+        self.enable_bias = config.enable_bias
 
         self.head_dim = self.embed_dim // self.num_heads
         self.is_causal = True
```
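A hedged sketch of the corresponding usage change: with this commit `OPTAttention` reads its hyper-parameters from the config only, so they belong on `OPTConfig` rather than being passed as the removed `embed_dim` / `num_heads` / `dropout` / `bias` keyword arguments.

```python
from transformers import OPTConfig

# Attention hyper-parameters now live exclusively on the config
config = OPTConfig(
    hidden_size=512,
    num_attention_heads=8,
    attention_dropout=0.1,
    enable_bias=True,
)
# OPTAttention(config, ...) reads config.hidden_size, config.num_attention_heads,
# config.attention_dropout and config.enable_bias directly (see the diff above).
```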
### src/transformers/utils/dummy_pt_objects.py (+0 −4)

```diff
@@ -408,10 +408,6 @@ class WhisperTimeStampLogitsProcessor(metaclass=DummyObject):
         requires_backends(self, ["torch"])
 
 
-def top_k_top_p_filtering(*args, **kwargs):
-    requires_backends(top_k_top_p_filtering, ["torch"])
-
-
 class PreTrainedModel(metaclass=DummyObject):
     _backends = ["torch"]
```
### src/transformers/utils/dummy_tf_objects.py (+0 −4)

```diff
@@ -128,10 +128,6 @@ class TFTopPLogitsWarper(metaclass=DummyObject):
         requires_backends(self, ["tf"])
 
 
-def tf_top_k_top_p_filtering(*args, **kwargs):
-    requires_backends(tf_top_k_top_p_filtering, ["tf"])
-
-
 class KerasMetricCallback(metaclass=DummyObject):
     _backends = ["tf"]
```
### tests/generation/test_tf_utils.py (+0 −97)

```diff
@@ -41,7 +41,6 @@ if is_tf_available():
         TFBartForConditionalGeneration,
         TFLogitsProcessorList,
         TFMinLengthLogitsProcessor,
-        tf_top_k_top_p_filtering,
     )
     from transformers.modeling_tf_utils import keras
 
@@ -49,102 +48,6 @@ if is_tensorflow_text_available():
     import tensorflow_text as text
 
 
-@require_tf
-class UtilsFunctionsTest(unittest.TestCase):
-    # tests whether the top_k_top_p_filtering function behaves as expected
-    def test_top_k_top_p_filtering(self):
-        logits = tf.convert_to_tensor(
-            [
-                [
-                    8.2220991,  # 3rd highest value; idx. 0
-                    -0.5620044,
-                    5.23229752,
-                    4.0386393,
-                    -6.8798378,
-                    -0.54785802,
-                    -3.2012153,
-                    2.92777176,
-                    1.88171953,
-                    7.35341276,  # 5th highest value; idx. 9
-                    8.43207833,  # 2nd highest value; idx. 10
-                    -9.85711836,
-                    -5.96209236,
-                    -1.13039161,
-                    -7.1115294,
-                    -0.8369633,
-                    -5.3186408,
-                    7.06427407,
-                    0.81369344,
-                    -0.82023817,
-                    -5.9179796,
-                    0.58813443,
-                    -6.99778438,
-                    4.71551189,
-                    -0.18771637,
-                    7.44020759,  # 4th highest value; idx. 25
-                    9.38450987,  # 1st highest value; idx. 26
-                    2.12662941,
-                    -9.32562038,
-                    2.35652522,
-                ],  # cummulative prob of 5 highest values <= 0.6
-                [
-                    0.58425518,
-                    4.53139238,
-                    -5.57510464,
-                    -6.28030699,
-                    -7.19529503,
-                    -4.02122551,
-                    1.39337037,
-                    -6.06707057,
-                    1.59480517,
-                    -9.643119,
-                    0.03907799,
-                    0.67231762,
-                    -8.88206726,
-                    6.27115922,  # 4th highest value; idx. 13
-                    2.28520723,
-                    4.82767506,
-                    4.30421368,
-                    8.8275313,  # 2nd highest value; idx. 17
-                    5.44029958,  # 5th highest value; idx. 18
-                    -4.4735794,
-                    7.38579536,  # 3rd highest value; idx. 20
-                    -2.91051663,
-                    2.61946077,
-                    -2.5674762,
-                    -9.48959302,
-                    -4.02922645,
-                    -1.35416918,
-                    9.67702323,  # 1st highest value; idx. 27
-                    -5.89478553,
-                    1.85370467,
-                ],  # cummulative prob of 5 highest values <= 0.6
-            ],
-            dtype=tf.float32,
-        )
-
-        non_inf_expected_idx = tf.convert_to_tensor(
-            [[0, 0], [0, 9], [0, 10], [0, 25], [0, 26], [1, 13], [1, 17], [1, 18], [1, 20], [1, 27]],
-            dtype=tf.int32,
-        )  # expected non filtered idx as noted above
-
-        non_inf_expected_output = tf.convert_to_tensor(
-            [8.222099, 7.3534126, 8.432078, 7.4402075, 9.38451, 6.271159, 8.827531, 5.4402995, 7.3857956, 9.677023],
-            dtype=tf.float32,
-        )  # expected non filtered values as noted above
-
-        output = tf_top_k_top_p_filtering(logits, top_k=10, top_p=0.6, min_tokens_to_keep=4)
-
-        non_inf_output = output[output != -float("inf")]
-        non_inf_idx = tf.cast(
-            tf.where(tf.not_equal(output, tf.constant(-float("inf"), dtype=tf.float32))),
-            dtype=tf.int32,
-        )
-
-        tf.debugging.assert_near(non_inf_output, non_inf_expected_output, rtol=1e-12)
-        tf.debugging.assert_equal(non_inf_idx, non_inf_expected_idx)
-
-
 @require_tf
 class TFGenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMixin):
     # setting framework_dependent_parameters needs to be gated, just like its contents' imports
```
### tests/generation/test_utils.py (+0 −128)

```diff
@@ -52,7 +52,6 @@ if is_torch_available():
         GPT2Tokenizer,
         ImageGPTForCausalImageModeling,
         SpeechEncoderDecoderModel,
-        top_k_top_p_filtering,
     )
     from transformers.cache_utils import DynamicCache
     from transformers.generation import (
@@ -2345,133 +2344,6 @@ class GenerationTesterMixin:
 @require_torch
 class UtilsFunctionsTest(unittest.TestCase):
-    # tests whether the top_k_top_p function behaves as expected
-    def test_top_k_top_p_filtering(self):
-        logits = torch.tensor(
-            [
-                [
-                    8.2220991,  # 3rd highest value; idx. 0
-                    -0.5620044,
-                    5.23229752,
-                    4.0386393,
-                    -6.8798378,
-                    -0.54785802,
-                    -3.2012153,
-                    2.92777176,
-                    1.88171953,
-                    7.35341276,
-                    8.43207833,  # 2nd highest value; idx. 10
-                    -9.85711836,
-                    -5.96209236,
-                    -1.13039161,
-                    -7.1115294,
-                    -0.8369633,
-                    -5.3186408,
-                    7.06427407,
-                    0.81369344,
-                    -0.82023817,
-                    -5.9179796,
-                    0.58813443,
-                    -6.99778438,
-                    4.71551189,
-                    -0.18771637,
-                    7.44020759,  # 4th highest value; idx. 25
-                    9.38450987,  # 1st highest value; idx. 26
-                    2.12662941,
-                    -9.32562038,
-                    2.35652522,
-                ],  # cummulative prob of 4 highest values <= 0.6
-                [
-                    0.58425518,
-                    4.53139238,
-                    -5.57510464,
-                    -6.28030699,
-                    -7.19529503,
-                    -4.02122551,
-                    1.39337037,
-                    -6.06707057,
-                    1.59480517,
-                    -9.643119,
-                    0.03907799,
-                    0.67231762,
-                    -8.88206726,
-                    6.27115922,  # 4th highest value; idx. 13
-                    2.28520723,
-                    4.82767506,
-                    4.30421368,
-                    8.8275313,  # 2nd highest value; idx. 17
-                    5.44029958,
-                    -4.4735794,
-                    7.38579536,  # 3rd highest value; idx. 20
-                    -2.91051663,
-                    2.61946077,
-                    -2.5674762,
-                    -9.48959302,
-                    -4.02922645,
-                    -1.35416918,
-                    9.67702323,  # 1st highest value; idx. 27
-                    -5.89478553,
-                    1.85370467,
-                ],  # cummulative prob of 4 highest values <= 0.6
-            ],
-            dtype=torch.float,
-            device=torch_device,
-        )
-
-        non_inf_expected_idx = torch.tensor(
-            [[0, 0], [0, 10], [0, 25], [0, 26], [1, 13], [1, 17], [1, 20], [1, 27]],
-            dtype=torch.long,
-            device=torch_device,
-        )  # expected non filtered idx as noted above
-
-        non_inf_expected_output = torch.tensor(
-            [
-                8.2221,
-                8.4321,
-                7.4402,
-                9.3845,
-                6.2712,
-                8.8275,
-                7.3858,
-                9.6770,
-            ],  # expected non filtered values as noted above
-            dtype=torch.float,
-            device=torch_device,
-        )
-
-        output = top_k_top_p_filtering(logits, top_k=10, top_p=0.6, min_tokens_to_keep=4)
-        non_inf_output = output[output != -float("inf")].to(device=torch_device)
-        non_inf_idx = (output != -float("inf")).nonzero().to(device=torch_device)
-
-        self.assertTrue(torch.allclose(non_inf_expected_output, non_inf_output, atol=1e-12))
-        self.assertTrue(torch.all(torch.eq(non_inf_expected_idx, non_inf_idx)))
-
-    # tests whether the function uses filter_value instead of default -inf
-    def test_top_k_top_p_filtering_with_filter_value(self):
-        logits = torch.tensor(
-            [
-                [
-                    1,
-                    1,
-                    1,
-                    0.99,  # get filtered by top-p filtering
-                    0.98,  # get filtered by top-k filtering
-                ]
-            ],
-            dtype=torch.float,
-            device=torch_device,
-        )
-
-        expected_output = torch.tensor(
-            [[1, 1, 1, 0, 0]],
-            dtype=torch.float,
-            device=torch_device,
-        )
-
-        output = top_k_top_p_filtering(logits, top_k=4, top_p=0.5, filter_value=0.0)
-
-        self.assertTrue(torch.allclose(expected_output, output, atol=1e-12))
-
     def test_speculative_sampling(self):
         # assume vocab size 10, input length 5 + 3 generated candidates
         candidate_input_ids = torch.tensor([[8, 0, 3, 9, 8, 1, 4, 5]])  # input tokens
```
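For completeness, a hedged sketch of how the removed `filter_value` check could be reproduced with the replacement warpers; the expected result matches the deleted test above, assuming current warper behaviour:

```python
import torch
from transformers import TopKLogitsWarper, TopPLogitsWarper

logits = torch.tensor([[1.0, 1.0, 1.0, 0.99, 0.98]])

out = TopKLogitsWarper(top_k=4, filter_value=0.0)(None, logits)  # 0.98 falls outside the top-4
out = TopPLogitsWarper(top_p=0.5, filter_value=0.0)(None, out)   # 0.99 falls outside the top-p nucleus
assert torch.allclose(out, torch.tensor([[1.0, 1.0, 1.0, 0.0, 0.0]]), atol=1e-12)
```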