Unverified Commit 72256bc7 authored by Yih-Dar, committed by GitHub

Fix `PersimmonIntegrationTest` OOM (#26750)



* fix

---------
Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
parent ab0ddc99
@@ -15,6 +15,7 @@
 """ Testing suite for the PyTorch Persimmon model. """
 
+import gc
 import unittest
 
 from parameterized import parameterized
 
@@ -395,19 +396,27 @@ class PersimmonIntegrationTest(unittest.TestCase):
     def test_model_8b_chat_logits(self):
         input_ids = [1, 306, 4658, 278, 6593, 310, 2834, 338]
         model = PersimmonForCausalLM.from_pretrained(
-            "adept/persimmon-8b-chat", device_map="auto", torch_dtype=torch.float16
+            "adept/persimmon-8b-chat", load_in_8bit=True, device_map={"": 0}, torch_dtype=torch.float16
         )
-        out = model(torch.tensor([input_ids])).logits
+        out = model(torch.tensor([input_ids], device=torch_device)).logits
 
         EXPECTED_MEAN = torch.tensor(
-            [[-11.2879, -11.2628, -11.2498, -11.2534, -11.2676, -11.2638, -11.2501, -11.2431]], dtype=torch.float16
+            [[-11.4726, -11.1495, -11.2694, -11.2223, -10.9452, -11.0663, -11.0031, -11.1028]]
         )
-        torch.testing.assert_close(out.cpu().mean(-1), EXPECTED_MEAN, atol=1e-4, rtol=1e-4)
+        # change dtype to `torch.float32` before calling `mean` to avoid `nan` values
+        torch.testing.assert_close(out.cpu().to(torch.float32).mean(-1), EXPECTED_MEAN, atol=1e-4, rtol=1e-4)
         # fmt: off
-        EXPECTED_SLICE = torch.tensor([-16.9670, -16.9647, -16.9649, -16.9630, -16.9577, -16.9623, -17.0164, -16.9673, -16.9648, -16.9668, -17.0160, -16.9651, -17.0156, -16.9668, -16.9655, -16.9653, -16.9665, -16.9682, -17.0112, -16.9667, -16.9717, -16.9654, -16.9650, -16.9701, -16.9657, -17.0160, -16.9676, -17.0138, -16.9610, -16.9695])
+        EXPECTED_SLICE = torch.tensor(
+            [-16.9062, -16.9062, -16.9062, -16.9062, -16.8906, -16.9062, -16.9531, -16.9062, -16.9062, -16.9062, -16.9531, -16.9062, -16.9531, -16.9062, -16.9062, -16.9062, -16.9062, -16.9062, -16.9531, -16.9062, -16.9062, -16.9062, -16.9062, -16.9062, -16.9062, -16.9531, -16.9062, -16.9531, -16.9062, -16.9062],
+            dtype=torch.float16
+        )
         # fmt: on
         torch.testing.assert_close(out.cpu()[0, 0, :30], EXPECTED_SLICE, atol=1e-5, rtol=1e-5)
 
+        torch.cuda.empty_cache()
+        del model
+        gc.collect()
+
     @slow
     @require_torch_gpu
     def test_model_8b_chat_greedy_generation(self):
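The memory saving in this hunk comes from two changes: the checkpoint is loaded with 8-bit weights pinned to a single GPU instead of fp16 under `device_map="auto"`, and the logits are upcast to `torch.float32` before the `mean` reduction, since reducing fp16 values over a large vocabulary axis can overflow and produce `nan`. A minimal standalone sketch of the same pattern (the checkpoint name and input ids come from the test; the surrounding script is illustrative and assumes a CUDA GPU plus the `bitsandbytes` package):

```python
import torch
from transformers import PersimmonForCausalLM

# Sketch of the loading pattern the diff switches to: int8 weights via
# bitsandbytes, with the whole model pinned to CUDA device 0.
model = PersimmonForCausalLM.from_pretrained(
    "adept/persimmon-8b-chat",
    load_in_8bit=True,          # quantize weights to int8 on load (needs `bitsandbytes`)
    device_map={"": 0},         # "" = the root module, so everything lands on GPU 0
    torch_dtype=torch.float16,  # non-quantized tensors stay in fp16
)

input_ids = torch.tensor([[1, 306, 4658, 278, 6593, 310, 2834, 338]], device="cuda:0")
with torch.no_grad():
    logits = model(input_ids).logits

# fp16 tops out around 65504, so a reduction over a large vocabulary axis can
# overflow; upcast to float32 first, as the updated assertion does.
mean_logits = logits.cpu().to(torch.float32).mean(-1)
```

Mapping the empty-string key in `device_map` places the root module, and therefore every submodule, on GPU 0, avoiding the CPU offloading that `device_map="auto"` may introduce when memory is tight.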
@@ -415,11 +424,15 @@ class PersimmonIntegrationTest(unittest.TestCase):
         prompt = "human: Simply put, the theory of relativity states that?\n\nadept:"
         tokenizer = AutoTokenizer.from_pretrained("adept/persimmon-8b-chat", use_fast=False)
         input_ids = tokenizer.encode(prompt, return_tensors="pt").to(torch_device)
-        model = PersimmonForCausalLM.from_pretrained("adept/persimmon-8b-chat", torch_dtype=torch.float16).to(
-            torch_device
+        model = PersimmonForCausalLM.from_pretrained(
+            "adept/persimmon-8b-chat", load_in_8bit=True, device_map={"": 0}, torch_dtype=torch.float16
         )
 
         # greedy generation outputs
         generated_ids = model.generate(input_ids, max_new_tokens=64)
         text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
         self.assertEqual(EXPECTED_TEXT_COMPLETION, text)
+
+        torch.cuda.empty_cache()
+        del model
+        gc.collect()
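Both tests now release GPU memory explicitly before returning, so the 8B model from one test is not still resident when the next test loads its own copy. A hedged sketch of that teardown order in an isolated `unittest` case (the tiny model and test body are illustrative; only the three cleanup lines mirror the diff):

```python
import gc
import unittest

import torch


class CleanupPatternTest(unittest.TestCase):
    # Illustrative stand-in for a memory-hungry integration test; only the
    # three teardown lines at the end mirror what the diff adds.
    def test_forward_then_release(self):
        device = "cuda" if torch.cuda.is_available() else "cpu"
        model = torch.nn.Linear(16, 16).to(device)
        out = model(torch.zeros(1, 16, device=device))
        self.assertEqual(tuple(out.shape), (1, 16))

        torch.cuda.empty_cache()  # return cached CUDA blocks to the driver (no-op on CPU)
        del model                 # drop the last reference to the weights
        gc.collect()              # collect lingering reference cycles


if __name__ == "__main__":
    unittest.main()
```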