Commit 82f1ffdf authored by zhuwenwen
Browse files

Add CUTLASS flash attention and support the NN weight layout for Bloom

parent b5160479
......@@ -290,7 +290,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
self.attn_func_triton = flash_attn_varlen_func
from flash_attn import flash_attn_varlen_func # noqa: F401
self.attn_func_ck = flash_attn_varlen_func
self.attn_func_cu = flash_attn_varlen_func
logger.debug("When SEQ_LEN > 8000, Use Triton FA in ROCmBackend, otherwise Use CK FA")
else:
# from vllm.attention.ops.triton_flash_attention import ( # noqa: F401
......@@ -428,7 +428,8 @@ class ROCmFlashAttentionImpl(AttentionImpl):
causal=True,
)
else:
out = self.attn_func_ck(
if envs.VLLM_USE_CL_FLASH_ATTN:
out = self.attn_func_cu(
q=query,
k=key,
v=value,
......@@ -438,6 +439,20 @@ class ROCmFlashAttentionImpl(AttentionImpl):
max_seqlen_k=prefill_meta.max_prefill_seq_len,
softmax_scale=self.scale,
causal=True,
window_size=self.sliding_window,
alibi_slopes=self.alibi_slopes,
)
else:
out = self.attn_func_cu(
q=query,
k=key,
v=value,
cu_seqlens_q=prefill_meta.seq_start_loc,
cu_seqlens_k=prefill_meta.seq_start_loc,
max_seqlens_q=prefill_meta.max_prefill_seq_len,
max_seqlens_k=prefill_meta.max_prefill_seq_len,
softmax_scale=self.scale,
causal=True,
)
else:
# out = self.attn_func(
......@@ -490,6 +505,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
attn_masks,
)
else:
if envs.VLLM_USE_CL_FLASH_ATTN:
out = self.attn_func(
q=query,
k=key,
......@@ -500,8 +516,20 @@ class ROCmFlashAttentionImpl(AttentionImpl):
max_seqlen_k=prefill_meta.max_prefill_seq_len,
softmax_scale=self.scale,
causal=True,
# window_size=self.sliding_window,
# alibi_slopes=self.alibi_slopes,
window_size=self.sliding_window,
alibi_slopes=self.alibi_slopes,
)
else:
out = self.attn_func(
q=query,
k=key,
v=value,
cu_seqlens_q=prefill_meta.seq_start_loc,
cu_seqlens_k=prefill_meta.seq_start_loc,
max_seqlens_q=prefill_meta.max_prefill_seq_len,
max_seqlens_k=prefill_meta.max_prefill_seq_len,
softmax_scale=self.scale,
causal=True,
)
# common code for prefill
......
......@@ -12,6 +12,7 @@ if TYPE_CHECKING:
VLLM_NCCL_SO_PATH: Optional[str] = None
LD_LIBRARY_PATH: Optional[str] = None
VLLM_USE_TRITON_FLASH_ATTN: bool = False
VLLM_USE_CL_FLASH_ATTN: bool = False
VLLM_USE_FLASH_ATTN_AUTO: bool = False
VLLM_USE_OPT_OP: bool = False
VLLM_USE_PA_PRINT_PARAM: bool = False
......@@ -196,6 +197,11 @@ environment_variables: Dict[str, Callable[[], Any]] = {
lambda: (os.environ.get("VLLM_USE_TRITON_FLASH_ATTN", "True").lower() in
("true", "1")),
# Flag controlling whether vLLM uses CUTLASS flash attention; disabled unless set to "true" or "1" (case-insensitive)
"VLLM_USE_CL_FLASH_ATTN":
lambda: (os.environ.get("VLLM_USE_CL_FLASH_ATTN", "False").lower() in
("true", "1")),
# flag to control vllm to automatically switch between Triton FA and CK FA
"VLLM_USE_FLASH_ATTN_AUTO":
lambda: (os.environ.get("VLLM_USE_FLASH_ATTN_AUTO", "True").lower() in
......
......@@ -22,7 +22,7 @@ def set_default_torch_dtype(dtype: torch.dtype):
def get_model_architecture(
model_config: ModelConfig) -> Tuple[Type[nn.Module], str]:
architectures = getattr(model_config.hf_config, "architectures", [])
support_nn_architectures = ['LlamaForCausalLM', 'QWenLMHeadModel', 'Qwen2ForCausalLM', 'ChatGLMModel', 'BaichuanForCausalLM']
support_nn_architectures = ['LlamaForCausalLM', 'QWenLMHeadModel', 'Qwen2ForCausalLM', 'ChatGLMModel', 'BaichuanForCausalLM', 'BloomForCausalLM']
use_triton_fa_architectures = ['DeepseekV2ForCausalLM']
if any(arch in architectures for arch in support_nn_architectures):
if os.getenv('LLAMA_NN') != '0':
......
......@@ -22,6 +22,8 @@ from typing import Iterable, List, Optional, Tuple
import torch
from torch import nn
from transformers import BloomConfig
import os
import re
from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig
......@@ -40,6 +42,8 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors, SamplerOutput
from vllm import _custom_ops as ops
from vllm.model_executor.utils import pad_weight, gemm_bank_conf
def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
......@@ -113,6 +117,10 @@ class BloomAttention(nn.Module):
alibi_slopes=alibi_slopes,
cache_config=cache_config,
quant_config=quant_config)
self.quant_method = None
if quant_config is not None:
self.quant_method=quant_config.get_name()
self.quant_config=quant_config
def forward(
self,
......@@ -285,6 +293,15 @@ class BloomForCausalLM(nn.Module):
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
self.quant_method = None
if quant_config is not None:
self.quant_method=quant_config.get_name()
self.quant_config=quant_config
self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
self.use_gemm_pad = os.environ.get('GEMM_PAD') == '1'
self.use_fa_pad = os.environ.get('FA_PAD') == '1'
def forward(
self,
input_ids: torch.Tensor,
......@@ -342,3 +359,40 @@ class BloomForCausalLM(nn.Module):
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
if self.use_llama_nn and self.quant_method is None:
lay_key_words = [
"self_attention.query_key_value.weight",
"self_attention.dense.weight",
"mlp.dense_h_to_4h.weight",
"mlp.dense_4h_to_h.weight"
]
combined_words = "|".join(lay_key_words)
lay_qkv_words = ["self_attention.query_key_value.weight"]
qkv_words = "|".join(lay_qkv_words)
lay_qkv_bias_words = ["self_attention.query_key_value.bias"]
qkv_bias_words = "|".join(lay_qkv_bias_words)
for layername, weight in params_dict.items():
if self.use_fa_pad and (re.findall(qkv_bias_words, layername)):
weight.data = pad_weight(weight.data, 32)
matches = re.findall(combined_words, layername)
if matches:
if self.use_gemm_pad and gemm_bank_conf(weight.data.shape[0]):
weight.data = pad_weight(weight.data, 32)
if self.use_fa_pad and (re.findall(qkv_words, layername)):
if not gemm_bank_conf(weight.data.shape[0]):
weight.data = pad_weight(weight.data, 32)
_weight = torch.zeros_like(weight.data)
ori_shape =_weight.shape
ops.trans_w16_gemm(_weight, weight.data, _weight.shape[0], _weight.shape[1])
weight.data.copy_(_weight)
weight.data=weight.data.reshape(ori_shape[1],-1)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment