"test/git@developer.sourcefind.cn:hehl2/torchaudio.git" did not exist on "f11fc7cf4732ff2393b92793bdf53967defbc2c7"
Unverified Commit 9934bb1f authored by Joao Gante's avatar Joao Gante Committed by GitHub
Browse files

Generate: multi-device support for contrastive search (#24635)

parent 4b26a616
......@@ -2060,8 +2060,10 @@ class GenerationMixin:
context_hidden = last_hidden_states.repeat_interleave(top_k, dim=0)
# compute the degeneration penalty and re-rank the candidates based on the degeneration penalty and the
# model confidence
# model confidence. Keeping `selected_idx` on CPU enables multi-device contrastive search and doesn't
# introduce (noticeable) slowdowns on single-device runs.
selected_idx = _ranking_fast(context_hidden, next_hidden, top_k_probs, penalty_alpha, top_k)
selected_idx = selected_idx.to("cpu")
# prepare for the next step: (1) next token_id; (2) past_key_values; (3) last_hidden_states for computing
# the degeneration penalty; (4) logits for selecting next top-k candidates; (5) selected tokens scores
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment