Unverified Commit ece1b62b authored by Joao Gante, committed by GitHub

Generate: v4.38 removals and related updates (#29171)

parent 24d59c79
@@ -40,6 +40,11 @@ else:
         "BeamSearchScorer",
         "ConstrainedBeamSearchScorer",
     ]
+    _import_structure["candidate_generator"] = [
+        "AssistedCandidateGenerator",
+        "CandidateGenerator",
+        "PromptLookupCandidateGenerator",
+    ]
     _import_structure["logits_process"] = [
         "AlternatingCodebooksLogitsProcessor",
         "ClassifierFreeGuidanceLogitsProcessor",
@@ -178,6 +183,7 @@ if TYPE_CHECKING:
     else:
         from .beam_constraints import Constraint, ConstraintListState, DisjunctiveConstraint, PhrasalConstraint
         from .beam_search import BeamHypotheses, BeamScorer, BeamSearchScorer, ConstrainedBeamSearchScorer
+        from .candidate_generator import AssistedCandidateGenerator, CandidateGenerator, PromptLookupCandidateGenerator
         from .logits_process import (
            AlternatingCodebooksLogitsProcessor,
            ClassifierFreeGuidanceLogitsProcessor,
...
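These exports make the candidate-generator classes publicly importable from `transformers.generation`. As a rough sketch of what that enables, the snippet below defines a toy subclass; it assumes the two-method interface (`get_candidates` and `update_candidate_strategy`) that `CandidateGenerator` documents, and everything beyond that (class name, arguments) is illustrative:

```python
import torch

from transformers.generation import CandidateGenerator


class RepeatLastTokenCandidateGenerator(CandidateGenerator):
    """Toy generator: proposes the last prompt token again, `num_candidates` times."""

    def __init__(self, num_candidates: int = 5):
        self.num_candidates = num_candidates

    def get_candidates(self, input_ids: torch.LongTensor):
        # Return the prompt extended with the candidate tokens, plus (optionally)
        # the candidate logits; `None` signals that no logits are available.
        last_token = input_ids[:, -1:].repeat(1, self.num_candidates)
        return torch.cat([input_ids, last_token], dim=-1), None

    def update_candidate_strategy(self, input_ids, scores, num_matches):
        # A real generator could adapt its proposal length to `num_matches`,
        # the count of candidates the main model accepted this round.
        pass
```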
@@ -99,7 +99,8 @@ class AssistedCandidateGenerator(CandidateGenerator):
         # Make sure all data at the same device as assistant model
         device = assistant_model.device
         input_ids = input_ids.to(device)
-        inputs_tensor = inputs_tensor.to(device)
+        if inputs_tensor is not None:
+            inputs_tensor = inputs_tensor.to(device)

         # Prepare the assistant and the starting number of candidate tokens
         self.assistant_model = assistant_model
...
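The guard matters because `inputs_tensor` is optional here, and calling `.to(device)` on `None` raises `AttributeError`. A minimal standalone reproduction of the failure mode and the fix:

```python
import torch

device = torch.device("cpu")
inputs_tensor = None  # optional extra model input; absent in some call paths

# Before: an unconditional move crashes when the tensor is absent.
#   inputs_tensor = inputs_tensor.to(device)  # AttributeError: 'NoneType' object has no attribute 'to'

# After: move the tensor only when it is present.
if inputs_tensor is not None:
    inputs_tensor = inputs_tensor.to(device)
```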
@@ -4319,7 +4319,6 @@ class GenerationMixin:
     def assisted_decoding(
         self,
         input_ids: torch.LongTensor,
-        assistant_model: Optional["PreTrainedModel"] = None,
         candidate_generator: Optional["CandidateGenerator"] = None,
         do_sample: bool = False,
         logits_processor: Optional[LogitsProcessorList] = None,
@@ -4355,12 +4354,7 @@ class GenerationMixin:
                 The sequence used as a prompt for the generation.
             candidate_generator (`CandidateGenerator`, *optional*):
                 A derived instance of [`CandidateGenerator`] that defines how candidate sequences are generated. For
-                more information, the documentation of [`CandidateGenerator`] should be read. Only one of `assistant_model` or `candidate_generator` should be passed as input to this function.
-            assistant_model (`PreTrainedModel`, *optional*):
-                An assistant model that can be used to accelerate generation. The assistant model must have the exact
-                same tokenizer. The acceleration is achieved when forecasting candidate tokens with the assistent model
-                is much faster than running generation with the model you're calling generate from. As such, the
-                assistant model should be much smaller.
+                more information, the documentation of [`CandidateGenerator`] should be read.
             do_sample (`bool`, *optional*, defaults to `False`):
                 Whether or not to use sampling ; use greedy decoding otherwise.
             logits_processor (`LogitsProcessorList`, *optional*):
@@ -4417,6 +4411,7 @@ class GenerationMixin:
         ...     StoppingCriteriaList,
         ...     MaxLengthCriteria,
         ... )
+        >>> from transformers.generation import AssistedCandidateGenerator

         >>> tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
         >>> model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")
@@ -4432,33 +4427,22 @@ class GenerationMixin:
         ...     ]
         ... )
         >>> stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=20)])
+        >>> candidate_generator = AssistedCandidateGenerator(
+        ...     input_ids=input_ids,
+        ...     assistant_model=assistant_model,
+        ...     generation_config=model.generation_config,
+        ...     logits_processor=logits_processor,
+        ...     model_kwargs={},
+        ... )
         >>> outputs = model.assisted_decoding(
         ...     input_ids,
-        ...     assistant_model=assistant_model,
+        ...     candidate_generator=candidate_generator,
         ...     logits_processor=logits_processor,
         ...     stopping_criteria=stopping_criteria,
         ... )
         >>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
         ["It might be possible to get a better understanding of the nature of the problem, but it's not"]
         ```"""
-        # handling deprecated arguments
-        if (assistant_model is None) == (candidate_generator is None):
-            raise ValueError("One (and only one) of `assistant_model` and `candidate_generator` should be defined.")
-        if assistant_model is not None:
-            candidate_generator = AssistedCandidateGenerator(
-                input_ids=input_ids,
-                assistant_model=assistant_model,
-                logits_processor=logits_processor,
-                model_kwargs=model_kwargs,
-                eos_token_id=eos_token_id,
-            )
-            warnings.warn(
-                "Passing `assistant_model` to `assisted_decoding` is deprecated and will be removed in v4.38. "
-                "Pass the `candidate_generator` argument instead.",
-                FutureWarning,
-            )
         # init values
         logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
         logits_warper = logits_warper if logits_warper is not None else LogitsProcessorList()
...
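Only the low-level `assisted_decoding` signature changes here; the high-level `generate` call still accepts `assistant_model` and builds the candidate generator internally. A sketch of that unaffected path (checkpoint choice and generation settings are illustrative):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2-large")
# The assistant must share the main model's tokenizer; a smaller checkpoint
# from the same family is the usual choice.
assistant_model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")

inputs = tokenizer("It might be possible to", return_tensors="pt")
outputs = model.generate(**inputs, assistant_model=assistant_model, max_new_tokens=16)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
```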
@@ -129,8 +129,8 @@ class OPTAttention(nn.Module):
             val = None
             if fn_arg_name in kwargs:
                 logging.warning(
-                    "Passing in {} to {self.__class__.__name__} is deprecated and won't be supported from v4.38."
-                    " Please set it in the config instead"
+                    "Passing in {fn_arg_name} to {self.__class__.__name__} is deprecated and won't be supported from "
+                    "v4.39. Please set it in the config instead"
                 )
                 val = kwargs.pop(fn_arg_name)
             else:
...
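For callers, the supported pattern is to set these values on the config before building the model rather than passing them as keyword arguments. A sketch, under the assumption that the deprecated kwargs map onto existing `OPTConfig` fields such as `attention_dropout` and `enable_bias` (those field names are my reading of the config, not taken from this diff). Note also that, as written, the new message is still not an f-string, so `{fn_arg_name}` renders literally:

```python
from transformers import OPTConfig, OPTModel

# Preferred: configure attention hyperparameters up front on the config.
config = OPTConfig.from_pretrained("facebook/opt-125m")
config.attention_dropout = 0.1  # instead of passing `dropout=...` to the attention module
config.enable_bias = True       # instead of passing `bias=...`
model = OPTModel(config)
```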
@@ -120,7 +120,6 @@ from .import_utils import (
     is_essentia_available,
     is_faiss_available,
     is_flash_attn_2_available,
-    is_flash_attn_available,
     is_flash_attn_greater_or_equal_2_10,
     is_flax_available,
     is_fsdp_available,
...
@@ -665,14 +665,6 @@ def is_flash_attn_greater_or_equal_2_10():
     return version.parse(importlib.metadata.version("flash_attn")) >= version.parse("2.1.0")


-def is_flash_attn_available():
-    logger.warning(
-        "Using `is_flash_attn_available` is deprecated and will be removed in v4.38. "
-        "Please use `is_flash_attn_2_available` instead."
-    )
-    return is_flash_attn_2_available()
-
-
 def is_torchdistx_available():
     return _torchdistx_available
...
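Any code still calling the removed helper should switch to `is_flash_attn_2_available`, which the deleted shim already delegated to. A minimal migration sketch (the `sdpa` fallback string is an assumption about which alternative the caller wants):

```python
from transformers.utils import is_flash_attn_2_available, is_flash_attn_greater_or_equal_2_10

# Replacement for the removed `is_flash_attn_available()`:
if is_flash_attn_2_available():
    attn_implementation = "flash_attention_2"
    print("flash-attn >= 2.1.0:", is_flash_attn_greater_or_equal_2_10())
else:
    attn_implementation = "sdpa"  # fall back to PyTorch scaled-dot-product attention
```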