Commit e0fbaa70 authored by Tri Dao

[Gen] Simplify decode_speculative

parent e6a80264
@@ -159,7 +159,7 @@ def decode(
else:
logits = model._decoding_cache.run(
input_ids, position_ids, inference_params.seqlen_offset
).clone()
).squeeze(dim=1)
return logits[..., :vocab_size] if vocab_size is not None else logits
def sample_tokens(logits, inference_params):
@@ -305,256 +305,250 @@ def decode_speculative(
batch_size,
seqlen_og,
max_length,
# draft model needs to process either 1 or 2 tokens at a time
decoding_seqlens=(1, 2),
tensor_parallel=tensor_parallel,
)
inference_params_draft = model_draft._decoding_cache.inference_params
inference_params_draft.reset(max_length, batch_size)
inference_params = InferenceParams(max_seqlen=max_length, max_batch_size=batch_size)
if not hasattr(model, "_decoding_cache"):
model._decoding_cache = None
model._decoding_cache = update_graph_cache(
model,
model._decoding_cache,
batch_size,
seqlen_og,
max_length,
decoding_seqlens=range(1, speculative_lookahead + 2),
tensor_parallel=tensor_parallel,
)
inference_params = model._decoding_cache.inference_params
inference_params.reset(max_length, batch_size)
else:
inference_params_draft = InferenceParams(max_seqlen=max_length, max_batch_size=batch_size)
inference_params = InferenceParams(max_seqlen=max_length, max_batch_size=batch_size)
def logits_forward_fn(model, input_ids, position_ids, inference_params, cg=False):
if not cg:
return model(
def get_logits(input_ids, inference_params, model, num_last_tokens=1, cg=False):
decoding = inference_params.seqlen_offset > 0
if decoding:
seqlen = input_ids.shape[1]
# if inference_params.lengths_per_sample is None:
# TODO: in the case of batched decoding where each sequence has a different length,
# we need to compute the position_ids for each sequence using lengths_per_sample
if True:
cache_seqlens = torch.full(
(input_ids.shape[0],),
inference_params.seqlen_offset,
dtype=torch.int32,
device=input_ids.device,
)
else:
cache_seqlens = inference_params.lengths_per_sample
position_ids = cache_seqlens[:, None] + torch.arange(
seqlen, dtype=torch.long, device=input_ids.device
)
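# e.g. (illustrative, decoding branch): with seqlen_offset=7 and input_ids of shape (2, 3),
# cache_seqlens = [7, 7] and position_ids = [[7, 8, 9], [7, 8, 9]].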
else:
position_ids = None
if not cg or not decoding:
logits = model(
input_ids,
position_ids=position_ids,
inference_params=inference_params,
num_last_tokens=1,
).logits.squeeze(dim=1)
num_last_tokens=num_last_tokens,
).logits
else:
return model._decoding_cache.run(
# NOTE: careful, CUDA graph is set to have num_last_tokens=input_ids.shape[1].
# This might not be compatible with the num_last_tokens used here.
assert num_last_tokens <= input_ids.shape[1]
logits = model._decoding_cache.run(
input_ids, position_ids, inference_params.seqlen_offset
).clone()
logits_postprocess_fn = (
lambda logits: logits[..., :vocab_size] if vocab_size is not None else logits
)
)[:, -num_last_tokens:]
return logits[..., :vocab_size] if vocab_size is not None else logits
def sample_tokens(
input_ids,
model,
inference_params,
sample_fn,
num_tokens=1,
cg=False,
decoding=True,
last_token_logits=False,
):
def sample_tokens(input_ids, get_logits_fn, inference_params, sample_fn, num_tokens=1):
"""Sample `num_tokens` tokens from the model, given the previous logits.
Also return the logits of the sampled tokens.
Arguments:
input_ids: (batch, seqlen)
decoding: whether we're in the decoding phase or the prefilling phase. Prefill doesn't
need special position_ids.
last_token_logits: whether to return the logits of the last token. Normally we don't need this.
However, for speculative sampling, if the main model accepts all the draft tokens, plus it
samples one new token, then at the next iteration the draft model needs to evaluate
the logits of the last draft token and the logits of the newly sampled token.
This makes the implementation more complicated, so here we just evaluate the logits of the last
token in the draft model to simplify the implementation.
Return:
tokens: (batch, num_tokens)
scores: (batch, num_tokens), which contains @previous_logits and the logits of the next
(num_tokens - 1) tokens. The logits of the last token aren't computed unless last_token_logits=True,
in which case scores has shape (batch, num_tokens + 1).
(num_tokens - 1) tokens. The logits of the last token aren't computed.
"""
batch_size, seqlen = input_ids.shape
assert num_tokens >= 1
sequences = []
if decoding:
assert seqlen == 1
position_ids = repeat(
torch.arange(seqlen, dtype=torch.long, device=input_ids.device)
+ inference_params.seqlen_offset,
"s -> b s",
b=batch_size,
)
# position_ids = torch.full(
# (batch_size, 1),
# inference_params.seqlen_offset,
# dtype=torch.long,
# device=input_ids.device,
# )
else:
position_ids = None
logits = logits_postprocess_fn(
logits_forward_fn(model, input_ids, position_ids, inference_params, cg=decoding and cg)
)
inference_params.seqlen_offset += input_ids.shape[1]
scores = [logits]
next_token = sample_fn(logits)
sequences.append(next_token)
sequences, scores = [input_ids], []
for i in range(num_tokens):
if i < num_tokens - 1 or last_token_logits:
position_ids = torch.full(
(batch_size, 1),
inference_params_draft.seqlen_offset,
dtype=torch.long,
device=input_ids.device,
)
logits = logits_postprocess_fn(
logits_forward_fn(
model,
rearrange(next_token, "b -> b 1"),
position_ids,
inference_params,
cg=cg,
)
)
inference_params.seqlen_offset += 1
scores.append(logits)
if i < num_tokens - 1:
next_token = sample_fn(logits)
sequences.append(next_token)
return torch.stack(sequences, dim=1), torch.stack(scores, dim=1)
scores.append(get_logits_fn(sequences[-1], inference_params)[:, -1])
inference_params.seqlen_offset += sequences[-1].shape[1]
sequences.append(sample_fn(scores[-1]).unsqueeze(1))
return torch.cat(sequences[1:], dim=1), torch.stack(scores, dim=1)
sampling_kwargs = dict(top_k=top_k, top_p=top_p, temperature=temperature)
sample_fn = partial(sample, **sampling_kwargs)
get_logits_main = partial(get_logits, model=model, cg=cg)
get_logits_draft = partial(get_logits, model=model_draft, cg=cg)
sample_tokens_main = partial(
sample_tokens, model=model, sample_fn=sample_fn, inference_params=inference_params, cg=False
) # main model doesn't use CUDA graph
sample_tokens,
get_logits_fn=get_logits_main,
sample_fn=sample_fn,
inference_params=inference_params,
)
sample_tokens_draft = partial(
sample_tokens,
model=model_draft,
get_logits_fn=get_logits_draft,
sample_fn=sample_fn,
last_token_logits=True,
inference_params=inference_params_draft,
cg=cg,
)
if debug:
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("gpt2")
sequences = [input_ids]
scores = []
with torch.inference_mode():
if enable_timing:
if tensor_parallel > 1:
torch.distributed.barrier()
torch.cuda.synchronize()
start = time.time()
if seqlen_og >= max_length - 1:
if enable_timing:
if tensor_parallel > 1:
torch.distributed.barrier()
torch.cuda.synchronize()
start = time.time()
sequences, scores = [input_ids], []
num_main_model_calls = 0
num_draft_tokens = 0
num_accepted_tokens_history = []
if seqlen_og >= max_length - 1:
# Don't do speculative sampling, just sample 1 token from the model
tokens, scores_new = sample_tokens_main(input_ids, num_tokens=1)
sequences.append(tokens)
scores.append(scores_new)
else:
# Sample from the draft model, which produces @n_spec_tokens tokens, which @model
# will then use to produce between 1 and 1 + @n_spec_tokens tokens.
# We want seqlen_og + 1 + @n_spec_tokens to be <= @max_length.
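# Worked example (illustrative): seqlen_og=100, max_length=104, speculative_lookahead=5
# gives n_spec_tokens = min(5, 104 - 100 - 1) = 3, so at most 100 + 1 + 3 = 104 tokens total.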
n_spec_tokens = min(speculative_lookahead, max_length - seqlen_og - 1)
tokens_draft, scores_draft = sample_tokens_draft(input_ids, num_tokens=n_spec_tokens)
num_draft_tokens += n_spec_tokens
if debug:
scores_draft_ref = model_draft(
torch.cat([input_ids, tokens_draft], dim=1), num_last_tokens=n_spec_tokens + 1
).logits
print((scores_draft - scores_draft_ref[:, :-1]).abs().max())
# Evaluate the draft tokens with the model
logits = get_logits_main(
torch.cat([input_ids, tokens_draft], dim=1),
inference_params,
num_last_tokens=n_spec_tokens + 1,
)
num_main_model_calls += 1
if debug:
logits_ref = model(
torch.cat([input_ids, tokens_draft], dim=1), num_last_tokens=n_spec_tokens + 1
).logits
print((logits - logits_ref).abs().max())
# breakpoint()
tokens, num_generated_tokens = sample_speculative(
logits, scores_draft, tokens_draft, **sampling_kwargs
)
num_accepted_tokens_history.append(num_generated_tokens - 1)
if debug:
print(tokens)
print(num_generated_tokens)
# breakpoint()
# TODO: we're using the fact that batch_size == 1
# TODO: check eos_token_id
sequences.append(tokens[:1, : num_generated_tokens[0]])
scores.append(logits[:1, : num_generated_tokens[0]])
# Note that @model has not evaluated the last sampled token yet, so we'll need to pass
# that in the next time we call @model.
num_generated = num_generated_tokens[0].item()
inference_params.seqlen_offset = seqlen_og + num_generated - 1
inference_params_draft.seqlen_offset = (
inference_params.seqlen_offset - 1
if num_generated > 1
else inference_params.seqlen_offset
)
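# Bookkeeping example (illustrative): with seqlen_og=100 and num_generated=3,
# inference_params.seqlen_offset becomes 102 and inference_params_draft.seqlen_offset becomes 101.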
if debug:
cur_ids = torch.cat([input_ids, sequences[-1]], dim=1)
scores_ref = model(cur_ids, num_last_tokens=num_generated_tokens[0].item() + 1).logits
print((scores[-1] - scores_ref[:, :-1]).abs().max())
# breakpoint()
while True:
# seqlen_offset is total length generated - 1
if inference_params.seqlen_offset >= max_length - 1:
break
if inference_params.seqlen_offset >= max_length - 2:
# Don't do speculative sampling, just sample 1 token from the model
tokens, scores_new = sample_tokens_main(input_ids, num_tokens=1, decoding=False)
tokens, scores_new = sample_tokens_main(sequences[-1][:, -1:], num_tokens=1)
sequences.append(tokens)
scores.append(scores_new)
else:
# Sample from the draft model, which produces @n_spec_tokens tokens, which @model
# will then use to produce between 1 and 1 + @n_spec_tokens tokens.
# We want seqlen_og + 1 + @n_spec_tokens to be <= @max_length.
n_spec_tokens = min(speculative_lookahead, max_length - seqlen_og - 1)
tokens_draft, scores_draft = sample_tokens_draft(
input_ids,
num_tokens=n_spec_tokens,
decoding=False,
)
if debug:
scores_draft_ref = model_draft(
torch.cat([input_ids, tokens_draft], dim=1), num_last_tokens=n_spec_tokens + 1
).logits
print((scores_draft[:, :-1] - scores_draft_ref[:, :-1]).abs().max())
# Evaluate the draft tokens with the model
logits = model(
torch.cat([input_ids, tokens_draft], dim=1),
inference_params=inference_params,
num_last_tokens=n_spec_tokens + 1,
break
# Sample from draft model
n_spec_tokens = min(
speculative_lookahead, max_length - inference_params_draft.seqlen_offset - 2
)
# If the main model accepts all the draft tokens, plus it samples one new token,
# then at the next iteration the draft model needs to evaluate the logits of the last draft
# token and the logits of the newly sampled token. So here we pass in the last 2 tokens
# of sequences[-1].
# The exception is when the main model rejects all the draft tokens, in which case we
# will only have 1 token to pass in.
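# Illustrative: if the previous round generated 4 tokens (3 accepted draft tokens + 1 newly
# sampled token), sequences[-1][:, -2:] holds the last accepted draft token and the new token;
# if only 1 token was generated (all draft tokens rejected), it holds just that single token.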
tokens_draft, scores_draft = sample_tokens_draft(
sequences[-1][:, -2:], num_tokens=n_spec_tokens
)
num_draft_tokens += n_spec_tokens
if debug:
scores_draft_ref = model_draft(
torch.cat([cur_ids, tokens_draft], dim=1), num_last_tokens=n_spec_tokens + 1
).logits
logits = logits_postprocess_fn(logits)
tokens, num_generated_tokens = sample_speculative(
logits, scores_draft[:, :-1], tokens_draft, **sampling_kwargs
)
if debug:
print(tokens)
print(num_generated_tokens)
# breakpoint()
# TODO: we're using the fact that batch_size == 1
# TODO: check eos_token_id
sequences.append(tokens[:1, : num_generated_tokens[0]])
scores.append(logits[:1, : num_generated_tokens[0]])
# Note that @model has not evaluated the last sampled token yet, so we'll need to pass
# that in the next time we call @model.
inference_params.seqlen_offset = seqlen_og + num_generated_tokens[0].item() - 1
inference_params_draft.seqlen_offset = inference_params.seqlen_offset
if debug:
cur_ids = torch.cat([input_ids, sequences[-1]], dim=1)
scores_ref = model(
cur_ids, num_last_tokens=num_generated_tokens[0].item() + 1
).logits
print((scores[-1] - scores_ref[:, :-1]).abs().max())
while True:
# seqlen_offset is total length generated - 1
if inference_params.seqlen_offset >= max_length - 1:
break
if inference_params.seqlen_offset >= max_length - 2:
# Don't do speculative sampling, just sample 1 token from the model
tokens, scores_new = sample_tokens_main(sequences[-1][:, -1:], num_tokens=1)
sequences.append(tokens)
scores.append(scores_new)
break
# Sample from draft model
n_spec_tokens = min(
speculative_lookahead, max_length - inference_params_draft.seqlen_offset - 2
)
tokens_draft, scores_draft = sample_tokens_draft(
sequences[-1][:, -1:], num_tokens=n_spec_tokens
)
if debug:
scores_draft_ref = model_draft(
torch.cat([cur_ids, tokens_draft], dim=1), num_last_tokens=n_spec_tokens + 1
).logits
print((scores_draft[:, :-1] - scores_draft_ref[:, :-1]).abs().max())
# Evaluate the draft tokens with the model
position_ids = repeat(
torch.arange(
inference_params.seqlen_offset,
# 1 extra token from last time that hasn't been passed through model
inference_params.seqlen_offset + n_spec_tokens + 1,
dtype=torch.long,
device=input_ids.device,
),
"s -> b s",
b=batch_size,
)
logits = model(
torch.cat([sequences[-1][:, -1:], tokens_draft], dim=1),
position_ids=position_ids,
inference_params=inference_params,
).logits # (batch, n_spec_tokens, vocab_size)
logits = logits_postprocess_fn(logits)
inference_params.seqlen_offset += 1
if debug:
logits_ref = model(
torch.cat([cur_ids, tokens_draft], dim=1), num_last_tokens=n_spec_tokens + 1
).logits
print((logits - logits_ref).abs().max())
tokens, num_generated_tokens = sample_speculative(
logits, scores_draft[:, :-1], tokens_draft, **sampling_kwargs
)
if debug:
print(tokens)
print(num_generated_tokens)
sequences.append(tokens[:1, : num_generated_tokens[0]])
scores.append(logits[:1, : num_generated_tokens[0]])
inference_params.seqlen_offset += num_generated_tokens[0].item() - 1
inference_params_draft.seqlen_offset = inference_params.seqlen_offset
print((scores_draft - scores_draft_ref[:, :-1]).abs().max())
# breakpoint()
# Evaluate the draft tokens with the model
logits = get_logits_main(
torch.cat([sequences[-1][:, -1:], tokens_draft], dim=1),
inference_params,
num_last_tokens=n_spec_tokens + 1,
) # (batch, n_spec_tokens + 1, vocab_size)
num_main_model_calls += 1
if debug:
logits_ref = model(
torch.cat([cur_ids, tokens_draft], dim=1), num_last_tokens=n_spec_tokens + 1
).logits
print((logits - logits_ref).abs().max())
# breakpoint()
tokens, num_generated_tokens = sample_speculative(
logits, scores_draft, tokens_draft, **sampling_kwargs
)
num_accepted_tokens_history.append(num_generated_tokens - 1)
if debug:
print(tokens)
print(num_generated_tokens)
# breakpoint()
if debug:
cur_ids = torch.cat([cur_ids, sequences[-1]], dim=1)
scores_ref = model(
cur_ids, num_last_tokens=num_generated_tokens[0].item() + 1
).logits
print((scores[-1] - scores_ref[:, :-1]).abs().max())
if enable_timing:
if tensor_parallel > 1:
torch.distributed.barrier()
torch.cuda.synchronize()
print(f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms")
sequences.append(tokens[:1, : num_generated_tokens[0]])
scores.append(logits[:1, : num_generated_tokens[0]])
# We've evaluated 1 token from sequences[-1][:, -1:] above, plus
# num_generated_tokens[0].item() - 1 tokens from the draft model.
num_generated = num_generated_tokens[0].item()
inference_params.seqlen_offset += num_generated
inference_params_draft.seqlen_offset = (
inference_params.seqlen_offset - 1
if num_generated > 1
else inference_params.seqlen_offset
)
if debug:
cur_ids = torch.cat([cur_ids, sequences[-1]], dim=1)
scores_ref = model(cur_ids, num_last_tokens=num_generated_tokens[0].item() + 1).logits
print((scores[-1] - scores_ref[:, :-1]).abs().max())
# breakpoint()
if enable_timing:
if tensor_parallel > 1:
torch.distributed.barrier()
torch.cuda.synchronize()
print(f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms")
print(f"Number of calls to main model: {num_main_model_calls}")
print(
f"Acceptance rate: {torch.cat(num_accepted_tokens_history).sum().item() / num_draft_tokens * 100:.2f}%"
)
sequences = torch.cat(sequences, dim=1)
scores = torch.cat(scores, dim=1)
if debug:
@@ -607,20 +601,6 @@ def allocate_inference_cache(
return {i: torch.empty(kv_cache_shape, device=device, dtype=dtype) for i in layers}
def seqlen_to_seqlen_type(seqlen: int) -> int:
"""Convert sequence length to a seqlen_type.
This is used to determine which cuda graph to use.
Arguments:
seqlen: int
"""
return 0
def seqlen_type_to_max_seqlen(seqlen_type: int) -> int:
assert seqlen_type in [0]
return 2**32
@dataclass
class DecodingCGCache:
max_batch_size: int = 0
@@ -640,6 +620,7 @@ def update_graph_cache(
batch_size,
seqlen_og,
max_seqlen,
decoding_seqlens=(1,),
tensor_parallel=1,
dtype=None,
n_warmups=2,
@@ -687,38 +668,36 @@ def update_graph_cache(
lengths_per_sample=lengths_per_sample,
)
cache.mempool = torch.cuda.graphs.graph_pool_handle()
for s_type in range(seqlen_to_seqlen_type(seqlen_og), seqlen_to_seqlen_type(max_seqlen) + 1):
if (batch_size, s_type) not in cache.callables:
max_seqlen_ = min(max(seqlen_og, seqlen_type_to_max_seqlen(s_type)), max_seqlen)
cache.callables[batch_size, s_type] = capture_graph(
for decoding_seqlen in decoding_seqlens:
if (batch_size, decoding_seqlen) not in cache.callables:
cache.callables[batch_size, decoding_seqlen] = capture_graph(
model,
cache.inference_params,
batch_size,
max_seqlen_,
max_seqlen,
decoding_seqlen=decoding_seqlen,
mempool=cache.mempool,
n_warmups=n_warmups,
)
def dispatch(input_ids, position_ids, seqlen):
batch_size = input_ids.shape[0]
return cache.callables[batch_size, seqlen_to_seqlen_type(seqlen)](
input_ids, position_ids, seqlen
)
batch_size, decoding_seqlen = input_ids.shape[:2]
return cache.callables[batch_size, decoding_seqlen](input_ids, position_ids, seqlen)
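# Descriptive note: one CUDA graph is captured per (batch_size, decoding_seqlen) pair, e.g.
# decoding_seqlens=(1, 2) for the draft model and range(1, speculative_lookahead + 2) for the
# main model in decode_speculative above; dispatch picks the graph from input_ids.shape[:2].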
cache.run = dispatch
cache.inference_params.seqlen_offset = 0 # Reset so it's not confusing
return cache
def capture_graph(model, inference_params, batch_size, max_seqlen, mempool=None, n_warmups=2):
def capture_graph(
model, inference_params, batch_size, max_seqlen, decoding_seqlen=1, mempool=None, n_warmups=2
):
device = next(iter(model.parameters())).device
input_ids = torch.full((batch_size, 1), 0, dtype=torch.long, device=device)
position_ids = torch.full((batch_size, 1), 0, dtype=torch.long, device=device)
input_ids = torch.full((batch_size, decoding_seqlen), 0, dtype=torch.long, device=device)
position_ids = torch.full((batch_size, decoding_seqlen), 0, dtype=torch.long, device=device)
seqlen_offset_og = inference_params.seqlen_offset
# TD [2023-04-14]: important for correctness of the FT's attention kernel, as seqlen_cpu is
# used to determine the size of smem. Hence seqlen_cpu must be >= lengths_per_sample.
inference_params.seqlen_offset = max_seqlen - 1
inference_params.lengths_per_sample[:] = max_seqlen - 1
inference_params.seqlen_offset = max_seqlen - decoding_seqlen
inference_params.lengths_per_sample[:] = inference_params.seqlen_offset
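# With decoding_seqlen tokens per step, the largest offset the graph will replay at is
# max_seqlen - decoding_seqlen; capturing at that offset presumably keeps the kernel sizing
# valid for all later replays (see the note above about seqlen_cpu).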
# Warmup before capture
s = torch.cuda.Stream()
@@ -729,7 +708,7 @@ def capture_graph(model, inference_params, batch_size, max_seqlen, mempool=None,
input_ids,
position_ids=position_ids,
inference_params=inference_params,
num_last_tokens=1,
num_last_tokens=decoding_seqlen,
).logits
s.synchronize()
# This might be needed for correctness if we run with NCCL_GRAPH_MIXING_SUPPORT=0,
@@ -746,8 +725,8 @@ def capture_graph(model, inference_params, batch_size, max_seqlen, mempool=None,
input_ids,
position_ids=position_ids,
inference_params=inference_params,
num_last_tokens=1,
).logits.squeeze(dim=1)
num_last_tokens=decoding_seqlen,
).logits
def run(new_input_ids, new_position_ids, seqlen):
inference_params.lengths_per_sample[:] = seqlen
@@ -383,11 +383,14 @@ def test_gpt2_multiple_token_generation(model_name, optimized):
@pytest.mark.parametrize("cg", [False, True])
# @pytest.mark.parametrize("optimized", [False, True])
@pytest.mark.parametrize("optimized", [True])
# @pytest.mark.parametrize("cg", [True])
@pytest.mark.parametrize("optimized", [False, True])
# @pytest.mark.parametrize("optimized", [True])
# @pytest.mark.parametrize("model_name", ["gpt2-medium"])
@pytest.mark.parametrize("model_name", ["gpt2-xl"])
def test_gpt2_speculative_decoding(model_name, optimized, cg):
if cg and not optimized:
pytest.skip() # CG requires use_flash_attn
dtype = torch.float16
device = "cuda"
rtol, atol = 3e-3, 3e-1
@@ -421,6 +424,7 @@ def test_gpt2_speculative_decoding(model_name, optimized, cg):
from flash_attn.utils.generation import decode_speculative
torch.manual_seed(42)
print(f"Speculative decoding, {optimized = }")
out = decode_speculative(
input_ids,
model,
@@ -430,13 +434,15 @@ def test_gpt2_speculative_decoding(model_name, optimized, cg):
cg=cg,
speculative_lookahead=4,
enable_timing=True,
# debug=True,
)
print(tokenizer.batch_decode(out.sequences))
print(f"Without speculative decoding, {cg = }")
out_og = model.generate(
input_ids,
max_length=max_length,
top_k=5,
cg=False,
cg=cg,
enable_timing=True,
return_dict_in_generate=True,
)