Commit 9f42cb6e authored by Tri Dao

[Gen] Clone logits before returning when cg=True

parent f8aea6ea
@@ -4,10 +4,12 @@ import gc
import time
from collections import namedtuple
from dataclasses import dataclass, field
from functools import partial
from typing import Callable, Optional, Sequence, Union

import torch
import torch.nn.functional as F
from einops import rearrange, repeat
from torch import Tensor
from torch.profiler import ProfilerActivity, profile, record_function
from transformers.generation import GreedySearchDecoderOnlyOutput, SampleDecoderOnlyOutput
@@ -205,6 +207,363 @@ def decode(
)
def sample_speculative(logits, logits_draft, tokens_draft, top_k=1, top_p=0.0, temperature=1.0):
"""Algorithm 1 from [1]
[1] Fast Inference from Transformers via Speculative Decoding
Yaniv Leviathan, Matan Kalman, Yossi Matias
https://arxiv.org/abs/2211.17192
Arguments:
logits: Tensor of shape (batch_size, seqlen + 1, vocab_size)
logits_draft: Tensor of shape (batch_size, seqlen, vocab_size)
tokens_draft: Tensor of shape (batch_size, seqlen)
Return:
tokens: Tensor of shape (batch_size, seqlen + 1)
num_generated_tokens: Tensor of shape (batch_size), with value in [1, seqlen + 1].
For each sequence in the batch, the number of valid tokens that were sampled by
speculative sampling.
"""
batch, seqlen_p_1, vocab_size = logits.shape
seqlen = seqlen_p_1 - 1
assert logits_draft.shape == (batch, seqlen, vocab_size)
assert tokens_draft.shape == (batch, seqlen)
assert tokens_draft.dtype in [torch.int64, torch.int32]
# TODO: if top_k = 1 we can simplify things and only work with indices
if top_p > 0.0:
assert top_p <= 1.0, "top-p should be in (0, 1]."
# Clone so that when we modify for top-k / top-p filtering we don't change the original logits
logits = logits / temperature if temperature != 1.0 else logits.clone()
logits_draft = logits_draft / temperature if temperature != 1.0 else logits_draft.clone()
if top_k > 0:
top_k = min(top_k, logits.size(-1)) # Safety check
modify_logits_for_top_k_filtering(logits, top_k)
modify_logits_for_top_k_filtering(logits_draft, top_k)
modify_logits_for_top_p_filtering(logits, top_p)
modify_logits_for_top_p_filtering(logits_draft, top_p)
probs = torch.softmax(logits, dim=-1)
probs_draft = torch.softmax(logits_draft, dim=-1)
gather = lambda probs, tokens: rearrange(
probs.gather(dim=-1, index=rearrange(tokens, "... -> ... 1")), "... 1 -> ..."
)
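# Accept draft token x_i with probability min(1, p(x_i) / q(x_i)), where p comes from the model
# and q from the draft model. Drawing u ~ U(0, 1) and accepting when u * q(x_i) <= p(x_i) is
# equivalent and avoids the division.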
# (batch, seqlen)
accepted = torch.rand(batch, seqlen, device=probs.device) * gather(
probs_draft, tokens_draft
) <= gather(probs[:, :-1], tokens_draft)
accepted_all = accepted.all(dim=-1)
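# argmin over the int mask gives the index of the first False, i.e. the first rejected draft
# token; if every draft token was accepted, use seqlen instead.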
# (batch,)
first_rejected_idx = torch.where(accepted_all, seqlen, accepted.int().argmin(dim=-1))
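# If position i was rejected, resample from the residual distribution max(p_i - q_i, 0);
# if everything was accepted, sample one extra token from the model's distribution at the last position.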
probs_diff = torch.clamp(probs[:, :-1] - probs_draft, min=0.0)
# torch.multinomial can deal with unnormalized probabilities
# probs_diff /= probs_diff.sum(dim=-1, keepdim=True)
resample_probs = torch.cat([probs_diff, probs[:, -1:]], dim=1)
resample_probs = rearrange(
resample_probs.gather(dim=1, index=repeat(first_rejected_idx, "b -> b 1 d", d=vocab_size)),
"b 1 d -> b d",
)
resample = torch.multinomial(resample_probs, num_samples=1).squeeze(dim=-1) # (batch,)
tokens = F.pad(tokens_draft, (0, 1))
tokens[:, first_rejected_idx] = resample
return tokens, first_rejected_idx + 1
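For illustration (not part of this commit), sample_speculative can be exercised on its own with
toy tensors; a minimal sketch, assuming flash-attn is installed so this module imports as
flash_attn.utils.generation:

import torch
from flash_attn.utils.generation import sample_speculative

batch, seqlen, vocab_size = 1, 3, 8
logits = torch.randn(batch, seqlen + 1, vocab_size)  # main model logits after the prompt and after each draft token
logits_draft = torch.randn(batch, seqlen, vocab_size)  # draft model logits for the draft tokens
tokens_draft = torch.randint(0, vocab_size, (batch, seqlen))  # int64 by default, as required

# top_k=0 disables top-k filtering (pure sampling); temperature defaults to 1.0.
tokens, num_generated = sample_speculative(logits, logits_draft, tokens_draft, top_k=0)
# tokens: (batch, seqlen + 1); only the first num_generated[0] entries of each row are valid.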
def decode_speculative(
input_ids,
model,
model_draft,
max_length,
speculative_lookahead=3,
top_k=1,
top_p=0.0,
temperature=1.0,
eos_token_id=None,
vocab_size=None,
tensor_parallel=1,
fused_ft_kernel=False,
cg=False,
timing=False,
debug=False,
):
"""
TD: WIP, for my own understanding, lightly tested. Only supports batch_size == 1 for now.
Speculative decoding, either greedy or with top-k or top-p sampling.
If top-k = 0, don't limit the number of candidates (pure sampling).
Top-k and top-p can be used together. If top_k > 0 and top_p > 0, then top-k is applied first,
then top-p.
We assume that all sequences in the same batch have the same length.
Arguments:
input_ids: (batch, seq_len)
max_length: int
Returns: GreedySearchDecoderOnlyOutput or SampleDecoderOnlyOutput, with the following fields:
sequences: (batch, max_length)
scores: (batch, generated_length, vocab_size), the logits used to sample each generated token
"""
batch_size, seqlen_og = input_ids.shape
assert batch_size == 1, "Speculative decoding implementation only supports batch_size=1"
assert eos_token_id is None, "Speculative decoding implementation doesn't support eos_token_id"
if cg:
assert fused_ft_kernel
if not hasattr(model_draft, "_decoding_cache"):
model_draft._decoding_cache = None
model_draft._decoding_cache = update_graph_cache(
model_draft,
model_draft._decoding_cache,
batch_size,
seqlen_og,
max_length,
tensor_parallel=tensor_parallel,
)
inference_params_draft = model_draft._decoding_cache.inference_params
inference_params_draft.max_sequence_len = max_length
inference_params_draft.max_batch_size = batch_size
inference_params_draft.sequence_len_offset = 0
# fused_ft_kernel doesn't support passing in multiple tokens at once
inference_params = InferenceParams(
max_sequence_len=max_length, max_batch_size=batch_size, fused_ft_kernel=False
)
else:
inference_params_draft = InferenceParams(
max_sequence_len=max_length, max_batch_size=batch_size, fused_ft_kernel=fused_ft_kernel
)
inference_params = InferenceParams(
max_sequence_len=max_length, max_batch_size=batch_size, fused_ft_kernel=False
)
def logits_forward_fn(model, input_ids, position_ids, inference_params, cg=False):
if not cg:
return model(
input_ids,
position_ids=position_ids,
inference_params=inference_params,
num_last_tokens=1,
).logits.squeeze(dim=1)
else:
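# The CUDA-graph callable writes its output into a static buffer that is overwritten on the
# next replay, so clone the logits before returning them to the caller.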
return model._decoding_cache.run(
input_ids, position_ids, inference_params.sequence_len_offset
).clone()
logits_postprocess_fn = (
lambda logits: logits[..., :vocab_size] if vocab_size is not None else logits
)
def sample_tokens(
input_ids, model, inference_params, sample_fn, num_tokens=1, cg=False, decoding=True,
last_token_logits=False
):
"""Sample `num_tokens` tokens from the model, given the previous logits.
Also return the logits of the sampled tokens.
Arguments:
input_ids: (batch, seqlen)
decoding: whether we're in the decoding phase or the prefilling phase. Prefill doesn't
need special position_ids.
last_token_logits: whether to also return the logits of the last sampled token. Normally we
don't need this. However, for speculative sampling, if the main model accepts all the draft
tokens and then samples one new token, the draft model would need, at the next iteration, the
logits of both the last draft token and the newly sampled token, which complicates the
implementation. So here we simply always evaluate the logits of the last draft token as well.
Return:
tokens: (batch, num_tokens)
scores: (batch, num_tokens, vocab_size), the logits from which each of the num_tokens tokens
was sampled. The logits after the last sampled token aren't computed unless
last_token_logits=True, in which case scores has shape (batch, num_tokens + 1, vocab_size).
"""
batch_size, seqlen = input_ids.shape
assert num_tokens >= 1
sequences = []
if decoding:
assert seqlen == 1
position_ids = torch.full(
(batch_size, 1),
inference_params.sequence_len_offset,
dtype=torch.long,
device=input_ids.device,
)
else:
position_ids = None
logits = logits_postprocess_fn(
logits_forward_fn(model, input_ids, position_ids, inference_params, cg=decoding and cg)
)
inference_params.sequence_len_offset += input_ids.shape[1]
scores = [logits]
next_token = sample_fn(logits)
sequences.append(next_token)
for i in range(num_tokens):
if i < num_tokens - 1 or last_token_logits:
position_ids = torch.full(
(batch_size, 1),
inference_params_draft.sequence_len_offset,
dtype=torch.long,
device=input_ids.device,
)
logits = logits_postprocess_fn(
logits_forward_fn(
model, rearrange(next_token, "b -> b 1"), position_ids, inference_params, cg=cg
)
)
inference_params.sequence_len_offset += 1
scores.append(logits)
if i < num_tokens - 1:
next_token = sample_fn(logits)
sequences.append(next_token)
return torch.stack(sequences, dim=1), torch.stack(scores, dim=1)
sampling_kwargs = dict(top_k=top_k, top_p=top_p, temperature=temperature)
sample_fn = partial(sample, **sampling_kwargs)
sample_tokens_main = partial(
sample_tokens, model=model, sample_fn=sample_fn, inference_params=inference_params, cg=False
) # main model doesn't use CUDA graph
sample_tokens_draft = partial(
sample_tokens,
model=model_draft,
sample_fn=sample_fn,
last_token_logits=True,
inference_params=inference_params_draft,
cg=cg
)
if debug:
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("gpt2")
sequences = [input_ids]
scores = []
with torch.inference_mode():
if timing:
if tensor_parallel > 1:
torch.distributed.barrier()
torch.cuda.synchronize()
start = time.time()
if seqlen_og >= max_length - 1:
# Don't do speculative sampling, just sample 1 token from the model
tokens, scores_new = sample_tokens_main(input_ids, num_tokens=1, decoding=False)
sequences.append(tokens)
scores.append(scores_new)
else:
# Sample @n_spec_tokens tokens from the draft model, which @model will then use to
# produce between 1 and 1 + @n_spec_tokens tokens.
# We want seqlen_og + 1 + @n_spec_tokens to be <= @max_length.
n_spec_tokens = min(speculative_lookahead, max_length - seqlen_og - 1)
tokens_draft, scores_draft = sample_tokens_draft(
input_ids,
num_tokens=n_spec_tokens,
decoding=False,
)
if debug:
scores_draft_ref = model_draft(
torch.cat([input_ids, tokens_draft], dim=1), num_last_tokens=n_spec_tokens + 1
).logits
print((scores_draft[:, :-1] - scores_draft_ref[:, :-1]).abs().max())
# Evaluate the draft tokens with the model
logits = model(
torch.cat([input_ids, tokens_draft], dim=1),
inference_params=inference_params,
num_last_tokens=n_spec_tokens + 1,
).logits
logits = logits_postprocess_fn(logits)
tokens, num_generated_tokens = sample_speculative(
logits, scores_draft[:, :-1], tokens_draft, **sampling_kwargs
)
if debug:
print(tokens)
print(num_generated_tokens)
# breakpoint()
# TODO: we're using the fact that batch_size == 1
# TODO: check eos_token_id
sequences.append(tokens[:1, : num_generated_tokens[0]])
scores.append(logits[:1, : num_generated_tokens[0]])
# Note that @model has not evaluated the last sampled token yet, so we'll need to pass
# that in the next time we call @model.
inference_params.sequence_len_offset = seqlen_og + num_generated_tokens[0].item() - 1
inference_params_draft.sequence_len_offset = inference_params.sequence_len_offset
if debug:
cur_ids = torch.cat([input_ids, sequences[-1]], dim=1)
scores_ref = model(
cur_ids, num_last_tokens=num_generated_tokens[0].item() + 1
).logits
print((scores[-1] - scores_ref[:, :-1]).abs().max())
while True:
# sequence_len_offset is (total sequence length so far) - 1, since the last sampled
# token has not been run through the model yet.
if inference_params.sequence_len_offset >= max_length - 1:
break
if inference_params.sequence_len_offset >= max_length - 2:
# Don't do speculative sampling, just sample 1 token from the model
tokens, scores_new = sample_tokens_main(sequences[-1][:, -1:], num_tokens=1)
sequences.append(tokens)
scores.append(scores_new)
break
# Sample from draft model
n_spec_tokens = min(
speculative_lookahead, max_length - inference_params_draft.sequence_len_offset - 2
)
tokens_draft, scores_draft = sample_tokens_draft(
sequences[-1][:, -1:], num_tokens=n_spec_tokens
)
if debug:
scores_draft_ref = model_draft(
torch.cat([cur_ids, tokens_draft], dim=1), num_last_tokens=n_spec_tokens + 1
).logits
print((scores_draft[:, :-1] - scores_draft_ref[:, :-1]).abs().max())
# Evaluate the draft tokens with the model
position_ids = repeat(
torch.arange(
inference_params.sequence_len_offset,
# 1 extra token from last time that hasn't been passed through model
inference_params.sequence_len_offset + n_spec_tokens + 1,
dtype=torch.long,
device=input_ids.device,
),
"s -> b s",
b=batch_size,
)
logits = model(
torch.cat([sequences[-1][:, -1:], tokens_draft], dim=1),
position_ids=position_ids,
inference_params=inference_params,
).logits  # (batch, n_spec_tokens + 1, vocab_size)
logits = logits_postprocess_fn(logits)
inference_params.sequence_len_offset += 1
if debug:
logits_ref = model(
torch.cat([cur_ids, tokens_draft], dim=1), num_last_tokens=n_spec_tokens + 1
).logits
print((logits - logits_ref).abs().max())
tokens, num_generated_tokens = sample_speculative(
logits, scores_draft[:, :-1], tokens_draft, **sampling_kwargs
)
if debug:
print(tokens)
print(num_generated_tokens)
sequences.append(tokens[:1, : num_generated_tokens[0]])
scores.append(logits[:1, : num_generated_tokens[0]])
inference_params.sequence_len_offset += num_generated_tokens[0].item() - 1
inference_params_draft.sequence_len_offset = inference_params.sequence_len_offset
# breakpoint()
if debug:
cur_ids = torch.cat([cur_ids, sequences[-1]], dim=1)
scores_ref = model(
cur_ids, num_last_tokens=num_generated_tokens[0].item() + 1
).logits
print((scores[-1] - scores_ref[:, :-1]).abs().max())
if timing:
if tensor_parallel > 1:
torch.distributed.barrier()
torch.cuda.synchronize()
print(f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms")
sequences = torch.cat(sequences, dim=1)
scores = torch.cat(scores, dim=1)
if debug:
scores_ref = model(sequences).logits
print((scores - scores_ref[:, seqlen_og - 1 : -1]).abs().max())
output_cls = GreedySearchDecoderOnlyOutput if top_k == 1 else SampleDecoderOnlyOutput
return output_cls(sequences=sequences, scores=scores)
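As an aside (not part of this commit): the payoff of speculative decoding is governed by the
per-token acceptance rate and the lookahead. Assuming each draft token is accepted independently
with probability alpha and gamma tokens are drafted per iteration, each verification pass of the
main model yields (1 - alpha**(gamma + 1)) / (1 - alpha) tokens in expectation (Leviathan et al.,
arXiv:2211.17192). A small helper illustrating the arithmetic:

def expected_tokens_per_verification(alpha: float, gamma: int) -> float:
    # Expected number of tokens kept per main-model forward pass under the independence assumption.
    if alpha >= 1.0:
        return gamma + 1.0
    return (1 - alpha ** (gamma + 1)) / (1 - alpha)

# e.g. expected_tokens_per_verification(0.8, 4) ~= 3.36, vs. exactly 1 token per pass without speculation.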
class GenerationMixin:
def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
raise NotImplementedError
@@ -394,7 +753,7 @@ def capture_graph(model, inference_params, batch_size, max_seqlen, mempool=None,
input_ids.copy_(new_input_ids)
position_ids.copy_(new_position_ids)
graph.replay()
return logits.clone()
inference_params.sequence_len_offset = sequence_len_offset_og
return run
@@ -368,11 +368,79 @@ def test_gpt2_multiple_token_generation(model_name, optimized):
logits_10 = model(input_ids[:, :10], inference_params=inference_params).logits
inference_params.sequence_len_offset += 10
position_ids = torch.arange(10, 14, dtype=torch.long, device=device)
logits_1014 = model(
    input_ids[:, 10:14], position_ids=position_ids, inference_params=inference_params
).logits
inference_params.sequence_len_offset += 4
position_ids = torch.arange(14, 20, dtype=torch.long, device=device)
logits_1420 = model(
    input_ids[:, 14:20], position_ids=position_ids, inference_params=inference_params
).logits
logits = torch.cat([logits_10, logits_1014, logits_1420], dim=1)
print(f"Logits max diff: {(logits - logits_ref).abs().max().item()}")
print(f"Logits mean diff: {(logits - logits_ref).abs().mean().item()}")
assert torch.allclose(logits, logits_ref, rtol=rtol, atol=atol)
@pytest.mark.parametrize("fused_ft_kernel, cg", [(False, False), (True, False), (True, True)])
# @pytest.mark.parametrize("fused_ft_kernel, cg", [(True, True)])
# @pytest.mark.parametrize("optimized", [False, True])
@pytest.mark.parametrize("optimized", [True])
# @pytest.mark.parametrize("model_name", ["gpt2-medium"])
@pytest.mark.parametrize("model_name", ["gpt2-xl"])
def test_gpt2_speculative_decoding(model_name, optimized, fused_ft_kernel, cg):
dtype = torch.float16
device = "cuda"
rtol, atol = 3e-3, 3e-1
config = GPT2Config.from_pretrained(model_name)
config.residual_in_fp32 = True
if optimized:
config.use_flash_attn = True
config.fused_bias_fc = True
config.fused_mlp = True
config.fused_dropout_add_ln = True
config_draft = GPT2Config.from_pretrained("gpt2")
config_draft.residual_in_fp32 = True
if optimized:
config_draft.use_flash_attn = True
config_draft.fused_bias_fc = True
config_draft.fused_mlp = True
config_draft.fused_dropout_add_ln = True
model = GPTLMHeadModel.from_pretrained(model_name, config, device=device, dtype=dtype)
model.eval()
model_draft = GPTLMHeadModel.from_pretrained("gpt2", config_draft, device=device, dtype=dtype)
model_draft.eval()
torch.manual_seed(0)
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
input_ids = tokenizer("Hello, my dog is cute and he", return_tensors="pt").input_ids.to(
device=device
)
max_length = 100
from flash_attn.utils.generation import decode_speculative
torch.manual_seed(42)
out = decode_speculative(
input_ids,
model,
model_draft,
max_length=max_length,
top_k=5,
fused_ft_kernel=fused_ft_kernel,
cg=cg,
speculative_lookahead=4,
timing=True,
)
print(tokenizer.batch_decode(out.sequences))
out_og = model.generate(
input_ids,
max_length=max_length,
top_k=5,
fused_ft_kernel=fused_ft_kernel,
cg=False,
timing=True,
return_dict_in_generate=True,
)
print(tokenizer.batch_decode(out_og.sequences))
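A further check one could run with the same models (a sketch only, not part of this commit): with
top_k=1 both paths decode greedily, so the speculative output should match the main model's own
greedy generation, up to rare fp16 argmax flips in the multi-token verification pass.

out_greedy = decode_speculative(
    input_ids,
    model,
    model_draft,
    max_length=max_length,
    top_k=1,
    fused_ft_kernel=fused_ft_kernel,
    cg=cg,
    speculative_lookahead=4,
)
out_og_greedy = model.generate(
    input_ids,
    max_length=max_length,
    top_k=1,
    fused_ft_kernel=fused_ft_kernel,
    cg=False,
    return_dict_in_generate=True,
)
# Report mismatches rather than hard-asserting, since fp16 numerics can differ slightly
# between single-token decoding and the batched verification pass.
print((out_greedy.sequences != out_og_greedy.sequences).sum().item())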