Unverified commit ece1b62b authored by Joao Gante, committed by GitHub

Generate: v4.38 removals and related updates (#29171)

parent 24d59c79
@@ -40,6 +40,11 @@ else:
"BeamSearchScorer",
"ConstrainedBeamSearchScorer",
]
_import_structure["candidate_generator"] = [
"AssistedCandidateGenerator",
"CandidateGenerator",
"PromptLookupCandidateGenerator",
]
_import_structure["logits_process"] = [
"AlternatingCodebooksLogitsProcessor",
"ClassifierFreeGuidanceLogitsProcessor",
@@ -178,6 +183,7 @@ if TYPE_CHECKING:
else:
from .beam_constraints import Constraint, ConstraintListState, DisjunctiveConstraint, PhrasalConstraint
from .beam_search import BeamHypotheses, BeamScorer, BeamSearchScorer, ConstrainedBeamSearchScorer
from .candidate_generator import AssistedCandidateGenerator, CandidateGenerator, PromptLookupCandidateGenerator
from .logits_process import (
AlternatingCodebooksLogitsProcessor,
ClassifierFreeGuidanceLogitsProcessor,
......
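With this change, the candidate generators become part of the public `transformers.generation` namespace. A minimal sketch of what the new exports enable; the `prompt_lookup_num_tokens` argument to `generate` (which `PromptLookupCandidateGenerator` backs under the hood) is not part of this diff and is assumed from the surrounding release:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation import (
    AssistedCandidateGenerator,
    CandidateGenerator,
    PromptLookupCandidateGenerator,
)

tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")
inputs = tokenizer("The quick brown fox jumps over the", return_tensors="pt")

# Prompt lookup decoding: candidates are drawn from the prompt itself,
# so no assistant model is needed.
outputs = model.generate(**inputs, prompt_lookup_num_tokens=3, max_new_tokens=20)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
```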
@@ -99,6 +99,7 @@ class AssistedCandidateGenerator(CandidateGenerator):
# Make sure all data is on the same device as the assistant model
device = assistant_model.device
input_ids = input_ids.to(device)
if inputs_tensor is not None:
inputs_tensor = inputs_tensor.to(device)
# Prepare the assistant and the starting number of candidate tokens
......
@@ -4319,7 +4319,6 @@ class GenerationMixin:
def assisted_decoding(
self,
input_ids: torch.LongTensor,
assistant_model: Optional["PreTrainedModel"] = None,
candidate_generator: Optional["CandidateGenerator"] = None,
do_sample: bool = False,
logits_processor: Optional[LogitsProcessorList] = None,
@@ -4355,12 +4354,7 @@
The sequence used as a prompt for the generation.
candidate_generator (`CandidateGenerator`, *optional*):
A derived instance of [`CandidateGenerator`] that defines how candidate sequences are generated. For
more information, the documentation of [`CandidateGenerator`] should be read. Only one of `assistant_model` or `candidate_generator` should be passed as input to this function.
assistant_model (`PreTrainedModel`, *optional*):
An assistant model that can be used to accelerate generation. The assistant model must have the exact
same tokenizer. The acceleration is achieved when forecasting candidate tokens with the assistant model
is much faster than running generation with the model you're calling generate from. As such, the
assistant model should be much smaller.
more information, the documentation of [`CandidateGenerator`] should be read.
do_sample (`bool`, *optional*, defaults to `False`):
Whether or not to use sampling; use greedy decoding otherwise.
logits_processor (`LogitsProcessorList`, *optional*):
@@ -4417,6 +4411,7 @@
... StoppingCriteriaList,
... MaxLengthCriteria,
... )
>>> from transformers.generation import AssistedCandidateGenerator
>>> tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
>>> model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")
@@ -4432,33 +4427,22 @@
... ]
... )
>>> stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=20)])
>>> candidate_generator = AssistedCandidateGenerator(
... input_ids=input_ids,
... assistant_model=assistant_model,
... generation_config=model.generation_config,
... logits_processor=logits_processor,
... model_kwargs={},
... )
>>> outputs = model.assisted_decoding(
... input_ids,
... assistant_model=assistant_model,
... candidate_generator=candidate_generator,
... logits_processor=logits_processor,
... stopping_criteria=stopping_criteria,
... )
>>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
["It might be possible to get a better understanding of the nature of the problem, but it's not"]
```"""
# handling deprecated arguments
if (assistant_model is None) == (candidate_generator is None):
raise ValueError("One (and only one) of `assistant_model` and `candidate_generator` should be defined.")
if assistant_model is not None:
candidate_generator = AssistedCandidateGenerator(
input_ids=input_ids,
assistant_model=assistant_model,
logits_processor=logits_processor,
model_kwargs=model_kwargs,
eos_token_id=eos_token_id,
)
warnings.warn(
"Passing `assistant_model` to `assisted_decoding` is deprecated and will be removed in v4.38. "
"Pass the `candidate_generator` argument instead.",
FutureWarning,
)
# init values
logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
logits_warper = logits_warper if logits_warper is not None else LogitsProcessorList()
......
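After this removal, `assisted_decoding` no longer accepts `assistant_model` and requires a pre-built `candidate_generator` (as in the updated docstring example above). The supported high-level entry point for assisted generation remains `generate`, which builds the candidate generator internally. A minimal sketch; the model choices are illustrative:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")
# The assistant must share the main model's tokenizer.
assistant = AutoModelForCausalLM.from_pretrained("distilbert/distilgpt2")

inputs = tokenizer("It might be possible to", return_tensors="pt")
# `generate` constructs an AssistedCandidateGenerator from `assistant_model`.
outputs = model.generate(**inputs, assistant_model=assistant, max_new_tokens=20)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
```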
@@ -129,8 +129,8 @@ class OPTAttention(nn.Module):
val = None
if fn_arg_name in kwargs:
logging.warning(
"Passing in {} to {self.__class__.__name__} is deprecated and won't be supported from v4.38."
" Please set it in the config instead"
"Passing in {fn_arg_name} to {self.__class__.__name__} is deprecated and won't be supported from "
"v4.39. Please set it in the config instead"
)
val = kwargs.pop(fn_arg_name)
else:
......
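The updated warning points users at the config path. A hedged sketch of that pattern, assuming the post-refactor `OPTAttention(config, is_decoder=...)` constructor; the exact signature and the set of config fields shown here are not part of this diff:

```python
from transformers import OPTConfig
from transformers.models.opt.modeling_opt import OPTAttention

# Set attention hyperparameters on the config instead of passing them as
# keyword arguments to OPTAttention.
config = OPTConfig(hidden_size=768, num_attention_heads=12, attention_dropout=0.0)
attn = OPTAttention(config, is_decoder=True)
```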
@@ -120,7 +120,6 @@ from .import_utils import (
is_essentia_available,
is_faiss_available,
is_flash_attn_2_available,
is_flash_attn_available,
is_flash_attn_greater_or_equal_2_10,
is_flax_available,
is_fsdp_available,
......
@@ -665,14 +665,6 @@ def is_flash_attn_greater_or_equal_2_10():
return version.parse(importlib.metadata.version("flash_attn")) >= version.parse("2.1.0")
def is_flash_attn_available():
logger.warning(
"Using `is_flash_attn_available` is deprecated and will be removed in v4.38. "
"Please use `is_flash_attn_2_available` instead."
)
return is_flash_attn_2_available()
def is_torchdistx_available():
return _torchdistx_available
......
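With `is_flash_attn_available` deleted, call sites should migrate to the FlashAttention-2-specific check that stays exported. A minimal before/after sketch:

```python
# Before (removed in this commit):
# from transformers.utils import is_flash_attn_available
# use_fa = is_flash_attn_available()

# After:
from transformers.utils import is_flash_attn_2_available

use_fa = is_flash_attn_2_available()
print(f"FlashAttention 2 available: {use_fa}")
```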