Commit 8a733cbd authored by Tri Dao

[Gen] Fix calling update_graph_cache in tests

parent 4c91621a
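
Summary: the generation tests now pass fused_ft_kernel explicitly when building the CUDA-graph decoding cache, and the MHA hunks switch the head-split reshapes to einops rearrange. Below is a minimal sketch of the corrected capture call, assuming a flash_attn GPTLMHeadModel-style model with a _decoding_cache attribute and a CUDA input_ids tensor as in the tests; the helper name is illustrative, not part of the commit.

from flash_attn.utils.generation import update_graph_cache


def capture_decoding_cache(model, input_ids, max_length, fused_ft_kernel=True):
    # Capture graph outside the timing loop (same comment as in the tests below).
    batch_size, seqlen_og = input_ids.shape
    model._decoding_cache = update_graph_cache(
        model, None, batch_size, seqlen_og, max_length, fused_ft_kernel=fused_ft_kernel
    )
    return model._decoding_cache
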
@@ -659,8 +659,7 @@ class MHA(nn.Module):
             qkv = rearrange(
                 self.dwconv_qkv(rearrange(qkv, "b s d -> b d s"))[..., :-2], "b d s -> b s d"
             ).contiguous()
-        # qkv = rearrange(qkv, "... (three h d) -> ... three h d", three=3, d=self.head_dim)
-        qkv = qkv.reshape(batch, seqlen, 3, self.num_heads, self.head_dim)
+        qkv = rearrange(qkv, "... (three h d) -> ... three h d", three=3, d=self.head_dim)
         if (
             inference_params is None
             or inference_params.sequence_len_offset == 0
@@ -700,10 +699,8 @@ class MHA(nn.Module):
                 qkv, x = self.Wqkv(x)
             q = qkv[..., : self.num_heads * self.head_dim]
             kv = qkv[..., self.num_heads * self.head_dim :]
-            # q = rearrange(q, "... (h d) -> ... h d", d=self.head_dim)
-            q = q.reshape(batch, seqlen, -1, self.head_dim)
-            # kv = rearrange(kv, "... (two hkv d) -> ... two hkv d", two=2, d=self.head_dim)
-            kv = kv.reshape(batch, seqlen, 2, -1, self.head_dim)
+            q = rearrange(q, "... (h d) -> ... h d", d=self.head_dim)
+            kv = rearrange(kv, "... (two hkv d) -> ... two hkv d", two=2, d=self.head_dim)
             if self.dwconv:
                 q = rearrange(
                     self.dwconv_q(rearrange(q, "b s d -> b d s"))[..., :-2], "b d s -> b s d"
@@ -731,8 +728,7 @@ class MHA(nn.Module):
                 context = self._update_kvcache_attention(q, kv, inference_params)
             else:
                 context = self._apply_rotary_single_query_attention(q, inference_params, kv=kv)
-        # out = self.out_proj(rearrange(context, "... h d -> ... (h d)"))
-        out = self.out_proj(context.reshape(batch, seqlen, -1))
+        out = self.out_proj(rearrange(context, "... h d -> ... (h d)"))
         return out if not self.return_residual else (out, x)
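
For reference, the rearrange patterns used above produce the same layout as the removed reshape calls. The standalone check below uses arbitrary illustrative dimensions, not values taken from any model config.

import torch
from einops import rearrange

# The packed last dimension is (three h d), so splitting it with a row-major
# reshape gives the same (batch, seqlen, 3, num_heads, head_dim) tensor.
batch, seqlen, num_heads, head_dim = 2, 3, 4, 8
qkv = torch.randn(batch, seqlen, 3 * num_heads * head_dim)

via_rearrange = rearrange(qkv, "... (three h d) -> ... three h d", three=3, d=head_dim)
via_reshape = qkv.reshape(batch, seqlen, 3, num_heads, head_dim)
assert torch.equal(via_rearrange, via_reshape)
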
+# Copyright (c) 2023, Tri Dao.
 from typing import Optional, Union
 import torch
@@ -404,7 +404,7 @@ def test_baichuan_parallel_generation(model_name, world_size):
     # Capture graph outside the timing loop
     batch_size, seqlen_og = input_ids.shape
     model._decoding_cache = update_graph_cache(
-        model, None, batch_size, seqlen_og, max_length
+        model, None, batch_size, seqlen_og, max_length, fused_ft_kernel=False
     )
     print("With CUDA graph")
     out_cg = model.generate(
@@ -253,7 +253,9 @@ def test_falcon_generation(model_name):
     # Capture graph outside the timing loop
     batch_size, seqlen_og = input_ids.shape
-    model._decoding_cache = update_graph_cache(model, None, batch_size, seqlen_og, max_length)
+    model._decoding_cache = update_graph_cache(
+        model, None, batch_size, seqlen_og, max_length, fused_ft_kernel=True
+    )
     print("With CUDA graph")
     torch.cuda.synchronize()
     start = time.time()
@@ -356,7 +358,9 @@ def test_falcon_parallel_generation(model_name, world_size):
     # Capture graph outside the timing loop
     batch_size, seqlen_og = input_ids.shape
-    model._decoding_cache = update_graph_cache(model, None, batch_size, seqlen_og, max_length)
+    model._decoding_cache = update_graph_cache(
+        model, None, batch_size, seqlen_og, max_length, fused_ft_kernel=True
+    )
     print("With CUDA graph")
     out_cg = model.generate(
         input_ids=input_ids,
@@ -6,7 +6,6 @@ import pytest
 import torch
 from flash_attn.models.gpt import GPTLMHeadModel
 from flash_attn.models.gpt_neox import gpt_neox_config_to_gpt2_config, remap_state_dict_hf_gpt_neox
-from flash_attn.utils.generation import update_graph_cache
 from flash_attn.utils.pretrained import state_dict_from_pretrained
 from transformers import AutoTokenizer, GPTNeoXConfig
 from transformers.models.gpt_neox.modeling_gpt_neox import GPTNeoXForCausalLM
@@ -83,8 +83,9 @@ def test_gptj_optimized(model_name):
     ).abs().max().item()


+@pytest.mark.parametrize("fused_ft_kernel", [False, True])
 @pytest.mark.parametrize("model_name", ["EleutherAI/gpt-j-6B"])
-def test_gptj_generation(model_name):
+def test_gptj_generation(model_name, fused_ft_kernel):
     """Check that our implementation of GPT-J (with all optimizations enabled) matches the
     HF implementation: the output of our forward pass in fp16 should be around the same as the HF
     forward pass in fp32.
@@ -140,8 +141,7 @@ def test_gptj_generation(model_name):
         input_ids=input_ids,
         max_length=max_length,
         eos_token_id=eos_token_id,
-        fused_ft_kernel=True,
-        # eos_token_id=eos_token_id, fused_ft_kernel=False,
+        fused_ft_kernel=fused_ft_kernel,
         return_dict_in_generate=True,
         output_scores=True,
         enable_timing=True,
@@ -152,14 +152,16 @@ def test_gptj_generation(model_name):
     # Capture graph outside the timing loop
     batch_size, seqlen_og = input_ids.shape
-    model._decoding_cache = update_graph_cache(model, None, batch_size, seqlen_og, max_length)
+    model._decoding_cache = update_graph_cache(
+        model, None, batch_size, seqlen_og, max_length, fused_ft_kernel=fused_ft_kernel
+    )
     print("With CUDA graph")
     torch.cuda.synchronize()
     start = time.time()
     out_cg = model.generate(
         input_ids=input_ids,
         max_length=max_length,
-        fused_ft_kernel=True,
+        fused_ft_kernel=fused_ft_kernel,
         cg=True,
         return_dict_in_generate=True,
         output_scores=True,
@@ -303,7 +303,9 @@ def test_llama_generation(model_name, checkpoint_format):
     # Capture graph outside the timing loop
     batch_size, seqlen_og = input_ids.shape
-    model._decoding_cache = update_graph_cache(model, None, batch_size, seqlen_og, max_length)
+    model._decoding_cache = update_graph_cache(
+        model, None, batch_size, seqlen_og, max_length, fused_ft_kernel=True
+    )
     print("With CUDA graph")
     torch.cuda.synchronize()
     start = time.time()
@@ -408,7 +410,9 @@ def test_llama_parallel_generation(model_name, world_size, checkpoint_format):
     # Capture graph outside the timing loop
     batch_size, seqlen_og = input_ids.shape
-    model._decoding_cache = update_graph_cache(model, None, batch_size, seqlen_og, max_length)
+    model._decoding_cache = update_graph_cache(
+        model, None, batch_size, seqlen_og, max_length, fused_ft_kernel=True
+    )
     print("With CUDA graph")
     out_cg = model.generate(
         input_ids=input_ids,
@@ -168,7 +168,9 @@ def test_opt_generation(model_name):
     if fused_ft_kernel:
         # Capture graph outside the timing loop
         batch_size, seqlen_og = input_ids.shape
-        model._decoding_cache = update_graph_cache(model, None, batch_size, seqlen_og, max_length)
+        model._decoding_cache = update_graph_cache(
+            model, None, batch_size, seqlen_og, max_length, fused_ft_kernel=True
+        )
         print("With CUDA graph")
         torch.cuda.synchronize()
         start = time.time()
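
The OPT hunk above shows the harness these tests share: the decoding graph is captured outside the timed region (and, in the OPT test, only on the fused-kernel path), then generate is called with cg=True. A condensed sketch of that harness follows, reusing the hypothetical capture_decoding_cache helper from the top of this page; the final elapsed-time print is an assumption, since the hunks end at start = time.time().

import time

import torch


def run_cuda_graph_generation(model, input_ids, max_length, fused_ft_kernel):
    if not fused_ft_kernel:
        # Mirror the OPT test: only exercise graph capture on the fused-kernel path.
        return None
    # Capture graph outside the timing loop
    capture_decoding_cache(model, input_ids, max_length, fused_ft_kernel=fused_ft_kernel)
    print("With CUDA graph")
    torch.cuda.synchronize()
    start = time.time()
    out_cg = model.generate(
        input_ids=input_ids,
        max_length=max_length,
        fused_ft_kernel=fused_ft_kernel,
        cg=True,
        return_dict_in_generate=True,
        output_scores=True,
    )
    torch.cuda.synchronize()
    print(f"Generation time: {(time.time() - start) * 1000:.0f} ms")  # assumed; not shown above
    return out_cg
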