Unverified Commit 92abe603 authored by fxmarty, committed by GitHub

>3-5x faster torch.compile forward compilation for autoregressive decoder models (#32227)



* draft

* apply changes to all relevant archs

* rerun ci - check_docstrings.py failing?

* fix docstring

* move 2D->4D mask creation to modeling file

* repo consistency

* fix the batch size = 1 case - calling contiguous is not enough

* nit

* style

* propagate to gemma/gemma-2

* prepare inputs for gemma generation

* implement test and tiny fix in gemma2

* Update src/transformers/models/bloom/modeling_bloom.py
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* fix copies

* ci pass

* fix gemma's test_compile_static_cache tests

* flaky

* retrigger ci

---------
Co-authored-by: sanchit-gandhi <sanchit@huggingface.co>
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
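
For reference, a minimal sketch of the compiled decoding path this change speeds up, mirroring the new `test_compile_cuda_graph_time` test added below. The checkpoint name and the CUDA device are placeholder assumptions, not something the commit prescribes:

```python
# Minimal sketch, assuming a CUDA device and a hypothetical compile-compatible checkpoint.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig

ckpt = "google/gemma-2b"  # placeholder: any decoder-only LM supporting StaticCache
tokenizer = AutoTokenizer.from_pretrained(ckpt)
model = AutoModelForCausalLM.from_pretrained(ckpt, torch_dtype=torch.float16).to("cuda")

# Compile only the forward pass; generate() then drives it with a static KV cache,
# so the compiled graph sees fixed shapes and can be replayed as a CUDA graph.
model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)

gen_config = GenerationConfig(max_new_tokens=50, do_sample=False, num_beams=1, use_cache=True)
inputs = tokenizer("Why cats are cute?", return_tensors="pt").to("cuda")
out = model.generate(**inputs, generation_config=gen_config, cache_implementation="static")
print(tokenizer.batch_decode(out, skip_special_tokens=True))
```

The first call pays the one-time forward compilation cost (the part this commit makes >3-5x faster); subsequent calls replay the recorded CUDA graphs.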
parent b46bd8b9
@@ -816,7 +816,7 @@ class GemmaIntegrationTest(unittest.TestCase):
         # Dynamic Cache
         generated_ids = model.generate(**inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False)
         dynamic_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
-        self.assertEqual(EXPECTED_TEXT_COMPLETION[8], dynamic_text)  # Both GPU architectures have the same output
+        self.assertEqual(EXPECTED_TEXT_COMPLETION, dynamic_text)  # Both GPU architectures have the same output
 
         # Static Cache
         generated_ids = model.generate(
@@ -22,6 +22,7 @@ import os.path
 import random
 import re
 import tempfile
+import time
 import warnings
 from collections import defaultdict
 from typing import Dict, List, Tuple
@@ -37,6 +38,7 @@ from transformers import (
     AutoModelForCausalLM,
     AutoModelForSequenceClassification,
     AutoTokenizer,
+    GenerationConfig,
     PretrainedConfig,
     PreTrainedModel,
     is_torch_available,
@@ -4605,7 +4607,6 @@ class ModelTesterMixin:
         tokenizer = AutoTokenizer.from_pretrained(ckpt)
         model = AutoModelForCausalLM.from_pretrained(ckpt, torch_dtype=torch.float16).to(torch_device)
 
         model.generation_config.max_new_tokens = 4
-        model.generation_config.max_new_tokens = 4
 
         model.generation_config.cache_implementation = "static"
@@ -4617,6 +4618,66 @@ class ModelTesterMixin:
         for i in range(n_iter):
             _ = model.generate(**input_ids, do_sample=False)
 
+    @slow
+    @require_torch_gpu  # Testing cuda graphs.
+    @require_read_token
+    def test_compile_cuda_graph_time(self):
+        if version.parse(torch.__version__) < version.parse("2.3"):
+            self.skipTest(reason="This test requires torch >= 2.3 to run.")
+
+        # TODO felix: All models supporting `StaticCache` or `torch.compile` should be tested.
+        # At the moment, only llama, gemma and gemma2 are tested here!
+        if not hasattr(self, "_torch_compile_test_ckpt"):
+            self.skipTest(f"{self.__class__.__name__} doesn't have the attribute `_torch_compile_test_ckpt`.")
+        ckpt = self._torch_compile_test_ckpt
+
+        os.environ["TOKENIZERS_PARALLELISM"] = "false"
+
+        tokenizer = AutoTokenizer.from_pretrained(ckpt)
+        model = AutoModelForCausalLM.from_pretrained(ckpt, torch_dtype=torch.float16).to(torch_device)
+
+        cache_implementation = "static"
+        if model.config.model_type == "gemma2":
+            cache_implementation = "hybrid"
+
+        new_tokens = 50
+        gen_config = GenerationConfig(
+            max_new_tokens=new_tokens,
+            min_new_tokens=new_tokens,
+            use_cache=True,
+            pad_token_id=tokenizer.pad_token_id,
+            num_beams=1,
+            do_sample=False,
+            eos_token_id=None,  # This is required for min_new_tokens to actually have an effect.
+        )
+        model.generation_config.eos_token_id = None  # greedy_search falls back on this eos_token_id that we need to set to None as well for min_new_tokens to have an effect.
+
+        model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)
+
+        inp = tokenizer("Why cats are cute?", return_tensors="pt").to(torch_device)
+
+        # First run: the first run warms up each graph, which does things like CuBlas or Triton benchmarking
+        start = time.perf_counter()
+        _ = model.generate(**inp, generation_config=gen_config, cache_implementation=cache_implementation)
+        end = time.perf_counter()
+        graph_warmup_time = end - start
+
+        # Second run: CUDA Graph recording, and replays it
+        start = time.perf_counter()
+        _ = model.generate(**inp, generation_config=gen_config, cache_implementation=cache_implementation)
+        end = time.perf_counter()
+        record_time = end - start
+
+        # Finally: we hit the optimized, CUDA Graph replay path
+        start = time.perf_counter()
+        _ = model.generate(**inp, generation_config=gen_config, cache_implementation=cache_implementation)
+        end = time.perf_counter()
+        opt_time = end - start
+
+        # For the recording step, we expect only two cuda graphs and this step should be much faster than the first.
+        self.assertTrue(record_time < 0.15 * graph_warmup_time)
+
+        self.assertTrue(opt_time < record_time)
 
 
 global_rng = random.Random()