Commit 3f7d5786 authored by Tri Dao

Pass alibi slopes to flash_attn_with_kvcache during generation

parent f8448524
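For background, ALiBi adds a per-head linear bias, slope × (key position − query position), to the attention scores; the slopes come from a fixed geometric sequence in the ALiBi paper rather than from learned parameters. A minimal sketch of that computation, assuming `nheads` is a power of two (`get_alibi_slopes` is an illustrative name, not a helper from this repo):

```python
import torch

def get_alibi_slopes(nheads: int) -> torch.Tensor:
    # Geometric sequence 2^(-8/n), 2^(-16/n), ..., 2^(-8) from the ALiBi
    # paper; this closed form holds only when nheads is a power of two
    # (the paper interleaves two such sequences otherwise).
    ratio = 2.0 ** (-8.0 / nheads)
    return torch.tensor([ratio ** (i + 1) for i in range(nheads)], dtype=torch.float32)
```

The commit below makes sure that this per-head tensor, when present on the attention module, also reaches `flash_attn_with_kvcache` during incremental decoding.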
-# Copyright (c) 2023, GGGGGGXY.
+# Copyright (c) 2023, GGGGGGXY, Tri Dao.
import math
import json
@@ -14,7 +14,6 @@ from einops import rearrange
from transformers import GPT2Config, AutoConfig, PretrainedConfig
# only support Baichuan-7B now
def remap_state_dict_hf_baichuan(state_dict, config):
def key_mapping_layers(key):
return re.sub(r"^model.", "transformer.", key)
......
@@ -501,6 +501,7 @@ class MHA(nn.Module):
if inference_params.lengths_per_sample is not None
else inference_params.seqlen_offset
)
+alibi_slopes = getattr(self.inner_cross_attn, "alibi_slopes", None)
context = flash_attn_with_kvcache(
q,
kv_cache[:, :, 0],
@@ -513,6 +514,7 @@ class MHA(nn.Module):
softmax_scale=self.inner_cross_attn.softmax_scale,
causal=self.inner_cross_attn.causal,
rotary_interleaved=self.rotary_emb.interleaved if self.rotary_emb_dim > 0 else False,
+alibi_slopes=alibi_slopes,
)
return context
@@ -534,6 +536,7 @@ class MHA(nn.Module):
if inference_params.lengths_per_sample is not None
else inference_params.seqlen_offset
)
+alibi_slopes = getattr(self.inner_cross_attn, "alibi_slopes", None)
return flash_attn_with_kvcache(
q,
kv_cache[:, :, 0],
@@ -543,6 +546,7 @@ class MHA(nn.Module):
cache_seqlens=cache_seqlens,
softmax_scale=self.inner_cross_attn.softmax_scale,
causal=self.inner_cross_attn.causal,
+alibi_slopes=alibi_slopes,
)
def forward(
@@ -847,6 +851,7 @@ class ParallelMHA(nn.Module):
if inference_params.lengths_per_sample is not None
else inference_params.seqlen_offset
)
+alibi_slopes = getattr(self.inner_cross_attn, "alibi_slopes", None)
context = flash_attn_with_kvcache(
q,
kv_cache[:, :, 0],
@@ -859,6 +864,7 @@ class ParallelMHA(nn.Module):
softmax_scale=self.inner_cross_attn.softmax_scale,
causal=self.inner_cross_attn.causal,
rotary_interleaved=self.rotary_emb.interleaved if self.rotary_emb_dim > 0 else False,
+alibi_slopes=alibi_slopes,
)
return context
@@ -876,6 +882,7 @@ class ParallelMHA(nn.Module):
if inference_params.lengths_per_sample is not None
else inference_params.seqlen_offset
)
+alibi_slopes = getattr(self.inner_cross_attn, "alibi_slopes", None)
context = flash_attn_with_kvcache(
q,
kv_cache[:, :, 0],
@@ -885,6 +892,7 @@ class ParallelMHA(nn.Module):
cache_seqlens=cache_seqlens,
softmax_scale=self.inner_cross_attn.softmax_scale,
causal=self.inner_cross_attn.causal,
+alibi_slopes=alibi_slopes,
)
return context
......
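All four call sites (the two decode paths in `MHA` and the two in `ParallelMHA`) follow the same pattern: `getattr(self.inner_cross_attn, "alibi_slopes", None)` falls back to `None` for models without ALiBi, and `flash_attn_with_kvcache` treats `None` as no positional bias. A hedged sketch of one decode step with the new kwarg; the shapes and slope values are illustrative, and running it requires a CUDA device with the flash-attn package installed:

```python
import torch
from flash_attn import flash_attn_with_kvcache

batch, nheads, headdim, cache_len = 2, 8, 64, 128
device, dtype = "cuda", torch.float16

q = torch.randn(batch, 1, nheads, headdim, device=device, dtype=dtype)  # one new token
k_cache = torch.zeros(batch, cache_len, nheads, headdim, device=device, dtype=dtype)
v_cache = torch.zeros_like(k_cache)
k = torch.randn(batch, 1, nheads, headdim, device=device, dtype=dtype)  # this step's K/V
v = torch.randn_like(k)
cache_seqlens = torch.full((batch,), 16, dtype=torch.int32, device=device)

# (nheads,) fp32 slopes, e.g. 2^-1 ... 2^-8 for 8 heads; pass None to disable ALiBi.
slopes = torch.tensor([2.0 ** -(i + 1) for i in range(nheads)], device=device)

out = flash_attn_with_kvcache(
    q, k_cache, v_cache, k=k, v=v,
    cache_seqlens=cache_seqlens,  # new K/V are written into the cache at this offset
    causal=True,
    alibi_slopes=slopes,
)
```

Because the bias depends only on the key-query distance, it composes with the in-place KV-cache update without any change to the cache layout.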
# Copyright (c) 2023, Tri Dao.
import os
import time
from pathlib import Path
......