"git@developer.sourcefind.cn:gaoqiong/pybind11.git" did not exist on "7ed08cb7aa26c0e905633e7f396d77c74b0ebc34"
Commit f5d0fbd4 authored by Tri Dao's avatar Tri Dao
Browse files

[FT] Fix FT's single query attention for bf16 hdim128 rotary

parent 4d87e4d8
...@@ -1669,22 +1669,6 @@ __device__ __inline__ void write_smem_transpose(const float& vec, float* smem, i ...@@ -1669,22 +1669,6 @@ __device__ __inline__ void write_smem_transpose(const float& vec, float* smem, i
return; return;
} }
#ifdef ENABLE_BF16
template<>
__device__ __inline__ void
write_smem_transpose(const bf16_4_t& vec, __nv_bfloat16* smem, int transpose_idx, int smem_pitch)
{
return;
}
template<>
__device__ __inline__ void
write_smem_transpose(const bf16_8_t& vec, __nv_bfloat16* smem, int transpose_idx, int smem_pitch)
{
return;
}
#endif
template<> template<>
__device__ __inline__ void write_smem_transpose(const uint4& vec, uint16_t* smem, int transpose_idx, int smem_pitch) __device__ __inline__ void write_smem_transpose(const uint4& vec, uint16_t* smem, int transpose_idx, int smem_pitch)
{ {
...@@ -1776,6 +1760,20 @@ write_smem_transpose(const __nv_bfloat162& vec, __nv_bfloat16* smem, int transpo ...@@ -1776,6 +1760,20 @@ write_smem_transpose(const __nv_bfloat162& vec, __nv_bfloat16* smem, int transpo
smem[transpose_idx] = vec.x; smem[transpose_idx] = vec.x;
smem[smem_pitch + transpose_idx] = vec.y; smem[smem_pitch + transpose_idx] = vec.y;
} }
template<>
__device__ __inline__ void
write_smem_transpose(const bf16_4_t& vec, __nv_bfloat16* smem, int transpose_idx, int smem_pitch)
{
write_smem_transpose(reinterpret_cast<const uint2&>(vec), reinterpret_cast<uint16_t*>(smem), transpose_idx, smem_pitch);
}
template<>
__device__ __inline__ void
write_smem_transpose(const bf16_8_t& vec, __nv_bfloat16* smem, int transpose_idx, int smem_pitch)
{
write_smem_transpose(reinterpret_cast<const uint4&>(vec), reinterpret_cast<uint16_t*>(smem), transpose_idx, smem_pitch);
}
#endif #endif
template<> template<>
......
...@@ -494,7 +494,8 @@ class MHA(nn.Module): ...@@ -494,7 +494,8 @@ class MHA(nn.Module):
*rearrange(qkv, 'b 1 three h d -> b three h d').unbind(dim=1), *rearrange(qkv, 'b 1 three h d -> b three h d').unbind(dim=1),
*inference_params.key_value_memory_dict[self.layer_idx], *inference_params.key_value_memory_dict[self.layer_idx],
inference_params.lengths_per_sample, inference_params.sequence_len_offset, inference_params.lengths_per_sample, inference_params.sequence_len_offset,
self.rotary_emb_dim self.rotary_emb_dim,
not self.rotary_emb.interleaved # neox_rotary_style
) )
context = rearrange(context, 'b h d -> b 1 h d') context = rearrange(context, 'b h d -> b 1 h d')
else: else:
...@@ -607,7 +608,8 @@ class ParallelMHA(nn.Module): ...@@ -607,7 +608,8 @@ class ParallelMHA(nn.Module):
*rearrange(qkv, 'b 1 three h d -> b three h d').unbind(dim=1), *rearrange(qkv, 'b 1 three h d -> b three h d').unbind(dim=1),
*inference_params.key_value_memory_dict[self.layer_idx], *inference_params.key_value_memory_dict[self.layer_idx],
inference_params.lengths_per_sample, inference_params.sequence_len_offset, inference_params.lengths_per_sample, inference_params.sequence_len_offset,
self.rotary_emb_dim self.rotary_emb_dim,
not self.rotary_emb.interleaved # neox_rotary_style
) )
context = rearrange(context, 'b h d -> b 1 h d') context = rearrange(context, 'b h d -> b 1 h d')
if seqlen is None: if seqlen is None:
......
...@@ -82,6 +82,8 @@ def decode(input_ids, model, max_length, top_k=1, top_p=0.0, temperature=1.0, ...@@ -82,6 +82,8 @@ def decode(input_ids, model, max_length, top_k=1, top_p=0.0, temperature=1.0,
Arguments: Arguments:
input_ids: (batch, seq_len) input_ids: (batch, seq_len)
max_length: int max_length: int
teacher_outputs (optional): (batch, seq_len). If provided, instead of sampling from the
logits, the next token is taken from the teacher_outputs. Useful for testing.
Returns: GreedySearchDecoderOnlyOutput or SampleDecoderOnlyOutput, with the following fields: Returns: GreedySearchDecoderOnlyOutput or SampleDecoderOnlyOutput, with the following fields:
sequences: (batch, max_length) sequences: (batch, max_length)
scores: tuples of (batch, vocab_size) scores: tuples of (batch, vocab_size)
...@@ -111,7 +113,7 @@ def decode(input_ids, model, max_length, top_k=1, top_p=0.0, temperature=1.0, ...@@ -111,7 +113,7 @@ def decode(input_ids, model, max_length, top_k=1, top_p=0.0, temperature=1.0,
start = time.time() start = time.time()
if vocab_size is not None: if vocab_size is not None:
logits = logits[..., :vocab_size] logits = logits[..., :vocab_size]
scores.append(logits) scores.append(logits if not cg else logits.clone())
if teacher_outputs is None or teacher_output_len <= seqlen_og: if teacher_outputs is None or teacher_output_len <= seqlen_og:
next_token = sample(logits, top_k=top_k, top_p=top_p, temperature=temperature) next_token = sample(logits, top_k=top_k, top_p=top_p, temperature=temperature)
else: else:
...@@ -129,7 +131,7 @@ def decode(input_ids, model, max_length, top_k=1, top_p=0.0, temperature=1.0, ...@@ -129,7 +131,7 @@ def decode(input_ids, model, max_length, top_k=1, top_p=0.0, temperature=1.0,
inference_params.sequence_len_offset) inference_params.sequence_len_offset)
if vocab_size is not None: if vocab_size is not None:
logits = logits[..., :vocab_size] logits = logits[..., :vocab_size]
scores.append(logits) scores.append(logits if not cg else logits.clone())
if teacher_outputs is None or teacher_output_len <= inference_params.sequence_len_offset + 1: if teacher_outputs is None or teacher_output_len <= inference_params.sequence_len_offset + 1:
next_token = sample(logits, top_k=top_k, temperature=temperature) next_token = sample(logits, top_k=top_k, temperature=temperature)
else: else:
......
...@@ -15,7 +15,6 @@ from flash_attn.models.gpt import GPTLMHeadModel ...@@ -15,7 +15,6 @@ from flash_attn.models.gpt import GPTLMHeadModel
from flash_attn.models.gpt import remap_state_dict_hf_gpt2 from flash_attn.models.gpt import remap_state_dict_hf_gpt2
from flash_attn.models.opt import remap_state_dict_hf_opt, opt_config_to_gpt2_config from flash_attn.models.opt import remap_state_dict_hf_opt, opt_config_to_gpt2_config
from flash_attn.utils.pretrained import state_dict_from_pretrained from flash_attn.utils.pretrained import state_dict_from_pretrained
from flash_attn.utils.distributed import all_gather_raw
from flash_attn.utils.generation import update_graph_cache from flash_attn.utils.generation import update_graph_cache
...@@ -61,7 +60,7 @@ def test_greedy_decode_gpt2(model_name, rotary, optimized, fused_ft_kernel): ...@@ -61,7 +60,7 @@ def test_greedy_decode_gpt2(model_name, rotary, optimized, fused_ft_kernel):
torch.manual_seed(0) torch.manual_seed(0)
tokenizer = GPT2Tokenizer.from_pretrained("gpt2") tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
input_ids = tokenizer("Hello, my dog is cute and", input_ids = tokenizer("Hello, my dog is cute and",
return_tensors="pt").input_ids.to(device=device) return_tensors="pt").input_ids.to(device=device)
max_length = 30 max_length = 30
# input_ids = torch.randint(0, 100, (2, 10), dtype=torch.long, device='cuda') # input_ids = torch.randint(0, 100, (2, 10), dtype=torch.long, device='cuda')
# max_length = input_ids.shape[1] + 40 # max_length = input_ids.shape[1] + 40
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment