Unverified Commit e3fc90ae authored by Andrei Panferov, committed by GitHub

Cleaner Cache `dtype` and `device` extraction for CUDA graph generation for quantizers compatibility (#29079)

* input_layernorm as the beacon of hope

* cleaner dtype extraction

* AQLM + CUDA graph test

* is available check

* shorter text test
parent a3f9221a
@@ -817,9 +817,13 @@ class LlamaPreTrainedModel(PreTrainedModel):
         self.register_buffer("causal_mask", torch.triu(causal_mask, diagonal=1), persistent=False)

         for layer in self.model.layers:
-            weights = layer.self_attn.o_proj.weight
+            device = layer.input_layernorm.weight.device
+            if hasattr(self.config, "_pre_quantization_dtype"):
+                dtype = self.config._pre_quantization_dtype
+            else:
+                dtype = layer.self_attn.o_proj.weight.dtype
             layer.self_attn.past_key_value = cache_cls(
-                self.config, max_batch_size, max_cache_len, device=weights.device, dtype=weights.dtype
+                self.config, max_batch_size, max_cache_len, device=device, dtype=dtype
             )

     def _reset_cache(self):
...
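Note on the hunk above (editor's illustration, not part of the diff): for AQLM and similar quantizers, `o_proj` is replaced by a quantized linear module whose parameters are packed integer codes and codebooks, so it may not expose a floating-point `weight` with a usable dtype at all. The `input_layernorm` weight is never quantized, so its device is a safe place to read the layer's device from, and the quantized loading path records the original floating-point dtype on the config as `_pre_quantization_dtype`. A minimal sketch of the same extraction as a standalone helper; the helper name is hypothetical and the logic simply mirrors the diff:

def _cache_dtype_and_device(config, layer):
    """Illustrative helper mirroring the hunk above; not a library API."""
    # The layernorm weight stays unquantized, so its device reflects where the layer lives.
    device = layer.input_layernorm.weight.device
    # Quantized checkpoints carry the original floating-point dtype on the config
    # (the `_pre_quantization_dtype` attribute used in the hunk); unquantized models
    # fall back to the projection weight's own dtype.
    if hasattr(config, "_pre_quantization_dtype"):
        dtype = config._pre_quantization_dtype
    else:
        dtype = layer.self_attn.o_proj.weight.dtype
    return dtype, device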
@@ -14,10 +14,13 @@
 # limitations under the License.

 import gc
+import importlib
 import tempfile
 import unittest

-from transformers import AqlmConfig, AutoConfig, AutoModelForCausalLM, AutoTokenizer, OPTForCausalLM
+from packaging import version
+
+from transformers import AqlmConfig, AutoConfig, AutoModelForCausalLM, AutoTokenizer, OPTForCausalLM, StaticCache
 from transformers.testing_utils import (
     require_accelerate,
     require_aqlm,
@@ -26,7 +29,7 @@ from transformers.testing_utils import (
     slow,
     torch_device,
 )
-from transformers.utils import is_accelerate_available, is_torch_available
+from transformers.utils import is_accelerate_available, is_aqlm_available, is_torch_available


 if is_torch_available():
@@ -71,11 +74,12 @@ class AqlmConfigTest(unittest.TestCase):
 @require_aqlm
 @require_accelerate
 class AqlmTest(unittest.TestCase):
-    model_name = "BlackSamorez/Mixtral-8x7b-AQLM-2Bit-1x16-hf-test-dispatch"
+    model_name = "BlackSamorez/Llama-2-7b-AQLM-2Bit-1x16-hf"

     input_text = "Hello my name is"
+    max_new_tokens = 32

-    EXPECTED_OUTPUT = "Hello my name is Katie and I am a 20 year old student at the University of North Carolina at Chapel Hill. I am currently a sophomore and am majoring in Psychology. I am"
+    EXPECTED_OUTPUT = "Hello my name is Katie. I am a 20 year old college student. I am a very outgoing person. I love to have fun and be active. I"

     device_map = "cuda"
@@ -144,7 +148,7 @@ class AqlmTest(unittest.TestCase):
         """
         input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(torch_device)

-        output = self.quantized_model.generate(**input_ids, max_new_tokens=40)
+        output = self.quantized_model.generate(**input_ids, max_new_tokens=self.max_new_tokens)
         self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)

     def test_raise_if_non_quantized(self):
@@ -164,7 +168,7 @@ class AqlmTest(unittest.TestCase):
         input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(torch_device)

-        output = model.generate(**input_ids, max_new_tokens=40)
+        output = model.generate(**input_ids, max_new_tokens=self.max_new_tokens)
         self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)

     @require_torch_multi_gpu
@@ -178,6 +182,56 @@ class AqlmTest(unittest.TestCase):
         self.assertTrue(set(quantized_model.hf_device_map.values()) == {0, 1})

-        output = quantized_model.generate(**input_ids, max_new_tokens=40)
+        output = quantized_model.generate(**input_ids, max_new_tokens=self.max_new_tokens)
         self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)

+    @unittest.skipUnless(
+        is_aqlm_available() and version.parse(importlib.metadata.version("aqlm")) >= version.parse("1.0.3"),
+        "test requires `aqlm>=1.0.3`",
+    )
+    def test_quantized_model_compile(self):
+        """
+        Simple test that checks if the quantized model generates correctly with a static cache and `torch.compile`
+        """
+
+        # Sample tokens greedily
+        def decode_one_tokens(model, cur_token, input_pos, cache_position):
+            logits = model(
+                cur_token, position_ids=input_pos, cache_position=cache_position, return_dict=False, use_cache=True
+            )[0]
+            new_token = torch.argmax(logits[:, [-1]], dim=-1).to(torch.int)
+
+            return new_token
+
+        # Tokenize the test input
+        input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(torch_device)["input_ids"]
+        seq_length = input_ids.shape[1]
+
+        # Set up a static KV cache for generation
+        self.quantized_model._setup_cache(StaticCache, 1, max_cache_len=seq_length + self.max_new_tokens + 1)
+
+        # Allocate the token ids to be generated and copy the prefix ids
+        cache_position = torch.arange(seq_length, device=torch_device)
+        generated_ids = torch.zeros(1, seq_length + self.max_new_tokens, dtype=torch.int, device=torch_device)
+        generated_ids[:, cache_position] = input_ids.to(torch_device).to(torch.int)
+
+        # Do a forward pass to fill the prefix cache and compile the kernels if necessary
+        logits = self.quantized_model(input_ids, cache_position=cache_position, return_dict=False, use_cache=True)[0]
+        next_token = torch.argmax(logits[:, [-1]], dim=-1).to(torch.int)
+        generated_ids[:, [seq_length]] = next_token
+
+        with torch.no_grad():
+            # Compile the CUDA graph
+            decode_one_tokens = torch.compile(decode_one_tokens, mode="reduce-overhead", fullgraph=True)
+
+            # Generate tokens one by one
+            cache_position = torch.tensor([seq_length + 1], device=torch_device)
+            for _ in range(1, self.max_new_tokens):
+                with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True):
+                    next_token = decode_one_tokens(self.quantized_model, next_token.clone(), None, cache_position)
+                    generated_ids.index_copy_(1, cache_position, next_token)
+                cache_position += 1
+
+        # Check generated text
+        self.assertEqual(self.tokenizer.decode(generated_ids[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)
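For context (editor's note, not part of the diff): the new `test_quantized_model_compile` reuses the `self.tokenizer` and `self.quantized_model` fixtures created in the class's `setUpClass`, which lives earlier in the test file and is untouched by this commit. A rough, hedged sketch of what that setup looks like, built only from the class attributes shown above; the real fixture code may differ in details:

from transformers import AutoModelForCausalLM, AutoTokenizer

# Fixtures the compile test assumes exist (sketch, not the actual setUpClass body)
model_name = "BlackSamorez/Llama-2-7b-AQLM-2Bit-1x16-hf"
tokenizer = AutoTokenizer.from_pretrained(model_name)
quantized_model = AutoModelForCausalLM.from_pretrained(model_name, device_map="cuda")

Sharing a single `max_new_tokens = 32` between the eager `generate` tests and the compiled decode loop keeps both paths producing the same `EXPECTED_OUTPUT` string, which is presumably what the "shorter text test" commit message refers to.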