"vscode:/vscode.git/clone" did not exist on "5f1918a4a8ed893822aa7dd2b75acf83f255ad79"
Unverified commit 92abe603, authored by fxmarty and committed by GitHub

>3-5x faster torch.compile forward compilation for autoregressive decoder models (#32227)
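For context, the workflow whose compile time this commit targets looks roughly like the sketch below: compile the decoder's forward with torch.compile and generate with a static KV cache, mirroring the timing test added in the diff further down. The checkpoint name, device, and prompt here are illustrative assumptions, not part of the commit.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

ckpt = "google/gemma-2b"  # assumed checkpoint; any decoder-only causal LM with static-cache support works
tokenizer = AutoTokenizer.from_pretrained(ckpt)
model = AutoModelForCausalLM.from_pretrained(ckpt, torch_dtype=torch.float16).to("cuda")

# Compile the forward pass once; this commit reduces how long that compilation takes.
model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)

inputs = tokenizer("Why cats are cute?", return_tensors="pt").to("cuda")
# A static cache keeps tensor shapes fixed, so the compiled graph can be reused at every decoding step.
output_ids = model.generate(**inputs, max_new_tokens=50, do_sample=False, cache_implementation="static")
print(tokenizer.batch_decode(output_ids, skip_special_tokens=True))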



* draft

* apply changes to all relevant archs

* rerun ci - check_docstrings.py failing?

* fix docstring

* move 2D->4D mask creation to modeling file (an illustrative sketch of this expansion follows the commit trailers below)

* repo consistency

* fix the batch size = 1 case - calling contiguous is not enough

* nit

* style

* propagate to gemma/gemma-2

* prepare inputs for gemma generation

* implement test and tiny fix in gemma2

* Update src/transformers/models/bloom/modeling_bloom.py
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* fix copies

* ci pass

* fix gemma's test_compile_static_cache tests

* flaky

* retrigger ci

---------
Co-authored-by: sanchit-gandhi <sanchit@huggingface.co>
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
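The "move 2D->4D mask creation to modeling file" item above refers to expanding a 2D padding mask of shape (batch, kv_length) into the 4D additive causal mask of shape (batch, 1, query_length, kv_length) that the attention layers consume. Below is a minimal illustrative sketch of that expansion; it is not the exact helper added by this commit, and the function name and the cache-offset handling are assumptions made for the example.

import torch


def to_4d_causal_mask(attention_mask_2d: torch.Tensor, query_length: int, dtype: torch.dtype) -> torch.Tensor:
    # attention_mask_2d: (batch, kv_length), 1 = attend, 0 = padding.
    batch_size, kv_length = attention_mask_2d.shape
    min_value = torch.finfo(dtype).min

    # Causal part: query i may only look at keys up to i, offset so that the last query row
    # lines up with the last key position when a KV cache holds earlier tokens.
    causal = torch.full((query_length, kv_length), min_value, dtype=dtype, device=attention_mask_2d.device)
    causal = torch.triu(causal, diagonal=kv_length - query_length + 1)

    # Broadcast to (batch, 1, query_length, kv_length) and also mask out padded key positions.
    mask_4d = causal[None, None, :, :].expand(batch_size, 1, query_length, kv_length)
    padding = attention_mask_2d[:, None, None, :] == 0
    return mask_4d.masked_fill(padding, min_value)


# Example: a prefill step with two sequences, one of them left-padded.
mask_2d = torch.tensor([[1, 1, 1, 1], [0, 0, 1, 1]])
mask_4d = to_4d_causal_mask(mask_2d, query_length=4, dtype=torch.float16)
print(mask_4d.shape)  # torch.Size([2, 1, 4, 4])

Doing this expansion once in the modeling file, rather than re-deriving masks deeper in the attention code, gives torch.compile a simpler graph to trace; the exact helper and call sites are in the full diff of #32227.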
parent b46bd8b9
@@ -816,7 +816,7 @@ class GemmaIntegrationTest(unittest.TestCase):
         # Dynamic Cache
         generated_ids = model.generate(**inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False)
         dynamic_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
-        self.assertEqual(EXPECTED_TEXT_COMPLETION[8], dynamic_text)  # Both GPU architectures have the same output
+        self.assertEqual(EXPECTED_TEXT_COMPLETION, dynamic_text)  # Both GPU architectures have the same output
         # Static Cache
         generated_ids = model.generate(
...
@@ -22,6 +22,7 @@ import os.path
 import random
 import re
 import tempfile
+import time
 import warnings
 from collections import defaultdict
 from typing import Dict, List, Tuple
@@ -37,6 +38,7 @@ from transformers import (
     AutoModelForCausalLM,
     AutoModelForSequenceClassification,
     AutoTokenizer,
+    GenerationConfig,
     PretrainedConfig,
     PreTrainedModel,
     is_torch_available,
@@ -4605,7 +4607,6 @@ class ModelTesterMixin:
         tokenizer = AutoTokenizer.from_pretrained(ckpt)
         model = AutoModelForCausalLM.from_pretrained(ckpt, torch_dtype=torch.float16).to(torch_device)

-        model.generation_config.max_new_tokens = 4
         model.generation_config.max_new_tokens = 4

         model.generation_config.cache_implementation = "static"
@@ -4617,6 +4618,66 @@ class ModelTesterMixin:
         for i in range(n_iter):
             _ = model.generate(**input_ids, do_sample=False)

+    @slow
+    @require_torch_gpu  # Testing cuda graphs.
+    @require_read_token
+    def test_compile_cuda_graph_time(self):
+        if version.parse(torch.__version__) < version.parse("2.3"):
+            self.skipTest(reason="This test requires torch >= 2.3 to run.")
+
+        # TODO felix: All models supporting `StaticCache` or `torch.compile` should be tested.
+        # At the moment, only llama, gemma and gemma2 are tested here!
+        if not hasattr(self, "_torch_compile_test_ckpt"):
+            self.skipTest(f"{self.__class__.__name__} doesn't have the attribute `_torch_compile_test_ckpt`.")
+        ckpt = self._torch_compile_test_ckpt
+
+        os.environ["TOKENIZERS_PARALLELISM"] = "false"
+
+        tokenizer = AutoTokenizer.from_pretrained(ckpt)
+        model = AutoModelForCausalLM.from_pretrained(ckpt, torch_dtype=torch.float16).to(torch_device)
+
+        cache_implementation = "static"
+        if model.config.model_type == "gemma2":
+            cache_implementation = "hybrid"
+
+        new_tokens = 50
+        gen_config = GenerationConfig(
+            max_new_tokens=new_tokens,
+            min_new_tokens=new_tokens,
+            use_cache=True,
+            pad_token_id=tokenizer.pad_token_id,
+            num_beams=1,
+            do_sample=False,
+            eos_token_id=None,  # This is required for min_new_tokens to actually have an effect.
+        )
+        model.generation_config.eos_token_id = None  # greedy_search falls back on this eos_token_id that we need to set to None as well for min_new_tokens to have an effect.
+
+        model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)
+
+        inp = tokenizer("Why cats are cute?", return_tensors="pt").to(torch_device)
+
+        # First run: the first run warms up each graph, which does things like CuBlas or Triton benchmarking
+        start = time.perf_counter()
+        _ = model.generate(**inp, generation_config=gen_config, cache_implementation=cache_implementation)
+        end = time.perf_counter()
+        graph_warmup_time = end - start
+
+        # Second run: CUDA Graph recording, and replays it
+        start = time.perf_counter()
+        _ = model.generate(**inp, generation_config=gen_config, cache_implementation=cache_implementation)
+        end = time.perf_counter()
+        record_time = end - start
+
+        # Finally: we hit the optimized, CUDA Graph replay path
+        start = time.perf_counter()
+        _ = model.generate(**inp, generation_config=gen_config, cache_implementation=cache_implementation)
+        end = time.perf_counter()
+        opt_time = end - start
+
+        # For the recording step, we expect only two cuda graphs and this step should be much faster than the first.
+        self.assertTrue(record_time < 0.15 * graph_warmup_time)
+
+        self.assertTrue(opt_time < record_time)

 global_rng = random.Random()
...