Commit e0fbaa70 authored by Tri Dao

[Gen] Simplify decode_speculative

parent e6a80264
@@ -159,7 +159,7 @@ def decode(
         else:
             logits = model._decoding_cache.run(
                 input_ids, position_ids, inference_params.seqlen_offset
-            ).clone()
+            ).squeeze(dim=1)
         return logits[..., :vocab_size] if vocab_size is not None else logits

     def sample_tokens(logits, inference_params):
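
The change above replaces .clone() with .squeeze(dim=1) because capture_graph (further down in this commit) now returns logits that keep the sequence dimension instead of squeezing it inside the graph. A small PyTorch sketch of the shape bookkeeping, with toy stand-in tensors:

    import torch

    # Toy stand-in for the logits returned by a cached decoding step:
    # shape (batch, decoding_seqlen, vocab_size), here with decoding_seqlen == 1.
    batch_size, vocab_size = 2, 8
    logits = torch.randn(batch_size, 1, vocab_size)

    # Squeezing the singleton sequence dimension yields (batch, vocab_size),
    # which is what the single-token decode path expects downstream.
    next_token_logits = logits.squeeze(dim=1)
    assert next_token_logits.shape == (batch_size, vocab_size)

    # Keeping the sequence dimension instead allows slicing the last
    # `num_last_tokens` positions when more than one token is processed per step.
    num_last_tokens = 1
    last_logits = logits[:, -num_last_tokens:]
    assert last_logits.shape == (batch_size, num_last_tokens, vocab_size)
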
@@ -305,256 +305,250 @@ def decode_speculative(
             batch_size,
             seqlen_og,
             max_length,
+            # draft model needs to process either 1 or 2 tokens at a time
+            decoding_seqlens=(1, 2),
             tensor_parallel=tensor_parallel,
         )
         inference_params_draft = model_draft._decoding_cache.inference_params
         inference_params_draft.reset(max_length, batch_size)
-        inference_params = InferenceParams(max_seqlen=max_length, max_batch_size=batch_size)
+        if not hasattr(model, "_decoding_cache"):
+            model._decoding_cache = None
+        model._decoding_cache = update_graph_cache(
+            model,
+            model._decoding_cache,
+            batch_size,
+            seqlen_og,
+            max_length,
+            decoding_seqlens=range(1, speculative_lookahead + 2),
+            tensor_parallel=tensor_parallel,
+        )
+        inference_params = model._decoding_cache.inference_params
+        inference_params.reset(max_length, batch_size)
     else:
         inference_params_draft = InferenceParams(max_seqlen=max_length, max_batch_size=batch_size)
         inference_params = InferenceParams(max_seqlen=max_length, max_batch_size=batch_size)
-    def logits_forward_fn(model, input_ids, position_ids, inference_params, cg=False):
-        if not cg:
-            return model(
+    def get_logits(input_ids, inference_params, model, num_last_tokens=1, cg=False):
+        decoding = inference_params.seqlen_offset > 0
+        if decoding:
+            seqlen = input_ids.shape[1]
+            # if inference_params.lengths_per_sample is None:
+            # TODO: in the case of batched decoding where each sequence has a different length,
+            # we need to compute the position_ids for each sequence using lengths_per_sample
+            if True:
+                cache_seqlens = torch.full(
+                    (input_ids.shape[0],),
+                    inference_params.seqlen_offset,
+                    dtype=torch.int32,
+                    device=input_ids.device,
+                )
+            else:
+                cache_seqlens = inference_params.lengths_per_sample
+            position_ids = cache_seqlens[:, None] + torch.arange(
+                seqlen, dtype=torch.long, device=input_ids.device
+            )
+        else:
+            position_ids = None
+        if not cg or not decoding:
+            logits = model(
                 input_ids,
                 position_ids=position_ids,
                 inference_params=inference_params,
-                num_last_tokens=1,
-            ).logits.squeeze(dim=1)
+                num_last_tokens=num_last_tokens,
+            ).logits
         else:
-            return model._decoding_cache.run(
+            # NOTE: careful, CUDA graph is set to have num_last_tokens=input_ids.shape[1].
+            # This might not be compatible the num_last_tokens used here.
+            assert num_last_tokens <= input_ids.shape[1]
+            logits = model._decoding_cache.run(
                 input_ids, position_ids, inference_params.seqlen_offset
-            ).clone()
-
-    logits_postprocess_fn = (
-        lambda logits: logits[..., :vocab_size] if vocab_size is not None else logits
-    )
+            )[:, -num_last_tokens:]
+        return logits[..., :vocab_size] if vocab_size is not None else logits
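
The new get_logits helper builds position_ids for the decoding path from the KV-cache offset plus an arange over the tokens being fed in, instead of taking position_ids as an argument. A self-contained sketch of that index arithmetic, with illustrative toy sizes:

    import torch

    batch_size, decoding_seqlen, seqlen_offset = 2, 3, 10

    # One cache length per sequence; here every sequence has the same offset,
    # mirroring the single-offset case handled in the diff above.
    cache_seqlens = torch.full((batch_size,), seqlen_offset, dtype=torch.int32)

    # Broadcast offset + [0, 1, ..., decoding_seqlen - 1] to get the absolute
    # positions of the tokens being fed in this step.
    position_ids = cache_seqlens[:, None] + torch.arange(decoding_seqlen, dtype=torch.long)
    print(position_ids)
    # tensor([[10, 11, 12],
    #         [10, 11, 12]])
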
-    def sample_tokens(
-        input_ids,
-        model,
-        inference_params,
-        sample_fn,
-        num_tokens=1,
-        cg=False,
-        decoding=True,
-        last_token_logits=False,
-    ):
+    def sample_tokens(input_ids, get_logits_fn, inference_params, sample_fn, num_tokens=1):
         """Sample `num_tokens` tokens from the model, given the previous logits.
         Also return the logits of the sampled tokens.
         Arguments:
             input_ids: (batch, seqlen)
-            decoding: whether we're in the decoding phase or the prefilling phase. Prefill doesn't
-                need special position_ids.
-            last_token_logits: whether to return the logits of the last token. Normally we don't need this.
-                However, for speculative sampling, if the main model accepts all the draft tokens, plus it
-                samples one new token, then by right at the next iteration the draft model need to evaluate
-                the logits of the last draft token and the logits of the newly sampled token.
-                This makes implementation more complicated. So here we just evaluate the logits of the last
-                token in the draft model to simplify the implementation.
         Return:
             tokens: (batch, num_tokens)
             scores: (batch, num_tokens), which contains @previous_logits and the logits of the next
-                (num_tokens - 1) tokens. The logits of the last token isn't computed unless last_token_logits=True.
-                In which case we have scores of shape (batch, num_tokens + 1)
+                (num_tokens - 1) tokens. The logits of the last token isn't computed.
         """
-        batch_size, seqlen = input_ids.shape
         assert num_tokens >= 1
-        sequences = []
-        if decoding:
-            assert seqlen == 1
-            position_ids = repeat(
-                torch.arange(seqlen, dtype=torch.long, device=input_ids.device)
-                + inference_params.seqlen_offset,
-                "s -> b s",
-                b=batch_size,
-            )
-            # position_ids = torch.full(
-            #     (batch_size, 1),
-            #     inference_params.seqlen_offset,
-            #     dtype=torch.long,
-            #     device=input_ids.device,
-            # )
-        else:
-            position_ids = None
-        logits = logits_postprocess_fn(
-            logits_forward_fn(model, input_ids, position_ids, inference_params, cg=decoding and cg)
-        )
-        inference_params.seqlen_offset += input_ids.shape[1]
-        scores = [logits]
-        next_token = sample_fn(logits)
-        sequences.append(next_token)
+        sequences, scores = [input_ids], []
         for i in range(num_tokens):
-            if i < num_tokens - 1 or last_token_logits:
-                position_ids = torch.full(
-                    (batch_size, 1),
-                    inference_params_draft.seqlen_offset,
-                    dtype=torch.long,
-                    device=input_ids.device,
-                )
-                logits = logits_postprocess_fn(
-                    logits_forward_fn(
-                        model,
-                        rearrange(next_token, "b -> b 1"),
-                        position_ids,
-                        inference_params,
-                        cg=cg,
-                    )
-                )
-                inference_params.seqlen_offset += 1
-                scores.append(logits)
-            if i < num_tokens - 1:
-                next_token = sample_fn(logits)
-                sequences.append(next_token)
-        return torch.stack(sequences, dim=1), torch.stack(scores, dim=1)
+            scores.append(get_logits_fn(sequences[-1], inference_params)[:, -1])
+            inference_params.seqlen_offset += sequences[-1].shape[1]
+            sequences.append(sample_fn(scores[-1]).unsqueeze(1))
+        return torch.cat(sequences[1:], dim=1), torch.stack(scores, dim=1)
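
The simplified sample_tokens treats prefill and decode uniformly: it repeatedly asks a get_logits_fn for the logits at the last position of whatever was fed in last, samples a token, and feeds that token back as the next single-token input. A toy, CPU-only sketch of that control flow; the random "model" and greedy sampler below are stand-ins, not the repo's functions:

    import torch

    torch.manual_seed(0)
    vocab_size = 16

    def toy_get_logits(input_ids):
        # Stand-in for get_logits_fn: returns logits for every position of input_ids.
        batch, seqlen = input_ids.shape
        return torch.randn(batch, seqlen, vocab_size)

    def greedy_sample(logits):
        # Stand-in for sample_fn: pick the argmax token from (batch, vocab) logits.
        return logits.argmax(dim=-1)

    def sample_tokens_sketch(input_ids, num_tokens=3):
        sequences, scores = [input_ids], []
        for _ in range(num_tokens):
            # Logits at the last position of whatever was fed in last.
            scores.append(toy_get_logits(sequences[-1])[:, -1])
            # The sampled token becomes the next single-token input.
            sequences.append(greedy_sample(scores[-1]).unsqueeze(1))
        return torch.cat(sequences[1:], dim=1), torch.stack(scores, dim=1)

    prompt = torch.randint(0, vocab_size, (1, 4))
    tokens, scores = sample_tokens_sketch(prompt)
    print(tokens.shape, scores.shape)  # torch.Size([1, 3]) torch.Size([1, 3, 16])
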
     sampling_kwargs = dict(top_k=top_k, top_p=top_p, temperature=temperature)
     sample_fn = partial(sample, **sampling_kwargs)
+    get_logits_main = partial(get_logits, model=model, cg=cg)
+    get_logits_draft = partial(get_logits, model=model_draft, cg=cg)
     sample_tokens_main = partial(
-        sample_tokens, model=model, sample_fn=sample_fn, inference_params=inference_params, cg=False
-    )  # main model doesn't use CUDA graph
+        sample_tokens,
+        get_logits_fn=get_logits_main,
+        sample_fn=sample_fn,
+        inference_params=inference_params,
+    )
     sample_tokens_draft = partial(
         sample_tokens,
-        model=model_draft,
+        get_logits_fn=get_logits_draft,
         sample_fn=sample_fn,
-        last_token_logits=True,
         inference_params=inference_params_draft,
-        cg=cg,
     )
     if debug:
         from transformers import AutoTokenizer
         tokenizer = AutoTokenizer.from_pretrained("gpt2")
-    sequences = [input_ids]
-    scores = []
-    with torch.inference_mode():
-        if enable_timing:
-            if tensor_parallel > 1:
-                torch.distributed.barrier()
-            torch.cuda.synchronize()
-            start = time.time()
-
-        if seqlen_og >= max_length - 1:
-            # Don't do speculative sampling, just sample 1 token from the model
-            tokens, scores_new = sample_tokens_main(input_ids, num_tokens=1, decoding=False)
-            sequences.append(tokens)
-            scores.append(scores_new)
-        else:
-            # Sample from draft model, which produces @n_spec_tokens, and @model
-            # will then use to produce between 1 and 1 + @n_spec_tokens tokens.
-            # We want seqlen_og + 1 + @n_spec_tokens to be <= @max_length.
-            n_spec_tokens = min(speculative_lookahead, max_length - seqlen_og - 1)
-            tokens_draft, scores_draft = sample_tokens_draft(
-                input_ids,
-                num_tokens=n_spec_tokens,
-                decoding=False,
-            )
-            if debug:
-                scores_draft_ref = model_draft(
-                    torch.cat([input_ids, tokens_draft], dim=1), num_last_tokens=n_spec_tokens + 1
-                ).logits
-                print((scores_draft[:, :-1] - scores_draft_ref[:, :-1]).abs().max())
-
-            # Evaluate the draft tokens with the model
-            logits = model(
-                torch.cat([input_ids, tokens_draft], dim=1),
-                inference_params=inference_params,
-                num_last_tokens=n_spec_tokens + 1,
-            ).logits
-            logits = logits_postprocess_fn(logits)
-            tokens, num_generated_tokens = sample_speculative(
-                logits, scores_draft[:, :-1], tokens_draft, **sampling_kwargs
-            )
-            if debug:
-                print(tokens)
-                print(num_generated_tokens)
-                # breakpoint()
-            # TODO: we're using the fact that batch_size == 1
-            # TODO: check eos_token_id
-            sequences.append(tokens[:1, : num_generated_tokens[0]])
-            scores.append(logits[:1, : num_generated_tokens[0]])
-            # Note that @model has not evaluated the last sampled token yet, so we'll need to pass
-            # that in the next time we call @model.
-            inference_params.seqlen_offset = seqlen_og + num_generated_tokens[0].item() - 1
-            inference_params_draft.seqlen_offset = inference_params.seqlen_offset
-            if debug:
-                cur_ids = torch.cat([input_ids, sequences[-1]], dim=1)
-                scores_ref = model(
-                    cur_ids, num_last_tokens=num_generated_tokens[0].item() + 1
-                ).logits
-                print((scores[-1] - scores_ref[:, :-1]).abs().max())
-        while True:
-            # seqlen_offset is total length generated - 1
-            if inference_params.seqlen_offset >= max_length - 1:
-                break
-            if inference_params.seqlen_offset >= max_length - 2:
-                # Don't do speculative sampling, just sample 1 token from the model
-                tokens, scores_new = sample_tokens_main(sequences[-1][:, -1:], num_tokens=1)
-                sequences.append(tokens)
-                scores.append(scores_new)
-                break
-            # Sample from draft model
-            n_spec_tokens = min(
-                speculative_lookahead, max_length - inference_params_draft.seqlen_offset - 2
-            )
-            tokens_draft, scores_draft = sample_tokens_draft(
-                sequences[-1][:, -1:], num_tokens=n_spec_tokens
-            )
-            if debug:
-                scores_draft_ref = model_draft(
-                    torch.cat([cur_ids, tokens_draft], dim=1), num_last_tokens=n_spec_tokens + 1
-                ).logits
-                print((scores_draft[:, :-1] - scores_draft_ref[:, :-1]).abs().max())
-            # Evaluate the draft tokens with the model
-            position_ids = repeat(
-                torch.arange(
-                    inference_params.seqlen_offset,
-                    # 1 extra token from last time that hasn't been passed through model
-                    inference_params.seqlen_offset + n_spec_tokens + 1,
-                    dtype=torch.long,
-                    device=input_ids.device,
-                ),
-                "s -> b s",
-                b=batch_size,
-            )
-            logits = model(
-                torch.cat([sequences[-1][:, -1:], tokens_draft], dim=1),
-                position_ids=position_ids,
-                inference_params=inference_params,
-            ).logits  # (batch, n_spec_tokens, vocab_size)
-            logits = logits_postprocess_fn(logits)
-            inference_params.seqlen_offset += 1
-            if debug:
-                logits_ref = model(
-                    torch.cat([cur_ids, tokens_draft], dim=1), num_last_tokens=n_spec_tokens + 1
-                ).logits
-                print((logits - logits_ref).abs().max())
-            tokens, num_generated_tokens = sample_speculative(
-                logits, scores_draft[:, :-1], tokens_draft, **sampling_kwargs
-            )
-            if debug:
-                print(tokens)
-                print(num_generated_tokens)
-            sequences.append(tokens[:1, : num_generated_tokens[0]])
-            scores.append(logits[:1, : num_generated_tokens[0]])
-            inference_params.seqlen_offset += num_generated_tokens[0].item() - 1
-            inference_params_draft.seqlen_offset = inference_params.seqlen_offset
-            # breakpoint()
-            if debug:
-                cur_ids = torch.cat([cur_ids, sequences[-1]], dim=1)
-                scores_ref = model(
-                    cur_ids, num_last_tokens=num_generated_tokens[0].item() + 1
-                ).logits
-                print((scores[-1] - scores_ref[:, :-1]).abs().max())
-
-        if enable_timing:
-            if tensor_parallel > 1:
-                torch.distributed.barrier()
-            torch.cuda.synchronize()
-            print(f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms")
+    if enable_timing:
+        if tensor_parallel > 1:
+            torch.distributed.barrier()
+        torch.cuda.synchronize()
+        start = time.time()
+
+    sequences, scores = [input_ids], []
+    num_main_model_calls = 0
+    num_draft_tokens = 0
+    num_accepted_tokens_history = []
+    if seqlen_og >= max_length - 1:
+        # Don't do speculative sampling, just sample 1 token from the model
+        tokens, scores_new = sample_tokens_main(input_ids, num_tokens=1)
+        sequences.append(tokens)
+        scores.append(scores_new)
+    else:
+        # Sample from draft model, which produces @n_spec_tokens, and @model
+        # will then use to produce between 1 and 1 + @n_spec_tokens tokens.
+        # We want seqlen_og + 1 + @n_spec_tokens to be <= @max_length.
+        n_spec_tokens = min(speculative_lookahead, max_length - seqlen_og - 1)
+        tokens_draft, scores_draft = sample_tokens_draft(input_ids, num_tokens=n_spec_tokens)
+        num_draft_tokens += n_spec_tokens
+        if debug:
+            scores_draft_ref = model_draft(
+                torch.cat([input_ids, tokens_draft], dim=1), num_last_tokens=n_spec_tokens + 1
+            ).logits
+            print((scores_draft - scores_draft_ref[:, :-1]).abs().max())
+        # Evaluate the draft tokens with the model
+        logits = get_logits_main(
+            torch.cat([input_ids, tokens_draft], dim=1),
+            inference_params,
+            num_last_tokens=n_spec_tokens + 1,
+        )
+        num_main_model_calls += 1
+        if debug:
+            logits_ref = model(
+                torch.cat([input_ids, tokens_draft], dim=1), num_last_tokens=n_spec_tokens + 1
+            ).logits
+            print((logits - logits_ref).abs().max())
+            # breakpoint()
+        tokens, num_generated_tokens = sample_speculative(
+            logits, scores_draft, tokens_draft, **sampling_kwargs
+        )
+        num_accepted_tokens_history.append(num_generated_tokens - 1)
+        if debug:
+            print(tokens)
+            print(num_generated_tokens)
+            # breakpoint()
+        # TODO: we're using the fact that batch_size == 1
+        # TODO: check eos_token_id
+        sequences.append(tokens[:1, : num_generated_tokens[0]])
+        scores.append(logits[:1, : num_generated_tokens[0]])
+        # Note that @model has not evaluated the last sampled token yet, so we'll need to pass
+        # that in the next time we call @model.
+        num_generated = num_generated_tokens[0].item()
+        inference_params.seqlen_offset = seqlen_og + num_generated - 1
+        inference_params_draft.seqlen_offset = (
+            inference_params.seqlen_offset - 1
+            if num_generated > 1
+            else inference_params.seqlen_offset
+        )
+        if debug:
+            cur_ids = torch.cat([input_ids, sequences[-1]], dim=1)
+            scores_ref = model(cur_ids, num_last_tokens=num_generated_tokens[0].item() + 1).logits
+            print((scores[-1] - scores_ref[:, :-1]).abs().max())
+            # breakpoint()
+    while True:
+        # seqlen_offset is total length generated - 1
+        if inference_params.seqlen_offset >= max_length - 1:
+            break
+        if inference_params.seqlen_offset >= max_length - 2:
+            # Don't do speculative sampling, just sample 1 token from the model
+            tokens, scores_new = sample_tokens_main(sequences[-1][:, -1:], num_tokens=1)
+            sequences.append(tokens)
+            scores.append(scores_new)
+            break
+        # Sample from draft model
+        n_spec_tokens = min(
+            speculative_lookahead, max_length - inference_params_draft.seqlen_offset - 2
+        )
+        # If the main model accepts all the draft tokens, plus it samples one new token,
+        # then at the next iteration the draft model need to evaluate the logits of the last draft
+        # token and the logits of the newly sampled token. So here we pass in the last 2 tokens
+        # of sequences[-1].
+        # This exception is when the main model rejects all the draft tokens, in which case we
+        # will only have 1 token to pass in.
+        tokens_draft, scores_draft = sample_tokens_draft(
+            sequences[-1][:, -2:], num_tokens=n_spec_tokens
+        )
+        num_draft_tokens += n_spec_tokens
+        if debug:
+            scores_draft_ref = model_draft(
+                torch.cat([cur_ids, tokens_draft], dim=1), num_last_tokens=n_spec_tokens + 1
+            ).logits
+            print((scores_draft - scores_draft_ref[:, :-1]).abs().max())
+            # breakpoint()
+        # Evaluate the draft tokens with the model
+        logits = get_logits_main(
+            torch.cat([sequences[-1][:, -1:], tokens_draft], dim=1),
+            inference_params,
+            num_last_tokens=n_spec_tokens + 1,
+        )  # (batch, n_spec_tokens + 1, vocab_size)
+        num_main_model_calls += 1
+        if debug:
+            logits_ref = model(
+                torch.cat([cur_ids, tokens_draft], dim=1), num_last_tokens=n_spec_tokens + 1
+            ).logits
+            print((logits - logits_ref).abs().max())
+            # breakpoint()
+        tokens, num_generated_tokens = sample_speculative(
+            logits, scores_draft, tokens_draft, **sampling_kwargs
+        )
+        num_accepted_tokens_history.append(num_generated_tokens - 1)
+        if debug:
+            print(tokens)
+            print(num_generated_tokens)
+            # breakpoint()
+        sequences.append(tokens[:1, : num_generated_tokens[0]])
+        scores.append(logits[:1, : num_generated_tokens[0]])
+        # We've evaluated 1 token from sequences[-1][:, -1:] above, plus
+        # num_generated_tokens[0].item() - 1 tokens from the draft model.
+        num_generated = num_generated_tokens[0].item()
+        inference_params.seqlen_offset += num_generated
+        inference_params_draft.seqlen_offset = (
+            inference_params.seqlen_offset - 1
+            if num_generated > 1
+            else inference_params.seqlen_offset
+        )
+        if debug:
+            cur_ids = torch.cat([cur_ids, sequences[-1]], dim=1)
+            scores_ref = model(cur_ids, num_last_tokens=num_generated_tokens[0].item() + 1).logits
+            print((scores[-1] - scores_ref[:, :-1]).abs().max())
+            # breakpoint()
+
+    if enable_timing:
+        if tensor_parallel > 1:
+            torch.distributed.barrier()
+        torch.cuda.synchronize()
+        print(f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms")
+        print(f"Number of calls to main model: {num_main_model_calls}")
+        print(
+            f"Acceptance rate: {torch.cat(num_accepted_tokens_history).sum().item() / num_draft_tokens * 100:.2f}%"
+        )
     sequences = torch.cat(sequences, dim=1)
     scores = torch.cat(scores, dim=1)
     if debug:
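
For background on the sample_speculative call that consumes scores_draft and the main-model logits: standard speculative sampling accepts draft token x with probability min(1, p(x)/q(x)) and, on the first rejection, resamples from the residual distribution max(p - q, 0); if every draft token is accepted, one bonus token is sampled from the main model. The sketch below illustrates that rule for batch size 1; it is an illustration of the general technique, not the repo's sample_speculative implementation:

    import torch

    def accept_reject_sketch(logits, logits_draft, tokens_draft):
        """logits: (n_spec + 1, vocab) from the main model, where row i scores the
        token at draft position i; logits_draft: (n_spec, vocab) from the draft
        model; tokens_draft: (n_spec,) draft token ids."""
        p = torch.softmax(logits.float(), dim=-1)        # main-model probabilities
        q = torch.softmax(logits_draft.float(), dim=-1)  # draft-model probabilities
        n_spec = tokens_draft.shape[0]
        accepted = []
        for i in range(n_spec):
            tok = tokens_draft[i]
            # Accept with probability min(1, p_i(tok) / q_i(tok)).
            if torch.rand(()) < torch.clamp(p[i, tok] / q[i, tok], max=1.0):
                accepted.append(tok)
            else:
                # Resample from the residual distribution max(p - q, 0), renormalized,
                # and stop: later draft tokens are no longer on-distribution.
                residual = torch.clamp(p[i] - q[i], min=0.0)
                residual = residual / residual.sum()
                accepted.append(torch.multinomial(residual, 1).squeeze(0))
                return torch.stack(accepted)
        # All drafts accepted: sample one bonus token from the last main-model row.
        accepted.append(torch.multinomial(p[n_spec], 1).squeeze(0))
        return torch.stack(accepted)

    # Tiny usage example with random logits over a vocabulary of 10 tokens.
    torch.manual_seed(0)
    out = accept_reject_sketch(torch.randn(4, 10), torch.randn(3, 10), torch.randint(0, 10, (3,)))
    print(out)  # between 1 and 4 token ids
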
@@ -607,20 +601,6 @@ def allocate_inference_cache(
     return {i: torch.empty(kv_cache_shape, device=device, dtype=dtype) for i in layers}


-def seqlen_to_seqlen_type(seqlen: int) -> int:
-    """Convert sequence length to a seqlen_type.
-    This is used to determine which cuda graph to use.
-    Arguments:
-        seqlen: int
-    """
-    return 0
-
-
-def seqlen_type_to_max_seqlen(seqlen_type: int) -> int:
-    assert seqlen_type in [0]
-    return 2**32
-
-
 @dataclass
 class DecodingCGCache:
     max_batch_size: int = 0
@@ -640,6 +620,7 @@ def update_graph_cache(
     batch_size,
     seqlen_og,
     max_seqlen,
+    decoding_seqlens=(1,),
     tensor_parallel=1,
     dtype=None,
     n_warmups=2,
@@ -687,38 +668,36 @@ def update_graph_cache(
         lengths_per_sample=lengths_per_sample,
     )
     cache.mempool = torch.cuda.graphs.graph_pool_handle()
-    for s_type in range(seqlen_to_seqlen_type(seqlen_og), seqlen_to_seqlen_type(max_seqlen) + 1):
-        if (batch_size, s_type) not in cache.callables:
-            max_seqlen_ = min(max(seqlen_og, seqlen_type_to_max_seqlen(s_type)), max_seqlen)
-            cache.callables[batch_size, s_type] = capture_graph(
+    for decoding_seqlen in decoding_seqlens:
+        if (batch_size, decoding_seqlen) not in cache.callables:
+            cache.callables[batch_size, decoding_seqlen] = capture_graph(
                 model,
                 cache.inference_params,
                 batch_size,
-                max_seqlen_,
+                max_seqlen,
+                decoding_seqlen=decoding_seqlen,
                 mempool=cache.mempool,
                 n_warmups=n_warmups,
             )

     def dispatch(input_ids, position_ids, seqlen):
-        batch_size = input_ids.shape[0]
-        return cache.callables[batch_size, seqlen_to_seqlen_type(seqlen)](
-            input_ids, position_ids, seqlen
-        )
+        batch_size, decoding_seqlen = input_ids.shape[:2]
+        return cache.callables[batch_size, decoding_seqlen](input_ids, position_ids, seqlen)

     cache.run = dispatch
     cache.inference_params.seqlen_offset = 0  # Reset so it's not confusing
     return cache
-def capture_graph(model, inference_params, batch_size, max_seqlen, mempool=None, n_warmups=2):
+def capture_graph(
+    model, inference_params, batch_size, max_seqlen, decoding_seqlen=1, mempool=None, n_warmups=2
+):
     device = next(iter(model.parameters())).device
-    input_ids = torch.full((batch_size, 1), 0, dtype=torch.long, device=device)
-    position_ids = torch.full((batch_size, 1), 0, dtype=torch.long, device=device)
+    input_ids = torch.full((batch_size, decoding_seqlen), 0, dtype=torch.long, device=device)
+    position_ids = torch.full((batch_size, decoding_seqlen), 0, dtype=torch.long, device=device)
     seqlen_offset_og = inference_params.seqlen_offset
-    # TD [2023-04-14]: important for correctness of the FT's attention kernel, as seqlen_cpu is
-    # used to determine the size of smem. Hence seqlen_cpu must be >= lengths_per_sample.
-    inference_params.seqlen_offset = max_seqlen - 1
-    inference_params.lengths_per_sample[:] = max_seqlen - 1
+    inference_params.seqlen_offset = max_seqlen - decoding_seqlen
+    inference_params.lengths_per_sample[:] = inference_params.seqlen_offset

     # Warmup before capture
     s = torch.cuda.Stream()
@@ -729,7 +708,7 @@ def capture_graph(model, inference_params, batch_size, max_seqlen, mempool=None,
                 input_ids,
                 position_ids=position_ids,
                 inference_params=inference_params,
-                num_last_tokens=1,
+                num_last_tokens=decoding_seqlen,
             ).logits
         s.synchronize()
         # This might be needed for correctness if we run with NCCL_GRAPH_MIXING_SUPPORT=0,
@@ -746,8 +725,8 @@ def capture_graph(model, inference_params, batch_size, max_seqlen, mempool=None,
             input_ids,
             position_ids=position_ids,
            inference_params=inference_params,
-            num_last_tokens=1,
-        ).logits.squeeze(dim=1)
+            num_last_tokens=decoding_seqlen,
+        ).logits

     def run(new_input_ids, new_position_ids, seqlen):
         inference_params.lengths_per_sample[:] = seqlen
......
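
The graph cache above is now keyed by (batch_size, decoding_seqlen) instead of a coarse seqlen type, and dispatch reads the decoding length straight off input_ids.shape. A minimal, CPU-only sketch of that dispatch pattern; ordinary Python callables stand in for captured CUDA graphs, and all names here are illustrative:

    import torch

    class TinyCallableCache:
        def __init__(self):
            self.callables = {}

        def ensure(self, batch_size, decoding_seqlens, make_callable):
            # One callable per (batch_size, decoding_seqlen), mirroring how one CUDA
            # graph would be captured per shape that decoding can encounter.
            for decoding_seqlen in decoding_seqlens:
                key = (batch_size, decoding_seqlen)
                if key not in self.callables:
                    self.callables[key] = make_callable(batch_size, decoding_seqlen)

        def run(self, input_ids, position_ids, seqlen):
            # Dispatch on the actual shape of the incoming batch.
            batch_size, decoding_seqlen = input_ids.shape[:2]
            return self.callables[batch_size, decoding_seqlen](input_ids, position_ids, seqlen)

    def make_fake_graph(batch_size, decoding_seqlen, vocab_size=8):
        # Stand-in for a captured graph: fixed output shape for this (batch, seqlen).
        def fn(input_ids, position_ids, seqlen):
            return torch.zeros(batch_size, decoding_seqlen, vocab_size)
        return fn

    cache = TinyCallableCache()
    # A draft model that feeds 1 or 2 tokens per step would need these two shapes.
    cache.ensure(batch_size=1, decoding_seqlens=(1, 2), make_callable=make_fake_graph)
    out = cache.run(torch.zeros(1, 2, dtype=torch.long), None, seqlen=10)
    print(out.shape)  # torch.Size([1, 2, 8])
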
@@ -383,11 +383,14 @@ def test_gpt2_multiple_token_generation(model_name, optimized):
 @pytest.mark.parametrize("cg", [False, True])
-# @pytest.mark.parametrize("optimized", [False, True])
-@pytest.mark.parametrize("optimized", [True])
+# @pytest.mark.parametrize("cg", [True])
+@pytest.mark.parametrize("optimized", [False, True])
+# @pytest.mark.parametrize("optimized", [True])
 # @pytest.mark.parametrize("model_name", ["gpt2-medium"])
 @pytest.mark.parametrize("model_name", ["gpt2-xl"])
 def test_gpt2_speculative_decoding(model_name, optimized, cg):
+    if cg and not optimized:
+        pytest.skip()  # CG requires use_flash_attn
     dtype = torch.float16
     device = "cuda"
     rtol, atol = 3e-3, 3e-1
@@ -421,6 +424,7 @@ def test_gpt2_speculative_decoding(model_name, optimized, cg):
     from flash_attn.utils.generation import decode_speculative

     torch.manual_seed(42)
+    print(f"Speculative decoding, {optimized = }")
     out = decode_speculative(
         input_ids,
         model,
@@ -430,13 +434,15 @@ def test_gpt2_speculative_decoding(model_name, optimized, cg):
         cg=cg,
         speculative_lookahead=4,
         enable_timing=True,
+        # debug=True,
     )
     print(tokenizer.batch_decode(out.sequences))
+    print(f"Without speculative decoding, {cg = }")
     out_og = model.generate(
         input_ids,
         max_length=max_length,
         top_k=5,
-        cg=False,
+        cg=cg,
         enable_timing=True,
         return_dict_in_generate=True,
     )
......
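
The test now runs the full optimized x cg grid and skips the one combination that cannot work (CUDA graph decoding without the flash-attention path, per the "CG requires use_flash_attn" comment). A generic pytest sketch of that parametrize-and-skip pattern; this is a standalone toy test, not the repo's test file:

    import pytest

    @pytest.mark.parametrize("cg", [False, True])
    @pytest.mark.parametrize("optimized", [False, True])
    def test_decoding_configs(optimized, cg):
        # Skip invalid combinations instead of dropping them from the grid,
        # so the report still shows every parameter pair.
        if cg and not optimized:
            pytest.skip("CUDA graph decoding requires the optimized attention path")
        assert True  # real assertions on generation outputs would go here
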