Commit e0fbaa70 authored by Tri Dao's avatar Tri Dao
Browse files

[Gen] Simplify decode_speculative

parent e6a80264
This diff is collapsed.
...@@ -383,11 +383,14 @@ def test_gpt2_multiple_token_generation(model_name, optimized): ...@@ -383,11 +383,14 @@ def test_gpt2_multiple_token_generation(model_name, optimized):
@pytest.mark.parametrize("cg", [False, True]) @pytest.mark.parametrize("cg", [False, True])
# @pytest.mark.parametrize("optimized", [False, True]) # @pytest.mark.parametrize("cg", [True])
@pytest.mark.parametrize("optimized", [True]) @pytest.mark.parametrize("optimized", [False, True])
# @pytest.mark.parametrize("optimized", [True])
# @pytest.mark.parametrize("model_name", ["gpt2-medium"]) # @pytest.mark.parametrize("model_name", ["gpt2-medium"])
@pytest.mark.parametrize("model_name", ["gpt2-xl"]) @pytest.mark.parametrize("model_name", ["gpt2-xl"])
def test_gpt2_speculative_decoding(model_name, optimized, cg): def test_gpt2_speculative_decoding(model_name, optimized, cg):
if cg and not optimized:
pytest.skip() # CG requires use_flash_attn
dtype = torch.float16 dtype = torch.float16
device = "cuda" device = "cuda"
rtol, atol = 3e-3, 3e-1 rtol, atol = 3e-3, 3e-1
...@@ -421,6 +424,7 @@ def test_gpt2_speculative_decoding(model_name, optimized, cg): ...@@ -421,6 +424,7 @@ def test_gpt2_speculative_decoding(model_name, optimized, cg):
from flash_attn.utils.generation import decode_speculative from flash_attn.utils.generation import decode_speculative
torch.manual_seed(42) torch.manual_seed(42)
print(f"Speculative decoding, {optimized = }")
out = decode_speculative( out = decode_speculative(
input_ids, input_ids,
model, model,
...@@ -430,13 +434,15 @@ def test_gpt2_speculative_decoding(model_name, optimized, cg): ...@@ -430,13 +434,15 @@ def test_gpt2_speculative_decoding(model_name, optimized, cg):
cg=cg, cg=cg,
speculative_lookahead=4, speculative_lookahead=4,
enable_timing=True, enable_timing=True,
# debug=True,
) )
print(tokenizer.batch_decode(out.sequences)) print(tokenizer.batch_decode(out.sequences))
print(f"Without speculative decoding, {cg = }")
out_og = model.generate( out_og = model.generate(
input_ids, input_ids,
max_length=max_length, max_length=max_length,
top_k=5, top_k=5,
cg=False, cg=cg,
enable_timing=True, enable_timing=True,
return_dict_in_generate=True, return_dict_in_generate=True,
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment