Unverified Commit 7ae4fc27 authored by Yoach Lacombe's avatar Yoach Lacombe Committed by GitHub
Browse files

Fix Bark logits processors device misplacement (#31416)

Fix Logits Processors device misplacement
parent 9af1b6a8
...@@ -991,11 +991,11 @@ class BarkSemanticModel(BarkCausalModel): ...@@ -991,11 +991,11 @@ class BarkSemanticModel(BarkCausalModel):
list(range(semantic_generation_config.semantic_pad_token + 1, self.config.output_vocab_size)) list(range(semantic_generation_config.semantic_pad_token + 1, self.config.output_vocab_size))
) )
suppress_tokens_logits_processor = SuppressTokensLogitsProcessor(tokens_to_suppress) suppress_tokens_logits_processor = SuppressTokensLogitsProcessor(tokens_to_suppress, device=input_ids.device)
min_eos_p = kwargs.get("min_eos_p", semantic_generation_config.min_eos_p) min_eos_p = kwargs.get("min_eos_p", semantic_generation_config.min_eos_p)
early_stopping_logits_processor = BarkEosPrioritizerLogitsProcessor( early_stopping_logits_processor = BarkEosPrioritizerLogitsProcessor(
eos_token_id=semantic_generation_config.eos_token_id, min_eos_p=min_eos_p eos_token_id=semantic_generation_config.eos_token_id, min_eos_p=min_eos_p, device=input_ids.device
) )
# pass input_ids in order to stay consistent with the transformers generate method even though it is not used # pass input_ids in order to stay consistent with the transformers generate method even though it is not used
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment