Commit 5cdabd7b authored by zhuwenwen
Browse files

add 7b pad dim (only strip the 32-column FA_PAD padding when the QKV output width is 12320)

parent 371b1251
...@@ -14,8 +14,6 @@ from vllm.logger import init_logger ...@@ -14,8 +14,6 @@ from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase) QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.utils import set_weight_attrs
from vllm.logger import init_logger
import os import os
logger = init_logger(__name__) logger = init_logger(__name__)
......
...@@ -179,7 +179,7 @@ class BaiChuanAttention(nn.Module): ...@@ -179,7 +179,7 @@ class BaiChuanAttention(nn.Module):
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
qkv, _ = self.W_pack(hidden_states) qkv, _ = self.W_pack(hidden_states)
if os.environ.get('FA_PAD') == '1': if os.environ.get('FA_PAD') == '1' and qkv.shape[-1] == 12320:
qkv = qkv[...,:-32] qkv = qkv[...,:-32]
q, k, v = qkv.chunk(chunks=3, dim=-1) q, k, v = qkv.chunk(chunks=3, dim=-1)
if self.postion_embedding != "ALIBI": if self.postion_embedding != "ALIBI":
......
...@@ -102,7 +102,7 @@ class GLMAttention(nn.Module): ...@@ -102,7 +102,7 @@ class GLMAttention(nn.Module):
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
qkv, _ = self.query_key_value(hidden_states) qkv, _ = self.query_key_value(hidden_states)
if os.environ.get('FA_PAD') == '1': if os.environ.get('FA_PAD') == '1' and qkv.shape[-1] == 12320:
qkv = qkv[...,:-32] qkv = qkv[...,:-32]
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(position_ids, q, k) q, k = self.rotary_emb(position_ids, q, k)
......
...@@ -157,7 +157,7 @@ class LlamaAttention(nn.Module): ...@@ -157,7 +157,7 @@ class LlamaAttention(nn.Module):
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states) qkv, _ = self.qkv_proj(hidden_states)
if os.environ.get('FA_PAD') == '1': if os.environ.get('FA_PAD') == '1' and qkv.shape[-1] == 12320:
qkv = qkv[...,:-32] qkv = qkv[...,:-32]
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k) q, k = self.rotary_emb(positions, q, k)
......
...@@ -149,7 +149,7 @@ class Qwen2Attention(nn.Module): ...@@ -149,7 +149,7 @@ class Qwen2Attention(nn.Module):
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states) qkv, _ = self.qkv_proj(hidden_states)
if os.environ.get('FA_PAD') == '1': if os.environ.get('FA_PAD') == '1' and qkv.shape[-1] == 12320:
qkv = qkv[...,:-32] qkv = qkv[...,:-32]
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k) q, k = self.rotary_emb(positions, q, k)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment