Commit 82f1ffdf authored by zhuwenwen's avatar zhuwenwen
Browse files

add cutlass fa and support bloom nn layout

parent b5160479
...@@ -290,7 +290,7 @@ class ROCmFlashAttentionImpl(AttentionImpl): ...@@ -290,7 +290,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
self.attn_func_triton = flash_attn_varlen_func self.attn_func_triton = flash_attn_varlen_func
from flash_attn import flash_attn_varlen_func # noqa: F401 from flash_attn import flash_attn_varlen_func # noqa: F401
self.attn_func_ck = flash_attn_varlen_func self.attn_func_cu = flash_attn_varlen_func
logger.debug("When SEQ_LEN > 8000, Use Triton FA in ROCmBackend, otherwise Use CK FA") logger.debug("When SEQ_LEN > 8000, Use Triton FA in ROCmBackend, otherwise Use CK FA")
else: else:
# from vllm.attention.ops.triton_flash_attention import ( # noqa: F401 # from vllm.attention.ops.triton_flash_attention import ( # noqa: F401
...@@ -428,7 +428,8 @@ class ROCmFlashAttentionImpl(AttentionImpl): ...@@ -428,7 +428,8 @@ class ROCmFlashAttentionImpl(AttentionImpl):
causal=True, causal=True,
) )
else: else:
out = self.attn_func_ck( if envs.VLLM_USE_CL_FLASH_ATTN:
out = self.attn_func_cu(
q=query, q=query,
k=key, k=key,
v=value, v=value,
...@@ -438,6 +439,20 @@ class ROCmFlashAttentionImpl(AttentionImpl): ...@@ -438,6 +439,20 @@ class ROCmFlashAttentionImpl(AttentionImpl):
max_seqlen_k=prefill_meta.max_prefill_seq_len, max_seqlen_k=prefill_meta.max_prefill_seq_len,
softmax_scale=self.scale, softmax_scale=self.scale,
causal=True, causal=True,
window_size=self.sliding_window,
alibi_slopes=self.alibi_slopes,
)
else:
out = self.attn_func_cu(
q=query,
k=key,
v=value,
cu_seqlens_q=prefill_meta.seq_start_loc,
cu_seqlens_k=prefill_meta.seq_start_loc,
max_seqlens_q=prefill_meta.max_prefill_seq_len,
max_seqlens_k=prefill_meta.max_prefill_seq_len,
softmax_scale=self.scale,
causal=True,
) )
else: else:
# out = self.attn_func( # out = self.attn_func(
...@@ -490,6 +505,7 @@ class ROCmFlashAttentionImpl(AttentionImpl): ...@@ -490,6 +505,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
attn_masks, attn_masks,
) )
else: else:
if envs.VLLM_USE_CL_FLASH_ATTN:
out = self.attn_func( out = self.attn_func(
q=query, q=query,
k=key, k=key,
...@@ -500,8 +516,20 @@ class ROCmFlashAttentionImpl(AttentionImpl): ...@@ -500,8 +516,20 @@ class ROCmFlashAttentionImpl(AttentionImpl):
max_seqlen_k=prefill_meta.max_prefill_seq_len, max_seqlen_k=prefill_meta.max_prefill_seq_len,
softmax_scale=self.scale, softmax_scale=self.scale,
causal=True, causal=True,
# window_size=self.sliding_window, window_size=self.sliding_window,
# alibi_slopes=self.alibi_slopes, alibi_slopes=self.alibi_slopes,
)
else:
out = self.attn_func(
q=query,
k=key,
v=value,
cu_seqlens_q=prefill_meta.seq_start_loc,
cu_seqlens_k=prefill_meta.seq_start_loc,
max_seqlens_q=prefill_meta.max_prefill_seq_len,
max_seqlens_k=prefill_meta.max_prefill_seq_len,
softmax_scale=self.scale,
causal=True,
) )
# common code for prefill # common code for prefill
......
...@@ -12,6 +12,7 @@ if TYPE_CHECKING: ...@@ -12,6 +12,7 @@ if TYPE_CHECKING:
VLLM_NCCL_SO_PATH: Optional[str] = None VLLM_NCCL_SO_PATH: Optional[str] = None
LD_LIBRARY_PATH: Optional[str] = None LD_LIBRARY_PATH: Optional[str] = None
VLLM_USE_TRITON_FLASH_ATTN: bool = False VLLM_USE_TRITON_FLASH_ATTN: bool = False
VLLM_USE_CL_FLASH_ATTN: bool = False
VLLM_USE_FLASH_ATTN_AUTO: bool = False VLLM_USE_FLASH_ATTN_AUTO: bool = False
VLLM_USE_OPT_OP: bool = False VLLM_USE_OPT_OP: bool = False
VLLM_USE_PA_PRINT_PARAM: bool = False VLLM_USE_PA_PRINT_PARAM: bool = False
...@@ -196,6 +197,11 @@ environment_variables: Dict[str, Callable[[], Any]] = { ...@@ -196,6 +197,11 @@ environment_variables: Dict[str, Callable[[], Any]] = {
lambda: (os.environ.get("VLLM_USE_TRITON_FLASH_ATTN", "True").lower() in lambda: (os.environ.get("VLLM_USE_TRITON_FLASH_ATTN", "True").lower() in
("true", "1")), ("true", "1")),
# flag to control if vllm should use cutlass flash attention
"VLLM_USE_CL_FLASH_ATTN":
lambda: (os.environ.get("VLLM_USE_CL_FLASH_ATTN", "False").lower() in
("true", "1")),
# flag to control vllm to automatically switch between Triton FA and CK FA # flag to control vllm to automatically switch between Triton FA and CK FA
"VLLM_USE_FLASH_ATTN_AUTO": "VLLM_USE_FLASH_ATTN_AUTO":
lambda: (os.environ.get("VLLM_USE_FLASH_ATTN_AUTO", "True").lower() in lambda: (os.environ.get("VLLM_USE_FLASH_ATTN_AUTO", "True").lower() in
......
...@@ -22,7 +22,7 @@ def set_default_torch_dtype(dtype: torch.dtype): ...@@ -22,7 +22,7 @@ def set_default_torch_dtype(dtype: torch.dtype):
def get_model_architecture( def get_model_architecture(
model_config: ModelConfig) -> Tuple[Type[nn.Module], str]: model_config: ModelConfig) -> Tuple[Type[nn.Module], str]:
architectures = getattr(model_config.hf_config, "architectures", []) architectures = getattr(model_config.hf_config, "architectures", [])
support_nn_architectures = ['LlamaForCausalLM', 'QWenLMHeadModel', 'Qwen2ForCausalLM', 'ChatGLMModel', 'BaichuanForCausalLM'] support_nn_architectures = ['LlamaForCausalLM', 'QWenLMHeadModel', 'Qwen2ForCausalLM', 'ChatGLMModel', 'BaichuanForCausalLM', 'BloomForCausalLM']
use_triton_fa_architectures = ['DeepseekV2ForCausalLM'] use_triton_fa_architectures = ['DeepseekV2ForCausalLM']
if any(arch in architectures for arch in support_nn_architectures): if any(arch in architectures for arch in support_nn_architectures):
if os.getenv('LLAMA_NN') != '0': if os.getenv('LLAMA_NN') != '0':
......
...@@ -22,6 +22,8 @@ from typing import Iterable, List, Optional, Tuple ...@@ -22,6 +22,8 @@ from typing import Iterable, List, Optional, Tuple
import torch import torch
from torch import nn from torch import nn
from transformers import BloomConfig from transformers import BloomConfig
import os
import re
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig from vllm.config import CacheConfig
...@@ -40,6 +42,8 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( ...@@ -40,6 +42,8 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors, SamplerOutput from vllm.sequence import IntermediateTensors, SamplerOutput
from vllm import _custom_ops as ops
from vllm.model_executor.utils import pad_weight, gemm_bank_conf
def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor: def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
...@@ -113,6 +117,10 @@ class BloomAttention(nn.Module): ...@@ -113,6 +117,10 @@ class BloomAttention(nn.Module):
alibi_slopes=alibi_slopes, alibi_slopes=alibi_slopes,
cache_config=cache_config, cache_config=cache_config,
quant_config=quant_config) quant_config=quant_config)
self.quant_method = None
if quant_config is not None:
self.quant_method=quant_config.get_name()
self.quant_config=quant_config
def forward( def forward(
self, self,
...@@ -285,6 +293,15 @@ class BloomForCausalLM(nn.Module): ...@@ -285,6 +293,15 @@ class BloomForCausalLM(nn.Module):
self.logits_processor = LogitsProcessor(config.vocab_size) self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler() self.sampler = Sampler()
self.quant_method = None
if quant_config is not None:
self.quant_method=quant_config.get_name()
self.quant_config=quant_config
self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
self.use_gemm_pad = os.environ.get('GEMM_PAD') == '1'
self.use_fa_pad = os.environ.get('FA_PAD') == '1'
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
...@@ -342,3 +359,40 @@ class BloomForCausalLM(nn.Module): ...@@ -342,3 +359,40 @@ class BloomForCausalLM(nn.Module):
weight_loader = getattr(param, "weight_loader", weight_loader = getattr(param, "weight_loader",
default_weight_loader) default_weight_loader)
weight_loader(param, loaded_weight) weight_loader(param, loaded_weight)
if self.use_llama_nn and self.quant_method is None:
lay_key_words = [
"self_attention.query_key_value.weight",
"self_attention.dense.weight",
"mlp.dense_h_to_4h.weight",
"mlp.dense_4h_to_h.weight"
]
combined_words = "|".join(lay_key_words)
lay_qkv_words = ["self_attention.query_key_value.weight"]
qkv_words = "|".join(lay_qkv_words)
lay_qkv_bias_words = ["self_attention.query_key_value.bias"]
qkv_bias_words = "|".join(lay_qkv_bias_words)
for layername, weight in params_dict.items():
if self.use_fa_pad and (re.findall(qkv_bias_words, layername)):
weight.data = pad_weight(weight.data, 32)
matches = re.findall(combined_words, layername)
if matches:
if self.use_gemm_pad and gemm_bank_conf(weight.data.shape[0]):
weight.data = pad_weight(weight.data, 32)
if self.use_fa_pad and (re.findall(qkv_words, layername)):
if not gemm_bank_conf(weight.data.shape[0]):
weight.data = pad_weight(weight.data, 32)
_weight = torch.zeros_like(weight.data)
ori_shape =_weight.shape
ops.trans_w16_gemm(_weight, weight.data, _weight.shape[0], _weight.shape[1])
weight.data.copy_(_weight)
weight.data=weight.data.reshape(ori_shape[1],-1)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment