Commit 3f7d5786 authored by Tri Dao

Pass alibi slopes to flash_attn_with_kvcache during generation

parent f8448524
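The diff below is mechanical: at each decoding step, the per-head ALiBi slopes already held by the inner cross-attention module are now forwarded to flash_attn_with_kvcache, so generation applies the same positional bias as the training-time attention path. A minimal sketch of the resulting decode-step call, assuming flash-attn is installed and a CUDA device is available; the batch/nheads/headdim values and tensor shapes are illustrative, not taken from the commit:

```python
import torch
from flash_attn import flash_attn_with_kvcache

batch, nheads, headdim, cache_len = 2, 8, 64, 128
device, dtype = "cuda", torch.float16

# One new query token per sequence attends against the existing KV cache.
q = torch.randn(batch, 1, nheads, headdim, device=device, dtype=dtype)
k_cache = torch.randn(batch, cache_len, nheads, headdim, device=device, dtype=dtype)
v_cache = torch.randn_like(k_cache)
cache_seqlens = torch.full((batch,), 100, dtype=torch.int32, device=device)

# Per-head ALiBi slopes, shape (nheads,), float32; before this commit the
# generation path simply did not pass them through.
alibi_slopes = torch.rand(nheads, dtype=torch.float32, device=device)

out = flash_attn_with_kvcache(
    q,
    k_cache,
    v_cache,
    cache_seqlens=cache_seqlens,
    causal=True,
    alibi_slopes=alibi_slopes,
)  # (batch, 1, nheads, headdim)
```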
-# Copyright (c) 2023, GGGGGGXY.
+# Copyright (c) 2023, GGGGGGXY, Tri Dao.
 import math
 import json
@@ -14,7 +14,6 @@ from einops import rearrange
 from transformers import GPT2Config, AutoConfig, PretrainedConfig
-# only support Baichuan-7B now
 def remap_state_dict_hf_baichuan(state_dict, config):
     def key_mapping_layers(key):
         return re.sub(r"^model.", "transformer.", key)
...
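For orientation, the only substantive change to this file is the copyright line; the remap_state_dict_hf_baichuan helper shown as context rewrites Hugging Face checkpoint keys into this repo's layout. A small illustration of what the key_mapping_layers regex above does (the checkpoint key is a made-up example):

```python
import re

def key_mapping_layers(key):
    return re.sub(r"^model.", "transformer.", key)

# A leading "model." prefix becomes "transformer.", leaving the rest of the
# key untouched.
print(key_mapping_layers("model.layers.0.mlp.down_proj.weight"))
# transformer.layers.0.mlp.down_proj.weight
```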
@@ -501,6 +501,7 @@ class MHA(nn.Module):
             if inference_params.lengths_per_sample is not None
             else inference_params.seqlen_offset
         )
+        alibi_slopes = getattr(self.inner_cross_attn, "alibi_slopes", None)
         context = flash_attn_with_kvcache(
             q,
             kv_cache[:, :, 0],
@@ -513,6 +514,7 @@ class MHA(nn.Module):
             softmax_scale=self.inner_cross_attn.softmax_scale,
             causal=self.inner_cross_attn.causal,
             rotary_interleaved=self.rotary_emb.interleaved if self.rotary_emb_dim > 0 else False,
+            alibi_slopes=alibi_slopes,
         )
         return context
@@ -534,6 +536,7 @@ class MHA(nn.Module):
             if inference_params.lengths_per_sample is not None
             else inference_params.seqlen_offset
         )
+        alibi_slopes = getattr(self.inner_cross_attn, "alibi_slopes", None)
         return flash_attn_with_kvcache(
             q,
             kv_cache[:, :, 0],
@@ -543,6 +546,7 @@ class MHA(nn.Module):
             cache_seqlens=cache_seqlens,
             softmax_scale=self.inner_cross_attn.softmax_scale,
             causal=self.inner_cross_attn.causal,
+            alibi_slopes=alibi_slopes,
         )
     def forward(
@@ -847,6 +851,7 @@ class ParallelMHA(nn.Module):
             if inference_params.lengths_per_sample is not None
             else inference_params.seqlen_offset
         )
+        alibi_slopes = getattr(self.inner_cross_attn, "alibi_slopes", None)
         context = flash_attn_with_kvcache(
             q,
             kv_cache[:, :, 0],
@@ -859,6 +864,7 @@ class ParallelMHA(nn.Module):
             softmax_scale=self.inner_cross_attn.softmax_scale,
             causal=self.inner_cross_attn.causal,
             rotary_interleaved=self.rotary_emb.interleaved if self.rotary_emb_dim > 0 else False,
+            alibi_slopes=alibi_slopes,
         )
         return context
@@ -876,6 +882,7 @@ class ParallelMHA(nn.Module):
             if inference_params.lengths_per_sample is not None
             else inference_params.seqlen_offset
         )
+        alibi_slopes = getattr(self.inner_cross_attn, "alibi_slopes", None)
         context = flash_attn_with_kvcache(
             q,
             kv_cache[:, :, 0],
@@ -885,6 +892,7 @@ class ParallelMHA(nn.Module):
             cache_seqlens=cache_seqlens,
             softmax_scale=self.inner_cross_attn.softmax_scale,
             causal=self.inner_cross_attn.causal,
+            alibi_slopes=alibi_slopes,
         )
         return context
...
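The getattr calls in these hunks pick up an alibi_slopes tensor from the inner cross-attention module when ALiBi is enabled and fall back to None otherwise, so non-ALiBi models are unaffected. For reference, a sketch of how such per-head slopes are conventionally computed (the standard ALiBi geometric sequence for a power-of-two head count; the helper name is illustrative, not from this diff):

```python
import math
import torch

def alibi_slopes_for(nheads):
    # Standard ALiBi recipe for a power-of-two head count: a geometric
    # sequence starting at 2 ** (-8 / nheads).
    start = 2 ** (-(2 ** -(math.log2(nheads) - 3)))
    return torch.tensor([start * start**i for i in range(nheads)], dtype=torch.float32)

# For 8 heads this gives 1/2, 1/4, ..., 1/256; the kernel adds a bias of
# -slope * distance to each attention score, one slope per head.
print(alibi_slopes_for(8))
```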
# Copyright (c) 2023, Tri Dao.
 import os
 import time
 from pathlib import Path
...