Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
ffe60fdc
Unverified
Commit
ffe60fdc
authored
Mar 07, 2024
by
Joao Gante
Committed by
GitHub
Mar 07, 2024
Browse files
v4.39 deprecations 🧼 (#29492)
parent
979fccc9
Changes
14
Show whitespace changes
Inline
Side-by-side
Showing
14 changed files
with
9 additions
and
400 deletions
+9
-400
docs/source/en/internal/generation_utils.md
docs/source/en/internal/generation_utils.md
+0
-6
docs/source/ja/internal/generation_utils.md
docs/source/ja/internal/generation_utils.md
+0
-6
docs/source/zh/internal/generation_utils.md
docs/source/zh/internal/generation_utils.md
+0
-6
src/transformers/__init__.py
src/transformers/__init__.py
+0
-4
src/transformers/activations.py
src/transformers/activations.py
+0
-9
src/transformers/generation/__init__.py
src/transformers/generation/__init__.py
+0
-4
src/transformers/generation/tf_utils.py
src/transformers/generation/tf_utils.py
+0
-62
src/transformers/generation/utils.py
src/transformers/generation/utils.py
+0
-41
src/transformers/models/llama/modeling_llama.py
src/transformers/models/llama/modeling_llama.py
+5
-8
src/transformers/models/opt/modeling_opt.py
src/transformers/models/opt/modeling_opt.py
+4
-21
src/transformers/utils/dummy_pt_objects.py
src/transformers/utils/dummy_pt_objects.py
+0
-4
src/transformers/utils/dummy_tf_objects.py
src/transformers/utils/dummy_tf_objects.py
+0
-4
tests/generation/test_tf_utils.py
tests/generation/test_tf_utils.py
+0
-97
tests/generation/test_utils.py
tests/generation/test_utils.py
+0
-128
No files found.
docs/source/en/internal/generation_utils.md
View file @
ffe60fdc
...
...
@@ -336,12 +336,6 @@ A [`Constraint`] can be used to force the generation to include specific tokens
-
process
-
finalize
## Utilities
[[autodoc]] top_k_top_p_filtering
[[autodoc]] tf_top_k_top_p_filtering
## Streamers
[[autodoc]] TextStreamer
...
...
docs/source/ja/internal/generation_utils.md
View file @
ffe60fdc
...
...
@@ -335,12 +335,6 @@ generation_output[:2]
-
process
-
finalize
## Utilities
[[autodoc]] top_k_top_p_filtering
[[autodoc]] tf_top_k_top_p_filtering
## Streamers
[[autodoc]] TextStreamer
...
...
docs/source/zh/internal/generation_utils.md
View file @
ffe60fdc
...
...
@@ -330,12 +330,6 @@ generation_output[:2]
-
process
-
finalize
## Utilities
[[autodoc]] top_k_top_p_filtering
[[autodoc]] tf_top_k_top_p_filtering
## Streamers
[[autodoc]] TextStreamer
...
...
src/transformers/__init__.py
View file @
ffe60fdc
...
...
@@ -1409,7 +1409,6 @@ else:
"TypicalLogitsWarper"
,
"UnbatchedClassifierFreeGuidanceLogitsProcessor"
,
"WhisperTimeStampLogitsProcessor"
,
"top_k_top_p_filtering"
,
]
)
_import_structure
[
"generation_utils"
]
=
[]
...
...
@@ -3814,7 +3813,6 @@ else:
"TFTemperatureLogitsWarper"
,
"TFTopKLogitsWarper"
,
"TFTopPLogitsWarper"
,
"tf_top_k_top_p_filtering"
,
]
)
_import_structure
[
"generation_tf_utils"
]
=
[]
...
...
@@ -6206,7 +6204,6 @@ if TYPE_CHECKING:
TypicalLogitsWarper
,
UnbatchedClassifierFreeGuidanceLogitsProcessor
,
WhisperTimeStampLogitsProcessor
,
top_k_top_p_filtering
,
)
from
.modeling_utils
import
PreTrainedModel
from
.models.albert
import
(
...
...
@@ -8178,7 +8175,6 @@ if TYPE_CHECKING:
TFTemperatureLogitsWarper
,
TFTopKLogitsWarper
,
TFTopPLogitsWarper
,
tf_top_k_top_p_filtering
,
)
from
.keras_callbacks
import
KerasMetricCallback
,
PushToHubCallback
from
.modeling_tf_utils
import
(
...
...
src/transformers/activations.py
View file @
ffe60fdc
...
...
@@ -13,7 +13,6 @@
# limitations under the License.
import
math
import
warnings
from
collections
import
OrderedDict
import
torch
...
...
@@ -138,14 +137,6 @@ class AccurateGELUActivation(nn.Module):
return
0.5
*
input
*
(
1
+
torch
.
tanh
(
self
.
precomputed_constant
*
(
input
+
0.044715
*
torch
.
pow
(
input
,
3
))))
class
SiLUActivation
(
nn
.
SiLU
):
def
__init__
(
self
,
*
args
,
**
kwargs
):
warnings
.
warn
(
"The SiLUActivation class has been deprecated and will be removed in v4.39. Please use nn.SiLU instead."
,
)
super
().
__init__
(
*
args
,
**
kwargs
)
class
MishActivation
(
nn
.
Module
):
"""
See Mish: A Self-Regularized Non-Monotonic Activation Function (Misra., https://arxiv.org/abs/1908.08681). Also
...
...
src/transformers/generation/__init__.py
View file @
ffe60fdc
...
...
@@ -88,7 +88,6 @@ else:
]
_import_structure
[
"utils"
]
=
[
"GenerationMixin"
,
"top_k_top_p_filtering"
,
"GreedySearchEncoderDecoderOutput"
,
"GreedySearchDecoderOnlyOutput"
,
"SampleEncoderDecoderOutput"
,
...
...
@@ -130,7 +129,6 @@ else:
]
_import_structure
[
"tf_utils"
]
=
[
"TFGenerationMixin"
,
"tf_top_k_top_p_filtering"
,
"TFGreedySearchDecoderOnlyOutput"
,
"TFGreedySearchEncoderDecoderOutput"
,
"TFSampleEncoderDecoderOutput"
,
...
...
@@ -241,7 +239,6 @@ if TYPE_CHECKING:
GreedySearchEncoderDecoderOutput
,
SampleDecoderOnlyOutput
,
SampleEncoderDecoderOutput
,
top_k_top_p_filtering
,
)
try
:
...
...
@@ -279,7 +276,6 @@ if TYPE_CHECKING:
TFGreedySearchEncoderDecoderOutput
,
TFSampleDecoderOnlyOutput
,
TFSampleEncoderDecoderOutput
,
tf_top_k_top_p_filtering
,
)
try
:
...
...
src/transformers/generation/tf_utils.py
View file @
ffe60fdc
...
...
@@ -3088,68 +3088,6 @@ class TFGenerationMixin:
return
generated
def
tf_top_k_top_p_filtering
(
logits
,
top_k
=
0
,
top_p
=
1.0
,
filter_value
=-
float
(
"Inf"
),
min_tokens_to_keep
=
1
):
"""
Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
Args:
logits: logits distribution shape (batch size, vocabulary size)
top_k (`int`, *optional*, defaults to 0):
If > 0, only keep the top k tokens with highest probability (top-k filtering)
top_p (`float`, *optional*, defaults to 1.0):
If < 1.0, only keep the top tokens with cumulative probability >= top_p (nucleus filtering). Nucleus
filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
min_tokens_to_keep (`int`, *optional*, defaults to 1):
Minimumber of tokens we keep per batch example in the output.
From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
"""
warnings
.
warn
(
"`tf_top_k_top_p_filtering` is scheduled for deletion in v4.39. Use `TFTopKLogitsWarper` and "
"`TFTopPLogitsWarper` instead."
,
DeprecationWarning
,
)
logits_shape
=
shape_list
(
logits
)
if
top_k
>
0
:
top_k
=
min
(
max
(
top_k
,
min_tokens_to_keep
),
logits_shape
[
-
1
])
# Safety check
# Remove all tokens with a probability less than the last token of the top-k
indices_to_remove
=
logits
<
tf
.
math
.
top_k
(
logits
,
k
=
top_k
)[
0
][...,
-
1
,
None
]
logits
=
tf
.
where
(
indices_to_remove
,
filter_value
,
logits
)
if
top_p
<
1.0
:
sorted_indices
=
tf
.
argsort
(
logits
,
direction
=
"DESCENDING"
)
sorted_logits
=
tf
.
gather
(
logits
,
sorted_indices
,
axis
=-
1
,
batch_dims
=
1
)
# expects logits to be of dim (batch_size, vocab_size)
cumulative_probs
=
tf
.
math
.
cumsum
(
stable_softmax
(
sorted_logits
,
axis
=-
1
),
axis
=-
1
)
# Remove tokens with cumulative probability above the threshold (token with 0 are kept)
sorted_indices_to_remove
=
cumulative_probs
>
top_p
if
min_tokens_to_keep
>
1
:
# Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below)
sorted_indices_to_remove
=
tf
.
concat
(
[
tf
.
zeros_like
(
sorted_indices_to_remove
[:,
:
min_tokens_to_keep
]),
sorted_indices_to_remove
[:,
min_tokens_to_keep
:],
],
-
1
,
)
# Shift the indices to the right to keep also the first token above the threshold
sorted_indices_to_remove
=
tf
.
concat
(
[
tf
.
zeros_like
(
sorted_indices_to_remove
[:,
:
1
]),
sorted_indices_to_remove
[:,
:
-
1
]],
-
1
,
)
# scatter sorted tensors to original indexing
indices_to_remove
=
scatter_values_on_batch_indices
(
sorted_indices_to_remove
,
sorted_indices
)
logits
=
tf
.
where
(
indices_to_remove
,
filter_value
,
logits
)
return
logits
def
scatter_values_on_batch_indices
(
values
,
batch_indices
):
shape
=
shape_list
(
batch_indices
)
# broadcast batch dim to shape
...
...
src/transformers/generation/utils.py
View file @
ffe60fdc
...
...
@@ -4810,47 +4810,6 @@ def _split_model_outputs(outputs, new_outputs, cur_len, added_len, is_decoder_at
return
outputs
def
top_k_top_p_filtering
(
logits
:
torch
.
FloatTensor
,
top_k
:
int
=
0
,
top_p
:
float
=
1.0
,
filter_value
:
float
=
-
float
(
"Inf"
),
min_tokens_to_keep
:
int
=
1
,
)
->
torch
.
FloatTensor
:
"""
Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
Args:
logits: logits distribution shape (batch size, vocabulary size)
top_k (`int`, *optional*, defaults to 0):
If > 0, only keep the top k tokens with highest probability (top-k filtering)
top_p (`float`, *optional*, defaults to 1.0):
If < 1.0, only keep the top tokens with cumulative probability >= top_p (nucleus filtering). Nucleus
filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
min_tokens_to_keep (`int`, *optional*, defaults to 1):
Minimumber of tokens we keep per batch example in the output.
From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
"""
warnings
.
warn
(
"`top_k_top_p_filtering` is scheduled for deletion in v4.39. Use `TopKLogitsWarper` and `TopPLogitsWarper` "
"instead."
,
DeprecationWarning
,
)
if
top_k
>
0
:
logits
=
TopKLogitsWarper
(
top_k
=
top_k
,
filter_value
=
filter_value
,
min_tokens_to_keep
=
min_tokens_to_keep
)(
None
,
logits
)
if
0
<=
top_p
<=
1.0
:
logits
=
TopPLogitsWarper
(
top_p
=
top_p
,
filter_value
=
filter_value
,
min_tokens_to_keep
=
min_tokens_to_keep
)(
None
,
logits
)
return
logits
def
_ranking_fast
(
context_hidden
:
torch
.
FloatTensor
,
next_hidden
:
torch
.
FloatTensor
,
...
...
src/transformers/models/llama/modeling_llama.py
View file @
ffe60fdc
...
...
@@ -129,10 +129,7 @@ class LlamaRotaryEmbedding(nn.Module):
return
self
.
_cos_cached
@
torch
.
no_grad
()
def
forward
(
self
,
x
,
position_ids
,
seq_len
=
None
):
if
seq_len
is
not
None
:
logger
.
warning_once
(
"The `seq_len` argument is deprecated and unused. It will be removed in v4.39."
)
def
forward
(
self
,
x
,
position_ids
):
# x: [bs, num_attention_heads, seq_len, head_size]
inv_freq_expanded
=
self
.
inv_freq
[
None
,
:,
None
].
float
().
expand
(
position_ids
.
shape
[
0
],
-
1
,
1
)
position_ids_expanded
=
position_ids
[:,
None
,
:].
float
()
...
...
@@ -151,17 +148,17 @@ class LlamaRotaryEmbedding(nn.Module):
class
LlamaLinearScalingRotaryEmbedding
(
LlamaRotaryEmbedding
):
"""LlamaRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
def
forward
(
self
,
x
,
position_ids
,
seq_len
=
None
):
def
forward
(
self
,
x
,
position_ids
):
# difference to the original RoPE: a scaling factor is aplied to the position ids
position_ids
=
position_ids
.
float
()
/
self
.
scaling_factor
cos
,
sin
=
super
().
forward
(
x
,
position_ids
,
seq_len
)
cos
,
sin
=
super
().
forward
(
x
,
position_ids
)
return
cos
,
sin
class
LlamaDynamicNTKScalingRotaryEmbedding
(
LlamaRotaryEmbedding
):
"""LlamaRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
def
forward
(
self
,
x
,
position_ids
,
seq_len
=
None
):
def
forward
(
self
,
x
,
position_ids
):
# difference to the original RoPE: inv_freq is recomputed when the sequence length > original length
seq_len
=
torch
.
max
(
position_ids
)
+
1
if
seq_len
>
self
.
max_position_embeddings
:
...
...
@@ -173,7 +170,7 @@ class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding):
)
self
.
register_buffer
(
"inv_freq"
,
inv_freq
,
persistent
=
False
)
# TODO joao: this may break with compilation
cos
,
sin
=
super
().
forward
(
x
,
position_ids
,
seq_len
)
cos
,
sin
=
super
().
forward
(
x
,
position_ids
)
return
cos
,
sin
...
...
src/transformers/models/opt/modeling_opt.py
View file @
ffe60fdc
...
...
@@ -120,27 +120,10 @@ class OPTAttention(nn.Module):
):
super
().
__init__
()
self
.
config
=
config
def
_handle_deprecated_argument
(
config_arg_name
,
config
,
fn_arg_name
,
kwargs
):
"""
If a the deprecated argument `fn_arg_name` is passed, raise a deprecation
warning and return that value, otherwise take the equivalent config.config_arg_name
"""
val
=
None
if
fn_arg_name
in
kwargs
:
logging
.
warning
(
"Passing in {fn_arg_name} to {self.__class__.__name__} is deprecated and won't be supported from "
"v4.39. Please set it in the config instead"
)
val
=
kwargs
.
pop
(
fn_arg_name
)
else
:
val
=
getattr
(
config
,
config_arg_name
)
return
val
self
.
embed_dim
=
_handle_deprecated_argument
(
"hidden_size"
,
config
,
"embed_dim"
,
kwargs
)
self
.
num_heads
=
_handle_deprecated_argument
(
"num_attention_heads"
,
config
,
"num_heads"
,
kwargs
)
self
.
dropout
=
_handle_deprecated_argument
(
"attention_dropout"
,
config
,
"dropout"
,
kwargs
)
self
.
enable_bias
=
_handle_deprecated_argument
(
"enable_bias"
,
config
,
"bias"
,
kwargs
)
self
.
embed_dim
=
config
.
hidden_size
self
.
num_heads
=
config
.
num_attention_heads
self
.
dropout
=
config
.
attention_dropout
self
.
enable_bias
=
config
.
enable_bias
self
.
head_dim
=
self
.
embed_dim
//
self
.
num_heads
self
.
is_causal
=
True
...
...
src/transformers/utils/dummy_pt_objects.py
View file @
ffe60fdc
...
...
@@ -408,10 +408,6 @@ class WhisperTimeStampLogitsProcessor(metaclass=DummyObject):
requires_backends
(
self
,
[
"torch"
])
def
top_k_top_p_filtering
(
*
args
,
**
kwargs
):
requires_backends
(
top_k_top_p_filtering
,
[
"torch"
])
class
PreTrainedModel
(
metaclass
=
DummyObject
):
_backends
=
[
"torch"
]
...
...
src/transformers/utils/dummy_tf_objects.py
View file @
ffe60fdc
...
...
@@ -128,10 +128,6 @@ class TFTopPLogitsWarper(metaclass=DummyObject):
requires_backends
(
self
,
[
"tf"
])
def
tf_top_k_top_p_filtering
(
*
args
,
**
kwargs
):
requires_backends
(
tf_top_k_top_p_filtering
,
[
"tf"
])
class
KerasMetricCallback
(
metaclass
=
DummyObject
):
_backends
=
[
"tf"
]
...
...
tests/generation/test_tf_utils.py
View file @
ffe60fdc
...
...
@@ -41,7 +41,6 @@ if is_tf_available():
TFBartForConditionalGeneration
,
TFLogitsProcessorList
,
TFMinLengthLogitsProcessor
,
tf_top_k_top_p_filtering
,
)
from
transformers.modeling_tf_utils
import
keras
...
...
@@ -49,102 +48,6 @@ if is_tensorflow_text_available():
import
tensorflow_text
as
text
@
require_tf
class
UtilsFunctionsTest
(
unittest
.
TestCase
):
# tests whether the top_k_top_p_filtering function behaves as expected
def
test_top_k_top_p_filtering
(
self
):
logits
=
tf
.
convert_to_tensor
(
[
[
8.2220991
,
# 3rd highest value; idx. 0
-
0.5620044
,
5.23229752
,
4.0386393
,
-
6.8798378
,
-
0.54785802
,
-
3.2012153
,
2.92777176
,
1.88171953
,
7.35341276
,
# 5th highest value; idx. 9
8.43207833
,
# 2nd highest value; idx. 10
-
9.85711836
,
-
5.96209236
,
-
1.13039161
,
-
7.1115294
,
-
0.8369633
,
-
5.3186408
,
7.06427407
,
0.81369344
,
-
0.82023817
,
-
5.9179796
,
0.58813443
,
-
6.99778438
,
4.71551189
,
-
0.18771637
,
7.44020759
,
# 4th highest value; idx. 25
9.38450987
,
# 1st highest value; idx. 26
2.12662941
,
-
9.32562038
,
2.35652522
,
],
# cummulative prob of 5 highest values <= 0.6
[
0.58425518
,
4.53139238
,
-
5.57510464
,
-
6.28030699
,
-
7.19529503
,
-
4.02122551
,
1.39337037
,
-
6.06707057
,
1.59480517
,
-
9.643119
,
0.03907799
,
0.67231762
,
-
8.88206726
,
6.27115922
,
# 4th highest value; idx. 13
2.28520723
,
4.82767506
,
4.30421368
,
8.8275313
,
# 2nd highest value; idx. 17
5.44029958
,
# 5th highest value; idx. 18
-
4.4735794
,
7.38579536
,
# 3rd highest value; idx. 20
-
2.91051663
,
2.61946077
,
-
2.5674762
,
-
9.48959302
,
-
4.02922645
,
-
1.35416918
,
9.67702323
,
# 1st highest value; idx. 27
-
5.89478553
,
1.85370467
,
],
# cummulative prob of 5 highest values <= 0.6
],
dtype
=
tf
.
float32
,
)
non_inf_expected_idx
=
tf
.
convert_to_tensor
(
[[
0
,
0
],
[
0
,
9
],
[
0
,
10
],
[
0
,
25
],
[
0
,
26
],
[
1
,
13
],
[
1
,
17
],
[
1
,
18
],
[
1
,
20
],
[
1
,
27
]],
dtype
=
tf
.
int32
,
)
# expected non filtered idx as noted above
non_inf_expected_output
=
tf
.
convert_to_tensor
(
[
8.222099
,
7.3534126
,
8.432078
,
7.4402075
,
9.38451
,
6.271159
,
8.827531
,
5.4402995
,
7.3857956
,
9.677023
],
dtype
=
tf
.
float32
,
)
# expected non filtered values as noted above
output
=
tf_top_k_top_p_filtering
(
logits
,
top_k
=
10
,
top_p
=
0.6
,
min_tokens_to_keep
=
4
)
non_inf_output
=
output
[
output
!=
-
float
(
"inf"
)]
non_inf_idx
=
tf
.
cast
(
tf
.
where
(
tf
.
not_equal
(
output
,
tf
.
constant
(
-
float
(
"inf"
),
dtype
=
tf
.
float32
))),
dtype
=
tf
.
int32
,
)
tf
.
debugging
.
assert_near
(
non_inf_output
,
non_inf_expected_output
,
rtol
=
1e-12
)
tf
.
debugging
.
assert_equal
(
non_inf_idx
,
non_inf_expected_idx
)
@
require_tf
class
TFGenerationIntegrationTests
(
unittest
.
TestCase
,
GenerationIntegrationTestsMixin
):
# setting framework_dependent_parameters needs to be gated, just like its contents' imports
...
...
tests/generation/test_utils.py
View file @
ffe60fdc
...
...
@@ -52,7 +52,6 @@ if is_torch_available():
GPT2Tokenizer
,
ImageGPTForCausalImageModeling
,
SpeechEncoderDecoderModel
,
top_k_top_p_filtering
,
)
from
transformers.cache_utils
import
DynamicCache
from
transformers.generation
import
(
...
...
@@ -2345,133 +2344,6 @@ class GenerationTesterMixin:
@
require_torch
class
UtilsFunctionsTest
(
unittest
.
TestCase
):
# tests whether the top_k_top_p function behaves as expected
def
test_top_k_top_p_filtering
(
self
):
logits
=
torch
.
tensor
(
[
[
8.2220991
,
# 3rd highest value; idx. 0
-
0.5620044
,
5.23229752
,
4.0386393
,
-
6.8798378
,
-
0.54785802
,
-
3.2012153
,
2.92777176
,
1.88171953
,
7.35341276
,
8.43207833
,
# 2nd highest value; idx. 10
-
9.85711836
,
-
5.96209236
,
-
1.13039161
,
-
7.1115294
,
-
0.8369633
,
-
5.3186408
,
7.06427407
,
0.81369344
,
-
0.82023817
,
-
5.9179796
,
0.58813443
,
-
6.99778438
,
4.71551189
,
-
0.18771637
,
7.44020759
,
# 4th highest value; idx. 25
9.38450987
,
# 1st highest value; idx. 26
2.12662941
,
-
9.32562038
,
2.35652522
,
],
# cummulative prob of 4 highest values <= 0.6
[
0.58425518
,
4.53139238
,
-
5.57510464
,
-
6.28030699
,
-
7.19529503
,
-
4.02122551
,
1.39337037
,
-
6.06707057
,
1.59480517
,
-
9.643119
,
0.03907799
,
0.67231762
,
-
8.88206726
,
6.27115922
,
# 4th highest value; idx. 13
2.28520723
,
4.82767506
,
4.30421368
,
8.8275313
,
# 2nd highest value; idx. 17
5.44029958
,
-
4.4735794
,
7.38579536
,
# 3rd highest value; idx. 20
-
2.91051663
,
2.61946077
,
-
2.5674762
,
-
9.48959302
,
-
4.02922645
,
-
1.35416918
,
9.67702323
,
# 1st highest value; idx. 27
-
5.89478553
,
1.85370467
,
],
# cummulative prob of 4 highest values <= 0.6
],
dtype
=
torch
.
float
,
device
=
torch_device
,
)
non_inf_expected_idx
=
torch
.
tensor
(
[[
0
,
0
],
[
0
,
10
],
[
0
,
25
],
[
0
,
26
],
[
1
,
13
],
[
1
,
17
],
[
1
,
20
],
[
1
,
27
]],
dtype
=
torch
.
long
,
device
=
torch_device
,
)
# expected non filtered idx as noted above
non_inf_expected_output
=
torch
.
tensor
(
[
8.2221
,
8.4321
,
7.4402
,
9.3845
,
6.2712
,
8.8275
,
7.3858
,
9.6770
,
],
# expected non filtered values as noted above
dtype
=
torch
.
float
,
device
=
torch_device
,
)
output
=
top_k_top_p_filtering
(
logits
,
top_k
=
10
,
top_p
=
0.6
,
min_tokens_to_keep
=
4
)
non_inf_output
=
output
[
output
!=
-
float
(
"inf"
)].
to
(
device
=
torch_device
)
non_inf_idx
=
(
output
!=
-
float
(
"inf"
)).
nonzero
().
to
(
device
=
torch_device
)
self
.
assertTrue
(
torch
.
allclose
(
non_inf_expected_output
,
non_inf_output
,
atol
=
1e-12
))
self
.
assertTrue
(
torch
.
all
(
torch
.
eq
(
non_inf_expected_idx
,
non_inf_idx
)))
# tests whether the function uses filter_value instead of default -inf
def
test_top_k_top_p_filtering_with_filter_value
(
self
):
logits
=
torch
.
tensor
(
[
[
1
,
1
,
1
,
0.99
,
# get filtered by top-p filtering
0.98
,
# get filtered by top-k filtering
]
],
dtype
=
torch
.
float
,
device
=
torch_device
,
)
expected_output
=
torch
.
tensor
(
[[
1
,
1
,
1
,
0
,
0
]],
dtype
=
torch
.
float
,
device
=
torch_device
,
)
output
=
top_k_top_p_filtering
(
logits
,
top_k
=
4
,
top_p
=
0.5
,
filter_value
=
0.0
)
self
.
assertTrue
(
torch
.
allclose
(
expected_output
,
output
,
atol
=
1e-12
))
def
test_speculative_sampling
(
self
):
# assume vocab size 10, input length 5 + 3 generated candidates
candidate_input_ids
=
torch
.
tensor
([[
8
,
0
,
3
,
9
,
8
,
1
,
4
,
5
]])
# input tokens
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment