"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "b972125ced4b60120e8bea606065059e3412d4ec"
Unverified commit 903b97d8, authored by Younes Belkada, committed by GitHub
Browse files

[`gpt2-int8`] Add gpt2-xl int8 test (#24543)

add gpt2-xl test
parent b0651655
...@@ -762,8 +762,24 @@ class MixedInt8TestTraining(BaseMixedInt8Test): ...@@ -762,8 +762,24 @@ class MixedInt8TestTraining(BaseMixedInt8Test):
class MixedInt8GPT2Test(MixedInt8Test): class MixedInt8GPT2Test(MixedInt8Test):
model_name = "gpt2-xl" model_name = "gpt2-xl"
EXPECTED_RELATIVE_DIFFERENCE = 1.8720077507258357 EXPECTED_RELATIVE_DIFFERENCE = 1.8720077507258357
EXPECTED_OUTPUT = "Hello my name is John Doe, and I am a member of the" EXPECTED_OUTPUT = "Hello my name is John Doe, and I'm a big fan of"
def test_int8_from_pretrained(self): def test_int8_from_pretrained(self):
# TODO @younesbelkada: Test loading quantized gpt2 model from the hub. r"""
pass Test whether loading a 8bit model from the Hub works as expected
"""
from bitsandbytes.nn import Int8Params
model_id = "ybelkada/gpt2-xl-8bit"
model = AutoModelForCausalLM.from_pretrained(model_id)
linear = get_some_linear_layer(model)
self.assertTrue(linear.weight.__class__ == Int8Params)
self.assertTrue(hasattr(linear.weight, "SCB"))
# generate
encoded_input = self.tokenizer(self.input_text, return_tensors="pt")
output_sequences = model.generate(input_ids=encoded_input["input_ids"].to(0), max_new_tokens=10)
self.assertEqual(self.tokenizer.decode(output_sequences[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment