Commit bd363067 authored by lizhigong
Browse files

Merge branch 'v0.8.5.post1-dev' into v0.8.5-zero_overhead

parents 87ef4618 d36deb1a
...@@ -50,7 +50,6 @@ from vllm.model_executor.model_loader.weight_utils import ( ...@@ -50,7 +50,6 @@ from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, maybe_remap_kv_scale_name) default_weight_loader, maybe_remap_kv_scale_name)
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.utils import W8a8GetCacheJSON
from .interfaces import SupportsLoRA, SupportsPP from .interfaces import SupportsLoRA, SupportsPP
from .utils import (AutoWeightsLoader, PPMissingLayer, extract_layer_index, from .utils import (AutoWeightsLoader, PPMissingLayer, extract_layer_index,
...@@ -356,14 +355,12 @@ class LlamaModel(nn.Module): ...@@ -356,14 +355,12 @@ class LlamaModel(nn.Module):
if quant_config is not None: if quant_config is not None:
self.quant_method=quant_config.get_name() self.quant_method=quant_config.get_name()
self.quant_config=quant_config self.quant_config=quant_config
self.tritonsingleton= W8a8GetCacheJSON()
self.use_llama_nn = os.environ.get('LLAMA_NN') == '1' self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
# self.use_lm_nn = os.environ.get('LM_NN') == '1' # self.use_lm_nn = os.environ.get('LM_NN') == '1'
self.use_gemm_pad = os.environ.get('GEMM_PAD') == '1' self.use_gemm_pad = os.environ.get('GEMM_PAD') == '1'
self.use_fa_pad = os.environ.get('FA_PAD') == '1' self.use_fa_pad = os.environ.get('FA_PAD') == '1'
self.use_awq_pad = os.environ.get('AWQ_PAD') == '1' self.use_awq_pad = os.environ.get('AWQ_PAD') == '1'
self.w8a8_strategy=int(os.getenv('W8A8_SUPPORT_METHODS', '1'))
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids) return self.embed_tokens(input_ids)
...@@ -517,97 +514,7 @@ class LlamaModel(nn.Module): ...@@ -517,97 +514,7 @@ class LlamaModel(nn.Module):
else: else:
os.environ['LM_NN'] = '0' os.environ['LM_NN'] = '0'
os.environ['LLAMA_NN'] = '0' os.environ['LLAMA_NN'] = '0'
# if self.quant_method == "awq" and not envs.VLLM_USE_TRITON_AWQ:
# lay_key_words = [
# "self_attn.qkv_proj.qweight",
# "self_attn.o_proj.qweight",
# "mlp.gate_up_proj.qweight",
# "mlp.down_proj.qweight"
# ]
# combined_words = "|".join(lay_key_words)
# for layername in loaded_params:
# weight = params_dict[layername]
# matches = re.findall(combined_words, layername)
# if matches:
# qweight =params_dict[layername]
# qzeros=params_dict[layername.replace("qweight", "qzeros")]
# scales=params_dict[layername.replace("qweight", "scales")]
# zeros_and_scalse =params_dict[layername.replace("qweight", "zeros_and_scales")]
# group_size= self.quant_config.group_size
# dim_n = scales.data.shape[1]
# dim_k = qweight.data.shape[0]
# pad_group=2
# _qw, _sz=ops.convert_s4(qweight,qzeros,scales,int(group_size))
# sz = ops.sz_permute(_sz).reshape(-1,dim_n)
# zeros_and_scalse.data.copy_(sz)
# qweight.data.copy_(_qw)
# #reshape
# zeros_and_scalse.data=zeros_and_scalse.reshape(dim_n,-1) #[k/greop_size,n]------>[n,k/group_size]
# qweight.data=qweight.data.reshape(dim_n,-1) #[k,n/8]---->[n,k/8]
# if dim_k % 4096==0 and self.use_awq_pad:
# zeros_and_scalse_pad= torch.zeros(dim_n,pad_group,dtype=torch.int32).cuda()
# zeros_and_scalse.data=torch.cat((zeros_and_scalse.data,zeros_and_scalse_pad),dim=1).contiguous()
# qweight_pad= torch.zeros(dim_n,int(group_size//4),dtype=torch.int32).cuda()
# qweight.data=torch.cat((qweight.data,qweight_pad),dim=1).contiguous()
#当为triton支持推理的时候不能进行处理
if self.quant_method == "compressed_tensors":
lay_key_words = [
"self_attn.qkv_proj.weight",
"self_attn.o_proj.weight",
"mlp.gate_up_proj.weight",
"mlp.down_proj.weight",
]
combined_words = "|".join(lay_key_words)
weight_shapes=[]
all_json={}
matched_key_words=set()
for layername, weight in params_dict.items():
matches = re.findall(combined_words, layername)
if matches and "scale" not in layername:
weight_data =params_dict[layername]
n=weight_data.shape[0]
# k=weight_data.shape[1]
# #判断当前size是否在优化的范围内,假如存在则走triton,假如不存在则走rocblas
# json_file=self.tritonsingleton.get_w8a8json_name(n,k)
#rocblas和cutlass目前都需要weight做处理,但是triton不用
if self.w8a8_strategy!=1:
_weight=weight_data.T.contiguous().reshape(n,-1)
weight_data.data.copy_(_weight)
#下面是针对模型记录模型出现k和n值
elif len(matched_key_words) < 4 and matches[0] not in matched_key_words:
matched_key_words.add(matches[0])
k=weight_data.shape[1]
weight_shapes.append({n,k})
json_file=self.tritonsingleton.get_w8a8json_name(n,k)
configs_dict=self.tritonsingleton.get_triton_cache(json_file,n,k)
if configs_dict:
all_json.update(configs_dict)
if self.w8a8_strategy==1:
self.tritonsingleton.triton_json_dict.append(all_json)
#找到的所有config都进行一次warmup
for key, value in all_json.items():
m=int(key.split('_')[0])
n=int(key.split('_')[1])
k=int(key.split('_')[2])
ops.triton_int8_gemm_helper(m=m,n=n,k=k,per_token_act_quant=True,per_out_channel_weight_quant=True,use_bias=False,best_config=value)
return loaded_params return loaded_params
......
...@@ -39,7 +39,6 @@ from .utils import (is_pp_missing_parameter, ...@@ -39,7 +39,6 @@ from .utils import (is_pp_missing_parameter,
maybe_prefix) maybe_prefix)
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.model_executor.utils import pad_weight, gemm_bank_conf from vllm.model_executor.utils import pad_weight, gemm_bank_conf
from vllm.utils import W8a8GetCacheJSON
class QWenMLP(nn.Module): class QWenMLP(nn.Module):
...@@ -291,13 +290,11 @@ class QWenBaseModel(nn.Module): ...@@ -291,13 +290,11 @@ class QWenBaseModel(nn.Module):
if quant_config is not None: if quant_config is not None:
self.quant_method=quant_config.get_name() self.quant_method=quant_config.get_name()
self.quant_config=quant_config self.quant_config=quant_config
self.tritonsingleton= W8a8GetCacheJSON()
self.use_llama_nn = os.environ.get('LLAMA_NN') == '1' self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
self.use_gemm_pad = os.environ.get('GEMM_PAD') == '1' self.use_gemm_pad = os.environ.get('GEMM_PAD') == '1'
self.use_fa_pad = os.environ.get('FA_PAD') == '1' self.use_fa_pad = os.environ.get('FA_PAD') == '1'
self.use_awq_pad = os.environ.get('AWQ_PAD') == '1' self.use_awq_pad = os.environ.get('AWQ_PAD') == '1'
self.w8a8_strategy=int(os.getenv('W8A8_SUPPORT_METHODS', '1'))
def compute_logits( def compute_logits(
self, self,
...@@ -384,51 +381,7 @@ class QWenBaseModel(nn.Module): ...@@ -384,51 +381,7 @@ class QWenBaseModel(nn.Module):
weight.data.copy_(_weight) weight.data.copy_(_weight)
weight.data=weight.data.reshape(ori_shape[1],-1) weight.data=weight.data.reshape(ori_shape[1],-1)
# if self.quant_method == "awq":
# os.environ['LM_NN'] = '0'
# lay_key_words = [
# "attn.c_attn.qweight",
# "attn.c_proj.qweight",
# "mlp.gate_up_proj.qweight",
# "mlp.c_proj.qweight"
# ]
# combined_words = "|".join(lay_key_words)
# for layername in loaded_params:
# weight = params_dict[layername]
# matches = re.findall(combined_words, layername)
# if matches:
# qweight =params_dict[layername]
# qzeros=params_dict[layername.replace("qweight", "qzeros")]
# scales=params_dict[layername.replace("qweight", "scales")]
# zeros_and_scalse =params_dict[layername.replace("qweight", "zeros_and_scales")]
# group_size= self.quant_config.group_size
# dim_n = scales.data.shape[1]
# dim_k = qweight.data.shape[0]
# pad_group=2
# _qw, _sz=ops.convert_s4(qweight,qzeros,scales,int(group_size))
# sz = ops.sz_permute(_sz).reshape(-1,dim_n)
# zeros_and_scalse.data.copy_(sz)
# qweight.data.copy_(_qw)
# #reshape
# zeros_and_scalse.data=zeros_and_scalse.reshape(dim_n,-1) #[k/greop_size,n]------>[n,k/group_size]
# qweight.data=qweight.data.reshape(dim_n,-1) #[k,n/8]---->[n,k/8]
# if dim_k % 4096==0 and self.use_awq_pad:
# zeros_and_scalse_pad= torch.zeros(dim_n,pad_group,dtype=torch.int32).cuda()
# zeros_and_scalse.data=torch.cat((zeros_and_scalse.data,zeros_and_scalse_pad),dim=1).contiguous()
# qweight_pad= torch.zeros(dim_n,int(group_size//4),dtype=torch.int32).cuda()
# qweight.data=torch.cat((qweight.data,qweight_pad),dim=1).contiguous()
if self.quant_method == "compressed_tensors":
os.environ['LM_NN'] = '0' os.environ['LM_NN'] = '0'
lay_key_words = [ lay_key_words = [
"attn.c_attn.weight", "attn.c_attn.weight",
......
...@@ -63,7 +63,6 @@ from .utils import (AutoWeightsLoader, PPMissingLayer, WeightsMapper, ...@@ -63,7 +63,6 @@ from .utils import (AutoWeightsLoader, PPMissingLayer, WeightsMapper,
maybe_prefix) maybe_prefix)
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.model_executor.utils import pad_weight, gemm_bank_conf from vllm.model_executor.utils import pad_weight, gemm_bank_conf
from vllm.utils import W8a8GetCacheJSON
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -338,13 +337,11 @@ class Qwen2Model(nn.Module): ...@@ -338,13 +337,11 @@ class Qwen2Model(nn.Module):
if quant_config is not None: if quant_config is not None:
self.quant_method=quant_config.get_name() self.quant_method=quant_config.get_name()
self.quant_config=quant_config self.quant_config=quant_config
self.tritonsingleton= W8a8GetCacheJSON()
self.use_llama_nn = os.environ.get('LLAMA_NN') == '1' self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
self.use_gemm_pad = os.environ.get('GEMM_PAD') == '1' self.use_gemm_pad = os.environ.get('GEMM_PAD') == '1'
self.use_fa_pad = os.environ.get('FA_PAD') == '1' self.use_fa_pad = os.environ.get('FA_PAD') == '1'
self.use_awq_pad = os.environ.get('AWQ_PAD') == '1' self.use_awq_pad = os.environ.get('AWQ_PAD') == '1'
self.w8a8_strategy=int(os.getenv('W8A8_SUPPORT_METHODS', '1'))
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids) return self.embed_tokens(input_ids)
...@@ -485,93 +482,6 @@ class Qwen2Model(nn.Module): ...@@ -485,93 +482,6 @@ class Qwen2Model(nn.Module):
else: else:
os.environ['LM_NN'] = '0' os.environ['LM_NN'] = '0'
os.environ['LLAMA_NN'] = '0' os.environ['LLAMA_NN'] = '0'
# if self.quant_method == "awq" and not envs.VLLM_USE_TRITON_AWQ:
# lay_key_words = [
# "self_attn.qkv_proj.qweight",
# "self_attn.o_proj.qweight",
# "mlp.gate_up_proj.qweight",
# "mlp.down_proj.qweight"
# ]
# combined_words = "|".join(lay_key_words)
# for layername in loaded_params:
# weight = params_dict[layername]
# matches = re.findall(combined_words, layername)
# if matches:
# qweight =params_dict[layername]
# qzeros=params_dict[layername.replace("qweight", "qzeros")]
# scales=params_dict[layername.replace("qweight", "scales")]
# zeros_and_scalse =params_dict[layername.replace("qweight", "zeros_and_scales")]
# group_size= self.quant_config.group_size
# dim_n = scales.data.shape[1]
# dim_k = qweight.data.shape[0]
# pad_group=2
# _qw, _sz=ops.convert_s4(qweight,qzeros,scales,int(group_size))
# sz = ops.sz_permute(_sz).reshape(-1,dim_n)
# zeros_and_scalse.data.copy_(sz)
# qweight.data.copy_(_qw)
# #reshape
# zeros_and_scalse.data=zeros_and_scalse.reshape(dim_n,-1) #[k/greop_size,n]------>[n,k/group_size]
# qweight.data=qweight.data.reshape(dim_n,-1) #[k,n/8]---->[n,k/8]
# if dim_k % 4096==0 and self.use_awq_pad:
# zeros_and_scalse_pad= torch.zeros(dim_n,pad_group,dtype=torch.int32).cuda()
# zeros_and_scalse.data=torch.cat((zeros_and_scalse.data,zeros_and_scalse_pad),dim=1).contiguous()
# qweight_pad= torch.zeros(dim_n,int(group_size//4),dtype=torch.int32).cuda()
# qweight.data=torch.cat((qweight.data,qweight_pad),dim=1).contiguous()
if self.quant_method == "compressed_tensors":
lay_key_words = [
"self_attn.qkv_proj.weight",
"self_attn.o_proj.weight",
"mlp.gate_up_proj.weight",
"mlp.down_proj.weight",
]
combined_words = "|".join(lay_key_words)
weight_shapes=[]
all_json={}
matched_key_words=set()
for layername in loaded_params:
weight = params_dict[layername]
matches = re.findall(combined_words, layername)
if matches and "scale" not in layername:
weight_data =params_dict[layername]
n=weight_data.shape[0]
#rocblas和cutlass目前都需要weight做处理,但是triton不用
if self.w8a8_strategy!=1:
_weight=weight_data.T.contiguous().reshape(n,-1)
weight_data.data.copy_(_weight)
#下面是针对模型记录模型出现k和n值
elif len(matched_key_words) < 4 and matches[0] not in matched_key_words:
matched_key_words.add(matches[0])
k=weight_data.shape[1]
weight_shapes.append({n,k})
json_file=self.tritonsingleton.get_w8a8json_name(n,k)
configs_dict=self.tritonsingleton.get_triton_cache(json_file,n,k)
if configs_dict:
all_json.update(configs_dict)
if self.w8a8_strategy==1:
self.tritonsingleton.triton_json_dict.append(all_json)
#找到的所有config都进行一次warmup
for key, value in all_json.items():
m=int(key.split('_')[0])
n=int(key.split('_')[1])
k=int(key.split('_')[2])
ops.triton_int8_gemm_helper(m=m,n=n,k=k,per_token_act_quant=True,per_out_channel_weight_quant=True,use_bias=False,best_config=value)
return loaded_params return loaded_params
......
...@@ -110,15 +110,16 @@ def use_rocm_custom_paged_attention(qtype: torch.dtype, head_size: int, ...@@ -110,15 +110,16 @@ def use_rocm_custom_paged_attention(qtype: torch.dtype, head_size: int,
# rocm custom page attention not support on gfx1* # rocm custom page attention not support on gfx1*
# custom paged attn always supported on V0. On V1, requires sliding window # custom paged attn always supported on V0. On V1, requires sliding window
# disabled due to observed numerical discrepancy. # disabled due to observed numerical discrepancy.
return (on_mi250_mi300() and (not envs.VLLM_USE_V1 or sliding_window == 0 return False
or sliding_window == (-1, -1)) # return (on_mi250_mi300() and (not envs.VLLM_USE_V1 or sliding_window == 0
and (qtype == torch.half or qtype == torch.bfloat16) # or sliding_window == (-1, -1))
and (head_size == 64 or head_size == 128) # and (qtype == torch.half or qtype == torch.bfloat16)
and (block_size == 16 or block_size == 32) # and (head_size == 64 or head_size == 128)
and (gqa_ratio >= 1 and gqa_ratio <= 16) and max_seq_len <= 32768 # and (block_size == 16 or block_size == 32)
and (envs.VLLM_ROCM_CUSTOM_PAGED_ATTN) # and (gqa_ratio >= 1 and gqa_ratio <= 16) and max_seq_len <= 32768
and not (envs.VLLM_ROCM_USE_AITER_PAGED_ATTN # and (envs.VLLM_ROCM_CUSTOM_PAGED_ATTN)
and envs.VLLM_ROCM_USE_AITER)) # and not (envs.VLLM_ROCM_USE_AITER_PAGED_ATTN
# and envs.VLLM_ROCM_USE_AITER))
class RocmPlatform(Platform): class RocmPlatform(Platform):
...@@ -205,20 +206,102 @@ class RocmPlatform(Platform): ...@@ -205,20 +206,102 @@ class RocmPlatform(Platform):
# f" The selected backend, {selected_backend.name}," # f" The selected backend, {selected_backend.name},"
# f"is not MLA type while requested for MLA backend.") # f"is not MLA type while requested for MLA backend.")
selected_backend = (_Backend.ROCM_FLASH if selected_backend if envs.VLLM_FLASH_ATTN_BACKEND:
== _Backend.FLASH_ATTN else selected_backend) if use_v1:
if envs.VLLM_USE_V1: if selected_backend == _Backend.FLASHINFER:
logger.info("Using Triton Attention backend on V1 engine.") raise ValueError("FlashInfer backend on V1 engine is not supported")
return ("vllm.v1.attention.backends." if selected_backend == _Backend.TRITON_ATTN_VLLM_V1:
"triton_attn.TritonAttentionBackend") logger.info_once("Using Triton backend on V1 engine.")
if selected_backend == _Backend.ROCM_FLASH: return ("vllm.v1.attention.backends."
if not cls.has_device_capability(90): "triton_attn.TritonAttentionBackend")
# not Instinct series GPUs. if cls.has_device_capability(80):
logger.info("flash_attn is not supported on NAVI GPUs.") logger.info_once("Using Flash Attention backend on V1 engine.")
return ("vllm.v1.attention.backends."
"flash_attn.FlashAttentionBackend")
if selected_backend == _Backend.FLASHINFER:
raise ValueError("FlashInfer backend is not supported")
elif selected_backend == _Backend.XFORMERS:
raise ValueError("XFormers backend is not supported")
elif selected_backend == _Backend.FLASH_ATTN:
pass
elif selected_backend:
raise ValueError(
f"Invalid attention backend for {cls.device_name}, "
f"with use_v1: {use_v1} use_mla: {use_mla}")
target_backend = _Backend.FLASH_ATTN
if not cls.has_device_capability(80):
# Volta and Turing NVIDIA GPUs.
logger.info(
"Cannot use FlashAttention-2 backend for Volta and Turing "
"GPUs.")
raise ValueError("XFormers backend is not supported")
elif dtype not in (torch.float16, torch.bfloat16):
logger.info(
"Cannot use FlashAttention-2 backend for dtype other than "
"torch.float16 or torch.bfloat16.")
# raise ValueError("XFormers backend is not supported")
pass
elif block_size % 16 != 0:
logger.info(
"Cannot use FlashAttention-2 backend for block size not "
"divisible by 16.")
raise ValueError("XFormers backend is not supported")
# FlashAttn is valid for the model, checking if the package is
# installed.
if target_backend == _Backend.FLASH_ATTN:
try:
import flash_attn # noqa: F401
from vllm.attention.backends.flash_attn import ( # noqa: F401
FlashAttentionBackend, flash_attn_supports_fp8)
supported_sizes = \
FlashAttentionBackend.get_supported_head_sizes()
if head_size not in supported_sizes:
logger.info(
"Cannot use FlashAttention-2 backend for head size %d.",
head_size)
raise ValueError("XFormers backend is not supported")
fp8_kv_cache = (kv_cache_dtype is not None
and kv_cache_dtype.startswith("fp8"))
if (fp8_kv_cache and not flash_attn_supports_fp8()):
logger.info(
"Cannot use FlashAttention backend for FP8 KV cache.")
logger.warning(
"Please use FlashInfer backend with FP8 KV Cache for "
"better performance by setting environment variable "
"VLLM_ATTENTION_BACKEND=FLASHINFER")
raise ValueError("XFormers backend is not supported")
except ImportError:
logger.info(
"Cannot use FlashAttention-2 backend because the "
"flash_attn package is not found. "
"Make sure that flash_attn was built and installed "
"(on by default).")
raise ValueError("XFormers backend is not supported")
if target_backend == _Backend.XFORMERS:
raise ValueError("XFormers backend is not supported")
logger.info("Using Flash Attention backend.")
return "vllm.attention.backends.flash_attn.FlashAttentionBackend"
else: else:
logger.info("%s is not supported in AMD GPUs.", selected_backend) selected_backend = (_Backend.ROCM_FLASH if selected_backend
logger.info("Using ROCmFlashAttention backend.") == _Backend.FLASH_ATTN else selected_backend)
return "vllm.attention.backends.rocm_flash_attn.ROCmFlashAttentionBackend" # noqa: E501 if envs.VLLM_USE_V1:
logger.info("Using Triton Attention backend on V1 engine.")
return ("vllm.v1.attention.backends."
"triton_attn.TritonAttentionBackend")
if selected_backend == _Backend.ROCM_FLASH:
if not cls.has_device_capability(90):
# not Instinct series GPUs.
logger.info("flash_attn is not supported on NAVI GPUs.")
else:
logger.info("%s is not supported in AMD GPUs.", selected_backend)
logger.info("Using ROCmFlashAttention backend.")
return "vllm.attention.backends.rocm_flash_attn.ROCmFlashAttentionBackend" # noqa: E501
@classmethod @classmethod
@lru_cache(maxsize=8) @lru_cache(maxsize=8)
......
...@@ -1725,7 +1725,9 @@ class W8a8GetCacheJSON: ...@@ -1725,7 +1725,9 @@ class W8a8GetCacheJSON:
json_folder_path=current_folder_path+'/../lmslim/configs/w8a8' json_folder_path=current_folder_path+'/../lmslim/configs/w8a8'
self.triton_json_dir=(os.getenv('TRITON_JSON_DIR', json_folder_path)) self.triton_json_dir=(os.getenv('TRITON_JSON_DIR', json_folder_path))
self.triton_json_dict=[] self.triton_json_dict={}
self.triton_json_list=[]
self.weight_shapes=[]
def getspec_config(self,configs_dict,M,N,K): def getspec_config(self,configs_dict,M,N,K):
if f"{M}_{N}_{K}" in configs_dict: if f"{M}_{N}_{K}" in configs_dict:
...@@ -1823,6 +1825,7 @@ class W8a8GetCacheJSON: ...@@ -1823,6 +1825,7 @@ class W8a8GetCacheJSON:
'kpack': int(sub_value["kpack"]), 'kpack': int(sub_value["kpack"]),
'num_stages':int(sub_value['num_stages']), 'num_stages':int(sub_value['num_stages']),
'num_warps':int(sub_value['num_warps']), 'num_warps':int(sub_value['num_warps']),
'enable_mmacfuse':int(sub_value['enable_mmacfuse']),
} }
configs_dict[configs_key]=configs_value configs_dict[configs_key]=configs_value
return configs_dict return configs_dict
......
...@@ -24,9 +24,11 @@ if TYPE_CHECKING: ...@@ -24,9 +24,11 @@ if TYPE_CHECKING:
from vllm.v1.worker.gpu_input_batch import InputBatch from vllm.v1.worker.gpu_input_batch import InputBatch
from vllm.v1.worker.gpu_model_runner import GPUModelRunner from vllm.v1.worker.gpu_model_runner import GPUModelRunner
if current_platform.is_cuda(): if not current_platform.is_rocm():
from vllm.vllm_flash_attn import (flash_attn_varlen_func, from vllm.vllm_flash_attn import (flash_attn_varlen_func,
get_scheduler_metadata) get_scheduler_metadata)
else:
from flash_attn import flash_attn_varlen_func, vllm_flash_attn_varlen_func
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -603,57 +605,79 @@ class FlashAttentionImpl(AttentionImpl): ...@@ -603,57 +605,79 @@ class FlashAttentionImpl(AttentionImpl):
descale_shape = (cu_seqlens_q.shape[0] - 1, key.shape[1]) descale_shape = (cu_seqlens_q.shape[0] - 1, key.shape[1])
flash_attn_varlen_func( if not current_platform.is_rocm():
q=query[:num_actual_tokens], flash_attn_varlen_func(
k=key_cache, q=query[:num_actual_tokens],
v=value_cache, k=key_cache,
out=output[:num_actual_tokens], v=value_cache,
cu_seqlens_q=cu_seqlens_q, out=output[:num_actual_tokens],
max_seqlen_q=max_seqlen_q, cu_seqlens_q=cu_seqlens_q,
seqused_k=seqused_k, max_seqlen_q=max_seqlen_q,
max_seqlen_k=max_seqlen_k, seqused_k=seqused_k,
softmax_scale=self.scale, max_seqlen_k=max_seqlen_k,
causal=True, softmax_scale=self.scale,
alibi_slopes=self.alibi_slopes, causal=True,
window_size=self.sliding_window, alibi_slopes=self.alibi_slopes,
block_table=block_table, window_size=self.sliding_window,
softcap=self.logits_soft_cap, block_table=block_table,
scheduler_metadata=scheduler_metadata, softcap=self.logits_soft_cap,
fa_version=self.vllm_flash_attn_version, scheduler_metadata=scheduler_metadata,
q_descale=layer._q_scale.expand(descale_shape), fa_version=self.vllm_flash_attn_version,
k_descale=layer._k_scale.expand(descale_shape), q_descale=layer._q_scale.expand(descale_shape),
v_descale=layer._v_scale.expand(descale_shape), k_descale=layer._k_scale.expand(descale_shape),
) v_descale=layer._v_scale.expand(descale_shape),
)
else:
vllm_flash_attn_varlen_func(
q=query[:num_actual_tokens],
k=key_cache,
v=value_cache,
out=output[:num_actual_tokens],
cu_seqlens_q=cu_seqlens_q,
max_seqlen_q=max_seqlen_q,
seqused_k=seqused_k,
max_seqlen_k=max_seqlen_k,
softmax_scale=self.scale,
causal=True,
alibi_slopes=self.alibi_slopes,
window_size=self.sliding_window,
block_table=block_table,
softcap=self.logits_soft_cap,
# scheduler_metadata=scheduler_metadata,
)
return output return output
assert not use_local_attn, ( assert not use_local_attn, (
"Cascade attention does not support local attention.") "Cascade attention does not support local attention.")
# Cascade attention (rare case). # Cascade attention (rare case).
cascade_attention( if not current_platform.is_rocm():
output[:num_actual_tokens], cascade_attention(
query[:num_actual_tokens], output[:num_actual_tokens],
key_cache, query[:num_actual_tokens],
value_cache, key_cache,
cu_query_lens=attn_metadata.query_start_loc, value_cache,
max_query_len=attn_metadata.max_query_len, cu_query_lens=attn_metadata.query_start_loc,
cu_prefix_query_lens=attn_metadata.cu_prefix_query_lens, max_query_len=attn_metadata.max_query_len,
prefix_kv_lens=attn_metadata.prefix_kv_lens, cu_prefix_query_lens=attn_metadata.cu_prefix_query_lens,
suffix_kv_lens=attn_metadata.suffix_kv_lens, prefix_kv_lens=attn_metadata.prefix_kv_lens,
max_kv_len=attn_metadata.max_seq_len, suffix_kv_lens=attn_metadata.suffix_kv_lens,
softmax_scale=self.scale, max_kv_len=attn_metadata.max_seq_len,
alibi_slopes=self.alibi_slopes, softmax_scale=self.scale,
sliding_window=self.sliding_window, alibi_slopes=self.alibi_slopes,
logits_soft_cap=self.logits_soft_cap, sliding_window=self.sliding_window,
block_table=attn_metadata.block_table, logits_soft_cap=self.logits_soft_cap,
common_prefix_len=attn_metadata.common_prefix_len, block_table=attn_metadata.block_table,
fa_version=self.vllm_flash_attn_version, common_prefix_len=attn_metadata.common_prefix_len,
prefix_scheduler_metadata=attn_metadata.prefix_scheduler_metadata, fa_version=self.vllm_flash_attn_version,
suffix_scheduler_metadata=attn_metadata.scheduler_metadata, prefix_scheduler_metadata=attn_metadata.prefix_scheduler_metadata,
q_descale=layer._q_scale, suffix_scheduler_metadata=attn_metadata.scheduler_metadata,
k_descale=layer._k_scale, q_descale=layer._q_scale,
v_descale=layer._v_scale, k_descale=layer._k_scale,
) v_descale=layer._v_scale,
return output )
return output
else:
raise ValueError("cascade attention is not supported on rocm")
def use_cascade_attention( def use_cascade_attention(
...@@ -761,56 +785,58 @@ def cascade_attention( ...@@ -761,56 +785,58 @@ def cascade_attention(
descale_shape = (cu_prefix_query_lens.shape[0] - 1, key_cache.shape[-2]) descale_shape = (cu_prefix_query_lens.shape[0] - 1, key_cache.shape[-2])
# Process shared prefix. # Process shared prefix.
prefix_output, prefix_lse = flash_attn_varlen_func( if not current_platform.is_rocm():
q=query, prefix_output, prefix_lse = flash_attn_varlen_func(
k=key_cache, q=query,
v=value_cache, k=key_cache,
cu_seqlens_q=cu_prefix_query_lens, v=value_cache,
seqused_k=prefix_kv_lens, cu_seqlens_q=cu_prefix_query_lens,
max_seqlen_q=num_tokens, seqused_k=prefix_kv_lens,
max_seqlen_k=common_prefix_len, max_seqlen_q=num_tokens,
softmax_scale=softmax_scale, max_seqlen_k=common_prefix_len,
causal=False, softmax_scale=softmax_scale,
window_size=sliding_window, causal=False,
block_table=block_table[:1], window_size=sliding_window,
softcap=logits_soft_cap, block_table=block_table[:1],
return_softmax_lse=True, softcap=logits_soft_cap,
scheduler_metadata=prefix_scheduler_metadata, return_softmax_lse=True,
fa_version=fa_version, scheduler_metadata=prefix_scheduler_metadata,
q_descale=q_descale.expand(descale_shape) fa_version=fa_version,
if q_descale is not None else None, q_descale=q_descale.expand(descale_shape)
k_descale=k_descale.expand(descale_shape) if q_descale is not None else None,
if k_descale is not None else None, k_descale=k_descale.expand(descale_shape)
v_descale=v_descale.expand(descale_shape) if k_descale is not None else None,
if v_descale is not None else None, v_descale=v_descale.expand(descale_shape)
) if v_descale is not None else None,
)
descale_shape = (cu_query_lens.shape[0] - 1, key_cache.shape[-2]) descale_shape = (cu_query_lens.shape[0] - 1, key_cache.shape[-2])
# Process suffix per query. # Process suffix per query.
suffix_output, suffix_lse = flash_attn_varlen_func( if not current_platform.is_rocm():
q=query, suffix_output, suffix_lse = flash_attn_varlen_func(
k=key_cache, q=query,
v=value_cache, k=key_cache,
cu_seqlens_q=cu_query_lens, v=value_cache,
seqused_k=suffix_kv_lens, cu_seqlens_q=cu_query_lens,
max_seqlen_q=max_query_len, seqused_k=suffix_kv_lens,
max_seqlen_k=max_kv_len - common_prefix_len, max_seqlen_q=max_query_len,
softmax_scale=softmax_scale, max_seqlen_k=max_kv_len - common_prefix_len,
causal=True, softmax_scale=softmax_scale,
window_size=sliding_window, causal=True,
block_table=block_table[:, num_common_kv_blocks:], window_size=sliding_window,
softcap=logits_soft_cap, block_table=block_table[:, num_common_kv_blocks:],
return_softmax_lse=True, softcap=logits_soft_cap,
scheduler_metadata=suffix_scheduler_metadata, return_softmax_lse=True,
fa_version=fa_version, scheduler_metadata=suffix_scheduler_metadata,
q_descale=q_descale.expand(descale_shape) fa_version=fa_version,
if q_descale is not None else None, q_descale=q_descale.expand(descale_shape)
k_descale=k_descale.expand(descale_shape) if q_descale is not None else None,
if k_descale is not None else None, k_descale=k_descale.expand(descale_shape)
v_descale=v_descale.expand(descale_shape) if k_descale is not None else None,
if v_descale is not None else None, v_descale=v_descale.expand(descale_shape)
) if v_descale is not None else None,
)
# Merge prefix and suffix outputs, and store the result in output. # Merge prefix and suffix outputs, and store the result in output.
merge_attn_states(output, prefix_output, prefix_lse, suffix_output, merge_attn_states(output, prefix_output, prefix_lse, suffix_output,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment