Commit e0fbaa70 authored by Tri Dao

[Gen] Simplify decode_speculative

parent e6a80264
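
For context, decode_speculative implements speculative decoding: a small draft model proposes a few tokens cheaply, and the large target model verifies them all in a single forward pass, accepting the longest agreeing prefix. Below is a minimal sketch of one such step, assuming greedy sampling, batch size 1, and HuggingFace-style models that return .logits; the helper name speculative_step is illustrative, and the real implementation in flash_attn.utils.generation also handles top-k/top-p sampling, KV caching, and CUDA-graph capture.

import torch

@torch.no_grad()
def speculative_step(model, model_draft, input_ids, lookahead=4):
    seqlen = input_ids.shape[1]
    # 1) The cheap draft model proposes `lookahead` tokens autoregressively
    #    (no KV cache here, for brevity).
    draft = input_ids
    for _ in range(lookahead):
        logits = model_draft(draft).logits[:, -1]
        draft = torch.cat([draft, logits.argmax(dim=-1, keepdim=True)], dim=1)
    proposed = draft[:, seqlen:]
    # 2) The target model scores every proposed position in one forward pass.
    target = model(draft).logits[:, seqlen - 1 : -1].argmax(dim=-1)
    # 3) Accept the longest prefix where draft and target agree; the first
    #    mismatch is replaced by the target model's own token.
    agree = (target == proposed).int().cumprod(dim=1)
    n_accepted = int(agree.sum())  # batch size 1 assumed
    n_keep = min(n_accepted + 1, lookahead)
    return torch.cat([input_ids, target[:, :n_keep]], dim=1)

A full decode loop repeats this step until max_length is reached, which is what the test changed below exercises.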
@@ -383,11 +383,14 @@ def test_gpt2_multiple_token_generation(model_name, optimized):
+@pytest.mark.parametrize("cg", [False, True])
-# @pytest.mark.parametrize("optimized", [False, True])
-@pytest.mark.parametrize("optimized", [True])
+# @pytest.mark.parametrize("cg", [True])
+@pytest.mark.parametrize("optimized", [False, True])
+# @pytest.mark.parametrize("optimized", [True])
 # @pytest.mark.parametrize("model_name", ["gpt2-medium"])
 @pytest.mark.parametrize("model_name", ["gpt2-xl"])
-def test_gpt2_speculative_decoding(model_name, optimized):
+def test_gpt2_speculative_decoding(model_name, optimized, cg):
+    if cg and not optimized:
+        pytest.skip()  # CG requires use_flash_attn
     dtype = torch.float16
     device = "cuda"
     rtol, atol = 3e-3, 3e-1
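
The new cg parameter exercises CUDA-graph capture, which relies on the fused/FlashAttention code path; that is why the combination cg=True with optimized=False is skipped. As a hedged sketch of what the optimized flag typically toggles (the helper build_model is hypothetical; the config fields follow flash_attn's GPT test setup):

import torch
from transformers import GPT2Config
from flash_attn.models.gpt import GPTLMHeadModel

def build_model(model_name, optimized, device="cuda", dtype=torch.float16):
    config = GPT2Config.from_pretrained(model_name)
    if optimized:
        # Fused kernels; CUDA-graph capture (cg=True) assumes this path.
        config.use_flash_attn = True
        config.fused_bias_fc = True
        config.fused_mlp = True
        config.fused_dropout_add_ln = True
        config.residual_in_fp32 = True
    return GPTLMHeadModel.from_pretrained(model_name, config, device=device, dtype=dtype)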
@@ -421,6 +424,7 @@ def test_gpt2_speculative_decoding(model_name, optimized, cg):
     from flash_attn.utils.generation import decode_speculative
 
     torch.manual_seed(42)
+    print(f"Speculative decoding, {optimized = }")
     out = decode_speculative(
         input_ids,
         model,
@@ -430,13 +434,15 @@ def test_gpt2_speculative_decoding(model_name, optimized, cg):
         cg=cg,
         speculative_lookahead=4,
         enable_timing=True,
+        # debug=True,
     )
     print(tokenizer.batch_decode(out.sequences))
+    print(f"Without speculative decoding, {cg = }")
     out_og = model.generate(
         input_ids,
         max_length=max_length,
         top_k=5,
-        cg=False,
+        cg=cg,
         enable_timing=True,
         return_dict_in_generate=True,
     )
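
With top_k=5 sampling, the two runs draw different tokens, so the test prints both outputs and their timings rather than asserting equality. A hypothetical stricter check, not part of this commit, is possible under greedy decoding (top_k=1), where speculative decoding must reproduce the base model token-for-token; this sketch assumes the draft model is passed as the third positional argument, as in the elided lines above.

out_spec = decode_speculative(
    input_ids, model, model_draft,
    max_length=max_length, top_k=1, speculative_lookahead=4,
)
out_base = model.generate(
    input_ids, max_length=max_length, top_k=1,
    return_dict_in_generate=True,
)
assert torch.equal(out_spec.sequences, out_base.sequences)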