Commit 11be742a authored by Tri Dao

[Gen] Test generation with rotary embedding

parent 8d9674ed
@@ -146,15 +146,17 @@ class GPTPreTrainedModel(nn.Module):
         self.config = config
 
     @classmethod
-    def from_pretrained(cls, model_name, config, *inputs, **kwargs):
+    def from_pretrained(cls, model_name, config, *args, strict=True, device=None, **kwargs):
         """
         Instantiate a GPTPreTrainedModel from a pre-trained model file or a pytorch state dict.
         Download and cache the pre-trained model file if needed.
         """
 
         # Instantiate model.
-        model = cls(config, *inputs, **kwargs)
+        model = cls(config, *args, device=device, **kwargs)
         load_return = model.load_state_dict(
-            remap_state_dict_gpt2(state_dict_from_pretrained(model_name), config))
+            remap_state_dict_gpt2(state_dict_from_pretrained(model_name, device=device), config),
+            strict=strict
+        )
         logger.info(load_return)
         return model
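For illustration, a minimal sketch of how the new keywords might be used, assuming the GPTLMHeadModel subclass defined in this file (the 'gpt2' checkpoint name and the CUDA device are examples, not part of the commit):

    from transformers import GPT2Config
    from flash_attn.models.gpt import GPTLMHeadModel

    config = GPT2Config.from_pretrained('gpt2')
    # device='cuda' loads the remapped HF weights directly onto the GPU;
    # strict=False would tolerate missing/unexpected keys when the config
    # diverges from the checkpoint (e.g. when rotary embeddings are enabled).
    model = GPTLMHeadModel.from_pretrained('gpt2', config, device='cuda')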
@@ -341,7 +341,6 @@ class MHA(nn.Module):
                 self.dwconv_qkv = nn.Conv1d(3 * embed_dim, 3 * embed_dim, kernel_size=3, padding=2,
                                             groups=3 * embed_dim)
         else:
-            inner_attn_cls = inner_cross_attn_cls
             self.Wq = linear_cls(embed_dim, embed_dim, bias=bias, **factory_kwargs)
             if not self.return_residual:
                 self.Wkv = linear_cls(embed_dim, 2 * embed_dim, bias=bias, **factory_kwargs)
@@ -482,9 +481,9 @@ class MHA(nn.Module):
                                       'b d s -> b s d').contiguous()
             if inference_params is None:
                 if not self.checkpointing:
-                    context = self.inner_attn(q, kv, **kwargs)
+                    context = self.inner_cross_attn(q, kv, **kwargs)
                 else:
-                    context = torch.utils.checkpoint.checkpoint(self.inner_attn, q, kv, **kwargs)
+                    context = torch.utils.checkpoint.checkpoint(self.inner_cross_attn, q, kv, **kwargs)
             else:
                 kv = self._update_kv_cache(kv)
                 context = self.inner_cross_attn(q, kv, causal=False)
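The change matters because the self- and cross-attention inner modules expect different inputs: self-attention takes one packed qkv tensor, while cross-attention takes q and kv separately, so the cross-attention path has to dispatch to inner_cross_attn. A rough sketch of that contract, assuming the pure-PyTorch SelfAttention and CrossAttention reference modules from flash_attn.modules.mha and their default call signatures:

    import torch
    from flash_attn.modules.mha import SelfAttention, CrossAttention

    batch, seqlen_q, seqlen_k, nheads, headdim = 2, 16, 32, 8, 64
    # Self-attention: queries, keys and values packed into a single tensor.
    qkv = torch.randn(batch, seqlen_q, 3, nheads, headdim)
    out_self = SelfAttention()(qkv)        # (batch, seqlen_q, nheads, headdim)
    # Cross-attention: q comes from one sequence, kv from another.
    q = torch.randn(batch, seqlen_q, nheads, headdim)
    kv = torch.randn(batch, seqlen_k, 2, nheads, headdim)
    out_cross = CrossAttention()(q, kv)    # (batch, seqlen_q, nheads, headdim)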
@@ -4,5 +4,5 @@ from transformers.utils import WEIGHTS_NAME
 from transformers.utils.hub import cached_file
 
 
-def state_dict_from_pretrained(model_name):
-    return torch.load(cached_file(model_name, WEIGHTS_NAME))
+def state_dict_from_pretrained(model_name, device=None):
+    return torch.load(cached_file(model_name, WEIGHTS_NAME), map_location=device)
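As a quick illustration (the 'gpt2' checkpoint name is an example), the device argument is forwarded to torch.load as map_location, so the tensors are materialized on the requested device instead of the one the checkpoint was saved from:

    from flash_attn.utils.pretrained import state_dict_from_pretrained

    # Load the HF checkpoint onto the CPU (a CUDA device string works the same way).
    state_dict = state_dict_from_pretrained('gpt2', device='cpu')
    print({k: v.device for k, v in list(state_dict.items())[:3]})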
@@ -14,32 +14,40 @@ from flash_attn.utils.pretrained import state_dict_from_pretrained
 from flash_attn.utils.generation import greedy_decode
 
 
-# TODO: test with rotary embedding
 @pytest.mark.parametrize('fused_ft_kernel', [False, True])
 @pytest.mark.parametrize('optimized', [False, True])
+# @pytest.mark.parametrize('fused_ft_kernel', [False])
 # @pytest.mark.parametrize('optimized', [True])
+@pytest.mark.parametrize('rotary', [False, True])
 @pytest.mark.parametrize('model_name', ["gpt2"])
-def test_greedy_decode(model_name, optimized, fused_ft_kernel):
+def test_greedy_decode(model_name, rotary, optimized, fused_ft_kernel):
     """Check that our implementation of GPT2 generation matches the HF implementation:
     the scores in fp16 should be around the same as the HF scores in fp16, when compared to
     the HF scores in fp32.
     """
     dtype = torch.float16
+    device = 'cuda'
     rtol, atol = 3e-3, 3e-1
     config = GPT2Config.from_pretrained(model_name)
+    if rotary:
+        config.n_positions = 0
+        config.rotary_emb_dim = 64
     if optimized:
         config.use_flash_attn = True
         config.fused_bias_fc = True
         config.fused_dense_gelu_dense = True
         config.fused_dropout_add_ln = True
-    model = GPTLMHeadModel.from_pretrained(model_name, config)
-    model = model.cuda().to(dtype=dtype)
+    # if not rotary, we load the weight from HF but ignore the position embeddings.
+    # The model would be nonsense but it doesn't matter for the test.
+    model = GPTLMHeadModel.from_pretrained(model_name, config, strict=not rotary, device=device)
+    model = model.to(dtype=dtype)
+    model.eval()
 
-    model_ref = GPT2LMHeadModelHF.from_pretrained(model_name).cuda()
-    model_hf = GPT2LMHeadModelHF.from_pretrained(model_name).cuda().to(dtype=dtype)
-    model.eval()
-    model_ref.eval()
-    model_hf.eval()
+    if not rotary:
+        model_ref = GPT2LMHeadModelHF.from_pretrained(model_name).cuda()
+        model_hf = GPT2LMHeadModelHF.from_pretrained(model_name).cuda().to(dtype=dtype)
+        model_ref.eval()
+        model_hf.eval()
@@ -47,6 +55,8 @@ def test_greedy_decode(model_name, optimized, fused_ft_kernel):
     tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
     input_ids = tokenizer("Hello, my dog is cute and ", return_tensors="pt").input_ids.cuda()
     max_length = 30
+    # input_ids = torch.randint(0, 100, (1, 512), dtype=torch.long, device='cuda')
+    # max_length = 512 + 50
 
     # Slow generation for reference
     sequences = []
@@ -66,6 +76,7 @@ def test_greedy_decode(model_name, optimized, fused_ft_kernel):
                          fused_ft_kernel=fused_ft_kernel,
                          return_dict_in_generate=True, output_scores=True)
-    out_hf = model_hf.generate(input_ids=input_ids, max_length=max_length,
-                               return_dict_in_generate=True, output_scores=True)
-    out_ref = model_ref.generate(input_ids=input_ids, max_length=max_length,
+    if not rotary:
+        out_hf = model_hf.generate(input_ids=input_ids, max_length=max_length,
+                                   return_dict_in_generate=True, output_scores=True)
+        out_ref = model_ref.generate(input_ids=input_ids, max_length=max_length,
@@ -79,6 +90,7 @@ def test_greedy_decode(model_name, optimized, fused_ft_kernel):
     assert torch.all(out.sequences == sequences)
     assert torch.allclose(torch.stack(out.scores, dim=1), torch.stack(scores, dim=1),
                           rtol=rtol, atol=atol)
-    assert torch.all(out.sequences == out_ref.sequences)
-    assert torch.all(out.sequences == out_hf.sequences)
+    if not rotary:
+        assert torch.all(out.sequences == out_ref.sequences)
+        assert torch.all(out.sequences == out_hf.sequences)
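Putting the pieces together, the fast generation path this test exercises looks roughly like the following sketch; the prompt, max_length and rotary_emb_dim are the values used in the test, and fused_ft_kernel is left at False here:

    import torch
    from transformers import GPT2Config, GPT2Tokenizer
    from flash_attn.models.gpt import GPTLMHeadModel

    config = GPT2Config.from_pretrained('gpt2')
    config.n_positions = 0       # drop the learned position embeddings ...
    config.rotary_emb_dim = 64   # ... and use rotary embeddings instead
    # The HF checkpoint has no rotary weights, so the load must be non-strict.
    model = GPTLMHeadModel.from_pretrained('gpt2', config, strict=False, device='cuda')
    model = model.to(dtype=torch.float16)
    model.eval()

    tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
    input_ids = tokenizer("Hello, my dog is cute and ", return_tensors="pt").input_ids.cuda()
    out = model.generate(input_ids=input_ids, max_length=30, fused_ft_kernel=False,
                         return_dict_in_generate=True, output_scores=True)
    print(tokenizer.decode(out.sequences[0]))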