Commit ffbef65c authored by zhuwenwen's avatar zhuwenwen
Browse files

support fa pad

parent 1190e964
......@@ -54,7 +54,10 @@ pip install setuptools wheel
```shell
git clone http://developer.hpccube.com/codes/OpenDAS/vllm.git # 根据需要的分支进行切换
```
安装依赖:
```shell
pip install -r requirements-rocm.txt
```
- 提供2种源码编译方式(进入vllm目录):
```
1. 编译whl包并安装
......
......@@ -183,7 +183,7 @@ class BaiChuanAttention(nn.Module):
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
qkv, _ = self.W_pack(hidden_states)
if os.environ.get('FA_PAD') == '1' and qkv.shape[-1] == 12320:
if os.environ.get('FA_PAD') == '1':
qkv = qkv[...,:-32]
q, k, v = qkv.chunk(chunks=3, dim=-1)
if self.postion_embedding != "ALIBI":
......@@ -413,14 +413,18 @@ class BaiChuanBaseForCausalLM(nn.Module):
]
combined_words = "|".join(lay_key_words)
lay_qkv_words = ["self_attn.W_pack.weight"]
qkv_words = "|".join(lay_qkv_words)
for layername, weight in params_dict.items():
matches = re.findall(combined_words, layername)
if matches:
if self.use_gemm_pad and gemm_bank_conf(weight.data.shape[0]):
weight.data = pad_weight(weight.data, 32)
if self.use_fa_pad and weight.data.shape[0] == 12288:
weight.data = pad_weight(weight.data, 32)
if self.use_fa_pad and (re.findall(qkv_words, layername)):
if not gemm_bank_conf(weight.data.shape[0]):
weight.data = pad_weight(weight.data, 32)
_weight = torch.zeros_like(weight.data)
ori_shape =_weight.shape
......
......@@ -106,7 +106,7 @@ class GLMAttention(nn.Module):
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
qkv, _ = self.query_key_value(hidden_states)
if os.environ.get('FA_PAD') == '1' and qkv.shape[-1] == 12320:
if os.environ.get('FA_PAD') == '1':
qkv = qkv[...,:-32]
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(position_ids, q, k)
......@@ -413,14 +413,24 @@ class ChatGLMForCausalLM(nn.Module):
]
combined_words = "|".join(lay_key_words)
lay_qkv_words = ["self_attention.query_key_value.weight"]
qkv_words = "|".join(lay_qkv_words)
lay_qkv_bias_words = ["self_attention.query_key_value.bias"]
qkv_bias_words = "|".join(lay_qkv_bias_words)
for layername, weight in params_dict.items():
if self.use_fa_pad and (re.findall(qkv_bias_words, layername)):
weight.data = pad_weight(weight.data, 32)
matches = re.findall(combined_words, layername)
if matches:
if self.use_gemm_pad and gemm_bank_conf(weight.data.shape[0]):
weight.data = pad_weight(weight.data, 32)
if self.use_fa_pad and weight.data.shape[0] == 12288:
weight.data = pad_weight(weight.data, 32)
if self.use_fa_pad and (re.findall(qkv_words, layername)):
if not gemm_bank_conf(weight.data.shape[0]):
weight.data = pad_weight(weight.data, 32)
_weight = torch.zeros_like(weight.data)
ori_shape =_weight.shape
......
......@@ -161,7 +161,7 @@ class LlamaAttention(nn.Module):
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
if os.environ.get('FA_PAD') == '1' and qkv.shape[-1] == 12320:
if os.environ.get('FA_PAD') == '1':
qkv = qkv[...,:-32]
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k)
......@@ -456,14 +456,18 @@ class LlamaForCausalLM(nn.Module):
]
combined_words = "|".join(lay_key_words)
lay_qkv_words = ["self_attn.qkv_proj.weight"]
qkv_words = "|".join(lay_qkv_words)
for layername, weight in params_dict.items():
matches = re.findall(combined_words, layername)
if matches:
if self.use_gemm_pad and gemm_bank_conf(weight.data.shape[0]):
weight.data = pad_weight(weight.data, 32)
if self.use_fa_pad and weight.data.shape[0] == 12288:
weight.data = pad_weight(weight.data, 32)
if self.use_fa_pad and (re.findall(qkv_words, layername)):
if not gemm_bank_conf(weight.data.shape[0]):
weight.data = pad_weight(weight.data, 32)
_weight = torch.zeros_like(weight.data)
ori_shape =_weight.shape
......
......@@ -123,7 +123,7 @@ class QWenAttention(nn.Module):
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
qkv, _ = self.c_attn(hidden_states)
if os.environ.get('FA_PAD') == '1' and qkv.shape[-1] == 12320:
if os.environ.get('FA_PAD') == '1':
qkv = qkv[...,:-32]
q, k, v = qkv.chunk(chunks=3, dim=-1)
q, k = self.rotary_emb(positions, q, k)
......@@ -312,14 +312,24 @@ class QWenLMHeadModel(nn.Module):
]
combined_words = "|".join(lay_key_words)
lay_qkv_words = ["attn.c_attn.weight"]
qkv_words = "|".join(lay_qkv_words)
lay_qkv_bias_words = ["attn.c_attn.bias"]
qkv_bias_words = "|".join(lay_qkv_bias_words)
for layername, weight in params_dict.items():
if self.use_fa_pad and (re.findall(qkv_bias_words, layername)):
weight.data = pad_weight(weight.data, 32)
matches = re.findall(combined_words, layername)
if matches:
if self.use_gemm_pad and gemm_bank_conf(weight.data.shape[0]):
weight.data = pad_weight(weight.data, 32)
if self.use_fa_pad and weight.data.shape[0] == 12288:
weight.data = pad_weight(weight.data, 32)
if self.use_fa_pad and (re.findall(qkv_words, layername)):
if not gemm_bank_conf(weight.data.shape[0]):
weight.data = pad_weight(weight.data, 32)
_weight = torch.zeros_like(weight.data)
ori_shape =_weight.shape
......
......@@ -153,7 +153,7 @@ class Qwen2Attention(nn.Module):
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
if os.environ.get('FA_PAD') == '1' and qkv.shape[-1] == 12320:
if os.environ.get('FA_PAD') == '1':
qkv = qkv[...,:-32]
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k)
......@@ -399,14 +399,24 @@ class Qwen2ForCausalLM(nn.Module):
]
combined_words = "|".join(lay_key_words)
lay_qkv_words = ["self_attn.qkv_proj.weight"]
qkv_words = "|".join(lay_qkv_words)
lay_qkv_bias_words = ["self_attn.qkv_proj.bias"]
qkv_bias_words = "|".join(lay_qkv_bias_words)
for layername, weight in params_dict.items():
if self.use_fa_pad and (re.findall(qkv_bias_words, layername)):
weight.data = pad_weight(weight.data, 32)
matches = re.findall(combined_words, layername)
if matches:
if self.use_gemm_pad and gemm_bank_conf(weight.data.shape[0]):
weight.data = pad_weight(weight.data, 32)
if self.use_fa_pad and weight.data.shape[0] == 12288:
weight.data = pad_weight(weight.data, 32)
if self.use_fa_pad and (re.findall(qkv_words, layername)):
if not gemm_bank_conf(weight.data.shape[0]):
weight.data = pad_weight(weight.data, 32)
_weight = torch.zeros_like(weight.data)
ori_shape =_weight.shape
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment