Unverified commit 1b3dba94, authored by Yih-Dar, committed by GitHub

Make `Gemma` work with `torch.compile` (#30775)



* fix

* [run-slow] gemma

* add test

* add `test_compile_static_cache`

* fix

* style

* remove subprocess

* use attribute

* fix

* style

* update

* [run-slow] dbrx,gemma,jetmoe,phi3,recurrent_gemma

---------
Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
parent 0753134f
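
For context before the file-by-file diffs: every modeling change below is the same edit to a rotary embedding class. The `inv_freq` buffer used to be registered as `None` and created lazily on the first `forward` call; it is now computed and registered in `__init__`, and the same pattern is applied to Dbrx, Gemma, JetMoe, Phi3, and RecurrentGemma. The sketch below shows the two patterns in isolation, with illustrative class names (it is not the actual Transformers code, and the stated motivation is an inference from the diff: a data-dependent `if self.inv_freq is None:` branch that mutates module state inside `forward` tends to break `torch.compile(fullgraph=True)`, whereas a buffer with a fixed value and shape from construction does not).

```python
# Minimal sketch of the pattern this commit applies; class names are illustrative,
# not the actual Transformers classes.
import torch
import torch.nn as nn


class LazyRotaryEmbedding(nn.Module):
    """Old pattern: the buffer is only materialized on the first forward call."""

    def __init__(self, dim, base=10000):
        super().__init__()
        self.dim = dim
        self.base = base
        self.register_buffer("inv_freq", None, persistent=False)

    @torch.no_grad()
    def forward(self, x, position_ids):
        if self.inv_freq is None:  # lazy init: data-dependent branch that mutates module state
            self.inv_freq = 1.0 / (
                self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64, device=x.device).float() / self.dim)
            )
        return self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)


class EagerRotaryEmbedding(nn.Module):
    """New pattern: the buffer exists, with a fixed value and shape, from construction."""

    def __init__(self, dim, base=10000):
        super().__init__()
        self.dim = dim
        self.base = base
        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float() / self.dim))
        self.register_buffer("inv_freq", tensor=inv_freq, persistent=False)

    @torch.no_grad()
    def forward(self, x, position_ids):
        return self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)


# The eager variant has no lazy-init branch, e.g.:
# rope = torch.compile(EagerRotaryEmbedding(64), fullgraph=True)
```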
@@ -55,15 +55,14 @@ class DbrxRotaryEmbedding(nn.Module):
         self.dim = dim
         self.max_position_embeddings = max_position_embeddings
         self.base = base
-        self.register_buffer("inv_freq", None, persistent=False)
+        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float() / self.dim))
+        self.register_buffer("inv_freq", tensor=inv_freq, persistent=False)
 
     @torch.no_grad()
     def forward(self, x, position_ids, seq_len=None):
         # x: [bs, num_attention_heads, seq_len, head_size]
-        if self.inv_freq is None:
-            self.inv_freq = 1.0 / (
-                self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64, device=x.device).float() / self.dim)
-            )
+        self.inv_freq.to(x.device)
         inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
         position_ids_expanded = position_ids[:, None, :].float()
         # Force float32 since bfloat16 loses precision on long contexts
@@ -104,15 +104,14 @@ class GemmaRotaryEmbedding(nn.Module):
         self.dim = dim
         self.max_position_embeddings = max_position_embeddings
         self.base = base
-        self.register_buffer("inv_freq", None, persistent=False)
+        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float() / self.dim))
+        self.register_buffer("inv_freq", tensor=inv_freq, persistent=False)
 
     @torch.no_grad()
     def forward(self, x, position_ids, seq_len=None):
         # x: [bs, num_attention_heads, seq_len, head_size]
-        if self.inv_freq is None:
-            self.inv_freq = 1.0 / (
-                self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64, device=x.device).float() / self.dim)
-            )
+        self.inv_freq.to(x.device)
         inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
         position_ids_expanded = position_ids[:, None, :].float()
         # Force float32 since bfloat16 loses precision on long contexts
@@ -397,15 +397,14 @@ class JetMoeRotaryEmbedding(nn.Module):
         self.dim = dim
         self.max_position_embeddings = max_position_embeddings
         self.base = base
-        self.register_buffer("inv_freq", None, persistent=False)
+        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float() / self.dim))
+        self.register_buffer("inv_freq", tensor=inv_freq, persistent=False)
 
     @torch.no_grad()
     def forward(self, x, position_ids, seq_len=None):
         # x: [bs, num_attention_heads, seq_len, head_size]
-        if self.inv_freq is None:
-            self.inv_freq = 1.0 / (
-                self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64, device=x.device).float() / self.dim)
-            )
+        self.inv_freq.to(x.device)
         inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
         position_ids_expanded = position_ids[:, None, :].float()
         # Force float32 since bfloat16 loses precision on long contexts
@@ -99,15 +99,14 @@ class Phi3RotaryEmbedding(nn.Module):
         self.dim = dim
         self.max_position_embeddings = max_position_embeddings
         self.base = base
-        self.register_buffer("inv_freq", None, persistent=False)
+        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float() / self.dim))
+        self.register_buffer("inv_freq", tensor=inv_freq, persistent=False)
 
     @torch.no_grad()
     def forward(self, x, position_ids, seq_len=None):
         # x: [bs, num_attention_heads, seq_len, head_size]
-        if self.inv_freq is None:
-            self.inv_freq = 1.0 / (
-                self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64, device=x.device).float() / self.dim)
-            )
+        self.inv_freq.to(x.device)
         inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
         position_ids_expanded = position_ids[:, None, :].float()
         # Force float32 since bfloat16 loses precision on long contexts
@@ -68,16 +68,14 @@ class RecurrentGemmaRotaryEmbedding(nn.Module):
         super().__init__()
         self.dim = dim
         self.base = base
-        self.register_buffer("inv_freq", None, persistent=False)
+        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float() / self.dim))
+        self.register_buffer("inv_freq", tensor=inv_freq, persistent=False)
 
     @torch.no_grad()
     # Copied from transformers.models.gemma.modeling_gemma.GemmaRotaryEmbedding.forward with Gemma->RecurrentGemma
     def forward(self, x, position_ids, seq_len=None):
         # x: [bs, num_attention_heads, seq_len, head_size]
-        if self.inv_freq is None:
-            self.inv_freq = 1.0 / (
-                self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64, device=x.device).float() / self.dim)
-            )
+        self.inv_freq.to(x.device)
         inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
         position_ids_expanded = position_ids[:, None, :].float()
         # Force float32 since bfloat16 loses precision on long contexts
@@ -17,6 +17,7 @@ import tempfile
 import unittest
 
 import pytest
+from packaging import version
 
 from transformers import AutoModelForCausalLM, AutoTokenizer, GemmaConfig, is_torch_available
 from transformers.testing_utils import (
@@ -40,7 +41,7 @@ from ...test_pipeline_mixin import PipelineTesterMixin
 if is_torch_available():
     import torch
 
-    from transformers import GemmaForCausalLM, GemmaForSequenceClassification, GemmaModel
+    from transformers import GemmaForCausalLM, GemmaForSequenceClassification, GemmaModel, GemmaTokenizer
 
 
 class GemmaModelTester:
@@ -302,6 +303,9 @@ class GemmaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
     # This is because we are hitting edge cases with the causal_mask buffer
     model_split_percents = [0.5, 0.6]
 
+    # used in `test_torch_compile`
+    _torch_compile_test_ckpt = "google/gemma-2b"
+
     # TODO (ydshieh): Check this. See https://app.circleci.com/pipelines/github/huggingface/transformers/79245/workflows/9490ef58-79c2-410d-8f51-e3495156cf9c/jobs/1012146
     def is_pipeline_test_to_skip(
         self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
@@ -801,3 +805,51 @@ class GemmaIntegrationTest(unittest.TestCase):
         output_text = tokenizer.batch_decode(output, skip_special_tokens=True)
 
         self.assertEqual(output_text, EXPECTED_TEXTS[self.cuda_compute_capability_major_version])
+
+    @slow
+    @require_torch_gpu
+    @require_read_token
+    def test_compile_static_cache(self):
+        # `torch==2.2` will throw an error on this test (as in other compilation tests), but torch==2.1.2 and torch>2.2
+        # work as intended. See https://github.com/pytorch/pytorch/issues/121943
+        if version.parse(torch.__version__) < version.parse("2.3.0"):
+            self.skipTest("This test requires torch >= 2.3 to run.")
+
+        NUM_TOKENS_TO_GENERATE = 40
+        # Note on `EXPECTED_TEXT_COMPLETION`'s diff: the current value matches the original test if the original test
+        # was changed to have a cache of 53 tokens (as opposed to 4096), on Ampere GPUs.
+        EXPECTED_TEXT_COMPLETION = {
+            8: [
+                "Hello I am doing a project on the 1990s and I need to know what the most popular music was in the 1990s. I have looked on the internet and I have found",
+                "Hi today\nI have a problem with my 2007 1.9 tdi 105bhp.\nI have a problem with the engine management light on.\nI have checked the",
+            ],
+            7: [
+                "Hello I am doing a project on the 1990s and I need to know what the most popular music was in the 1990s. I have looked on the internet and I have found",
+                "Hi today\nI have a problem with my 2007 1.9 tdi 105bhp.\nI have a problem with the engine management light on.\nI have checked the",
+            ],
+        }
+
+        prompts = ["Hello I am doing", "Hi today"]
+        tokenizer = GemmaTokenizer.from_pretrained("google/gemma-2b", pad_token="</s>", padding_side="right")
+        model = GemmaForCausalLM.from_pretrained("google/gemma-2b", device_map="sequential", torch_dtype=torch.float16)
+        inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device)
+
+        # Dynamic Cache
+        generated_ids = model.generate(**inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False)
+        dynamic_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
+        self.assertEqual(EXPECTED_TEXT_COMPLETION[8], dynamic_text)  # Both GPU architectures have the same output
+
+        # Static Cache
+        generated_ids = model.generate(
+            **inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False, cache_implementation="static"
+        )
+        static_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
+        self.assertEqual(EXPECTED_TEXT_COMPLETION[self.cuda_compute_capability_major_version], static_text)
+
+        # Static Cache + compile
+        model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)
+        generated_ids = model.generate(
+            **inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False, cache_implementation="static"
+        )
+        static_compiled_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
+        self.assertEqual(EXPECTED_TEXT_COMPLETION[self.cuda_compute_capability_major_version], static_compiled_text)
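
Outside the test suite, the flow that `test_compile_static_cache` exercises can be reproduced roughly as follows. This is a sketch based on the calls in the test above rather than an official recipe; it assumes a CUDA GPU, `torch >= 2.3`, access to the `google/gemma-2b` checkpoint, and `accelerate` installed for `device_map`.

```python
# Sketch of the generation flow covered by the new test: dynamic cache, then static
# cache, then static cache with a compiled forward. Assumes a CUDA GPU and torch >= 2.3.
import torch

from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b", padding_side="right")
model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-2b", device_map="sequential", torch_dtype=torch.float16
)
inputs = tokenizer(["Hello I am doing", "Hi today"], return_tensors="pt", padding=True).to(model.device)

# 1) Default (dynamic) cache
dynamic_ids = model.generate(**inputs, max_new_tokens=40, do_sample=False)

# 2) Static cache: a pre-allocated KV cache with a fixed shape
static_ids = model.generate(**inputs, max_new_tokens=40, do_sample=False, cache_implementation="static")

# 3) Static cache + compiled forward (what this PR enables for Gemma)
model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)
compiled_ids = model.generate(**inputs, max_new_tokens=40, do_sample=False, cache_implementation="static")

print(tokenizer.batch_decode(compiled_ids, skip_special_tokens=True))
```

The first compiled call is expected to be noticeably slower than the following ones because of graph capture and compilation, which is presumably why the shared `test_torch_compile` further down calls `generate` in a small loop.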
@@ -312,6 +312,9 @@ class LlamaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
     # This is because we are hitting edge cases with the causal_mask buffer
     model_split_percents = [0.5, 0.7, 0.8]
 
+    # used in `test_torch_compile`
+    _torch_compile_test_ckpt = "meta-llama/Llama-2-7b-hf"
+
     def setUp(self):
         self.model_tester = LlamaModelTester(self)
         self.config_tester = ConfigTester(self, config_class=LlamaConfig, hidden_size=37)
@@ -27,6 +27,7 @@ from collections import defaultdict
 from typing import Dict, List, Tuple
 
 import numpy as np
+from packaging import version
 from parameterized import parameterized
 from pytest import mark
@@ -35,6 +36,7 @@ from transformers import (
     AutoModel,
     AutoModelForCausalLM,
     AutoModelForSequenceClassification,
+    AutoTokenizer,
     PretrainedConfig,
     PreTrainedModel,
     is_torch_available,
@@ -71,6 +73,7 @@ from transformers.testing_utils import (
     require_accelerate,
     require_bitsandbytes,
     require_flash_attn,
+    require_read_token,
     require_safetensors,
     require_torch,
     require_torch_gpu,
@@ -4399,6 +4402,38 @@ class ModelTesterMixin:
         normalized_1 = F.softmax(out_shared_prefix_last_tokens)
         torch.testing.assert_close(normalized_0, normalized_1, rtol=1e-3, atol=1e-4)
 
+    # For now, Let's focus only on GPU for `torch.compile`
+    @slow
+    @require_torch_gpu
+    @require_read_token
+    def test_torch_compile(self):
+        if version.parse(torch.__version__) < version.parse("2.3"):
+            self.skipTest("This test requires torch >= 2.3 to run.")
+
+        if not hasattr(self, "_torch_compile_test_ckpt"):
+            self.skipTest(f"{self.__class__.__name__} doesn't have the attribute `_torch_compile_test_ckpt`.")
+        ckpt = self._torch_compile_test_ckpt
+
+        os.environ["TOKENIZERS_PARALLELISM"] = "false"
+
+        batch_size = 1
+        n_iter = 3
+
+        tokenizer = AutoTokenizer.from_pretrained(ckpt)
+        model = AutoModelForCausalLM.from_pretrained(ckpt, torch_dtype=torch.float16).to(torch_device)
+
+        model.generation_config.max_new_tokens = 4
+        model.generation_config.cache_implementation = "static"
+
+        model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)
+
+        input_text = "Why dogs are cute?"
+        input_ids = tokenizer([input_text] * batch_size, return_tensors="pt").to(torch_device)
+
+        for i in range(n_iter):
+            _ = model.generate(**input_ids, do_sample=False)
+
 
 global_rng = random.Random()